Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Zach Mueller 2024-10-17 17:01:56 -04:00 committed by GitHub
parent 7a06d07e14
commit 6ba31a8a94
2 changed files with 326 additions and 125 deletions

src/transformers/trainer.py

@@ -340,12 +340,16 @@ class Trainer:
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
+ compute_loss_func (`Callable`, *optional*):
+ A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+ batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, here is one using
+ the loss function from `transformers`
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
`True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
after the last eval batch to signal that the function needs to calculate and return the global summary
- statistics rather than accumulating the batch-level statistics.
+ statistics rather than accumulating the batch-level statistics
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
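Note: the new docstring entry above trails off before giving its example. A minimal sketch of a compatible compute_loss_func, modeled on the ForCausalLMLoss helper this PR adds to the test file below (the function name here is illustrative, not part of the PR):

import torch.nn.functional as F

def causal_lm_loss(outputs, labels, num_items_in_batch=None):
    # Upcast for stability, then shift so that tokens < n predict n.
    logits = outputs["logits"].float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    if num_items_in_batch is None:
        # No accumulated token count available: fall back to a per-batch mean.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    # Sum over tokens, then normalize by the token count of the entire
    # accumulated batch so gradient accumulation matches full-batch training.
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    return loss / num_items_in_batch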
@@ -394,6 +398,7 @@ class Trainer:
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -415,6 +420,7 @@
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
)
self.args = args
+ self.compute_loss_func = compute_loss_func
# Seed must be set before instantiating the model when using model
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
@@ -2369,16 +2375,16 @@ class Trainer:
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
- epoch_iterator = train_dataloader
- if hasattr(epoch_iterator, "set_epoch"):
- epoch_iterator.set_epoch(epoch)
+ epoch_dataloader = train_dataloader
+ if hasattr(epoch_dataloader, "set_epoch"):
+ epoch_dataloader.set_epoch(epoch)
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
- len(epoch_iterator)
+ len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
@@ -2390,14 +2396,32 @@
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
- epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+ epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
- for step, inputs in enumerate(epoch_iterator):
+ epoch_iterator = iter(epoch_dataloader)
+ # We chunkify the epoch iterator into gradient accumulation steps `n` batches
+ remainder = num_examples % args.gradient_accumulation_steps
+ num_items_in_batch = None
+ if remainder == 0:
+ remainder = args.gradient_accumulation_steps
+ update_step = -1
+ total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+ for _ in range(total_updates):
+ update_step += 1
+ num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+ batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+ for inputs in batch_samples:
+ step += 1
total_batched_samples += 1
+ # Since we perform prefetching, we need to manually set sync_gradients
+ if total_batched_samples % args.gradient_accumulation_steps != 0:
+ self.accelerator.gradient_state._set_sync_gradients(False)
+ else:
+ self.accelerator.gradient_state._set_sync_gradients(True)
if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
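The rewrite above replaces the flat `for step, inputs in enumerate(epoch_iterator)` loop with prefetched accumulation windows: each optimizer update first pulls up to gradient_accumulation_steps batches, so num_items_in_batch is known before any forward pass runs. A standalone sketch of the chunking arithmetic (simplified: the remainder here is taken over steps per epoch, whereas the code above derives it from num_examples):

from itertools import islice

def iter_update_steps(dataloader, steps_in_epoch, grad_accum_steps):
    # Yield one accumulation window (a list of batches) per optimizer update.
    it = iter(dataloader)
    remainder = steps_in_epoch % grad_accum_steps or grad_accum_steps
    total_updates = steps_in_epoch // grad_accum_steps + 1
    for update_step in range(total_updates):
        # The last window may be shorter than grad_accum_steps.
        num_batches = grad_accum_steps if update_step != total_updates - 1 else remainder
        window = list(islice(it, num_batches))
        if not window:
            break
        yield window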
@@ -2408,17 +2432,9 @@
"a `main_input_name` attribute to the model class you are using."
)
else:
- self.state.num_input_tokens_seen += (
- torch.sum(
- self.accelerator.gather(
- torch.tensor(
- inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
- )
- )
- )
- .cpu()
- .item()
- )
+ input_tokens = inputs[main_input_name].numel()
+ input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+ self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
@@ -2439,7 +2455,7 @@
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
with self.accelerator.accumulate(model):
- tr_loss_step = self.training_step(model, inputs)
+ tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if (
args.logging_nan_inf_filter
@@ -2462,14 +2478,12 @@
)
if (
- total_batched_samples % args.gradient_accumulation_steps == 0
+ (total_batched_samples) % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
- # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
- # in accelerate. So, explicitly enable sync gradients to True in that case.
- if is_last_step_and_steps_less_than_grad_acc:
+ # Since we perform prefetching, we need to manually set sync_gradients to True
self.accelerator.gradient_state._set_sync_gradients(True)
# Gradient clipping
@@ -2517,15 +2531,19 @@
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
- if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
+ if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
+ # We also need to break out of the nested loop
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ if is_torch_xla_available():
+ xm.mark_step()
+ break
@@ -3514,7 +3532,9 @@ class Trainer:
return ctx_manager
- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+ self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
+ ) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
@@ -3542,7 +3562,7 @@
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs)
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
@@ -3575,20 +3595,23 @@
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
+ loss *= self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps
- def compute_loss(self, model, inputs, return_outputs=False):
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
- if self.label_smoother is not None and "labels" in inputs:
+ if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
+ # if num_items_in_batch is not None:
+ # inputs["num_items_in_batch"] = num_items_in_batch
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
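Aside on the new `loss *= self.args.gradient_accumulation_steps` line in the hunk above (the "Negate accelerate grad accum div" commit): when gradient accumulation is configured, Accelerate divides the loss by the number of accumulation steps inside accelerator.backward, and a loss already normalized by num_items_in_batch would be scaled down twice, so it is pre-multiplied to cancel that division. Illustrative arithmetic only, not Trainer code:

ga_steps = 4                     # gradient_accumulation_steps
loss = 2.0                       # already divided by num_items_in_batch
loss = loss * ga_steps           # pre-multiply to negate Accelerate's division
backward_loss = loss / ga_steps  # what accelerator.backward effectively applies
assert backward_loss == 2.0      # gradients keep the num_items_in_batch scale
reported = loss / ga_steps       # the detached return value undoes the multiply
assert reported == 2.0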
@@ -3601,7 +3624,10 @@ class Trainer:
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
- if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ # User-defined compute_loss function
+ if self.compute_loss_func is not None:
+ loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+ elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
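With the dispatch above, a user-supplied loss takes precedence over both the label smoother and the model's built-in loss. A usage sketch (model, train_dataset, and my_loss are placeholders, not part of the PR):

from transformers import Trainer, TrainingArguments

def my_loss(outputs, labels, num_items_in_batch=None):
    ...  # any callable with this signature works

trainer = Trainer(
    model=model,                  # your pretrained model
    args=TrainingArguments(output_dir="out", report_to="none"),
    train_dataset=train_dataset,  # your tokenized dataset
    compute_loss_func=my_loss,    # new in this PR
)
trainer.train()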
@@ -4993,3 +5019,21 @@ class Trainer:
fsdp_plugin.set_mixed_precision(
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)
+ def get_batch_samples(self, epoch_iterator, num_batches):
+ batch_samples = []
+ num_items_in_batch = None
+ for _ in range(num_batches):
+ try:
+ batch_samples += [next(epoch_iterator)]
+ except StopIteration:
+ break
+ if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+ # For now we don't support object detection
+ try:
+ num_items_in_batch = sum(
+ [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
+ )
+ except TypeError:
+ pass
+ return batch_samples, num_items_in_batch
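get_batch_samples drains one accumulation window from the iterator and, when the batches carry token labels, counts the supervised positions: labels[..., 1:] mirrors the causal shift (the first position is never predicted) and .ne(-100) skips the ignore index, while the TypeError guard bails out for collators whose labels are not tensors (hence the object-detection comment). A self-contained check of the same count, with made-up values:

import torch

labels = torch.tensor([[1, 2, 3, -100, -100],
                       [4, 5, 6, 7, -100]])
batch = {"labels": labels}
# Per batch: row one contributes 2 shifted non-ignored labels, row two contributes 3.
num_items = sum(b["labels"][..., 1:].ne(-100).sum().item() for b in [batch, batch])
print(num_items)  # 10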

tests/trainer/test_trainer.py

@@ -42,6 +42,7 @@ from transformers import (
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
+ DataCollatorForLanguageModeling,
IntervalStrategy,
PretrainedConfig,
TrainerCallback,
@@ -49,6 +50,7 @@ from transformers import (
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
logging,
+ set_seed,
)
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
from transformers.testing_utils import (
@@ -153,6 +155,19 @@ if is_accelerate_available():
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
+ class StoreLossCallback(TrainerCallback):
+ """
+ Simple callback to store the loss.
+ """
+ def __init__(self):
+ self.losses = []
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if "loss" in logs:
+ self.losses.append(logs["loss"])
class MockCudaOOMCallback(TrainerCallback):
"""
Simple callback to simulate CUDA OOM error if
@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
raise RuntimeError("CUDA out of memory.")
+ def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ if num_items_in_batch is None or disable_num_items_in_batch:
+ loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
+ else:
+ loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
+ loss = loss / num_items_in_batch
+ return loss
class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)
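Stepping back to the ForCausalLMLoss helper above: its disable_num_items_in_batch=True branch reproduces the bug the new test targets, since averaging per micro-batch and then averaging those means overweights micro-batches with few label tokens. Made-up numbers showing the mismatch:

# Two micro-batches in one accumulation window, with unequal token counts.
sums = [8.0, 4.0]    # summed token losses per micro-batch
counts = [8, 2]      # non-ignored label tokens per micro-batch

full_batch = sum(sums) / sum(counts)     # 12.0 / 10 = 1.2, the full-batch loss
mean_of_means = (8.0 / 8 + 4.0 / 2) / 2  # (1.0 + 2.0) / 2 = 1.5, noticeably off
# reduction="sum" plus division by num_items_in_batch recovers 1.2 exactly.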
@@ -438,6 +473,31 @@ if is_torch_available():
loss = nn.functional.mse_loss(y, labels)
return (loss, y)
+ class BasicTextGenerationModel(nn.Module):
+ def __init__(self, vocab_size, hidden_size):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
+ self.fc = nn.Linear(hidden_size, vocab_size)
+ def forward(self, input_ids, **kwargs):
+ embedded = self.embedding(input_ids)
+ lstm_out, _ = self.lstm(embedded)
+ logits = self.fc(lstm_out)
+ return logits
+ def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
+ import datasets
+ import numpy as np
+ # Create random input sequences
+ input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))
+ # Create a datasets.Dataset
+ dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
+ return dataset
class TstLayer(nn.Module):
def __init__(self, hidden_size):
super().__init__()
@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
+ @slow
+ def test_gradient_accumulation_loss_alignment(self):
+ set_seed(42)
+ import datasets
+ model_name = "distilgpt2"
+ dataset_name = "wikitext"
+ dataset_config = "wikitext-2-raw-v1"
+ dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+ dataset = dataset.train_test_split(test_size=0.2)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ def tokenize_function(examples):
+ return tokenizer(examples["text"])
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+ tokenizer.pad_token = tokenizer.eos_token
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
+ return ForCausalLMLoss(
+ logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
+ )
+ loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)
+ base_loss_callback = StoreLossCallback()
+ args_kwargs = {
+ "report_to": "none",
+ "logging_steps": 1,
+ "max_steps": 20,
+ "learning_rate": 3e-4,
+ "disable_tqdm": True,
+ }
+ args = TrainingArguments(
+ "./generation",
+ **args_kwargs,
+ )
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[base_loss_callback],
+ compute_loss_func=loss_fn,
+ data_collator=data_collator,
+ )
+ trainer.train()
+ grad_accum_loss_callback = StoreLossCallback()
+ args = TrainingArguments(
+ "./generation",
+ **args_kwargs,
+ gradient_accumulation_steps=2,
+ per_device_train_batch_size=4,
+ )
+ set_seed(42)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[grad_accum_loss_callback],
+ compute_loss_func=loss_fn,
+ data_collator=data_collator,
+ )
+ trainer.train()
+ set_seed(42)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ broken_loss_callback = StoreLossCallback()
+ loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[broken_loss_callback],
+ compute_loss_func=loss_fn,
+ data_collator=data_collator,
+ )
+ trainer.train()
+ # Calculate the difference between the base loss and the grad_accum loss
+ diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
+ diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+ # These should be quite close
+ for diff in diff_truth:
+ self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+ # These should be very off
+ for diff in diff_broken:
+ self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
def test_gradient_accumulation(self):
- # Training with half the batch size but accumulation steps as 2 should give the same results.
+ # Training with half the batch size but accumulation steps as 2 should give the same training losses.
trainer = get_regression_trainer(
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
)