Better typing for num_items_in_batch (#38728)

* fix

* style

* type checking ?

* maybe this ?

* fix

* can't be an int anymore

* fix
Marc Sun 2025-06-11 16:26:41 +02:00 committed by GitHub
parent 84710a4291
commit 11ad9be153
5 changed files with 47 additions and 18 deletions

View File

@@ -187,14 +187,17 @@ from torch import nn
 from transformers import Trainer

 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
         labels = inputs.pop("labels")
         # forward pass
         outputs = model(**inputs)
         logits = outputs.get("logits")
         # compute custom loss for 3 labels with different weights
-        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device), reduction=reduction)
         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+        if num_items_in_batch is not None:
+            loss = loss / num_items_in_batch
         return (loss, outputs) if return_outputs else loss
 ```
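
For reference, the normalization pattern the updated example relies on can be tried outside of `Trainer`. This is a minimal sketch with a made-up `weighted_ce_loss` helper and toy 3-label shapes; it is not part of the repository:

```python
import torch
from torch import nn


def weighted_ce_loss(logits, labels, num_items_in_batch=None):
    # Sum the per-example loss, then normalize by the item count when it is given
    # (mirrors the compute_loss override shown in the diff above).
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss_fct = nn.CrossEntropyLoss(
        weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device), reduction=reduction
    )
    loss = loss_fct(logits.view(-1, 3), labels.view(-1))
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss


logits = torch.randn(4, 3)                 # 4 examples, 3 labels
labels = torch.randint(0, 3, (4,))
print(weighted_ce_loss(logits, labels))                                      # weighted mean
print(weighted_ce_loss(logits, labels, num_items_in_batch=torch.tensor(4)))  # sum / 4
```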

View File

@@ -28,18 +28,16 @@ from .loss_rt_detr import RTDetrForObjectDetectionLoss
 def fixed_cross_entropy(
     source: torch.Tensor,
     target: torch.Tensor,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ) -> torch.Tensor:
     reduction = "sum" if num_items_in_batch is not None else "mean"
     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
-        if not isinstance(num_items_in_batch, torch.Tensor):
-            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
-        elif num_items_in_batch.device != loss.device:
+        # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
+        if torch.is_tensor(num_items_in_batch):
             num_items_in_batch = num_items_in_batch.to(loss.device)
         loss = loss / num_items_in_batch
     return loss
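
As a sanity check on the simplified branch above: dividing a tensor loss by a plain Python `int` already works, so the only case that needs handling is a tensor that may live on a different device than the loss. A small illustration (CPU-only here, so the `.to()` is a no-op):

```python
import torch

loss = torch.tensor(12.0)

print(loss / 4)                    # int divisor works as-is -> tensor(3.)
print(loss / torch.tensor(4))      # tensor divisor on the same device -> tensor(3.)

num_items_in_batch = torch.tensor(4)
if torch.is_tensor(num_items_in_batch):
    # only tensors need to be moved onto the loss device (relevant on GPU)
    num_items_in_batch = num_items_in_batch.to(loss.device)
print(loss / num_items_in_batch)   # tensor(3.)
```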
@@ -48,7 +46,7 @@ def ForCausalLMLoss(
     logits,
     labels,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     shift_labels: Optional[torch.Tensor] = None,
     **kwargs,
@@ -74,7 +72,7 @@ def ForMaskedLMLoss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ):
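
The signature changes above only affect the type of `num_items_in_batch`; the call pattern stays the same. A hedged usage sketch with toy shapes, assuming the function is importable from `transformers.loss.loss_utils` as in recent releases:

```python
import torch
from transformers.loss.loss_utils import ForCausalLMLoss  # assumed import path

batch, seq, vocab_size = 2, 6, 32
logits = torch.randn(batch, seq, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq))

# Labels are shifted by one inside the loss, so batch * (seq - 1) tokens contribute.
loss = ForCausalLMLoss(
    logits,
    labels,
    vocab_size=vocab_size,
    num_items_in_batch=torch.tensor(batch * (seq - 1)),  # now a tensor, not an int
)
```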

View File

@@ -34,7 +34,7 @@ import warnings
 from collections.abc import Mapping
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union

 # Integrations must be imported before ML frameworks:
@@ -3714,7 +3714,10 @@ class Trainer:
         return ctx_manager

     def training_step(
-        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -3783,7 +3786,7 @@
                 scaled_loss.backward()
         else:
             # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
-            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+            if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
                 loss = loss / self.args.gradient_accumulation_steps

             # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
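
The widened condition encodes the gradient-accumulation fix: when `num_items_in_batch` is available, the per-micro-batch losses are summed and divided by the total item count, which already reproduces the full-batch mean, so the extra division by `gradient_accumulation_steps` is only needed when that count is missing. A small arithmetic check with toy shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)                  # 8 items, 5-way classification
labels = torch.randint(0, 5, (8,))

# Loss a single, non-accumulated step would see.
full_batch_mean = F.cross_entropy(logits, labels, reduction="mean")

# Two accumulated micro-batches: sum each, divide once by the total item count.
num_items_in_batch = torch.tensor(8)
accumulated = sum(
    F.cross_entropy(logits[i : i + 4], labels[i : i + 4], reduction="sum")
    for i in (0, 4)
) / num_items_in_batch

print(torch.allclose(full_batch_mean, accumulated))  # True
```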
@@ -3795,11 +3798,31 @@
         return loss.detach()

-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.

-        Subclass and override for custom behavior.
+        Args:
+            model (`nn.Module`):
+                The model to compute the loss for.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The input data for the model.
+            return_outputs (`bool`, *optional*, defaults to `False`):
+                Whether to return the model outputs along with the loss.
+            num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
+                The number of items in the batch. If `num_items_in_batch` is not passed, the loss is
+                normalized by the number of gradient accumulation steps instead.
+
+        Returns:
+            The loss of the model along with its output if `return_outputs` was set to `True`.
+
+        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your
+        loss, make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation
+        might be slightly inaccurate when performing gradient accumulation.
         """
         if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
@@ -5257,7 +5280,12 @@ class Trainer:
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
             )

-    def get_batch_samples(self, epoch_iterator, num_batches, device):
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> Tuple[list, Optional[torch.Tensor]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of items
+        in the batches to properly scale the loss.
+        """
         batch_samples = []
         num_items_in_batch = None
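
For intuition, here is a simplified sketch of what a `get_batch_samples`-style helper does with the new return type; the `collect_batches` name and the `labels`-based counting are illustrative, not the actual Trainer code:

```python
from collections.abc import Iterator
from typing import Optional

import torch


def collect_batches(
    epoch_iterator: Iterator, num_batches: int, device: torch.device
) -> tuple[list, Optional[torch.Tensor]]:
    # Pull up to `num_batches` batches, then count the label tokens that actually
    # contribute to the loss (ignore_index == -100) and return that count as a tensor.
    batch_samples, num_items_in_batch = [], None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break
    if batch_samples and "labels" in batch_samples[0]:
        num_items_in_batch = torch.tensor(
            sum(batch["labels"].ne(-100).sum().item() for batch in batch_samples),
            device=device,
        )
    return batch_samples, num_items_in_batch
```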

View File

@@ -853,12 +853,12 @@ class LossKwargs(TypedDict, total=False):
     Keyword arguments to be passed to the loss function

     Attributes:
-        num_items_in_batch (`int`, *optional*):
+        num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
             Number of items in the batch. It is recommended to pass it when
             you are doing gradient accumulation.
     """

-    num_items_in_batch: Optional[int]
+    num_items_in_batch: Optional["torch.Tensor"]


 def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
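
Models consume this TypedDict through a `**kwargs: Unpack[...]` signature, so the annotation change is what type checkers see when `num_items_in_batch` is forwarded to the loss. A minimal, self-contained sketch (the `LossKwargs` class is re-declared here only to keep the snippet runnable, and `forward` is a stand-in, not a real model method):

```python
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[torch.Tensor]


def forward(logits: torch.Tensor, labels: torch.Tensor, **kwargs: Unpack[LossKwargs]) -> torch.Tensor:
    # Type checkers now expect a tensor (or nothing) here, matching the Trainer side.
    num_items_in_batch = kwargs.get("num_items_in_batch")
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch.to(loss.device)
    return loss


loss = forward(torch.randn(4, 10), torch.randint(0, 10, (4,)), num_items_in_batch=torch.tensor(4))
```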

View File

@@ -944,7 +944,7 @@ class ModelTesterMixin:
                 model = AutoModelForCausalLM.from_pretrained(
                     tmpdir, torch_dtype=torch.float32, device_map=torch_device
                 )
-                inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
+                inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
                 inputs_dict["labels"] = inputs_dict["input_ids"]
                 _ = model(**inputs_dict, return_dict=False)
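
Callers that drive a model forward directly, as the updated test does, now wrap the item count in a tensor rather than passing a bare `int`. A short sketch of building that value from the inputs (toy data, no model needed to see the shape of the argument):

```python
import torch

input_ids = torch.randint(0, 100, (2, 8))   # toy token ids
labels = input_ids.clone()
labels[:, :2] = -100                        # e.g. masked/padded positions

# Either the batch size (as in the test) or the number of non-ignored label tokens,
# but in both cases a 0-dim tensor rather than a Python int.
num_items_in_batch = torch.tensor(input_ids.shape[0])
num_loss_tokens = labels.ne(-100).sum()
print(num_items_in_batch, num_loss_tokens)  # tensor(2) tensor(12)
```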