[Seq2SeqTrainer] Move import to init to make file self-contained (#8194)

* boom boom

* reverse order
Patrick von Platen 2020-11-01 23:31:55 +01:00 committed by GitHub
parent 1f12934df4
commit 9bd30f7cf4


@@ -20,12 +20,6 @@ from transformers.optimization import (
 from transformers.trainer_pt_utils import get_tpu_sampler
-try:
-    from .utils import label_smoothed_nll_loss
-except ImportError:
-    from utils import label_smoothed_nll_loss
 logger = logging.get_logger(__name__)
 arg_to_scheduler = {
@@ -64,6 +58,17 @@ class Seq2SeqTrainer(Trainer):
                 f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
             )
+        if self.args.label_smoothing == 0:
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+        else:
+            # dynamically import label_smoothed_nll_loss
+            try:
+                from .utils import label_smoothed_nll_loss
+            except ImportError:
+                from utils import label_smoothed_nll_loss
+            self.loss_fn = label_smoothed_nll_loss
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
@@ -135,9 +140,7 @@ class Seq2SeqTrainer(Trainer):
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
                 logits = model(**inputs, use_cache=False)[0]
-                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
                 loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
@@ -145,9 +148,7 @@ class Seq2SeqTrainer(Trainer):
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
-            )
+            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
         return loss, logits
     def compute_loss(self, model, inputs):
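
For reference, a minimal sketch of a label-smoothed NLL loss with the call signature used above (lprobs, labels, epsilon, ignore_index=...). This is an illustrative stand-in for the `label_smoothed_nll_loss` imported from utils, not necessarily the exact implementation in examples/seq2seq/utils.py; the default ignore_index below is an assumption for the sketch.

import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    # lprobs: (..., vocab_size) log-probabilities; target: (...) gold token ids
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)   # log-prob of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)   # sum over vocab for the uniform smoothing term
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)          # padded positions contribute nothing
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss  # mix gold-token loss with the uniform term
    return loss, nll_loss

Used as in the hunk above: lprobs = torch.nn.functional.log_softmax(logits, dim=-1), then loss, _ = label_smoothed_nll_loss(lprobs, labels, 0.1, ignore_index=pad_token_id).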