Cleanup Trainer and expose customization points (#5982)

* Clean up Trainer and expose customization points

* Formatting

* eval_step -> prediction_step
Sylvain Gugger 2020-07-23 12:05:41 -04:00 committed by GitHub
parent 76f52324b1
commit e168488a74


@@ -202,7 +202,7 @@ class Trainer:
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
)
if is_wandb_available():
self._setup_wandb()
self.setup_wandb()
else:
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -226,23 +226,32 @@ class Trainer:
FutureWarning,
)
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training :class:`~torch.utils.data.DataLoader`.
"""
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
train_sampler = None
elif self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
return None
elif is_torch_tpu_available():
train_sampler = get_tpu_sampler(self.train_dataset)
return get_tpu_sampler(self.train_dataset)
else:
train_sampler = (
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
data_loader = DataLoader(
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
(adapted to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_sampler = self._get_train_sampler()
return DataLoader(
self.train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
@@ -250,69 +259,66 @@ class Trainer:
drop_last=self.args.dataloader_drop_last,
)
return data_loader
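
Splitting the sampler selection out of get_train_dataloader means a subclass can swap just the sampler. A minimal sketch, assuming a hypothetical WeightedTrainer whose user supplies a per-example weight tensor as self.sample_weights (not part of the Trainer API):

from torch.utils.data import WeightedRandomSampler

from transformers import Trainer


class WeightedTrainer(Trainer):
    def _get_train_sampler(self):
        # self.sample_weights is a hypothetical attribute set by the user;
        # it holds one sampling weight per training example.
        return WeightedRandomSampler(
            weights=self.sample_weights,
            num_samples=len(self.train_dataset),
            replacement=True,
        )
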
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.args.local_rank != -1:
return SequentialDistributedSampler(eval_dataset)
else:
return SequentialSampler(eval_dataset)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
"""
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
If provided, will override `self.eval_dataset`.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
eval_sampler = self._get_eval_sampler(eval_dataset)
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
sampler = None
elif is_torch_tpu_available():
sampler = SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
elif self.args.local_rank != -1:
sampler = SequentialDistributedSampler(eval_dataset)
else:
sampler = SequentialSampler(eval_dataset)
data_loader = DataLoader(
return DataLoader(
eval_dataset,
sampler=sampler,
sampler=eval_sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
return data_loader
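
Since _get_eval_sampler is shared by get_eval_dataloader and get_test_dataloader, overriding it affects evaluation and prediction alike. A hedged sketch that evaluates on a random subset to keep runs short (single-process only, map-style dataset assumed; the class name and the 1000-example cap are illustrative):

import torch
from torch.utils.data import SubsetRandomSampler

from transformers import Trainer


class SubsetEvalTrainer(Trainer):
    def _get_eval_sampler(self, eval_dataset):
        # Draw at most 1000 random indices; the TPU/distributed branches are ignored here.
        n = min(1000, len(eval_dataset))
        indices = torch.randperm(len(eval_dataset))[:n].tolist()
        return SubsetRandomSampler(indices)
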
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
Returns the test :class:`~torch.utils.data.DataLoader`.
Args:
test_dataset (obj:`Dataset`): The test dataset to use.
"""
# We use the same batch_size as for eval.
if isinstance(self.test_dataset, torch.utils.data.IterableDataset):
sampler = None
elif is_torch_tpu_available():
sampler = SequentialDistributedSampler(
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
elif self.args.local_rank != -1:
sampler = SequentialDistributedSampler(test_dataset)
else:
sampler = SequentialSampler(test_dataset)
Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.
data_loader = DataLoader(
Subclass and override this method if you want to inject some custom behavior.
Args:
test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
The test dataset to use.
"""
test_sampler = self._get_eval_sampler(test_dataset)
# We use the same batch_size as for eval.
return DataLoader(
test_dataset,
sampler=sampler,
sampler=test_sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
return data_loader
def get_optimizers(
self, num_training_steps: int
@@ -321,7 +327,7 @@ class Trainer:
Setup the optimizer and the learning rate scheduler.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or override this method in a subclass.
Trainer's init through :obj:`optimizers`, or subclass and override this method.
"""
if self.optimizers is not None:
return self.optimizers
@@ -343,12 +349,12 @@ class Trainer:
)
return optimizer, scheduler
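
get_optimizers remains the hook for swapping the default AdamW plus linear warmup; alternatively, a ready-made (optimizer, scheduler) pair can still be passed through the Trainer's optimizers argument. A sketch under the assumption that plain SGD with a cosine schedule is wanted (the class name is illustrative):

import torch

from transformers import Trainer


class SGDTrainer(Trainer):
    def get_optimizers(self, num_training_steps):
        # Keep the base behavior of honoring an (optimizer, scheduler) pair passed at init time.
        if self.optimizers is not None:
            return self.optimizers
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.args.learning_rate, momentum=0.9
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_training_steps)
        return optimizer, scheduler
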
def _setup_wandb(self):
def setup_wandb(self):
"""
Setup the optional Weights & Biases (`wandb`) integration.
One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
You can also override the following environment variables:
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_WATCH:
@@ -359,6 +365,13 @@ class Trainer:
WANDB_DISABLED:
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
"""
if hasattr(self, "_setup_wandb"):
warnings.warn(
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
FutureWarning,
)
return self._setup_wandb()
if self.is_world_master():
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
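
With setup_wandb public, run creation can be customized directly. A minimal sketch that pins the project name (the "my-experiments" string is arbitrary; wandb.init is the standard Weights & Biases entry point):

import wandb

from transformers import Trainer


class ProjectWandbTrainer(Trainer):
    def setup_wandb(self):
        if self.is_world_master():
            # vars(self.args) records every TrainingArguments field as the run config.
            wandb.init(project="my-experiments", config=vars(self.args))
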
@@ -372,7 +385,7 @@ class Trainer:
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its Dataset.
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
"""
return len(dataloader.dataset)
@@ -500,7 +513,7 @@ class Trainer:
steps_trained_in_current_epoch -= 1
continue
tr_loss += self._training_step(model, inputs, optimizer)
tr_loss += self.training_step(model, inputs, optimizer)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -535,7 +548,7 @@ class Trainer:
)
logging_loss = tr_loss
self._log(logs)
self.log(logs)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
self.evaluate()
@@ -582,7 +595,25 @@ class Trainer:
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
return TrainOutput(self.global_step, tr_loss / self.global_step)
def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
"""
Log :obj:`logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (:obj:`Dict[str, float]`):
The values to log.
iterator (:obj:`tqdm`, `optional`):
A potential tqdm progress bar to write the logs on.
"""
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
FutureWarning,
)
return self._log(logs, iterator=iterator)
if self.epoch is not None:
logs["epoch"] = self.epoch
if self.global_step is None:
@@ -612,10 +643,13 @@ class Trainer:
else:
logger.info(output)
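
Making log public means extra sinks can be attached without copying the training loop. A sketch that mirrors every log entry to a JSON-lines file (the metrics.jsonl path is arbitrary):

import json

from transformers import Trainer


class FileLoggingTrainer(Trainer):
    def log(self, logs, iterator=None):
        # Keep the default TensorBoard/W&B/console logging...
        super().log(logs, iterator=iterator)
        # ...and append each log dict (loss, learning rate, eval metrics, ...) as one JSON line.
        if self.is_world_master():
            with open("metrics.jsonl", "a") as f:
                f.write(json.dumps(logs) + "\n")
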
def _training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
) -> float:
model.train()
def _prepare_inputs(
self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
@@ -625,9 +659,44 @@ class Trainer:
# Our model outputs do not work with DataParallel, so forcing return tuple.
if isinstance(model, nn.DataParallel):
inputs["return_tuple"] = True
return inputs
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
) -> float:
"""
Perform a training step on :obj:`model` using :obj:`inputs` and :obj:`optimizer`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to train.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer to use to make a step.
Return:
`float`:
The training loss on this batch.
"""
if hasattr(self, "_training_step"):
warnings.warn(
"The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
FutureWarning,
)
return self._training_step(model, inputs, optimizer)
model.train()
inputs = self._prepare_inputs(inputs, model)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs[0]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
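
training_step is where a custom loss belongs. A simplified sketch that adds an L2 penalty on top of the model loss; it reuses _prepare_inputs but deliberately omits the fp16/apex handling and multi-GPU loss averaging of the base method, and the l2_lambda value is illustrative:

from transformers import Trainer


class L2Trainer(Trainer):
    def training_step(self, model, inputs, optimizer):
        model.train()
        inputs = self._prepare_inputs(inputs, model)
        # Model outputs are tuples; the loss comes first when labels are provided.
        loss = model(**inputs)[0]
        # Hypothetical extra term: L2 penalty on all parameters.
        l2_lambda = 1e-5
        loss = loss + l2_lambda * sum(p.pow(2).sum() for p in model.parameters())
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        loss.backward()
        return loss.item()
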
@@ -741,6 +810,8 @@ class Trainer:
The calling script will be responsible for providing a method to compute metrics, as they are
task-dependent (pass it to the init :obj:`compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
@@ -749,9 +820,9 @@ class Trainer:
"""
eval_dataloader = self.get_eval_dataloader(eval_dataset)
output = self._prediction_loop(eval_dataloader, description="Evaluation")
output = self.prediction_loop(eval_dataloader, description="Evaluation")
self._log(output.metrics)
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
@@ -780,16 +851,22 @@ class Trainer:
"""
test_dataloader = self.get_test_dataloader(test_dataset)
return self._prediction_loop(test_dataloader, description="Prediction")
return self.prediction_loop(test_dataloader, description="Prediction")
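
From the caller's side nothing changes: evaluate() still returns the metrics dict and predict() a PredictionOutput, both now routed through the overridable prediction_loop/prediction_step. Typical usage, assuming trainer, eval_dataset and test_dataset already exist:

# eval_loss is present when the dataset provides labels.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
print(metrics["eval_loss"])

output = trainer.predict(test_dataset)
predictions, label_ids, test_metrics = output.predictions, output.label_ids, output.metrics
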
def _prediction_loop(
def prediction_loop(
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
Works both with or without labels.
"""
if hasattr(self, "_prediction_loop"):
warnings.warn(
"The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
FutureWarning,
)
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
@@ -815,40 +892,20 @@ class Trainer:
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
if self.args.past_index >= 0:
past = None
self._past = None
for inputs in tqdm(dataloader, desc=description):
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
if loss is not None:
eval_losses.append(loss)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0:
inputs["mems"] = past
# Our model outputs do not work with DataParallel, so forcing return tuple.
if isinstance(model, nn.DataParallel):
inputs["return_tuple"] = True
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
step_eval_loss, logits = outputs[:2]
eval_losses += [step_eval_loss.mean().item()]
else:
logits = outputs[0]
if self.args.past_index >= 0:
past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
if not prediction_loss_only:
if preds is None:
preds = logits.detach()
else:
preds = torch.cat((preds, logits.detach()), dim=0)
if inputs.get("labels") is not None:
if label_ids is None:
label_ids = inputs["labels"].detach()
else:
label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
@@ -894,3 +951,49 @@ class Trainer:
# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
inputs = self._prepare_inputs(inputs, model)
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
loss, logits = outputs[:2]
loss = loss.mean().item()
else:
loss = None
logits = outputs[0]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
labels = inputs.get("labels")
if labels is not None:
labels = labels.detach()
return (loss, logits.detach(), labels)
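
prediction_step is the hook for task-specific evaluation logic. A hedged sketch that keeps only the predicted class ids instead of full logits so large eval sets do not accumulate huge tensors; note that compute_metrics then receives class ids in EvalPrediction.predictions (the class name is illustrative):

import torch

from transformers import Trainer


class ArgmaxTrainer(Trainer):
    def prediction_step(self, model, inputs, prediction_loss_only):
        # Reuse the base forward pass and loss handling...
        loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
        # ...then shrink the logits to class ids before they get concatenated.
        if logits is not None:
            logits = torch.argmax(logits, dim=-1)
        return loss, logits, labels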