From af4b98ed97ec9d10c22c45e033f8dd2c0da3b69e Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 21 Sep 2020 12:13:19 -0700
Subject: [PATCH] [s2s] adjust finetune + test to work with fsmt (#7263)

---
 examples/seq2seq/finetune.py              | 18 ++++++++++++------
 examples/seq2seq/test_seq2seq_examples.py | 13 ++++++++++---
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index f54f15c1d55..0da637f13b2 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -61,6 +61,8 @@ class SummarizationModule(BaseTransformer):
         pickle_save(self.hparams, self.hparams_save_path)
         self.step_count = 0
         self.metrics = defaultdict(list)
+        self.model_type = self.config.model_type
+        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

         self.dataset_kwargs: dict = dict(
             data_dir=self.hparams.data_dir,
@@ -106,14 +108,18 @@ class SummarizationModule(BaseTransformer):

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        try:
-            freeze_params(self.model.model.shared)
+        if self.model_type == "t5":
+            freeze_params(self.model.shared)
+            for d in [self.model.encoder, self.model.decoder]:
+                freeze_params(d.embed_tokens)
+        elif self.model_type == "fsmt":
             for d in [self.model.model.encoder, self.model.model.decoder]:
                 freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
-        except AttributeError:
-            freeze_params(self.model.shared)
-            for d in [self.model.encoder, self.model.decoder]:
+        else:
+            freeze_params(self.model.model.shared)
+            for d in [self.model.model.encoder, self.model.model.decoder]:
+                freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)

     def forward(self, input_ids, **kwargs):
@@ -140,7 +146,7 @@ class SummarizationModule(BaseTransformer):
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
             ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

-            assert lm_logits.shape[-1] == self.model.config.vocab_size
+            assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 2772647288a..68a27f0f380 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -103,6 +103,7 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
 BART_TINY = "sshleifer/bart-tiny-random"
 MBART_TINY = "sshleifer/tiny-mbart"
 MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+FSMT_TINY = "stas/tiny-wmt19-en-de"


 stream_handler = logging.StreamHandler(sys.stdout)
@@ -374,11 +375,11 @@ def test_run_eval_search(model):

 @pytest.mark.parametrize(
     "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
+    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
-    task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
+    task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
     args_d["label_smoothing"] = 0.1 if task == "translation" else 0

     tmp_dir = make_test_data_dir()
@@ -407,7 +408,13 @@ def test_finetune(model):
         lm_head = module.model.lm_head
         assert not lm_head.weight.requires_grad
         assert (lm_head.weight == input_embeds.weight).all().item()
-
+    elif model == FSMT_TINY:
+        fsmt = module.model.model
+        embed_pos = fsmt.decoder.embed_positions
+        assert not embed_pos.weight.requires_grad
+        assert not fsmt.decoder.embed_tokens.weight.requires_grad
+        # check that embeds are not the same
+        assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
     else:
         bart = module.model.model
         embed_pos = bart.decoder.embed_positions