mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[Seq2Seq Trainer] Make sure padding is implemented for models without pad_token (#8043)
* make sure padding is implemented for non-padding tokens models as well * add better error message * add better warning * remove results files * Update examples/seq2seq/seq2seq_trainer.py * remove unnecessary copy line * correct usage of labels * delete test files
This commit is contained in:
parent
098ddc2244
commit
664c7ec453
@ -1,4 +1,3 @@
|
|||||||
import copy
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -60,6 +59,11 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
self.config.pad_token_id is not None
|
self.config.pad_token_id is not None
|
||||||
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
||||||
|
|
||||||
|
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
|
||||||
|
logger.warn(
|
||||||
|
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
|
||||||
|
)
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
"""
|
"""
|
||||||
Setup the optimizer and the learning rate scheduler.
|
Setup the optimizer and the learning rate scheduler.
|
||||||
@ -126,22 +130,19 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
else DistributedSampler(self.train_dataset)
|
else DistributedSampler(self.train_dataset)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_loss(self, model, inputs):
|
def _compute_loss(self, model, inputs, labels):
|
||||||
inputs = copy.deepcopy(inputs)
|
|
||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
||||||
# force training to ignore pad token
|
# force training to ignore pad token
|
||||||
labels = inputs.pop("labels")
|
|
||||||
logits = model(**inputs, use_cache=False)[0]
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
# compute usual loss via models
|
# compute usual loss via models
|
||||||
loss, logits = model(**inputs, use_cache=False)[:2]
|
loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
|
||||||
else:
|
else:
|
||||||
# compute label smoothed loss
|
# compute label smoothed loss
|
||||||
labels = inputs.pop("labels")
|
|
||||||
logits = model(**inputs, use_cache=False)[0]
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
loss, _ = label_smoothed_nll_loss(
|
loss, _ = label_smoothed_nll_loss(
|
||||||
@ -150,7 +151,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
return loss, logits
|
return loss, logits
|
||||||
|
|
||||||
def compute_loss(self, model, inputs):
|
def compute_loss(self, model, inputs):
|
||||||
loss, _ = self._compute_loss(model, inputs)
|
labels = inputs.pop("labels")
|
||||||
|
loss, _ = self._compute_loss(model, inputs, labels)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
@ -178,25 +180,27 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
gen_kwargs = {
|
||||||
|
"max_length": self.data_args.val_max_target_length
|
||||||
|
if self.data_args is not None
|
||||||
|
else self.config.max_length,
|
||||||
|
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
|
||||||
|
}
|
||||||
|
|
||||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||||
gen_kwargs = {
|
|
||||||
"max_length": self.data_args.val_max_target_length
|
|
||||||
if self.data_args is not None
|
|
||||||
else self.config.max_length,
|
|
||||||
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
|
|
||||||
}
|
|
||||||
generated_tokens = model.generate(
|
generated_tokens = model.generate(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
**gen_kwargs,
|
**gen_kwargs,
|
||||||
)
|
)
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
if self.config.pad_token_id is not None:
|
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
||||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||||
|
|
||||||
# compute loss on predict data
|
labels = inputs.pop("labels")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
loss, logits = self._compute_loss(model, inputs)
|
# compute loss on predict data
|
||||||
|
loss, logits = self._compute_loss(model, inputs, labels)
|
||||||
|
|
||||||
loss = loss.mean().detach()
|
loss = loss.mean().detach()
|
||||||
if self.args.prediction_loss_only:
|
if self.args.prediction_loss_only:
|
||||||
@ -204,14 +208,21 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
|
|
||||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||||
|
|
||||||
labels = inputs["labels"]
|
if labels.shape[-1] < gen_kwargs["max_length"]:
|
||||||
if self.config.pad_token_id is not None:
|
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
||||||
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
|
|
||||||
|
|
||||||
return (loss, logits, labels)
|
return (loss, logits, labels)
|
||||||
|
|
||||||
def _pad_tensors_to_max_len(self, tensor, max_length):
|
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||||
padded_tensor = self.config.pad_token_id * torch.ones(
|
# If PAD token is not defined at least EOS token has to be defined
|
||||||
|
pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id
|
||||||
|
|
||||||
|
if pad_token_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_tensor = pad_token_id * torch.ones(
|
||||||
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||||
)
|
)
|
||||||
padded_tensor[:, : tensor.shape[-1]] = tensor
|
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||||
|
@ -63,7 +63,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||||
|
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
||||||
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||||
|
bert2bert.config.max_length = 128
|
||||||
|
|
||||||
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||||
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||||
|
Loading…
Reference in New Issue
Block a user