Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Seq2SeqTrainer] Move import to init to make file self-contained (#8194)
* boom boom
* reverse order
This commit is contained in:
parent 1f12934df4
commit 9bd30f7cf4
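Context for the change: before this commit, seq2seq_trainer.py pulled label_smoothed_nll_loss from its sibling utils module at import time (the try/except removed in the first hunk below), so the trainer file could not even be imported unless utils was importable too. The commit defers that import to Seq2SeqTrainer.__init__, and only on the label-smoothing path, so the file stands on its own when label smoothing is off. A minimal sketch of the same lazy-import-with-fallback pattern, using a hypothetical helper name:

    def _resolve_label_smoothed_loss():
        # Resolve the helper lazily: works both when this file is imported as part
        # of a package (relative import) and when it is run as a plain script with
        # utils.py on sys.path.
        try:
            from .utils import label_smoothed_nll_loss  # package-style import
        except ImportError:
            from utils import label_smoothed_nll_loss   # script-style import
        return label_smoothed_nll_loss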
@@ -20,12 +20,6 @@ from transformers.optimization import (
 from transformers.trainer_pt_utils import get_tpu_sampler
 
 
-try:
-    from .utils import label_smoothed_nll_loss
-except ImportError:
-    from utils import label_smoothed_nll_loss
-
-
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -64,6 +58,17 @@ class Seq2SeqTrainer(Trainer):
                 f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
             )
 
+        if self.args.label_smoothing == 0:
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+        else:
+            # dynamically import label_smoothed_nll_loss
+            try:
+                from .utils import label_smoothed_nll_loss
+            except ImportError:
+                from utils import label_smoothed_nll_loss
+
+            self.loss_fn = label_smoothed_nll_loss
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
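A quick aside on the non-smoothed branch above (not part of the commit): CrossEntropyLoss with ignore_index=pad_token_id drops padded positions from the loss average entirely. A toy check with made-up values:

    import torch

    pad_id = 0  # hypothetical pad token id
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_id)

    logits = torch.randn(4, 10)                    # (num_tokens, vocab_size)
    labels = torch.tensor([3, pad_id, 7, pad_id])  # two real tokens, two pads

    # the mean is taken over the two non-pad positions only
    print(loss_fn(logits, labels).item())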
@@ -135,9 +140,7 @@ class Seq2SeqTrainer(Trainer):
             if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                 # force training to ignore pad token
                 logits = model(**inputs, use_cache=False)[0]
-
-                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-                loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
             else:
                 # compute usual loss via models
                 loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
@@ -145,9 +148,7 @@ class Seq2SeqTrainer(Trainer):
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
-            )
+            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
         return loss, logits
 
     def compute_loss(self, model, inputs):
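The label-smoothing branch calls the loss as loss_fn(lprobs, labels, epsilon, ignore_index=...). For readers without the sibling utils module at hand, a fairseq-style sketch of what such a helper looks like (an illustration of the signature assumed above, not a verbatim copy of the repo's code):

    def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
        # lprobs: FloatTensor of log-probabilities over the vocabulary,
        # target: LongTensor of gold token ids.
        # Mix the NLL of the gold token with a uniform-smoothing term (the sum of
        # -log-probs over the vocabulary, scaled by epsilon / vocab_size);
        # positions equal to ignore_index are masked out of both terms.
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        if ignore_index is not None:
            pad_mask = target.eq(ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        eps_i = epsilon / lprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss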