[Seq2Seq Trainer] Make sure padding is implemented for models without pad_token (#8043)

* make sure padding is implemented for models without a padding token as well

* add better error message

* add better warning

* remove results files

* Update examples/seq2seq/seq2seq_trainer.py

* remove unnecessary copy line

* correct usage of labels

* delete test files
Patrick von Platen 2020-10-26 17:28:16 +01:00 committed by GitHub
parent 098ddc2244
commit 664c7ec453
2 changed files with 33 additions and 20 deletions
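
For context, a minimal sketch (not taken verbatim from this commit) of the padding fallback that the changes below implement: pad with `config.pad_token_id` when it exists, fall back to `config.eos_token_id` otherwise, and raise an error when neither is defined. The function name `pad_to_max_len` is made up for illustration.

    import torch

    def pad_to_max_len(tensor, max_length, pad_token_id=None, eos_token_id=None):
        # Prefer the model's pad token; fall back to EOS for models without a pad token.
        pad_id = pad_token_id if pad_token_id is not None else eos_token_id
        if pad_id is None:
            raise ValueError("Either a pad token id or an eos token id is needed to pad to max_length.")
        # Start from a tensor filled with the padding id, then copy the real tokens in.
        padded = pad_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
        padded[:, : tensor.shape[-1]] = tensor
        return padded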

examples/seq2seq/seq2seq_trainer.py

@@ -1,4 +1,3 @@
-import copy
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -60,6 +59,11 @@ class Seq2SeqTrainer(Trainer):
                 self.config.pad_token_id is not None
             ), "Make sure that `config.pad_token_id` is correctly defined when ignoring `pad_token` for loss calculation or doing label smoothing."

+        if self.config.pad_token_id is None and self.config.eos_token_id is not None:
+            logger.warn(
+                f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding."
+            )
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
@@ -126,22 +130,19 @@ class Seq2SeqTrainer(Trainer):
             else DistributedSampler(self.train_dataset)
         )

-    def _compute_loss(self, model, inputs):
-        inputs = copy.deepcopy(inputs)
+    def _compute_loss(self, model, inputs, labels):
         if self.args.label_smoothing == 0:
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
-                labels = inputs.pop("labels")
                 logits = model(**inputs, use_cache=False)[0]
                 loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
                 loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
-                loss, logits = model(**inputs, use_cache=False)[:2]
+                loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
         else:
             # compute label smoothed loss
-            labels = inputs.pop("labels")
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
             loss, _ = label_smoothed_nll_loss(
@@ -150,7 +151,8 @@ class Seq2SeqTrainer(Trainer):
         return loss, logits

     def compute_loss(self, model, inputs):
-        loss, _ = self._compute_loss(model, inputs)
+        labels = inputs.pop("labels")
+        loss, _ = self._compute_loss(model, inputs, labels)
         return loss

     def prediction_step(
@@ -178,25 +180,27 @@ class Seq2SeqTrainer(Trainer):
         """
         inputs = self._prepare_inputs(inputs)

-        gen_kwargs = {
-            "max_length": self.data_args.val_max_target_length
-            if self.data_args is not None
-            else self.config.max_length,
-            "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
-        }
         if self.args.predict_with_generate and not self.args.prediction_loss_only:
+            gen_kwargs = {
+                "max_length": self.data_args.val_max_target_length
+                if self.data_args is not None
+                else self.config.max_length,
+                "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
+            }
+
             generated_tokens = model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 **gen_kwargs,
             )
             # in case the batch is shorter than max length, the output should be padded
-            if self.config.pad_token_id is not None:
+            if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
                 generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

-        # compute loss on predict data
+        labels = inputs.pop("labels")
         with torch.no_grad():
-            loss, logits = self._compute_loss(model, inputs)
+            # compute loss on predict data
+            loss, logits = self._compute_loss(model, inputs, labels)

         loss = loss.mean().detach()
         if self.args.prediction_loss_only:
@@ -204,14 +208,21 @@ class Seq2SeqTrainer(Trainer):

         logits = generated_tokens if self.args.predict_with_generate else logits

-        labels = inputs["labels"]
-        if self.config.pad_token_id is not None:
-            labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
+        if labels.shape[-1] < gen_kwargs["max_length"]:
+            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

         return (loss, logits, labels)

     def _pad_tensors_to_max_len(self, tensor, max_length):
-        padded_tensor = self.config.pad_token_id * torch.ones(
+        # If PAD token is not defined at least EOS token has to be defined
+        pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id
+
+        if pad_token_id is None:
+            raise ValueError(
+                f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
+            )
+
+        padded_tensor = pad_token_id * torch.ones(
             (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
         )
         padded_tensor[:, : tensor.shape[-1]] = tensor
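
As a usage note, the reworked loss path above pops `labels` from the inputs once and threads them through explicitly instead of deep-copying the batch. A rough sketch of the calling pattern, with `batch` as a placeholder name for a prepared input dict:

    # Hypothetical calling pattern mirroring compute_loss / prediction_step above.
    labels = batch.pop("labels")  # labels leave the model inputs exactly once
    loss, logits = self._compute_loss(model, batch, labels)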

examples/seq2seq/test_finetune_trainer.py

@@ -63,7 +63,9 @@ class TestFinetuneTrainer(TestCasePlus):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

         bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
+        bert2bert.config.eos_token_id = tokenizer.sep_token_id
         bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
+        bert2bert.config.max_length = 128

         train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
         val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
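
The test change above gives the BERT-to-BERT model an explicit `eos_token_id` and `max_length`, so generation and the padding fallback in the trainer have usable values to work with. A minimal sketch of how such a model is typically assembled; the `from_encoder_decoder_pretrained` construction is an assumption here, since the excerpt does not show how `bert2bert` is built:

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    # Mirror the test configuration: SEP doubles as EOS, CLS starts decoding,
    # and generation/padding length is capped at 128 tokens.
    bert2bert.config.eos_token_id = tokenizer.sep_token_id
    bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
    bert2bert.config.max_length = 128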