Better typing for num_items_in_batch (#38728)

* fix

* style

* type checking ?

* maybe this ?

* fix

* can't be an int anymore

* fix
Marc Sun 2025-06-11 16:26:41 +02:00 committed by GitHub
parent 84710a4291
commit 11ad9be153
5 changed files with 47 additions and 18 deletions


@@ -187,14 +187,17 @@ from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)
logits = outputs.get("logits")
# compute custom loss for 3 labels with different weights
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
reduction = "mean" if num_items_in_batch is not None else "sum"
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return (loss, outputs) if return_outputs else loss
```
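For context, a minimal sketch of how the `CustomTrainer` above might be wired up; the checkpoint name, the toy dataset, and the `TrainingArguments` values below are illustrative placeholders, not part of this commit.

```python
# Hypothetical usage of the CustomTrainer defined above (placeholder model/data/arguments).
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

# Tiny in-memory dataset with 3 classes, just to make the sketch runnable.
raw = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 1, 2, 0]})
train_dataset = raw.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)

args = TrainingArguments(
    output_dir="custom-loss-run",      # placeholder output directory
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,     # num_items_in_batch matters most with accumulation
    num_train_epochs=1,
)

trainer = CustomTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```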


@@ -28,18 +28,16 @@ from .loss_rt_detr import RTDetrForObjectDetectionLoss
def fixed_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[int] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
) -> torch.Tensor:
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if reduction == "sum":
if not isinstance(num_items_in_batch, torch.Tensor):
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
elif num_items_in_batch.device != loss.device:
# just in case users pass an int for num_items_in_batch, which could be the case for a custom trainer
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(loss.device)
loss = loss / num_items_in_batch
return loss
@@ -48,7 +46,7 @@ def ForCausalLMLoss(
logits,
labels,
vocab_size: int,
num_items_in_batch: Optional[int] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
@@ -74,7 +72,7 @@ def ForMaskedLMLoss(
logits: torch.Tensor,
labels: torch.Tensor,
vocab_size: int,
num_items_in_batch: Optional[int] = None,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
**kwargs,
):
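As an aside (not part of the diff), a quick toy check of why `fixed_cross_entropy` sums the loss and then divides by `num_items_in_batch`: under gradient accumulation, averaging per-micro-batch means weights every micro-batch equally, which drifts from the true mean when micro-batches contain different numbers of items.

```python
# Toy illustration only: sum-then-divide vs. mean-of-means under gradient accumulation.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Two micro-batches with 5 and 2 valid targets respectively.
logits = [torch.randn(5, 10), torch.randn(2, 10)]
labels = [torch.randint(0, 10, (5,)), torch.randint(0, 10, (2,))]

# Reference: cross-entropy averaged over all 7 items at once.
true_mean = F.cross_entropy(torch.cat(logits), torch.cat(labels))

# Averaging per-micro-batch means (what happens without num_items_in_batch).
mean_of_means = sum(F.cross_entropy(lg, lb) for lg, lb in zip(logits, labels)) / 2

# Summing per micro-batch and dividing by the global item count, passed as a tensor.
num_items_in_batch = torch.tensor(7)
sum_over_items = sum(
    F.cross_entropy(lg, lb, reduction="sum") for lg, lb in zip(logits, labels)
) / num_items_in_batch

print(true_mean.item(), mean_of_means.item(), sum_over_items.item())
# sum_over_items matches true_mean; mean_of_means generally does not.
```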


@@ -34,7 +34,7 @@ import warnings
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union
# Integrations must be imported before ML frameworks:
@@ -3714,7 +3714,10 @@ class Trainer:
return ctx_manager
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
num_items_in_batch: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
@@ -3783,7 +3786,7 @@
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
@@ -3795,11 +3798,31 @@
return loss.detach()
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
num_items_in_batch: Optional[torch.Tensor] = None,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
Args:
model (`nn.Module`):
The model to compute the loss for.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The input data for the model.
return_outputs (`bool`, *optional*, defaults to `False`):
Whether to return the model outputs along with the loss.
num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
The number of items in the batch. If num_items_in_batch is not passed, the loss is
normalized by the number of gradient accumulation steps instead, which can be slightly inaccurate.
Returns:
The loss of the model along with its output if return_outputs was set to True
Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation might be slightly inaccurate when performing gradient accumulation.
"""
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
@@ -5257,7 +5280,12 @@
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)
def get_batch_samples(self, epoch_iterator, num_batches, device):
def get_batch_samples(
self, epoch_iterator: Iterator, num_batches: int, device: torch.device
) -> Tuple[list, Optional[torch.Tensor]]:
"""
Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
"""
batch_samples = []
num_items_in_batch = None
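Related to the note added to the `compute_loss` docstring above, a hypothetical sketch of a subclass that does not consume `num_items_in_batch`; the class name and the plain-mean loss are illustrative only.

```python
# Sketch (not part of this commit): an override that ignores num_items_in_batch.
# Per the docstring above, model_accepts_loss_kwargs is set to False so that
# training_step falls back to dividing the loss by gradient_accumulation_steps.
from torch import nn
from transformers import Trainer


class MeanLossTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_accepts_loss_kwargs = False  # num_items_in_batch is not used below

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Plain mean reduction; num_items_in_batch is deliberately left unused.
        loss = nn.functional.cross_entropy(
            outputs.logits.view(-1, self.model.config.num_labels), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss
```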


@@ -853,12 +853,12 @@ class LossKwargs(TypedDict, total=False):
Keyword arguments to be passed to the loss function
Attributes:
num_items_in_batch (`int`, *optional*):
num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
Number of items in the batch. It is recommended to pass it when
you are doing gradient accumulation.
"""
num_items_in_batch: Optional[int]
num_items_in_batch: Optional["torch.Tensor"]
def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
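For illustration only (not in the diff), a sketch of a loss helper whose `**kwargs` are typed with `LossKwargs`, now expecting `num_items_in_batch` as a tensor. The helper name is made up, and the import assumes `LossKwargs` can be imported from `transformers.utils.generic`, where it is defined.

```python
# Sketch: typing **kwargs with the LossKwargs TypedDict (PEP 692 Unpack).
from typing import Optional

import torch
from typing_extensions import Unpack

from transformers.utils.generic import LossKwargs


def scaled_loss(per_item_loss: torch.Tensor, **kwargs: Unpack[LossKwargs]) -> torch.Tensor:
    # With the new annotation, num_items_in_batch is a tensor (or absent).
    num_items_in_batch: Optional[torch.Tensor] = kwargs.get("num_items_in_batch")
    if num_items_in_batch is not None:
        return per_item_loss.sum() / num_items_in_batch.to(per_item_loss.device)
    return per_item_loss.mean()


loss = scaled_loss(torch.rand(8), num_items_in_batch=torch.tensor(8))
```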


@@ -944,7 +944,7 @@ class ModelTesterMixin:
model = AutoModelForCausalLM.from_pretrained(
tmpdir, torch_dtype=torch.float32, device_map=torch_device
)
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)