Mirror of https://github.com/huggingface/transformers.git
Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)
* bookmark
* Bookmark
* Bookmark
* Actually implement
* Pass in kwarg explicitly
* Adjust for if we do or don't have labels
* Bookmark fix for od
* bookmark
* Fin
* closer
* Negate accelerate grad accum div
* Fixup not training long enough
* Add in compute_loss to take full model output
* Document
* compute_loss -> compute_loss_fn
* Add a test
* Refactor
* Refactor
* Uncomment tests
* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Parent: 7a06d07e14
Commit: 6ba31a8a94
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py

@@ -340,12 +340,16 @@ class Trainer:
             The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
             be able to choose different architectures according to hyper parameters (such as layer count, sizes of
             inner layers, dropout probabilities etc).
+        compute_loss_func (`Callable`, *optional*):
+            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, here is one using
+            the loss function from `transformers`
         compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
             The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
             a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
             `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
             after the last eval batch to signal that the function needs to calculate and return the global summary
-            statistics rather than accumulating the batch-level statistics.
+            statistics rather than accumulating the batch-level statistics
         callbacks (List of [`TrainerCallback`], *optional*):
             A list of callbacks to customize the training loop. Will add those to the list of default callbacks
             detailed in [here](callback).
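To ground the new docstring entry, here is a rough sketch of such a loss function and how it would be wired into `Trainer` (a sketch based on this diff, not code taken from the commit; the checkpoint name and output directory are placeholders, and the dataset/collator arguments are omitted):

    import torch.nn.functional as F
    from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

    def causal_lm_loss(outputs, labels, num_items_in_batch=None):
        # Shift so that tokens < n predict n, then flatten.
        logits = outputs["logits"].float()
        vocab_size = logits.size(-1)
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        if num_items_in_batch is None:
            return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
        # Sum over tokens, then normalize by the token count of the whole accumulated
        # batch, which is what keeps gradient accumulation equivalent to a large batch.
        loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
        return loss / num_items_in_batch

    model = AutoModelForCausalLM.from_pretrained("distilgpt2")  # placeholder checkpoint
    trainer = Trainer(
        model=model,
        args=TrainingArguments("out", gradient_accumulation_steps=2),  # placeholder output dir
        compute_loss_func=causal_lm_loss,
        # train_dataset=..., data_collator=... would be passed as usual
    )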
@@ -394,6 +398,7 @@ class Trainer:
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_loss_func: Optional[Callable] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -415,6 +420,7 @@ class Trainer:
                 f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
             )
         self.args = args
+        self.compute_loss_func = compute_loss_func
         # Seed must be set before instantiating the model when using model
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)

@@ -2369,16 +2375,16 @@ class Trainer:

         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
-            epoch_iterator = train_dataloader
-            if hasattr(epoch_iterator, "set_epoch"):
-                epoch_iterator.set_epoch(epoch)
+            epoch_dataloader = train_dataloader
+            if hasattr(epoch_dataloader, "set_epoch"):
+                epoch_dataloader.set_epoch(epoch)

             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None

             steps_in_epoch = (
-                len(epoch_iterator)
+                len(epoch_dataloader)
                 if len_dataloader is not None
                 else args.max_steps * args.gradient_accumulation_steps
             )
@@ -2390,14 +2396,32 @@ class Trainer:
             rng_to_sync = False
             steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
-                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True

             step = -1
-            for step, inputs in enumerate(epoch_iterator):
+            epoch_iterator = iter(epoch_dataloader)
+            # We chunkify the epoch iterator into gradient accumulation steps `n` batches
+            remainder = num_examples % args.gradient_accumulation_steps
+            num_items_in_batch = None
+            if remainder == 0:
+                remainder = args.gradient_accumulation_steps
+            update_step = -1
+            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+            for _ in range(total_updates):
+                update_step += 1
+                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+                for inputs in batch_samples:
+                    step += 1
                     total_batched_samples += 1
+                    # Since we perform prefetching, we need to manually set sync_gradients
+                    if total_batched_samples % args.gradient_accumulation_steps != 0:
+                        self.accelerator.gradient_state._set_sync_gradients(False)
+                    else:
+                        self.accelerator.gradient_state._set_sync_gradients(True)

                     if self.args.include_num_input_tokens_seen:
                         main_input_name = getattr(self.model, "main_input_name", "input_ids")
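To make the restructured loop concrete, here is a standalone toy version of the chunking (plain Python, no Trainer; the hunk above computes the remainder from num_examples, while this sketch uses the per-epoch batch count directly): batches are pulled from the iterator in windows of gradient_accumulation_steps, and the last window falls back to the remainder so nothing is dropped.

    # Toy illustration of the accumulation-window chunking used above.
    gradient_accumulation_steps = 4
    batches = list(range(10))  # stand-ins for dataloader batches
    steps_in_epoch = len(batches)

    epoch_iterator = iter(batches)
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        remainder = gradient_accumulation_steps
    total_updates = steps_in_epoch // gradient_accumulation_steps + 1

    windows = []
    for update_step in range(total_updates):
        num_batches = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
        # Like get_batch_samples: pull at most num_batches items, stopping early if exhausted.
        window = [batch for _, batch in zip(range(num_batches), epoch_iterator)]
        if window:
            windows.append(window)

    print(windows)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

Because the batches of a window are prefetched up front, only the last micro-batch of each window should trigger a gradient sync, which is why the diff sets sync_gradients by hand (per its own comment about prefetching) rather than leaving it to Accelerate's step counting.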
@@ -2408,17 +2432,9 @@ class Trainer:
                                 "a `main_input_name` attribute to the model class you are using."
                             )
                         else:
-                        self.state.num_input_tokens_seen += (
-                            torch.sum(
-                                self.accelerator.gather(
-                                    torch.tensor(
-                                        inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
-                                    )
-                                )
-                            )
-                            .cpu()
-                            .item()
-                        )
+                            input_tokens = inputs[main_input_name].numel()
+                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
                     if rng_to_sync:
                         self._load_rng_state(resume_from_checkpoint)
                         rng_to_sync = False
@@ -2439,7 +2455,7 @@ class Trainer:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                     with self.accelerator.accumulate(model):
-                        tr_loss_step = self.training_step(model, inputs)
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                     if (
                         args.logging_nan_inf_filter
@@ -2462,14 +2478,12 @@ class Trainer:
                         )

                     if (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
+                        (total_batched_samples) % args.gradient_accumulation_steps == 0
                         or
                         # last step in epoch but step is always smaller than gradient_accumulation_steps
                         is_last_step_and_steps_less_than_grad_acc
                     ):
-                        # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
-                        # in accelerate. So, explicitly enable sync gradients to True in that case.
-                        if is_last_step_and_steps_less_than_grad_acc:
-                            self.accelerator.gradient_state._set_sync_gradients(True)
+                        # Since we perform prefetching, we need to manually set sync_gradients to True
+                        self.accelerator.gradient_state._set_sync_gradients(True)

                         # Gradient clipping
@@ -2517,15 +2531,19 @@ class Trainer:
                         self.state.global_step += 1
                         self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                         self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                     else:
                         self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

-                if self.control.should_epoch_stop or self.control.should_training_stop:
                     # PyTorch/XLA relies on the data loader to insert the mark_step for
                     # each step. Since we are breaking the loop early, we need to manually
                     # insert the mark_step here.
+                    if self.control.should_epoch_stop or self.control.should_training_stop:
+                        if is_torch_xla_available():
+                            xm.mark_step()
+                        break
+                # We also need to break out of the nested loop
+                if self.control.should_epoch_stop or self.control.should_training_stop:
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
@@ -3514,7 +3532,9 @@ class Trainer:

         return ctx_manager

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+    ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.

@@ -3542,7 +3562,7 @@ class Trainer:
             return loss_mb.reduce_mean().detach().to(self.args.device)

         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

         del inputs
         if (
@@ -3575,20 +3595,23 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
+            loss *= self.args.gradient_accumulation_steps
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps

-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.

         Subclass and override for custom behavior.
         """
-        if self.label_smoother is not None and "labels" in inputs:
+        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
         else:
             labels = None
+        # if num_items_in_batch is not None:
+        #     inputs["num_items_in_batch"] = num_items_in_batch
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
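A note on the scaling just above (my reading of the diff, matching the commit title's "Negate accelerate grad accum div", not text from the source): `Accelerator.backward` divides the loss by `gradient_accumulation_steps` when accumulation is enabled, but a loss already normalized by `num_items_in_batch` is an average over the whole accumulated batch, so it is multiplied back up before `backward` and only divided again for logging. A tiny numeric sketch:

    import torch

    # Illustrative numbers only.
    grad_accum_steps = 2
    num_items_in_batch = 37                     # non-ignored label tokens across the whole accumulated batch
    token_sum_loss = torch.tensor(74.0)         # summed cross-entropy over one micro-batch

    loss = token_sum_loss / num_items_in_batch  # already a "global batch" average
    loss = loss * grad_accum_steps              # pre-multiply ...
    effective = loss / grad_accum_steps         # ... because Accelerate divides again inside backward()
    assert torch.isclose(effective, token_sum_loss / num_items_in_batch)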
@@ -3601,7 +3624,10 @@ class Trainer:
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
                 model_name = unwrapped_model._get_name()
-            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+            # User-defined compute_loss function
+            if self.compute_loss_func is not None:
+                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                 loss = self.label_smoother(outputs, labels, shift_labels=True)
             else:
                 loss = self.label_smoother(outputs, labels)
@@ -4993,3 +5019,21 @@ class Trainer:
                 fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
+
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        batch_samples = []
+        num_items_in_batch = None
+        for _ in range(num_batches):
+            try:
+                batch_samples += [next(epoch_iterator)]
+            except StopIteration:
+                break
+        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+            # For now we don't support object detection
+            try:
+                num_items_in_batch = sum(
+                    [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
+                )
+            except TypeError:
+                pass
+        return batch_samples, num_items_in_batch
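To make the token count concrete, here is what `num_items_in_batch` ends up counting for causal-LM style labels, where -100 marks ignored or padded positions (a standalone sketch with made-up tensors, mirroring the counting in `get_batch_samples` above):

    import torch

    # Two prefetched micro-batches of labels; -100 marks ignored/padded positions.
    batch_samples = [
        {"labels": torch.tensor([[5, 7, -100, 9]])},
        {"labels": torch.tensor([[-100, 3, 4, -100]])},
    ]
    # Drop the first position (it is never a prediction target after the shift)
    # and count label tokens that are not -100, summed over the whole window.
    num_items_in_batch = sum(
        batch["labels"][..., 1:].ne(-100).sum().item() for batch in batch_samples
    )
    print(num_items_in_batch)  # 4  (7 and 9 from the first batch, 3 and 4 from the second)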
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py

@@ -42,6 +42,7 @@ from transformers import (
     AutoImageProcessor,
     AutoProcessor,
     AutoTokenizer,
+    DataCollatorForLanguageModeling,
     IntervalStrategy,
     PretrainedConfig,
     TrainerCallback,
@@ -49,6 +50,7 @@ from transformers import (
     get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
+    set_seed,
 )
 from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
 from transformers.testing_utils import (
@@ -153,6 +155,19 @@ if is_accelerate_available():
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"


+class StoreLossCallback(TrainerCallback):
+    """
+    Simple callback to store the loss.
+    """
+
+    def __init__(self):
+        self.losses = []
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if "loss" in logs:
+            self.losses.append(logs["loss"])
+
+
 class MockCudaOOMCallback(TrainerCallback):
     """
     Simple callback to simulate CUDA OOM error if
@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
             raise RuntimeError("CUDA out of memory.")


+def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    # Flatten the tokens
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1)
+    # Enable model parallelism
+    shift_labels = shift_labels.to(shift_logits.device)
+    if num_items_in_batch is None or disable_num_items_in_batch:
+        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
+    else:
+        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
+        loss = loss / num_items_in_batch
+    return loss
+
+
 class RegressionDataset:
     def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
         np.random.seed(seed)
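A quick sanity check of the `ForCausalLMLoss` helper added above (a standalone sketch, not part of the test file): when `num_items_in_batch` equals the number of non-ignored label tokens of a single micro-batch, the sum-then-divide path reduces to the plain mean; the two only diverge once `num_items_in_batch` counts tokens across the whole accumulated batch, which is what the alignment test further below relies on.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(1, 5, 11)                 # (batch, seq, vocab)
    labels = torch.tensor([[3, -100, 7, 2, -100]])

    shift_logits = logits[..., :-1, :].reshape(-1, 11)
    shift_labels = labels[..., 1:].reshape(-1)

    mean_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
    sum_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    n_items = shift_labels.ne(-100).sum()          # 2 valid targets in this toy batch
    assert torch.isclose(mean_loss, sum_loss / n_items)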
@@ -438,6 +473,31 @@ if is_torch_available():
             loss = nn.functional.mse_loss(y, labels)
             return (loss, y)

+    class BasicTextGenerationModel(nn.Module):
+        def __init__(self, vocab_size, hidden_size):
+            super().__init__()
+            self.embedding = nn.Embedding(vocab_size, hidden_size)
+            self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
+            self.fc = nn.Linear(hidden_size, vocab_size)
+
+        def forward(self, input_ids, **kwargs):
+            embedded = self.embedding(input_ids)
+            lstm_out, _ = self.lstm(embedded)
+            logits = self.fc(lstm_out)
+            return logits
+
+    def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
+        import datasets
+        import numpy as np
+
+        # Create random input sequences
+        input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))
+
+        # Create a datasets.Dataset
+        dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
+
+        return dataset
+
     class TstLayer(nn.Module):
         def __init__(self, hidden_size):
             super().__init__()
@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)

+    @slow
+    def test_gradient_accumulation_loss_alignment(self):
+        set_seed(42)
+        import datasets
+
+        model_name = "distilgpt2"
+        dataset_name = "wikitext"
+        dataset_config = "wikitext-2-raw-v1"
+        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+        dataset = dataset.train_test_split(test_size=0.2)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        def tokenize_function(examples):
+            return tokenizer(examples["text"])
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+        tokenizer.pad_token = tokenizer.eos_token
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+            return ForCausalLMLoss(
+                logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
+            )
+
+        loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)
+
+        base_loss_callback = StoreLossCallback()
+
+        args_kwargs = {
+            "report_to": "none",
+            "logging_steps": 1,
+            "max_steps": 20,
+            "learning_rate": 3e-4,
+            "disable_tqdm": True,
+        }
+
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+        )
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[base_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        grad_accum_loss_callback = StoreLossCallback()
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+            gradient_accumulation_steps=2,
+            per_device_train_batch_size=4,
+        )
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[grad_accum_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        broken_loss_callback = StoreLossCallback()
+        loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[broken_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        # Calculate the difference between the base loss and the grad_accum loss
+        diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
+        diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+        # These should be quite close
+        for diff in diff_truth:
+            self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+
+        # These should be very off
+        for diff in diff_broken:
+            self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+
     def test_gradient_accumulation(self):
-        # Training with half the batch size but accumulation steps as 2 should give the same results.
+        # Training with half the batch size but accumulation steps as 2 should give the same training losses.
         trainer = get_regression_trainer(
             gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
         )