Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Author: Zach Mueller
Date: 2024-10-17 17:01:56 -04:00 (committed by GitHub)
Parent: 7a06d07e14
Commit: 6ba31a8a94
2 changed files with 326 additions and 125 deletions

File: src/transformers/trainer.py

@@ -340,12 +340,16 @@ class Trainer:
             The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
             be able to choose different architectures according to hyper parameters (such as layer count, sizes of
             inner layers, dropout probabilities etc).
+        compute_loss_func (`Callable`, *optional*):
+            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, here is one using
+            the loss function from `transformers`
         compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
             The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
             a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
             `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
             after the last eval batch to signal that the function needs to calculate and return the global summary
             statistics rather than accumulating the batch-level statistics.
         callbacks (List of [`TrainerCallback`], *optional*):
             A list of callbacks to customize the training loop. Will add those to the list of default callbacks
             detailed in [here](callback).
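Note: the example the new docstring refers to is not part of this hunk. As a rough illustration only (the argument names are inferred from the `self.compute_loss_func(outputs, labels, num_items_in_batch=...)` call later in this diff, and the function name is hypothetical), a user-supplied loss might look like:

import torch.nn.functional as F

def my_compute_loss(outputs, labels, num_items_in_batch=None):
    # `outputs` is the raw model output, e.g. a CausalLMOutput exposing `logits`.
    logits = outputs["logits"].float()
    # Shift so that tokens < n predict n, as in standard causal LM training.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    if num_items_in_batch is None:
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    # Sum over tokens and divide by the number of items in the whole accumulated batch,
    # so that gradient accumulation matches full-batch training.
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    return loss / num_items_in_batch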
@@ -394,6 +398,7 @@ class Trainer:
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_loss_func: Optional[Callable] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -415,6 +420,7 @@ class Trainer:
                 f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
             )
         self.args = args
+        self.compute_loss_func = compute_loss_func
         # Seed must be set before instantiating the model when using model
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
@@ -2369,16 +2375,16 @@ class Trainer:
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
-            epoch_iterator = train_dataloader
-            if hasattr(epoch_iterator, "set_epoch"):
-                epoch_iterator.set_epoch(epoch)
+            epoch_dataloader = train_dataloader
+            if hasattr(epoch_dataloader, "set_epoch"):
+                epoch_dataloader.set_epoch(epoch)

             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None

             steps_in_epoch = (
-                len(epoch_iterator)
+                len(epoch_dataloader)
                 if len_dataloader is not None
                 else args.max_steps * args.gradient_accumulation_steps
             )
@@ -2390,14 +2396,32 @@ class Trainer:
             rng_to_sync = False
             steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
-                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True

             step = -1
-            for step, inputs in enumerate(epoch_iterator):
+            epoch_iterator = iter(epoch_dataloader)
+            # We chunkify the epoch iterator into gradient accumulation steps `n` batches
+            remainder = num_examples % args.gradient_accumulation_steps
+            num_items_in_batch = None
+            if remainder == 0:
+                remainder = args.gradient_accumulation_steps
+            update_step = -1
+            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+            for _ in range(total_updates):
+                update_step += 1
+                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+                for inputs in batch_samples:
+                    step += 1
                     total_batched_samples += 1
+                    # Since we perform prefetching, we need to manually set sync_gradients
+                    if total_batched_samples % args.gradient_accumulation_steps != 0:
+                        self.accelerator.gradient_state._set_sync_gradients(False)
+                    else:
+                        self.accelerator.gradient_state._set_sync_gradients(True)

                     if self.args.include_num_input_tokens_seen:
                         main_input_name = getattr(self.model, "main_input_name", "input_ids")
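To make the chunking above concrete, here is a small standalone sketch (not Trainer code, and assuming for simplicity that num_examples equals steps_in_epoch; the helper name chunk_sizes is hypothetical) of how many batches each update step consumes:

def chunk_sizes(steps_in_epoch, gradient_accumulation_steps):
    # Mirrors the remainder / total_updates arithmetic from the hunk above.
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        remainder = gradient_accumulation_steps
    total_updates = steps_in_epoch // gradient_accumulation_steps + 1
    return [
        gradient_accumulation_steps if update_step != total_updates - 1 else remainder
        for update_step in range(total_updates)
    ]

print(chunk_sizes(10, 4))  # [4, 4, 2] -- the last update gets the leftover batches
print(chunk_sizes(8, 4))   # [4, 4, 4] -- the extra chunk is harmless: get_batch_samples just hits StopIteration

Because whole chunks are prefetched before any forward pass, the number of label tokens in the full accumulated batch (num_items_in_batch) is known up front, which is what allows the loss to be normalized correctly across accumulation steps.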
@@ -2408,17 +2432,9 @@ class Trainer:
                                 "a `main_input_name` attribute to the model class you are using."
                             )
                         else:
-                            self.state.num_input_tokens_seen += (
-                                torch.sum(
-                                    self.accelerator.gather(
-                                        torch.tensor(
-                                            inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
-                                        )
-                                    )
-                                )
-                                .cpu()
-                                .item()
-                            )
+                            input_tokens = inputs[main_input_name].numel()
+                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
                     if rng_to_sync:
                         self._load_rng_state(resume_from_checkpoint)
                         rng_to_sync = False
@@ -2439,7 +2455,7 @@ class Trainer:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                     with self.accelerator.accumulate(model):
-                        tr_loss_step = self.training_step(model, inputs)
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                     if (
                         args.logging_nan_inf_filter
@@ -2462,14 +2478,12 @@ class Trainer:
                     )

                     if (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
+                        (total_batched_samples) % args.gradient_accumulation_steps == 0
                         or
                         # last step in epoch but step is always smaller than gradient_accumulation_steps
                         is_last_step_and_steps_less_than_grad_acc
                     ):
-                        # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
-                        # in accelerate. So, explicitly enable sync gradients to True in that case.
-                        if is_last_step_and_steps_less_than_grad_acc:
-                            self.accelerator.gradient_state._set_sync_gradients(True)
+                        # Since we perform prefetching, we need to manually set sync_gradients to True
+                        self.accelerator.gradient_state._set_sync_gradients(True)

                         # Gradient clipping
@@ -2517,15 +2531,19 @@ class Trainer:
                         self.state.global_step += 1
                         self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                         self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                     else:
                         self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

-                    if self.control.should_epoch_stop or self.control.should_training_stop:
                     # PyTorch/XLA relies on the data loader to insert the mark_step for
                     # each step. Since we are breaking the loop early, we need to manually
                     # insert the mark_step here.
+                    if self.control.should_epoch_stop or self.control.should_training_stop:
+                        if is_torch_xla_available():
+                            xm.mark_step()
+                        break
+
+                # We also need to break out of the nested loop
+                if self.control.should_epoch_stop or self.control.should_training_stop:
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
@@ -3514,7 +3532,9 @@ class Trainer:
         return ctx_manager

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+    ) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -3542,7 +3562,7 @@ class Trainer:
             return loss_mb.reduce_mean().detach().to(self.args.device)

         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

         del inputs
         if (
@@ -3575,20 +3595,23 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
+            loss *= self.args.gradient_accumulation_steps
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps

-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.

         Subclass and override for custom behavior.
         """
-        if self.label_smoother is not None and "labels" in inputs:
+        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
             labels = inputs.pop("labels")
         else:
             labels = None
+        # if num_items_in_batch is not None:
+        #     inputs["num_items_in_batch"] = num_items_in_batch
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
@@ -3601,7 +3624,10 @@ class Trainer:
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
                 model_name = unwrapped_model._get_name()
-            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+            # User-defined compute_loss function
+            if self.compute_loss_func is not None:
+                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                 loss = self.label_smoother(outputs, labels, shift_labels=True)
             else:
                 loss = self.label_smoother(outputs, labels)
@@ -4993,3 +5019,21 @@ class Trainer:
                 fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
+
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        batch_samples = []
+        num_items_in_batch = None
+        for _ in range(num_batches):
+            try:
+                batch_samples += [next(epoch_iterator)]
+            except StopIteration:
+                break
+        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+            # For now we don't support object detection
+            try:
+                num_items_in_batch = sum(
+                    [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
+                )
+            except TypeError:
+                pass
+        return batch_samples, num_items_in_batch
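Taken together, these changes let a custom loss be passed straight to `Trainer`. The following is a minimal wiring sketch, not part of the PR: the model name mirrors the slow test added in tests/trainer/test_trainer.py below, the toy dataset and output directory are arbitrary, and `my_compute_loss` is the hypothetical function sketched after the docstring hunk above.

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# A toy dataset, just enough to exercise the API end to end.
train_dataset = Dataset.from_dict(dict(tokenizer(["hello world"] * 32)))

trainer = Trainer(
    model,
    TrainingArguments(
        "tmp-compute-loss-func",  # arbitrary output dir
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        max_steps=5,
        report_to="none",
    ),
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    compute_loss_func=my_compute_loss,  # new argument introduced by this PR
)
trainer.train()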

File: tests/trainer/test_trainer.py

@@ -42,6 +42,7 @@ from transformers import (
     AutoImageProcessor,
     AutoProcessor,
     AutoTokenizer,
+    DataCollatorForLanguageModeling,
     IntervalStrategy,
     PretrainedConfig,
     TrainerCallback,
@@ -49,6 +50,7 @@ from transformers import (
     get_polynomial_decay_schedule_with_warmup,
     is_torch_available,
     logging,
+    set_seed,
 )
 from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
 from transformers.testing_utils import (
@@ -153,6 +155,19 @@ if is_accelerate_available():
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"


+class StoreLossCallback(TrainerCallback):
+    """
+    Simple callback to store the loss.
+    """
+
+    def __init__(self):
+        self.losses = []
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if "loss" in logs:
+            self.losses.append(logs["loss"])
+
+
 class MockCudaOOMCallback(TrainerCallback):
     """
     Simple callback to simulate CUDA OOM error if
@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
             raise RuntimeError("CUDA out of memory.")


+def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    # Flatten the tokens
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1)
+    # Enable model parallelism
+    shift_labels = shift_labels.to(shift_logits.device)
+    if num_items_in_batch is None or disable_num_items_in_batch:
+        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
+    else:
+        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
+        loss = loss / num_items_in_batch
+    return loss
+
+
 class RegressionDataset:
     def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
         np.random.seed(seed)
@@ -438,6 +473,31 @@ if is_torch_available():
                 loss = nn.functional.mse_loss(y, labels)
             return (loss, y)

+    class BasicTextGenerationModel(nn.Module):
+        def __init__(self, vocab_size, hidden_size):
+            super().__init__()
+            self.embedding = nn.Embedding(vocab_size, hidden_size)
+            self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
+            self.fc = nn.Linear(hidden_size, vocab_size)
+
+        def forward(self, input_ids, **kwargs):
+            embedded = self.embedding(input_ids)
+            lstm_out, _ = self.lstm(embedded)
+            logits = self.fc(lstm_out)
+            return logits
+
+    def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
+        import datasets
+        import numpy as np
+
+        # Create random input sequences
+        input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))
+
+        # Create a datasets.Dataset
+        dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
+        return dataset
+
     class TstLayer(nn.Module):
         def __init__(self, hidden_size):
             super().__init__()
@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
         self.check_trained_model(trainer.model, alternate_seed=True)

+    @slow
+    def test_gradient_accumulation_loss_alignment(self):
+        set_seed(42)
+        import datasets
+
+        model_name = "distilgpt2"
+        dataset_name = "wikitext"
+        dataset_config = "wikitext-2-raw-v1"
+        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+        dataset = dataset.train_test_split(test_size=0.2)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        def tokenize_function(examples):
+            return tokenizer(examples["text"])
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+        tokenizer.pad_token = tokenizer.eos_token
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+            return ForCausalLMLoss(
+                logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
+            )
+
+        loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)
+
+        base_loss_callback = StoreLossCallback()
+
+        args_kwargs = {
+            "report_to": "none",
+            "logging_steps": 1,
+            "max_steps": 20,
+            "learning_rate": 3e-4,
+            "disable_tqdm": True,
+        }
+
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+        )
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[base_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        grad_accum_loss_callback = StoreLossCallback()
+        args = TrainingArguments(
+            "./generation",
+            **args_kwargs,
+            gradient_accumulation_steps=2,
+            per_device_train_batch_size=4,
+        )
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[grad_accum_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        set_seed(42)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        broken_loss_callback = StoreLossCallback()
+        loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
+        trainer = Trainer(
+            model,
+            args,
+            train_dataset=tokenized_dataset["train"],
+            callbacks=[broken_loss_callback],
+            compute_loss_func=loss_fn,
+            data_collator=data_collator,
+        )
+        trainer.train()
+
+        # Calculate the difference between the base loss and the grad_accum loss
+        diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
+        diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+        # These should be quite close
+        for diff in diff_truth:
+            self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+
+        # These should be very off
+        for diff in diff_broken:
+            self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+
     def test_gradient_accumulation(self):
-        # Training with half the batch size but accumulation steps as 2 should give the same results.
+        # Training with half the batch size but accumulation steps as 2 should give the same training losses.
         trainer = get_regression_trainer(
             gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
         )