diff --git a/scripts/sample.py b/scripts/sample.py
index ce3df41..a68273d 100644
--- a/scripts/sample.py
+++ b/scripts/sample.py
@@ -22,6 +22,7 @@ def configure_arg_parser():
     parser.add_argument("seeds", nargs="*", type=int, help="Random seeds")
     parser.add_argument("-m", "--mode", type=str, default="ddpm", help="Sampling mode")
     parser.add_argument("-f", "--freq", type=int, default=1, help="Sampling step frequency")
+    parser.add_argument("-s", "--guidance_strength", type=float, default=0.0, help="Guidance strength")
     parser.add_argument("-i", "--dataset_dir", type=str, help="Input file for generation")
     parser.add_argument("-o", "--output_dir", type=str, help="Output directory for sampling result")
     parser.add_argument("-d", "--device_id", type=int, default=0, help="GPU device id")
@@ -34,6 +35,7 @@ def main(
     seeds: list[int],
     mode: str,
     freq: int,
+    guidance_strength: float,
     dataset_dir: str,
     output_dir: str,
     device_id: int,
@@ -52,7 +54,7 @@ def main(
 
     device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
 
-    _, _, enc_dim, dec_dim = get_components(config.base_name)
+    _, _, enc_dim, dec_dim = get_components(config.base_name, config.encoder.pretrained, **config.decoder)
     model = DiDi.load_from_checkpoint(model_path, enc_dim=enc_dim, dec_dim=dec_dim, map_location=device)
     model.eval()
 
@@ -64,7 +66,7 @@ def main(
             context.append(utterance)
             joined_context, _ = preprocess(context, "")
             raw_context = context_tokenizer(joined_context, **tokenizer_kwargs).to(device)
-            reply = sample(raw_context, model, mode, freq, context_tokenizer)[0]
+            reply = sample(raw_context, model, mode, freq, guidance_strength, context_tokenizer)[0]
             context.append(reply)
             print("DiDi:", reply)
     except KeyboardInterrupt:
diff --git a/scripts/train.py b/scripts/train.py
index 5ae69e5..7f53e7a 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -79,6 +79,8 @@ def main(config_path: str, dataset_dir: str, ckpt_dir: str = None, resume: str =
         train_dataset.vocab_size,
         config.encoder.freeze,
         pad_idx=train_dataset.pad_idx,
+        bos_idx=train_dataset.bos_idx,
+        eos_idx=train_dataset.eos_idx,
         batch_decoder=batch_decoder,
         **config.didi,
     )
diff --git a/src/data/commonsense_dataset.py b/src/data/commonsense_dataset.py
index f7737c6..338dc8e 100644
--- a/src/data/commonsense_dataset.py
+++ b/src/data/commonsense_dataset.py
@@ -47,6 +47,14 @@ def vocab_size(self) -> int:
     def pad_idx(self) -> int:
         return self.context_tokenizer.pad_token_id
 
+    @property
+    def bos_idx(self) -> int:
+        return self.context_tokenizer.bos_token_id or self.context_tokenizer.cls_token_id
+
+    @property
+    def eos_idx(self) -> int:
+        return self.context_tokenizer.eos_token_id or self.context_tokenizer.sep_token_id
+
     def __iter__(self) -> Iterator[tuple[str, str]]:
         n_epochs = 0
         while True:
diff --git a/src/data/reddit_dataset.py b/src/data/reddit_dataset.py
index 00c5d34..e653f3b 100644
--- a/src/data/reddit_dataset.py
+++ b/src/data/reddit_dataset.py
@@ -60,6 +60,14 @@ def vocab_size(self) -> int:
     def pad_idx(self) -> int:
         return self.context_tokenizer.pad_token_id
 
+    @property
+    def bos_idx(self) -> int:
+        return self.context_tokenizer.bos_token_id or self.context_tokenizer.cls_token_id
+
+    @property
+    def eos_idx(self) -> int:
+        return self.context_tokenizer.eos_token_id or self.context_tokenizer.sep_token_id
+
     def __iter__(self) -> Iterator[tuple[str, str]]:
         for file in self.files:
             zero_rank_info(f"Reading file: {file}")
diff --git a/src/diffusion/model.py b/src/diffusion/model.py
index c232797..086f5e3 100644
--- a/src/diffusion/model.py
+++ b/src/diffusion/model.py
@@ -77,6 +77,10 @@ def __init__(
         schedule: str,
         step_freq: int,
         pad_idx: int,
+        bos_idx: int,
+        eos_idx: int,
+        context_dropout_prob: float = 0.0,
+        guidance_strength: float = 0.0,
         tie_weights: bool = False,
         lr: float = 0.0001,
         weight_decay: float = 0.0,
@@ -91,11 +95,17 @@ def __init__(
         super().__init__()
         self.save_hyperparameters(ignore=[encoder, decoder])
         self.diffusion_steps = diffusion_steps
-        self.pad_idx = pad_idx
         self.step_freq = step_freq
         self.encoder_dim = enc_dim
         self.decoder_dim = dec_dim
 
+        self.dropout_prob = context_dropout_prob
+        self.w = guidance_strength
+
+        self.pad_idx = pad_idx
+        self.bos_idx = bos_idx
+        self.eos_idx = eos_idx
+
         self.emb = nn.Embedding(vocabulary_size, dec_dim, padding_idx=pad_idx)
         self.time_embeds = nn.Embedding(diffusion_steps + 1, dec_dim)
 
@@ -153,6 +163,22 @@ def _encode_context(self, encoder_input_ids, encoder_attention_mask):
             context = self.adapter(context)
         return context
 
+    def dropout_context(self, context, dropout_prob):
+        out_context = context.copy()
+        batch_size = context.input_ids.shape[0]
+
+        empty_context = torch.full_like(context.input_ids[0], self.pad_idx)
+        empty_mask = torch.zeros_like(context.attention_mask[0])
+        empty_context[0] = self.bos_idx
+        empty_context[1] = self.eos_idx
+        empty_mask[0] = 1
+        empty_mask[1] = 1
+
+        condition = torch.rand((batch_size, 1), device=context.input_ids.device) < dropout_prob
+        out_context["input_ids"] = torch.where(condition, empty_context, context.input_ids)
+        out_context["attention_mask"] = torch.where(condition, empty_mask, context.attention_mask)
+        return out_context
+
     def forward(
         self,
         encoder_input_ids: torch.Tensor = None,
@@ -183,6 +209,10 @@ def forward(
 
     def training_step(self, batch: list, batch_idx: int):
         raw_context, target = batch
+
+        if self.dropout_prob:
+            raw_context = self.dropout_context(raw_context, self.dropout_prob)
+
         emb = self.emb(target.input_ids)
         x_0 = get_x0(emb, self.std_0)
         noise = torch.randn_like(x_0)
@@ -224,7 +254,15 @@ def training_step(self, batch: list, batch_idx: int):
     def validation_step(self, batch: list, batch_idx: int):
         raw_context, target = batch
         max_trg_len = target.input_ids.shape[1]
-        logits = sample(raw_context, self, self.sampling_mode, self.step_freq, max_len=max_trg_len, raw_output=True)
+        logits = sample(
+            raw_context,
+            self,
+            self.sampling_mode,
+            self.step_freq,
+            guidance_strength=self.w,
+            max_len=max_trg_len,
+            raw_output=True,
+        )
         predictions = logits.argmax(-1)
 
         self.val_ce.append(calculate_batch_ce(logits, target.input_ids, target.attention_mask).item())
diff --git a/src/sampling.py b/src/sampling.py
index bfe7a7c..c3a6e18 100644
--- a/src/sampling.py
+++ b/src/sampling.py
@@ -4,18 +4,31 @@
 
 
 @torch.no_grad()
-def sample(raw_context, model, mode, step_freq, tokenizer=None, max_len=-1, raw_output=False, skip_special=True):
+def sample(
+    raw_context,
+    model,
+    mode,
+    step_freq,
+    guidance_strength=0.0,
+    tokenizer=None,
+    max_len=-1,
+    raw_output=False,
+    skip_special=True,
+):
     input_ids = raw_context.input_ids
     emb = model.emb(input_ids)[:, :max_len]
 
     x_t = torch.randn_like(emb) * model.sigmas[-1]
 
     cached_context = None
+    empty_cached_context = None
     ones = torch.ones((emb.shape[0], 1), dtype=torch.long, device=emb.device)
     noise = torch.empty_like(emb)
 
     if mode == "ddpm":
-        logits = sample_ddpm(model, x_t, raw_context, cached_context, noise, ones, step_freq)
+        logits = sample_ddpm(
+            model, x_t, raw_context, cached_context, empty_cached_context, noise, ones, step_freq, guidance_strength
+        )
     elif mode == "euler":
         logits = sample_euler(model, x_t, raw_context, cached_context, noise, ones, step_freq)
     else:
@@ -33,31 +46,46 @@ def sample(raw_context, model, mode, step_freq, tokenizer=None, max_len=-1, raw_
     return select_reply(replies)
 
 
-def sample_ddpm(model, x_t, raw_context, cached_context, noise, ones, step_freq):
+def guided_step(model, x_t, t, raw_context, cached_context, empty_cached_context, ones, guidance_strength):
+    x_0, cached_context = model(
+        encoder_input_ids=raw_context.input_ids,
+        encoder_attention_mask=raw_context.attention_mask,
+        decoder_inputs_embeds=x_t,
+        time_ids=t * ones,
+        context=cached_context,
+    )
+
+    if guidance_strength:
+        empty_context = model.dropout_context(raw_context, 1)
+        x_0_uncond, empty_cached_context = model(
+            encoder_input_ids=empty_context.input_ids,
+            encoder_attention_mask=empty_context.attention_mask,
+            decoder_inputs_embeds=x_t,
+            time_ids=t * ones,
+            context=empty_cached_context,
+        )
+        x_0 = (1 + guidance_strength) * x_0 - guidance_strength * x_0_uncond
+    return x_0, cached_context, empty_cached_context
+
+
+def sample_ddpm(
+    model, x_t, raw_context, cached_context, empty_cached_context, noise, ones, step_freq, guidance_strength
+):
     diffusion_steps = model.diffusion_steps
     timesteps = range(diffusion_steps, 1, -step_freq)
 
     x_t = scale_input(x_t, model.sigmas[-1])
 
     for t in timesteps:
-        x_0, cached_context = model(
-            encoder_input_ids=raw_context.input_ids,
-            encoder_attention_mask=raw_context.attention_mask,
-            decoder_inputs_embeds=x_t,
-            time_ids=t * ones,
-            context=cached_context,
+        x_0, cached_context, empty_cached_context = guided_step(
+            model, x_t, t, raw_context, cached_context, empty_cached_context, ones, guidance_strength
         )
         sigma_t = model.sigmas[max(t - step_freq, 1)]
 
        noise.normal_(0, 1)
         x_t = scale_input(x_0 + sigma_t * noise, sigma_t)
 
-    x_0, _ = model(
-        encoder_attention_mask=raw_context.attention_mask,
-        decoder_inputs_embeds=x_t,
-        time_ids=ones,
-        context=cached_context,
-    )
+    x_0, *_ = guided_step(model, x_t, 1, raw_context, cached_context, empty_cached_context, ones, guidance_strength)
 
     logits = model.classifier(x_0)
     return logits
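Note (not part of the patch): the changes above add classifier-free guidance to DiDi's sampling. During training the conditioning context is replaced, with probability context_dropout_prob, by an "empty" context that carries only BOS/EOS tokens; at sampling time guided_step runs the model on both the real and the empty context and extrapolates x_0 = (1 + w) * x_0_cond - w * x_0_uncond, where w is guidance_strength. The standalone toy sketch below restates just those two pieces with made-up helper names (drop_context, guided_prediction) and plain tensors; it is an illustration of the mechanism, not the repository's API.

# Illustrative sketch only -- assumed helper names, not code from the patch.
import torch


def drop_context(input_ids, attention_mask, bos_idx, eos_idx, pad_idx, p):
    # With probability p per sample, swap the context for a BOS/EOS-only sequence.
    empty_ids = torch.full_like(input_ids[0], pad_idx)
    empty_mask = torch.zeros_like(attention_mask[0])
    empty_ids[0], empty_ids[1] = bos_idx, eos_idx
    empty_mask[0], empty_mask[1] = 1, 1
    drop = torch.rand((input_ids.shape[0], 1), device=input_ids.device) < p  # (batch, 1) broadcasts over seq
    return torch.where(drop, empty_ids, input_ids), torch.where(drop, empty_mask, attention_mask)


def guided_prediction(x_0_cond, x_0_uncond, w):
    # Classifier-free guidance: push the conditional prediction away from the unconditional one.
    return (1 + w) * x_0_cond - w * x_0_uncond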