Mirror of https://github.com/huggingface/transformers.git
[s2s] adjust finetune + test to work with fsmt (#7263)
parent 8d562a2d1a
commit af4b98ed97

@@ -61,6 +61,8 @@ class SummarizationModule(BaseTransformer):
         pickle_save(self.hparams, self.hparams_save_path)
         self.step_count = 0
         self.metrics = defaultdict(list)
+        self.model_type = self.config.model_type
+        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
 
         self.dataset_kwargs: dict = dict(
             data_dir=self.hparams.data_dir,
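
Background (a sketch, not part of the diff): FSMT keeps separate source and target vocabularies, so its config exposes tgt_vocab_size rather than a single vocab_size; the module therefore caches the model type and the decoder-side vocabulary size for the shape check used later in the loss. The helper name pick_vocab_size below is hypothetical, only the attribute names come from the hunk above:

    def pick_vocab_size(config):
        # FSMT keeps separate source/target vocabularies; the LM logits are
        # projected onto the target side, so use tgt_vocab_size for it.
        if config.model_type == "fsmt":
            return config.tgt_vocab_size
        return config.vocab_size
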
@@ -106,14 +108,18 @@ class SummarizationModule(BaseTransformer):
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        try:
-            freeze_params(self.model.model.shared)
+        if self.model_type == "t5":
+            freeze_params(self.model.shared)
+            for d in [self.model.encoder, self.model.decoder]:
+                freeze_params(d.embed_tokens)
+        elif self.model_type == "fsmt":
             for d in [self.model.model.encoder, self.model.model.decoder]:
                 freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
-        except AttributeError:
-            freeze_params(self.model.shared)
-            for d in [self.model.encoder, self.model.decoder]:
+        else:
+            freeze_params(self.model.model.shared)
+            for d in [self.model.model.encoder, self.model.model.decoder]:
+                freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
 
     def forward(self, input_ids, **kwargs):
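
Background (a sketch, not part of the diff): the try/except duck-typing is replaced by explicit branches on self.model_type, since FSMT has no shared embedding table and keeps per-side embed_tokens and embed_positions on model.model.encoder/decoder. For context, freeze_params is the helper this script already uses; roughly it just disables gradients, as in this sketch (the body is an assumption, not copied from the repo):

    def freeze_params(module):
        # exclude every parameter of the module from gradient updates
        for p in module.parameters():
            p.requires_grad = False
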
@@ -140,7 +146,7 @@ class SummarizationModule(BaseTransformer):
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
             ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
 
-            assert lm_logits.shape[-1] == self.model.config.vocab_size
+            assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
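
Background (a sketch, not part of the diff): the assertion now compares the logits' last dimension against the cached self.vocab_size, which for FSMT is the target-side vocabulary rather than a single config.vocab_size. An illustrative, self-contained version of the surrounding loss computation, with placeholder shapes and values:

    import torch

    pad_token_id, vocab_size = 0, 11
    lm_logits = torch.randn(2, 5, vocab_size)       # (batch, tgt_len, vocab)
    tgt_ids = torch.randint(0, vocab_size, (2, 5))  # (batch, tgt_len)
    assert lm_logits.shape[-1] == vocab_size
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
    loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
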
@@ -103,6 +103,7 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
 BART_TINY = "sshleifer/bart-tiny-random"
 MBART_TINY = "sshleifer/tiny-mbart"
 MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+FSMT_TINY = "stas/tiny-wmt19-en-de"
 
 
 stream_handler = logging.StreamHandler(sys.stdout)
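
Background (a sketch, not part of the diff): like the other *_TINY constants in this test module, FSMT_TINY points at a tiny randomly-initialized checkpoint so the tests stay fast. A minimal loading sketch, assuming the checkpoint is reachable on the Hugging Face Hub and loadable through the Auto classes:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    FSMT_TINY = "stas/tiny-wmt19-en-de"
    tokenizer = AutoTokenizer.from_pretrained(FSMT_TINY)
    model = AutoModelForSeq2SeqLM.from_pretrained(FSMT_TINY)
    print(model.config.model_type)  # expected: "fsmt"
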
@@ -374,11 +375,11 @@ def test_run_eval_search(model):
 
 @pytest.mark.parametrize(
     "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
+    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
-    task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
+    task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
     args_d["label_smoothing"] = 0.1 if task == "translation" else 0
 
     tmp_dir = make_test_data_dir()
@@ -407,7 +408,13 @@ def test_finetune(model):
         lm_head = module.model.lm_head
         assert not lm_head.weight.requires_grad
         assert (lm_head.weight == input_embeds.weight).all().item()
+    elif model == FSMT_TINY:
+        fsmt = module.model.model
+        embed_pos = fsmt.decoder.embed_positions
+        assert not embed_pos.weight.requires_grad
+        assert not fsmt.decoder.embed_tokens.weight.requires_grad
+        # check that embeds are not the same
+        assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
     else:
         bart = module.model.model
         embed_pos = bart.decoder.embed_positions
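
Background (a sketch, not part of the diff): with embedding freezing enabled in the test args, the new FSMT branch asserts that the decoder's positional and token embeddings no longer require gradients and that encoder and decoder keep distinct embedding modules. A standalone sketch of that freeze-then-check pattern on a toy embedding:

    import torch

    emb = torch.nn.Embedding(10, 4)  # stand-in for decoder.embed_tokens
    for p in emb.parameters():
        p.requires_grad = False      # what freeze_params does to it
    assert not emb.weight.requires_grad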