Update trainer for easier handling of accumulate, compile fixes, and proper reporting (#34511)

* Update trainer for easier handling of accumulate + proper reporting

* test

* Fixup tests

* Full fix

* Fix style

* rm comment

* Fix tests

* Minimize test + remove py 311 check

* Unused import

* Forward contrib credits from discussions

* Fix reported metrics

* Refactor, good as it's going to get

* rm pad tok id check

* object detection and audio are being annoying

* Fin

* Fin x2

---------

Co-authored-by: Gyanateet Dutta <Ryukijano@users.noreply.github.com>
This commit is contained in:
Zach Mueller 2024-11-04 07:47:34 -05:00 committed by GitHub
parent 33868a057c
commit ef976a7e18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 48 deletions

View File

@ -28,7 +28,7 @@ import tempfile
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
@ -5014,7 +5014,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return self.hf_quantizer.is_trainable
@property
@lru_cache
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type

View File

@ -233,7 +233,6 @@ if is_accelerate_available():
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
@ -601,8 +600,10 @@ class Trainer:
if not _is_peft_model(unwrapped_model)
else unwrapped_model.get_base_model().forward
)
self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
forward_params = inspect.signature(model_forward).parameters
self.model_accepts_loss_kwargs = (
"loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
)
self.neftune_noise_alpha = args.neftune_noise_alpha
@ -2444,7 +2445,7 @@ class Trainer:
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
for i, inputs in enumerate(batch_samples):
step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients
@ -2484,7 +2485,13 @@ class Trainer:
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
with self.accelerator.accumulate(model):
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if (
@ -3636,15 +3643,11 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
@ -3656,9 +3659,6 @@ class Trainer:
labels = inputs.pop("labels")
else:
labels = None
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
@ -3692,6 +3692,9 @@ class Trainer:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss
def is_local_process_zero(self) -> bool:
@ -4946,24 +4949,21 @@ class Trainer:
self.repo.git_push()
def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
accelerator_config = self.args.accelerator_config.to_dict()
@ -4994,7 +4994,6 @@ class Trainer:
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
@ -5090,12 +5089,18 @@ class Trainer:
batch_samples += [next(epoch_iterator)]
except StopIteration:
break
# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
num_items_in_batch = sum(
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
)
except TypeError:
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
except (TypeError, AttributeError):
pass
if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
return batch_samples, num_items_in_batch

View File

@ -272,6 +272,19 @@ class RepeatDataset:
return {"input_ids": self.x, "labels": self.x}
class SequenceClassificationDataset:
def __init__(self, length=64, vocab_size=100, num_labels=5):
self.length = length
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
self.labels = torch.randint(0, num_labels, (length,)).tolist()
def __len__(self):
return self.length
def __getitem__(self, i):
return {"input_ids": self.sequences[i], "label": self.labels[i]}
class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
@ -1144,6 +1157,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
def test_torch_compile_loss_func_compatibility(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
per_device_train_batch_size=2,
torch_compile=True,
max_steps=1, # compile happens on the first step
)
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
trainer.train()
@require_peft
@require_bitsandbytes
def test_bnb_compile(self):
@ -3676,9 +3706,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@ -3691,8 +3718,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@ -3706,9 +3731,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@ -3754,10 +3776,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)
def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through