diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index b83d9ed60bc..7174344487f 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -187,14 +187,17 @@ from torch import nn
 from transformers import Trainer

 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
         labels = inputs.pop("labels")
         # forward pass
         outputs = model(**inputs)
         logits = outputs.get("logits")
         # compute custom loss for 3 labels with different weights
-        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device), reduction=reduction)
         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+        if num_items_in_batch is not None:
+            loss = loss / num_items_in_batch
         return (loss, outputs) if return_outputs else loss
 ```

diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index cf6d9078a94..764d28d6f34 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -28,18 +28,18 @@ from .loss_rt_detr import RTDetrForObjectDetectionLoss
 def fixed_cross_entropy(
     source: torch.Tensor,
     target: torch.Tensor,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ) -> torch.Tensor:
     reduction = "sum" if num_items_in_batch is not None else "mean"
     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
-        if not isinstance(num_items_in_batch, torch.Tensor):
-            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
-        elif num_items_in_batch.device != loss.device:
+        # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
+        if torch.is_tensor(num_items_in_batch):
             num_items_in_batch = num_items_in_batch.to(loss.device)
+
         loss = loss / num_items_in_batch
     return loss


@@ -48,7 +48,7 @@ def ForCausalLMLoss(
     logits,
     labels,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     shift_labels: Optional[torch.Tensor] = None,
     **kwargs,
@@ -74,7 +74,7 @@ def ForMaskedLMLoss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 533f5f5ae6e..7ec12010a88 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -34,7 +34,7 @@ import warnings
 from collections.abc import Mapping
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union


 # Integrations must be imported before ML frameworks:
@@ -3714,7 +3714,10 @@ class Trainer:
         return ctx_manager

     def training_step(
-        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -3783,7 +3786,7 @@ class Trainer:
                 scaled_loss.backward()
         else:
             # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
-            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+            if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
                 loss = loss / self.args.gradient_accumulation_steps

             # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
@@ -3795,11 +3798,33 @@ class Trainer:

         return loss.detach()

-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.

-        Subclass and override for custom behavior.
+        Args:
+            model (`nn.Module`):
+                The model to compute the loss for.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The input data for the model.
+            return_outputs (`bool`, *optional*, defaults to `False`):
+                Whether to return the model outputs along with the loss.
+            num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
+                The number of items in the batch. If `num_items_in_batch` is not passed, the Trainer falls back
+                to normalizing the loss by the number of gradient accumulation steps in `training_step`.
+
+        Returns:
+            The loss of the model along with its output if `return_outputs` was set to `True`.
+
+        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing
+        your loss, make sure to set `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss
+        calculation might be slightly inaccurate when performing gradient accumulation.
         """
         if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
@@ -5257,7 +5282,13 @@ class Trainer:
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )

-    def get_batch_samples(self, epoch_iterator, num_batches, device):
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> Tuple[list, Optional[torch.Tensor]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of
+        items in the batches to properly scale the loss.
+        """
         batch_samples = []
         num_items_in_batch = None

diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index 060140f31eb..067af41c2e8 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -853,12 +853,12 @@ class LossKwargs(TypedDict, total=False):
     Keyword arguments to be passed to the loss function

     Attributes:
-        num_items_in_batch (`int`, *optional*):
+        num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
            Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation.
""" - num_items_in_batch: Optional[int] + num_items_in_batch: Optional["torch.Tensor"] def is_timm_config_dict(config_dict: dict[str, Any]) -> bool: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4aab7a69ebd..f185c30daec 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -944,7 +944,7 @@ class ModelTesterMixin: model = AutoModelForCausalLM.from_pretrained( tmpdir, torch_dtype=torch.float32, device_map=torch_device ) - inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0] + inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0]) inputs_dict["labels"] = inputs_dict["input_ids"] _ = model(**inputs_dict, return_dict=False)