Harmonize both Trainers API (#6157)

* Harmonize both Trainers API

* Fix test

* main_process -> process_zero
Sylvain Gugger 2020-07-31 09:43:23 -04:00 committed by GitHub
parent 603cd81a01
commit 86caab1e0b
3 changed files with 125 additions and 97 deletions


@@ -11,6 +11,23 @@ customization during training.
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
+ Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
+ previous features. To inject custom behavior you can subclass them and override the following methods:
+ - **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
+ - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
+ - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
+ - **log** -- Logs information on the various objects watching training.
+ - **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
+ - **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
+   init.
+ - **training_step** -- Performs a training step.
+ - **prediction_step** -- Performs an evaluation/test step.
+ - **run_model** (TensorFlow only) -- Basic pass through the model.
+ - **evaluate** -- Runs an evaluation loop and returns metrics.
+ - **predict** -- Returns predictions (with metrics if labels are available) on a test set.
``Trainer``
~~~~~~~~~~~
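To make the subclassing pattern documented in the hunk above concrete, here is a minimal sketch (not part of the diff) of a PyTorch ``Trainer`` subclass overriding one of these hooks; the sequential sampler is an arbitrary customization chosen for illustration.

from torch.utils.data import DataLoader, SequentialSampler
from transformers import Trainer

class OrderedTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        # Arbitrary customization: iterate the training set in order instead of with the default sampler.
        return DataLoader(
            self.train_dataset,
            sampler=SequentialSampler(self.train_dataset),
            batch_size=self.args.train_batch_size,
            collate_fn=self.data_collator,
        )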


@@ -172,18 +172,6 @@ class Trainer:
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
"""
- model: PreTrainedModel
- args: TrainingArguments
- data_collator: DataCollator
- train_dataset: Optional[Dataset]
- eval_dataset: Optional[Dataset]
- compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
- prediction_loss_only: bool
- tb_writer: Optional["SummaryWriter"] = None
- optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
- global_step: Optional[int] = None
- epoch: Optional[float] = None
def __init__(
self,
model: PreTrainedModel,
@@ -194,7 +182,7 @@ class Trainer:
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
- optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
):
self.model = model.to(args.device)
self.args = args
@@ -203,10 +191,9 @@ class Trainer:
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
- self.optimizers = optimizers
- if tb_writer is not None:
- self.tb_writer = tb_writer
- elif is_tensorboard_available() and self.is_world_master():
+ self.optimizer, self.lr_scheduler = optimizers
+ self.tb_writer = tb_writer
+ if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
logger.warning(
@@ -221,7 +208,7 @@ class Trainer:
)
set_seed(self.args.seed)
# Create output directory if needed
- if self.is_world_master():
+ if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_torch_tpu_available():
# Set an xla_device flag on the model's config.
@@ -236,6 +223,8 @@ class Trainer:
),
FutureWarning,
)
+ self.global_step = None
+ self.epoch = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
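With the harmonized signature, a custom optimizer/scheduler pair is passed as a single ``optimizers`` tuple that now defaults to ``(None, None)``. A hedged sketch of what this looks like at construction time; the checkpoint name, learning rate, and step counts are illustrative, not taken from the PR:

from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = TrainingArguments(output_dir="out")

# Build the pair yourself instead of letting create_optimizer_and_scheduler fill it in.
optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)

trainer = Trainer(model=model, args=args, optimizers=(optimizer, scheduler))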
@@ -333,39 +322,35 @@ class Trainer:
drop_last=self.args.dataloader_drop_last,
)
- def get_optimizers(
- self, num_training_steps: int
- ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
"""
- if self.optimizers is not None:
- return self.optimizers
- # Prepare optimizer and schedule (linear warmup and decay)
- no_decay = ["bias", "LayerNorm.weight"]
- optimizer_grouped_parameters = [
- {
- "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
- "weight_decay": self.args.weight_decay,
- },
- {
- "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
- "weight_decay": 0.0,
- },
- ]
- optimizer = AdamW(
- optimizer_grouped_parameters,
- lr=self.args.learning_rate,
- betas=(self.args.adam_beta1, self.args.adam_beta2),
- eps=self.args.adam_epsilon,
- )
- scheduler = get_linear_schedule_with_warmup(
- optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
- )
- return optimizer, scheduler
+ if self.optimizer is None:
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+ self.optimizer = AdamW(
+ optimizer_grouped_parameters,
+ lr=self.args.learning_rate,
+ betas=(self.args.adam_beta1, self.args.adam_beta2),
+ eps=self.args.adam_epsilon,
+ )
+ if self.lr_scheduler is None:
+ self.lr_scheduler = get_linear_schedule_with_warmup(
+ self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+ )
def setup_wandb(self):
"""
@@ -390,7 +375,7 @@ class Trainer:
)
return self._setup_wandb()
- if self.is_world_master():
+ if self.is_world_process_zero():
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
@@ -426,7 +411,7 @@ class Trainer:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
- optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)
+ self.create_optimizer_and_scheduler(num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if (
@@ -435,16 +420,16 @@ class Trainer:
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
- optimizer.load_state_dict(
+ self.optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
- scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+ self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
- model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
+ model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
@@ -506,7 +491,7 @@ class Trainer:
logging_loss = 0.0
model.zero_grad()
train_iterator = trange(
- epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
+ epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
)
for epoch in train_iterator:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -516,9 +501,9 @@ class Trainer:
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
- epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+ epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
else:
- epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
@@ -531,7 +516,7 @@ class Trainer:
steps_trained_in_current_epoch -= 1
continue
- tr_loss += self.training_step(model, inputs, optimizer)
+ tr_loss += self.training_step(model, inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -539,23 +524,22 @@ class Trainer:
and (step + 1) == len(epoch_iterator)
):
if self.args.fp16 and _use_native_amp:
- self.scaler.unscale_(optimizer)
+ self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.args.fp16 and _use_apex:
- torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
+ torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
if is_torch_tpu_available():
- xm.optimizer_step(optimizer)
+ xm.optimizer_step(self.optimizer)
if self.args.fp16 and _use_native_amp:
- self.scaler.step(optimizer)
+ self.scaler.step(self.optimizer)
self.scaler.update()
else:
- optimizer.step()
- scheduler.step()
+ self.optimizer.step()
+ self.lr_scheduler.step()
model.zero_grad()
self.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator)
@@ -567,9 +551,9 @@ class Trainer:
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
- scheduler.get_last_lr()[0]
+ self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
- else scheduler.get_lr()[0]
+ else self.lr_scheduler.get_lr()[0]
)
logging_loss = tr_loss
@@ -590,16 +574,16 @@ class Trainer:
self.save_model(output_dir)
- if self.is_world_master():
+ if self.is_world_process_zero():
self._rotate_checkpoints()
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
- xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
- xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
- elif self.is_world_master():
- torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
- torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+ xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+ xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+ elif self.is_world_process_zero():
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
epoch_iterator.close()
@@ -660,7 +644,7 @@ class Trainer:
)
self.tb_writer.flush()
if is_wandb_available():
- if self.is_world_master():
+ if self.is_world_process_zero():
wandb.log(logs, step=self.global_step)
output = {**logs, **{"step": self.global_step}}
if iterator is not None:
@@ -684,11 +668,9 @@ class Trainer:
return inputs
- def training_step(
- self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
- ) -> float:
+ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
"""
- Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
+ Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
@@ -700,19 +682,16 @@ class Trainer:
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
- optimizer (:obj:`torch.optim.Optimizer`):
- The optimizer to use to make a step.
Return:
- `float`:
- The training loss on this batch.
+ :obj:`float`: The training loss on this batch.
"""
if hasattr(self, "_training_step"):
warnings.warn(
"The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
FutureWarning,
)
- return self._training_step(model, inputs, optimizer)
+ return self._training_step(model, inputs, self.optimizer)
model.train()
inputs = self._prepare_inputs(inputs, model)
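With the optimizer argument dropped, a custom ``training_step`` only needs the model and the batch; everything else is reachable on ``self``. A minimal sketch of an override under the new signature (the periodic print is an arbitrary illustration, not something the PR adds):

from typing import Any, Dict, Union

import torch
import torch.nn as nn
from transformers import Trainer

class VerboseTrainer(Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
        # Delegate the forward/backward pass to the base class, which now uses self.optimizer internally.
        loss = super().training_step(model, inputs)
        if self.is_world_process_zero() and self.global_step is not None and self.global_step % 100 == 0:
            print(f"step {self.global_step}: loss {loss:.4f}")
        return loss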
@@ -738,7 +717,7 @@ class Trainer:
if self.args.fp16 and _use_native_amp:
self.scaler.scale(loss).backward()
elif self.args.fp16 and _use_apex:
- with amp.scale_loss(loss, optimizer) as scaled_loss:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
@@ -746,6 +725,22 @@ class Trainer:
return loss.item()
def is_local_master(self) -> bool:
+ """
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+ several machines) main process.
+ .. warning::
+ This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
+ """
+ warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
+ return self.is_local_process_zero()
+ def is_local_process_zero(self) -> bool:
+ """
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+ several machines) main process.
+ """
if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
else:
@@ -753,8 +748,20 @@ class Trainer:
def is_world_master(self) -> bool:
"""
- This will be True only in one process, even in distributed mode,
- even when training on multiple machines.
+ Whether or not this process is the global main process (when training in a distributed fashion on
+ several machines, this is only going to be :obj:`True` for one process).
+ .. warning::
+ This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
+ """
+ warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
+ return self.is_world_process_zero()
+ def is_world_process_zero(self) -> bool:
+ """
+ Whether or not this process is the global main process (when training in a distributed fashion on
+ several machines, this is only going to be :obj:`True` for one process).
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
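The old ``is_local_master``/``is_world_master`` names keep working but emit a ``FutureWarning``, so downstream scripts should move to the new names. A hedged sketch of the typical guard, assuming ``trainer`` is an already-constructed ``Trainer`` and the output path is illustrative:

trainer.train()

# Only the global main process writes artifacts shared across machines.
if trainer.is_world_process_zero():
    trainer.save_model("out/final")

# One process per machine can still handle node-local side effects.
if trainer.is_local_process_zero():
    print("training finished on this node")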
@@ -770,7 +777,7 @@ class Trainer:
if is_torch_tpu_available():
self._save_tpu(output_dir)
- elif self.is_world_master():
+ elif self.is_world_process_zero():
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
@@ -846,6 +853,7 @@ class Trainer:
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
@@ -871,6 +879,7 @@ class Trainer:
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on.
Returns:
`NamedTuple`:
predictions (:obj:`np.ndarray`):


@@ -63,17 +63,6 @@ class TFTrainer:
an instance of :class:`~transformers.WarmUp`.
"""
- model: TFPreTrainedModel
- args: TFTrainingArguments
- train_dataset: Optional[tf.data.Dataset]
- eval_dataset: Optional[tf.data.Dataset]
- compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
- prediction_loss_only: bool
- tb_writer: Optional[tf.summary.SummaryWriter] = None
- optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
- global_step: Optional[int] = None
- epoch_logging: Optional[float] = None
def __init__(
self,
model: TFPreTrainedModel,
@@ -325,6 +314,15 @@ class TFTrainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log :obj:`logs` on the various objects watching training.
+ Subclass and override this method to inject custom behavior.
+ Args:
+ logs (:obj:`Dict[str, float]`):
+ The values to log.
+ """
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
@@ -356,6 +354,7 @@ class TFTrainer:
Args:
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
@@ -577,9 +576,12 @@ class TFTrainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
- features: the batched features.
- labels: the batched labels.
- training: run the model in training mode or not
+ features (:obj:`tf.Tensor`): A batch of input features.
+ labels (:obj:`tf.Tensor`): A batch of labels.
+ training (:obj:`bool`): Whether or not to run the model in training mode.
+ Returns:
+ A tuple of two :obj:`tf.Tensor`: The loss and logits.
"""
if hasattr(self, "_run_model"):
warnings.warn(
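Per the clarified docstring, ``run_model`` takes the batched features, the labels, and a training flag, and returns a ``(loss, logits)`` tuple, so a subclass can post-process the default forward pass. A hedged sketch under that assumption; the logit clipping is an arbitrary example, not something the PR does:

import tensorflow as tf
from transformers import TFTrainer

class ClippedTFTrainer(TFTrainer):
    def run_model(self, features, labels, training):
        # Reuse the default forward pass, then clamp the logits used downstream for predictions/metrics.
        loss, logits = super().run_model(features, labels, training)
        return loss, tf.clip_by_value(logits, -10.0, 10.0)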