Cleanup Trainer and expose customization points (#5982)
* Clean up Trainer and expose customization points
* Formatting
* eval_step -> prediction_step

parent 76f52324b1
commit e168488a74
@@ -202,7 +202,7 @@ class Trainer:
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
             )
         if is_wandb_available():
-            self._setup_wandb()
+            self.setup_wandb()
         else:
             logger.info(
                 "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -226,23 +226,32 @@ class Trainer:
                 FutureWarning,
             )
 
-    def get_train_dataloader(self) -> DataLoader:
-        """
-        Returns the training :class:`~torch.utils.data.DataLoader`.
-        """
+    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
-            train_sampler = None
-        elif self.train_dataset is None:
-            raise ValueError("Trainer: training requires a train_dataset.")
+            return None
         elif is_torch_tpu_available():
-            train_sampler = get_tpu_sampler(self.train_dataset)
+            return get_tpu_sampler(self.train_dataset)
         else:
-            train_sampler = (
+            return (
                 RandomSampler(self.train_dataset)
                 if self.args.local_rank == -1
                 else DistributedSampler(self.train_dataset)
             )
 
-        data_loader = DataLoader(
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training :class:`~torch.utils.data.DataLoader`.
+
+        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
+        (adapted to distributed training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+        train_sampler = self._get_train_sampler()
+
+        return DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
             sampler=train_sampler,
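The point of the split above is that a subclass can now change how training examples are sampled without copying the dataloader construction. A minimal sketch (not part of the commit; the `WeightedSamplingTrainer` name and the `example_weights` argument are assumptions for illustration):

from torch.utils.data import WeightedRandomSampler

from transformers import Trainer


class WeightedSamplingTrainer(Trainer):
    def __init__(self, *args, example_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # One sampling weight per training example, supplied by the caller.
        self.example_weights = example_weights

    def _get_train_sampler(self):
        # Oversample rare examples instead of drawing uniformly at random.
        return WeightedRandomSampler(self.example_weights, num_samples=len(self.train_dataset))

Because `get_train_dataloader` now delegates to `_get_train_sampler`, the rest of the dataloader setup (batch size, collator, drop_last) is inherited unchanged.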
@@ -250,69 +259,66 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
         )
 
-        return data_loader
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+            return None
+        elif is_torch_tpu_available():
+            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.args.local_rank != -1:
+            return SequentialDistributedSampler(eval_dataset)
+        else:
+            return SequentialSampler(eval_dataset)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         """
         Returns the evaluation :class:`~torch.utils.data.DataLoader`.
 
+        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
+        sampler (adapted to distributed training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+
         Args:
-            eval_dataset (:obj:`Dataset`, `optional`):
-                If provided, will override `self.eval_dataset`.
+            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+                If provided, will override :obj:`self.eval_dataset`.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
 
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        eval_sampler = self._get_eval_sampler(eval_dataset)
 
-        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            sampler = None
-        elif is_torch_tpu_available():
-            sampler = SequentialDistributedSampler(
-                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
-            )
-        elif self.args.local_rank != -1:
-            sampler = SequentialDistributedSampler(eval_dataset)
-        else:
-            sampler = SequentialSampler(eval_dataset)
-
-        data_loader = DataLoader(
+        return DataLoader(
             eval_dataset,
-            sampler=sampler,
+            sampler=eval_sampler,
             batch_size=self.args.eval_batch_size,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
 
-        return data_loader
-
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
         Returns the test :class:`~torch.utils.data.DataLoader`.
 
-        Args:
-            test_dataset (obj:`Dataset`): The test dataset to use.
-        """
-        # We use the same batch_size as for eval.
-        if isinstance(self.test_dataset, torch.utils.data.IterableDataset):
-            sampler = None
-        elif is_torch_tpu_available():
-            sampler = SequentialDistributedSampler(
-                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
-            )
-        elif self.args.local_rank != -1:
-            sampler = SequentialDistributedSampler(test_dataset)
-        else:
-            sampler = SequentialSampler(test_dataset)
+        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
+        sampler (adapted to distributed training if necessary) otherwise.
 
-        data_loader = DataLoader(
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+                The test dataset to use.
+        """
+        test_sampler = self._get_eval_sampler(test_dataset)
+
+        # We use the same batch_size as for eval.
+        return DataLoader(
             test_dataset,
-            sampler=sampler,
+            sampler=test_sampler,
             batch_size=self.args.eval_batch_size,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
-        return data_loader
 
     def get_optimizers(
         self, num_training_steps: int
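Because evaluation and test dataloaders now share `_get_eval_sampler`, overriding only the dataloader body is enough for simple tweaks. A minimal sketch (not part of the commit; the doubled eval batch size is an arbitrary illustration):

from torch.utils.data import DataLoader

from transformers import Trainer


class LargeEvalBatchTrainer(Trainer):
    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            sampler=self._get_eval_sampler(eval_dataset),
            batch_size=self.args.eval_batch_size * 2,  # evaluation usually fits larger batches than training
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )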
@@ -321,7 +327,7 @@ class Trainer:
         Setup the optimizer and the learning rate scheduler.
 
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        Trainer's init through :obj:`optimizers`, or override this method in a subclass.
+        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
         if self.optimizers is not None:
             return self.optimizers
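Since `get_optimizers` is an explicit customization point, swapping the optimizer only requires returning the same `(optimizer, scheduler)` pair. A minimal sketch (not part of the commit; plain Adam with a constant schedule is an arbitrary choice):

import torch

from transformers import Trainer


class PlainAdamTrainer(Trainer):
    def get_optimizers(self, num_training_steps):
        # Plain Adam with a constant learning rate instead of the default AdamW + linear warmup.
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
        return optimizer, scheduler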
@@ -343,12 +349,12 @@ class Trainer:
         )
         return optimizer, scheduler
 
-    def _setup_wandb(self):
+    def setup_wandb(self):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
 
-        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
-        You can also override the following environment variables:
+        One can subclass and override this method to customize the setup if needed. Find more information
+        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
 
         Environment:
             WANDB_WATCH:
@@ -359,6 +365,13 @@ class Trainer:
             WANDB_DISABLED:
                 (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
+        if hasattr(self, "_setup_wandb"):
+            warnings.warn(
+                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
+                FutureWarning,
+            )
+            return self._setup_wandb()
+
         if self.is_world_master():
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
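With the rename to a public `setup_wandb` (plus the deprecation shim above for old `_setup_wandb` overrides), customizing the run configuration looks like this. A minimal sketch (not part of the commit; the project and run names are hypothetical):

import wandb

from transformers import Trainer


class MyProjectTrainer(Trainer):
    def setup_wandb(self):
        if self.is_world_master():
            # Hypothetical project/run names; the config mirrors what the default setup logs.
            wandb.init(project="my-project", name="my-run", config=vars(self.args))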
@@ -372,7 +385,7 @@ class Trainer:
 
     def num_examples(self, dataloader: DataLoader) -> int:
         """
-        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its Dataset.
+        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
         """
         return len(dataloader.dataset)
 
@@ -500,7 +513,7 @@ class Trainer:
                     steps_trained_in_current_epoch -= 1
                     continue
 
-                tr_loss += self._training_step(model, inputs, optimizer)
+                tr_loss += self.training_step(model, inputs, optimizer)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -535,7 +548,7 @@ class Trainer:
                         )
                         logging_loss = tr_loss
 
-                        self._log(logs)
+                        self.log(logs)
 
                     if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                         self.evaluate()
@@ -582,7 +595,25 @@ class Trainer:
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         return TrainOutput(self.global_step, tr_loss / self.global_step)
 
-    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+        """
+        Log :obj:`logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (:obj:`Dict[str, float]`):
+                The values to log.
+            iterator (:obj:`tqdm`, `optional`):
+                A potential tqdm progress bar to write the logs on.
+        """
+        if hasattr(self, "_log"):
+            warnings.warn(
+                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
+                FutureWarning,
+            )
+            return self._log(logs, iterator=iterator)
+
         if self.epoch is not None:
             logs["epoch"] = self.epoch
         if self.global_step is None:
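The same pattern applies to `log`: keep the default behavior via `super()` and add a side effect. A minimal sketch (not part of the commit; the output path is hypothetical):

import json
from typing import Dict, Optional

from tqdm.auto import tqdm

from transformers import Trainer


class FileLoggingTrainer(Trainer):
    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        super().log(logs, iterator=iterator)
        # Mirror every log entry to a local JSON-lines file.
        with open("trainer_logs.jsonl", "a") as f:
            f.write(json.dumps(logs) + "\n")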
@@ -612,10 +643,13 @@ class Trainer:
         else:
             logger.info(output)
 
-    def _training_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
-    ) -> float:
-        model.train()
+    def _prepare_inputs(
+        self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
+    ) -> Dict[str, Union[torch.Tensor, Any]]:
+        """
+        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
+        handling potential state.
+        """
         for k, v in inputs.items():
             if isinstance(v, torch.Tensor):
                 inputs[k] = v.to(self.args.device)
@@ -625,9 +659,44 @@ class Trainer:
         # Our model outputs do not work with DataParallel, so forcing return tuple.
         if isinstance(model, nn.DataParallel):
             inputs["return_tuple"] = True
+        return inputs
+
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
+    ) -> float:
+        """
+        Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (:obj:`nn.Module`):
+                The model to train.
+            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+            optimizer (:obj:`torch.optim.Optimizer`):
+                The optimizer to use to make a step.
+
+        Return:
+            `float`:
+                The training loss on this batch.
+        """
+        if hasattr(self, "_training_step"):
+            warnings.warn(
+                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
+                FutureWarning,
+            )
+            return self._training_step(model, inputs, optimizer)
+
+        model.train()
+        inputs = self._prepare_inputs(inputs, model)
 
         outputs = model(**inputs)
-        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+        # We don't use .loss here since the model may return tuples instead of ModelOutput.
+        loss = outputs[0]
 
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
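Because `training_step` is now public, a subclass can adjust the inputs (or the loss) and still defer to the default implementation for backward, fp16 and multi-GPU handling. A minimal sketch (not part of the commit; dropping `token_type_ids` is purely illustrative):

from typing import Any, Dict, Union

import torch
from torch import nn

from transformers import Trainer


class FilteredInputsTrainer(Trainer):
    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        # Remove a key the model does not accept, then run the default step.
        inputs.pop("token_type_ids", None)
        return super().training_step(model, inputs, optimizer)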
@@ -741,6 +810,8 @@ class Trainer:
         The calling script will be responsible for providing a method to compute metrics, as they are
         task-dependent (pass it to the init :obj:`compute_metrics` argument).
 
+        You can also subclass and override this method to inject custom behavior.
+
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
@@ -749,9 +820,9 @@ class Trainer:
         """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
 
-        output = self._prediction_loop(eval_dataloader, description="Evaluation")
+        output = self.prediction_loop(eval_dataloader, description="Evaluation")
 
-        self._log(output.metrics)
+        self.log(output.metrics)
 
         if self.args.tpu_metrics_debug or self.args.debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
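`evaluate` returns the metrics dict, so a subclass can post-process or persist results while reusing the whole evaluation machinery. A minimal sketch (not part of the commit; the output file is hypothetical):

import json

from transformers import Trainer


class MetricsDumpingTrainer(Trainer):
    def evaluate(self, eval_dataset=None):
        metrics = super().evaluate(eval_dataset)
        # Append every evaluation result to a local JSON-lines file.
        with open("eval_metrics.jsonl", "a") as f:
            f.write(json.dumps(metrics) + "\n")
        return metrics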
@@ -780,16 +851,22 @@ class Trainer:
         """
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self._prediction_loop(test_dataloader, description="Prediction")
+        return self.prediction_loop(test_dataloader, description="Prediction")
 
-    def _prediction_loop(
+    def prediction_loop(
         self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
     ) -> PredictionOutput:
         """
-        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
 
         Works both with or without labels.
         """
+        if hasattr(self, "_prediction_loop"):
+            warnings.warn(
+                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
+                FutureWarning,
+            )
+            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
 
@@ -815,40 +892,20 @@ class Trainer:
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
 
         if self.args.past_index >= 0:
-            past = None
+            self._past = None
 
         for inputs in tqdm(dataloader, desc=description):
-            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
-
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    inputs[k] = v.to(self.args.device)
-            if self.args.past_index >= 0:
-                inputs["mems"] = past
-            # Our model outputs do not work with DataParallel, so forcing return tuple.
-            if isinstance(model, nn.DataParallel):
-                inputs["return_tuple"] = True
-
-            with torch.no_grad():
-                outputs = model(**inputs)
-                if has_labels:
-                    step_eval_loss, logits = outputs[:2]
-                    eval_losses += [step_eval_loss.mean().item()]
-                else:
-                    logits = outputs[0]
-                if self.args.past_index >= 0:
-                    past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
-
-            if not prediction_loss_only:
-                if preds is None:
-                    preds = logits.detach()
-                else:
-                    preds = torch.cat((preds, logits.detach()), dim=0)
-                if inputs.get("labels") is not None:
-                    if label_ids is None:
-                        label_ids = inputs["labels"].detach()
-                    else:
-                        label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            if loss is not None:
+                eval_losses.append(loss)
+            if logits is not None:
+                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
+            if labels is not None:
+                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
+
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
 
         if self.args.local_rank != -1:
             # In distributed mode, concatenate all results from all nodes:
@@ -894,3 +951,49 @@ class Trainer:
         # truncate the dummy elements added by SequentialDistributedSampler
         output = concat[:num_total_examples]
         return output
+
+    def prediction_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on :obj:`model` using obj:`inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (:obj:`nn.Module`):
+                The model to evaluate.
+            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (:obj:`bool`):
+                Whether or not to return the loss only.
+
+        Return:
+            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+                A tuple with the loss, logits and labels (each being optional).
+        """
+        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
+
+        inputs = self._prepare_inputs(inputs, model)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+            if has_labels:
+                loss, logits = outputs[:2]
+                loss = loss.mean().item()
+            else:
+                loss = None
+                logits = outputs[0]
+            if self.args.past_index >= 0:
+                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        labels = inputs.get("labels")
+        if labels is not None:
+            labels = labels.detach()
+        return (loss, logits.detach(), labels)
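The new `prediction_step` is the per-batch hook used by `prediction_loop`, so overriding it changes what `evaluate` and `predict` collect. A minimal sketch (not part of the commit; converting logits to probabilities is an arbitrary example):

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from transformers import Trainer


class ProbabilityTrainer(Trainer):
    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
        if logits is not None:
            # Store softmax probabilities instead of raw logits.
            logits = torch.nn.functional.softmax(logits, dim=-1)
        return loss, logits, labels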