mirror of https://github.com/huggingface/transformers.git
Better typing for num_items_in_batch (#38728)

* fix
* style
* type checking ?
* maybe this ?
* fix
* can't be an int anymore
* fix

Parent: 84710a4291
Commit: 11ad9be153
@@ -187,14 +187,17 @@ from torch import nn
 from transformers import Trainer
 
 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
         labels = inputs.pop("labels")
         # forward pass
         outputs = model(**inputs)
         logits = outputs.get("logits")
         # compute custom loss for 3 labels with different weights
-        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device), reduction=reduction)
         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+        if num_items_in_batch is not None:
+            loss = loss / num_items_in_batch
         return (loss, outputs) if return_outputs else loss
 ```
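The example above sums per-item losses when a global `num_items_in_batch` is supplied and divides by that count, falling back to a plain mean otherwise. A minimal standalone sketch of that normalization, with made-up shapes and class weights:

```python
import torch
from torch import nn

# Stand-ins for a real model's outputs (shapes and weights are made up).
logits = torch.randn(8, 3)                # (batch_size, num_labels)
labels = torch.randint(0, 3, (8,))
num_items_in_batch = torch.tensor(8)      # total items across all accumulation steps

# Sum per-item losses, then normalize by the global item count.
reduction = "sum" if num_items_in_batch is not None else "mean"
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]), reduction=reduction)
loss = loss_fct(logits, labels)
if num_items_in_batch is not None:
    loss = loss / num_items_in_batch
print(loss)  # scalar, comparable across accumulation steps of different sizes
```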
@@ -28,18 +28,16 @@ from .loss_rt_detr import RTDetrForObjectDetectionLoss
 def fixed_cross_entropy(
     source: torch.Tensor,
     target: torch.Tensor,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ) -> torch.Tensor:
     reduction = "sum" if num_items_in_batch is not None else "mean"
     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
-        if not isinstance(num_items_in_batch, torch.Tensor):
-            num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype)
-        elif num_items_in_batch.device != loss.device:
+        # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
+        if torch.is_tensor(num_items_in_batch):
             num_items_in_batch = num_items_in_batch.to(loss.device)
         loss = loss / num_items_in_batch
     return loss
 
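The tensor typing matters here because a tensor count may sit on a different device than the loss; it only needs a `.to(loss.device)`, while a plain int divides correctly as-is. A hedged standalone sketch of the same logic (`fixed_cross_entropy_demo` is a local stand-in, not the library function):

```python
import torch
import torch.nn.functional as F

def fixed_cross_entropy_demo(source, target, num_items_in_batch=None, ignore_index=-100):
    # Sum-reduce when a global item count is given, mean-reduce otherwise.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # A tensor count may live on another device; an int divides fine as-is.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss

src, tgt = torch.randn(4, 10), torch.randint(0, 10, (4,))
print(fixed_cross_entropy_demo(src, tgt, torch.tensor(4)))  # tensor count
print(fixed_cross_entropy_demo(src, tgt, 4))                # int count still works
```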
@@ -48,7 +46,7 @@ def ForCausalLMLoss(
     logits,
     labels,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     shift_labels: Optional[torch.Tensor] = None,
     **kwargs,
@@ -74,7 +72,7 @@ def ForMaskedLMLoss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     vocab_size: int,
-    num_items_in_batch: Optional[int] = None,
+    num_items_in_batch: Optional[torch.Tensor] = None,
     ignore_index: int = -100,
     **kwargs,
 ):
@@ -34,7 +34,7 @@ import warnings
 from collections.abc import Mapping
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union
 
 
 # Integrations must be imported before ML frameworks:
@@ -3714,7 +3714,10 @@ class Trainer:
         return ctx_manager
 
     def training_step(
-        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -3783,7 +3786,7 @@ class Trainer:
                 scaled_loss.backward()
         else:
             # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
-            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+            if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
                 loss = loss / self.args.gradient_accumulation_steps
 
             # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
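The widened condition makes the fallback division by `gradient_accumulation_steps` apply not only when the model rejects loss kwargs but also when no item count could be computed. A toy illustration, with made-up loss sums and batch sizes, of why the two normalization paths differ when micro-batches are uneven:

```python
import torch

gradient_accumulation_steps = 4
per_step_loss_sums = [torch.tensor(3.5), torch.tensor(2.0), torch.tensor(3.0), torch.tensor(3.0)]
per_step_item_counts = [7, 5, 6, 6]  # items per micro-batch (made up)

# With a global count: each micro-batch contributes loss_sum / total_items.
num_items_in_batch = torch.tensor(sum(per_step_item_counts))
exact = sum(s / num_items_in_batch for s in per_step_loss_sums)

# Without a count: each micro-batch mean is divided by the step count.
approx = sum((s / n) / gradient_accumulation_steps
             for s, n in zip(per_step_loss_sums, per_step_item_counts))

print(exact, approx)  # ~0.4792 vs 0.4750; equal only when all micro-batches are the same size
```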
@@ -3795,11 +3798,31 @@ class Trainer:
 
         return loss.detach()
 
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
 
-        Subclass and override for custom behavior.
+        Args:
+            model (`nn.Module`):
+                The model to compute the loss for.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The input data for the model.
+            return_outputs (`bool`, *optional*, defaults to `False`):
+                Whether to return the model outputs along with the loss.
+            num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
+                The number of items in the batch. If num_items_in_batch is not passed, the loss is
+                normalized by the number of gradient accumulation steps instead.
+
+        Returns:
+            The loss of the model along with its output if `return_outputs` was set to `True`.
+
+        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
+        make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation might be slightly inaccurate when performing gradient accumulation.
         """
         if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
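Following the new docstring's advice, a subclass that ignores `num_items_in_batch` can opt out of loss kwargs so the trainer falls back to dividing by the accumulation step count. A hedged sketch (`MeanLossTrainer` is illustrative; model and data wiring are omitted):

```python
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import Trainer

class MeanLossTrainer(Trainer):
    """Illustrative subclass computing a plain mean cross-entropy loss."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # We never use num_items_in_batch, so let the trainer normalize
        # by gradient_accumulation_steps instead.
        self.model_accepts_loss_kwargs = False

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(
            outputs.logits.view(-1, model.config.num_labels), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss
```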
@@ -5257,7 +5280,12 @@ class Trainer:
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
 
-    def get_batch_samples(self, epoch_iterator, num_batches, device):
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> Tuple[list, Optional[torch.Tensor]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        """
         batch_samples = []
         num_items_in_batch = None
 
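The counting that `get_batch_samples` performs amounts, roughly, to summing the label tokens that contribute to the loss across the prefetched micro-batches. A sketch under that assumption (`count_items` is a stand-in, not the trainer's internal helper):

```python
import torch

def count_items(batch_samples: list, ignore_index: int = -100):
    # Return None when any batch lacks labels; otherwise sum contributing tokens.
    if not all("labels" in batch for batch in batch_samples):
        return None
    return sum((batch["labels"] != ignore_index).sum() for batch in batch_samples)

batches = [
    {"labels": torch.tensor([[1, 2, -100]])},
    {"labels": torch.tensor([[3, -100, -100]])},
]
print(count_items(batches))  # tensor(3)
```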
@@ -853,12 +853,12 @@ class LossKwargs(TypedDict, total=False):
     Keyword arguments to be passed to the loss function
 
     Attributes:
-        num_items_in_batch (`int`, *optional*):
+        num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
             Number of items in the batch. It is recommended to pass it when
             you are doing gradient accumulation.
     """
 
-    num_items_in_batch: Optional[int]
+    num_items_in_batch: Optional["torch.Tensor"]
 
 
 def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
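Models that accept extra loss kwargs can annotate them with this TypedDict so type checkers see the tensor-typed count. A minimal sketch (`toy_forward` is hypothetical, not a library function):

```python
from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack

class LossKwargs(TypedDict, total=False):
    # Matches the updated definition: a tensor, so it can follow the loss's device.
    num_items_in_batch: Optional[torch.Tensor]

def toy_forward(logits: torch.Tensor, labels: torch.Tensor, **kwargs: Unpack[LossKwargs]) -> torch.Tensor:
    loss = torch.nn.functional.cross_entropy(logits, labels, reduction="sum")
    n = kwargs.get("num_items_in_batch")
    return loss / n if n is not None else loss / labels.numel()

print(toy_forward(torch.randn(4, 10), torch.randint(0, 10, (4,)), num_items_in_batch=torch.tensor(4)))
```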
@@ -944,7 +944,7 @@ class ModelTesterMixin:
             model = AutoModelForCausalLM.from_pretrained(
                 tmpdir, torch_dtype=torch.float32, device_map=torch_device
             )
-            inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
+            inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
             inputs_dict["labels"] = inputs_dict["input_ids"]
             _ = model(**inputs_dict, return_dict=False)
 