Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)
* bookmark
* Bookmark
* Bookmark
* Actually implement
* Pass in kwarg explicitly
* Adjust for if we do or don't have labels
* Bookmark fix for od
* bookmark
* Fin
* closer
* Negate accelerate grad accum div
* Fixup not training long enough
* Add in compute_loss to take full model output
* Document
* compute_loss -> compute_loss_fn
* Add a test
* Refactor
* Refactor
* Uncomment tests
* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Parent commit: 7a06d07e14
This commit: 6ba31a8a94
src/transformers/trainer.py

@@ -340,12 +340,16 @@ class Trainer:
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_loss_func (`Callable`, *optional*):
A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, here is one using
the loss function from `transformers`
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
`True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
after the last eval batch to signal that the function needs to calculate and return the global summary
statistics rather than accumulating the batch-level statistics.
statistics rather than accumulating the batch-level statistics
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
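The docstring above refers to an example built on the loss function from `transformers`, but that snippet did not survive the page rendering. As a hedged sketch only (the name `my_causal_lm_loss` is illustrative; the signature follows how `compute_loss` invokes the hook later in this diff: raw outputs, labels, and `num_items_in_batch` as a keyword):

import torch
import torch.nn.functional as F

def my_causal_lm_loss(outputs, labels, num_items_in_batch=None):
    # `outputs` is the raw model output (with a `logits` field), `labels` are the
    # unshifted target ids, and `num_items_in_batch` counts non-padding label tokens
    # across the whole accumulated batch (batch_size * gradient_accumulation_steps).
    logits = outputs["logits"].float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    if num_items_in_batch is None:
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    # Sum the per-token losses and normalize once over the global token count,
    # so gradient accumulation matches training on one large batch.
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    return loss / num_items_in_batch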
@@ -394,6 +398,7 @@ class Trainer:
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),

@@ -415,6 +420,7 @@ class Trainer:
f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
)
self.args = args
self.compute_loss_func = compute_loss_func
# Seed must be set before instantiating the model when using model
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
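With the new keyword argument in place, wiring a custom loss into `Trainer` looks roughly like the following. This is a usage sketch under the assumption that `my_model`, `my_train_dataset`, and `my_causal_lm_loss` (from the sketch above) already exist:

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=my_model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=4),
    train_dataset=my_train_dataset,
    compute_loss_func=my_causal_lm_loss,  # new in this commit
)
trainer.train()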
@@ -2369,16 +2375,16 @@ class Trainer:

total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_iterator = train_dataloader
if hasattr(epoch_iterator, "set_epoch"):
epoch_iterator.set_epoch(epoch)
epoch_dataloader = train_dataloader
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)

# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None

steps_in_epoch = (
len(epoch_iterator)
len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)

@@ -2390,142 +2396,154 @@ class Trainer:
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True

step = -1
for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1

if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if main_input_name not in inputs:
logger.warning(
"Tried to track the number of tokens seen, however the current model is "
"not configured properly to know what item is the input. To fix this, add "
"a `main_input_name` attribute to the model class you are using."
)
epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = num_examples % args.gradient_accumulation_steps
num_items_in_batch = None
if remainder == 0:
remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
for _ in range(total_updates):
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
step += 1
total_batched_samples += 1
# Since we perform prefetching, we need to manually set sync_gradients
if total_batched_samples % args.gradient_accumulation_steps != 0:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.state.num_input_tokens_seen += (
torch.sum(
self.accelerator.gather(
torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
)
.cpu()
.item()
)
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False

# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs)

if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss = tr_loss + tr_loss_step

self.current_flos += float(self.floating_point_ops(inputs))

is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)

if (
total_batched_samples % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc:
self.accelerator.gradient_state._set_sync_gradients(True)

# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

if is_sagemaker_mp_enabled() and args.fp16:
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
_grad_norm = nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if main_input_name not in inputs:
logger.warning(
"Tried to track the number of tokens seen, however the current model is "
"not configured properly to know what item is the input. To fix this, add "
"a `main_input_name` attribute to the model class you are using."
)
else:
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False

# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss = tr_loss + tr_loss_step

if (
is_accelerate_available()
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm
self.current_flos += float(self.floating_point_ops(inputs))

self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)

self.optimizer.step()
if (
(total_batched_samples) % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# Since we perform prefetching, we need to manually set sync_gradients to True
self.accelerator.gradient_state._set_sync_gradients(True)

self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
if is_sagemaker_mp_enabled() and args.fp16:
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
_grad_norm = nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
else:
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if (
is_accelerate_available()
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm

self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

self.optimizer.step()

self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()

model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
# We also need to break out of the nested loop
if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
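The point of prefetching `gradient_accumulation_steps` batches before stepping is that `num_items_in_batch` can be counted over the whole accumulated batch, so every micro-batch loss is normalized by one shared denominator. A tiny arithmetic sketch (illustrative numbers only) of why averaging per-micro-batch means is not equivalent when micro-batches contain different numbers of real target tokens:

import torch

# Two micro-batches with 3 and 1 real (non -100) target tokens respectively.
per_token_losses = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0])]

# Naive accumulation: mean of the per-micro-batch means.
mean_of_means = sum(l.mean() for l in per_token_losses) / len(per_token_losses)  # (2.0 + 4.0) / 2 = 3.0

# Single large batch: one global mean over all 4 tokens.
num_items_in_batch = sum(l.numel() for l in per_token_losses)
global_mean = sum(l.sum() for l in per_token_losses) / num_items_in_batch  # 10.0 / 4 = 2.5

print(mean_of_means.item(), global_mean.item())  # 3.0 vs 2.5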
@@ -3514,7 +3532,9 @@ class Trainer:

return ctx_manager

def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.

@@ -3542,7 +3562,7 @@ class Trainer:
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

del inputs
if (

@@ -3575,20 +3595,23 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss *= self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps

def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.

Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
# if num_items_in_batch is not None:
# inputs["num_items_in_batch"] = num_items_in_batch
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.

@@ -3601,7 +3624,10 @@ class Trainer:
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
# User-defined compute_loss function
if self.compute_loss_func is not None:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
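Because `compute_loss` now takes `num_items_in_batch`, Trainer subclasses that override it should accept the extra argument as well. A minimal hedged sketch of such an override (the loss itself is illustrative and assumes the batch contains `labels` and the model returns `logits`):

import torch
from transformers import Trainer

class SumThenNormalizeTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs["logits"].float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits, shift_labels, ignore_index=-100, reduction="sum"
        )
        # Prefer the global token count when the Trainer provides it.
        denom = num_items_in_batch if num_items_in_batch is not None else shift_labels.ne(-100).sum()
        loss = loss / denom
        return (loss, outputs) if return_outputs else loss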
@@ -4993,3 +5019,21 @@ class Trainer:
fsdp_plugin.set_mixed_precision(
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)

def get_batch_samples(self, epoch_iterator, num_batches):
batch_samples = []
num_items_in_batch = None
for _ in range(num_batches):
try:
batch_samples += [next(epoch_iterator)]
except StopIteration:
break
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
num_items_in_batch = sum(
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
)
except TypeError:
pass
return batch_samples, num_items_in_batch
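The `labels[..., 1:].ne(-100)` count above is what feeds `num_items_in_batch`: padded positions carry -100 and the first position drops out because of the causal shift, so only real target tokens are counted. A small self-contained demonstration with made-up tensors:

import torch

batch_samples = [
    {"labels": torch.tensor([[5, 6, 7, -100]])},     # shifted labels [6, 7, -100] -> 2 real targets
    {"labels": torch.tensor([[8, 9, -100, -100]])},  # shifted labels [9, -100, -100] -> 1 real target
]
num_items_in_batch = sum(
    data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples
)
print(num_items_in_batch)  # 3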
tests/trainer/test_trainer.py

@@ -42,6 +42,7 @@ from transformers import (
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
DataCollatorForLanguageModeling,
IntervalStrategy,
PretrainedConfig,
TrainerCallback,

@@ -49,6 +50,7 @@ from transformers import (
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
logging,
set_seed,
)
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
from transformers.testing_utils import (

@@ -153,6 +155,19 @@ if is_accelerate_available():
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"


class StoreLossCallback(TrainerCallback):
"""
Simple callback to store the loss.
"""

def __init__(self):
self.losses = []

def on_log(self, args, state, control, logs=None, **kwargs):
if "loss" in logs:
self.losses.append(logs["loss"])


class MockCudaOOMCallback(TrainerCallback):
"""
Simple callback to simulate CUDA OOM error if

@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
raise RuntimeError("CUDA out of memory.")


def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if num_items_in_batch is None or disable_num_items_in_batch:
loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
else:
loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
loss = loss / num_items_in_batch
return loss


class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)

@@ -438,6 +473,31 @@ if is_torch_available():
loss = nn.functional.mse_loss(y, labels)
return (loss, y)

class BasicTextGenerationModel(nn.Module):
def __init__(self, vocab_size, hidden_size):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, vocab_size)

def forward(self, input_ids, **kwargs):
embedded = self.embedding(input_ids)
lstm_out, _ = self.lstm(embedded)
logits = self.fc(lstm_out)
return logits

def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
import datasets
import numpy as np

# Create random input sequences
input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))

# Create a datasets.Dataset
dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})

return dataset
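The two helpers above can be put through a minimal end-to-end run. This is a hedged sketch, not part of the diff: `lstm_lm_loss` is made up, the hyperparameters are arbitrary, and it assumes Trainer's defaults (column pruning, default collator) handle this bare `nn.Module` the same way they do the other plain-module fixtures in this file:

from functools import partial
import torch
from transformers import Trainer, TrainingArguments

def lstm_lm_loss(outputs, labels, num_items_in_batch=None, vocab_size=100):
    # BasicTextGenerationModel returns raw logits rather than a ModelOutput.
    logits = outputs.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    loss = torch.nn.functional.cross_entropy(shift_logits, shift_labels, reduction="sum")
    denom = num_items_in_batch if num_items_in_batch is not None else shift_labels.numel()
    return loss / denom

model = BasicTextGenerationModel(vocab_size=100, hidden_size=32)
train_dataset = create_dummy_dataset_for_text_generation(vocab_size=100, seq_length=16, num_samples=64)
trainer = Trainer(
    model,
    TrainingArguments("out", max_steps=5, report_to="none", per_device_train_batch_size=8, gradient_accumulation_steps=2),
    train_dataset=train_dataset,
    compute_loss_func=partial(lstm_lm_loss, vocab_size=100),
)
trainer.train()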
class TstLayer(nn.Module):
def __init__(self, hidden_size):
super().__init__()

@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)

@slow
def test_gradient_accumulation_loss_alignment(self):
set_seed(42)
import datasets

model_name = "distilgpt2"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
return tokenizer(examples["text"])

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

model = AutoModelForCausalLM.from_pretrained(model_name)

def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
return ForCausalLMLoss(
logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
)

loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)

base_loss_callback = StoreLossCallback()

args_kwargs = {
"report_to": "none",
"logging_steps": 1,
"max_steps": 20,
"learning_rate": 3e-4,
"disable_tqdm": True,
}

args = TrainingArguments(
"./generation",
**args_kwargs,
)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[base_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()

grad_accum_loss_callback = StoreLossCallback()
args = TrainingArguments(
"./generation",
**args_kwargs,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
)
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[grad_accum_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()

set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
broken_loss_callback = StoreLossCallback()
loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[broken_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()

# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
# These should be quite close
for diff in diff_truth:
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")

# These should be very off
for diff in diff_broken:
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")

def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same results.
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
trainer = get_regression_trainer(
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
)