mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Better typing for num_items_in_batch (#38728)
* fix * style * type checking ? * maybe this ? * fix * can't be an int anymore * fix
This commit is contained in:
parent
84710a4291
commit
11ad9be153
@ -187,14 +187,17 @@ from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
def compute_losss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
# compute custom loss for 3 labels with different weights
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
|
||||
reduction = "mean" if num_items_in_batch is not None else "sum"
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device, reduction=reduction))
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
||||
|
@ -28,18 +28,16 @@ from .loss_rt_detr import RTDetrForObjectDetectionLoss
|
||||
def fixed_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
||||
if reduction == "sum":
|
||||
if not isinstance(num_items_in_batch, torch.Tensor):
|
||||
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
|
||||
elif num_items_in_batch.device != loss.device:
|
||||
# just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(loss.device)
|
||||
|
||||
loss = loss / num_items_in_batch
|
||||
return loss
|
||||
|
||||
@ -48,7 +46,7 @@ def ForCausalLMLoss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
@ -74,7 +72,7 @@ def ForMaskedLMLoss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -34,7 +34,7 @@ import warnings
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union
|
||||
|
||||
|
||||
# Integrations must be imported before ML frameworks:
|
||||
@ -3714,7 +3714,10 @@ class Trainer:
|
||||
return ctx_manager
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs.
|
||||
@ -3783,7 +3786,7 @@ class Trainer:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
|
||||
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
||||
if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
|
||||
@ -3795,11 +3798,31 @@ class Trainer:
|
||||
|
||||
return loss.detach()
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
return_outputs: bool = False,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
Args:
|
||||
model (`nn.Module`):
|
||||
The model to compute the loss for.
|
||||
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The input data for the model.
|
||||
return_outputs (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return the model outputs along with the loss.
|
||||
num_items_in_batch (Optional[torch.Tensor], *optional*):
|
||||
The number of items in the batch. If num_items_in_batch is not passed,
|
||||
|
||||
Returns:
|
||||
The loss of the model along with its output if return_outputs was set to True
|
||||
|
||||
Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
|
||||
make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculationg might be slightly inacurate when performing gradient accumulation.
|
||||
"""
|
||||
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
@ -5257,7 +5280,12 @@ class Trainer:
|
||||
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
|
||||
)
|
||||
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, device):
|
||||
def get_batch_samples(
|
||||
self, epoch_iterator: Iterator, num_batches: int, device: torch.device
|
||||
) -> Tuple[list, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
|
||||
"""
|
||||
batch_samples = []
|
||||
num_items_in_batch = None
|
||||
|
||||
|
@ -853,12 +853,12 @@ class LossKwargs(TypedDict, total=False):
|
||||
Keyword arguments to be passed to the loss function
|
||||
|
||||
Attributes:
|
||||
num_items_in_batch (`int`, *optional*):
|
||||
num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
|
||||
Number of items in the batch. It is recommended to pass it when
|
||||
you are doing gradient accumulation.
|
||||
"""
|
||||
|
||||
num_items_in_batch: Optional[int]
|
||||
num_items_in_batch: Optional["torch.Tensor"]
|
||||
|
||||
|
||||
def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
|
||||
|
@ -944,7 +944,7 @@ class ModelTesterMixin:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
tmpdir, torch_dtype=torch.float32, device_map=torch_device
|
||||
)
|
||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||
inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
_ = model(**inputs_dict, return_dict=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user