Harmonize both Trainers API (#6157)

* Harmonize both Trainers API

* Fix test

* main_process -> process_zero
Sylvain Gugger 2020-07-31 09:43:23 -04:00 committed by GitHub
parent 603cd81a01
commit 86caab1e0b
3 changed files with 125 additions and 97 deletions

docs/source/main_classes/trainer.rst

@@ -11,6 +11,23 @@ customization during training.
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
previous features. To inject custom behavior you can subclass them and override the following methods (a short sketch follows the list):
- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at init.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **run_model** (TensorFlow only) -- Basic pass through the model.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
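A short, hedged sketch of the subclassing pattern (``PrintLogTrainer`` and the print-based logging are illustrative, not part of the library):

.. code-block:: python

    from transformers import Trainer

    class PrintLogTrainer(Trainer):
        def log(self, logs):
            # Only print from the main process so distributed runs do not
            # emit the same line once per process.
            if self.is_world_process_zero():
                print({**logs, "step": self.global_step})
            super().log(logs)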
``Trainer``
~~~~~~~~~~~

src/transformers/trainer.py

@@ -172,18 +172,6 @@ class Trainer:
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
"""
model: PreTrainedModel
args: TrainingArguments
data_collator: DataCollator
train_dataset: Optional[Dataset]
eval_dataset: Optional[Dataset]
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
tb_writer: Optional["SummaryWriter"] = None
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
global_step: Optional[int] = None
epoch: Optional[float] = None
def __init__(
self,
model: PreTrainedModel,
@@ -194,7 +182,7 @@ class Trainer:
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
):
self.model = model.to(args.device)
self.args = args
@@ -203,10 +191,9 @@ class Trainer:
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizers = optimizers
if tb_writer is not None:
self.tb_writer = tb_writer
elif is_tensorboard_available() and self.is_world_master():
self.optimizer, self.lr_scheduler = optimizers
self.tb_writer = tb_writer
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
logger.warning(
@@ -221,7 +208,7 @@ class Trainer:
)
set_seed(self.args.seed)
# Create output directory if needed
if self.is_world_master():
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_torch_tpu_available():
# Set an xla_device flag on the model's config.
@@ -236,6 +223,8 @@ class Trainer:
),
FutureWarning,
)
self.global_step = None
self.epoch = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
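For context, a hedged sketch of driving the native-AMP branch above: with a PyTorch build that ships torch.cuda.amp, setting fp16=True is all that is needed, since the constructor creates the GradScaler itself (the output directory is illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="./outputs", fp16=True)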
@@ -333,39 +322,35 @@ class Trainer:
drop_last=self.args.dataloader_drop_last,
)
def get_optimizers(
self, num_training_steps: int
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Set up the optimizer and the learning rate scheduler.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or subclass and override this method.
"""
if self.optimizers is not None:
return self.optimizers
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
return optimizer, scheduler
if self.optimizer is None:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
self.optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
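With this change, either slot of the optimizers tuple left as None is filled in by the defaults above. A hedged sketch of supplying a custom pair (model, args and train_dataset are assumed to exist; the SGD/constant-schedule choice is illustrative):

    from torch.optim import SGD
    from torch.optim.lr_scheduler import LambdaLR
    from transformers import Trainer

    optimizer = SGD(model.parameters(), lr=args.learning_rate)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # constant LR

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        optimizers=(optimizer, scheduler),  # skips the AdamW/linear-warmup defaults
    )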
def setup_wandb(self):
"""
@@ -390,7 +375,7 @@ class Trainer:
)
return self._setup_wandb()
if self.is_world_master():
if self.is_world_process_zero():
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
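As the log message above says, the automatic tracking can be turned off with an environment variable set before the Trainer is created:

    import os

    os.environ["WANDB_DISABLED"] = "true"  # disables automatic Weights & Biases logging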
@@ -426,7 +411,7 @@ class Trainer:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)
self.create_optimizer_and_scheduler(num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if (
@@ -435,16 +420,16 @@ class Trainer:
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(
self.optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
@@ -506,7 +491,7 @@ class Trainer:
logging_loss = 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
)
for epoch in train_iterator:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -516,9 +501,9 @@ class Trainer:
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
else:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
@@ -531,7 +516,7 @@ class Trainer:
steps_trained_in_current_epoch -= 1
continue
tr_loss += self.training_step(model, inputs, optimizer)
tr_loss += self.training_step(model, inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -539,23 +524,22 @@ class Trainer:
and (step + 1) == len(epoch_iterator)
):
if self.args.fp16 and _use_native_amp:
self.scaler.unscale_(optimizer)
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.args.fp16 and _use_apex:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
if is_torch_tpu_available():
xm.optimizer_step(optimizer)
xm.optimizer_step(self.optimizer)
if self.args.fp16 and _use_native_amp:
self.scaler.step(optimizer)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
optimizer.step()
self.optimizer.step()
scheduler.step()
self.lr_scheduler.step()
model.zero_grad()
self.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator)
@@ -567,9 +551,9 @@ class Trainer:
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
else self.lr_scheduler.get_lr()[0]
)
logging_loss = tr_loss
@@ -590,16 +574,16 @@ class Trainer:
self.save_model(output_dir)
if self.is_world_master():
if self.is_world_process_zero():
self._rotate_checkpoints()
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
elif self.is_world_master():
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
epoch_iterator.close()
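Since the states are now saved from self.optimizer and self.lr_scheduler, resuming reloads them through the same attributes via the code earlier in train(). A hedged sketch (the checkpoint path is illustrative):

    # train() reloads optimizer.pt and scheduler.pt if they exist in model_path
    trainer.train(model_path="./outputs/checkpoint-500")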
@@ -660,7 +644,7 @@ class Trainer:
)
self.tb_writer.flush()
if is_wandb_available():
if self.is_world_master():
if self.is_world_process_zero():
wandb.log(logs, step=self.global_step)
output = {**logs, **{"step": self.global_step}}
if iterator is not None:
@@ -684,11 +668,9 @@ class Trainer:
return inputs
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
) -> float:
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
"""
Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
@@ -700,19 +682,16 @@ class Trainer:
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer to use to make a step.
Return:
`float`:
The training loss on this batch.
:obj:`float`: The training loss on this batch.
"""
if hasattr(self, "_training_step"):
warnings.warn(
"The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
FutureWarning,
)
return self._training_step(model, inputs, optimizer)
return self._training_step(model, inputs, self.optimizer)
model.train()
inputs = self._prepare_inputs(inputs, model)
@@ -738,7 +717,7 @@ class Trainer:
if self.args.fp16 and _use_native_amp:
self.scaler.scale(loss).backward()
elif self.args.fp16 and _use_apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
@@ -746,6 +725,22 @@ class Trainer:
return loss.item()
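With the optimizer argument gone, a subclass that still needs it reaches it through self.optimizer. A minimal sketch of the new two-argument signature (LrLogTrainer and the learning-rate printing are illustrative assumptions):

    from transformers import Trainer

    class LrLogTrainer(Trainer):
        def training_step(self, model, inputs):
            # The base implementation runs the forward/backward pass and
            # returns the detached loss as a float.
            loss = super().training_step(model, inputs)
            # The optimizer is now an attribute instead of an argument:
            if self.is_world_process_zero():
                print("lr:", self.optimizer.param_groups[0]["lr"])
            return loss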
def is_local_master(self) -> bool:
"""
Whether or not this process is the local main process (e.g., the main process on one machine when training
in a distributed fashion on several machines).
.. warning::
This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
"""
warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
return self.is_local_process_zero()
def is_local_process_zero(self) -> bool:
"""
Whether or not this process is the local main process (e.g., the main process on one machine when training
in a distributed fashion on several machines).
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
else:
@@ -753,8 +748,20 @@ class Trainer:
def is_world_master(self) -> bool:
"""
This will be True only in one process, even in distributed mode,
even when training on multiple machines.
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
.. warning::
This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
"""
warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
return self.is_world_process_zero()
def is_world_process_zero(self) -> bool:
"""
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
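A hedged usage sketch of the two renamed helpers on an initialized trainer (the path and message are illustrative):

    if trainer.is_local_process_zero():
        print("main process on this machine")  # true once per node
    if trainer.is_world_process_zero():
        trainer.save_model("./outputs/final")  # true once across all nodes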
@@ -770,7 +777,7 @@ class Trainer:
if is_torch_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_master():
elif self.is_world_process_zero():
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
@@ -846,6 +853,7 @@ class Trainer:
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
@@ -871,6 +879,7 @@ class Trainer:
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on.
Returns:
`NamedTuple`:
predictions (:obj:`np.ndarray`):

src/transformers/trainer_tf.py

@@ -63,17 +63,6 @@ class TFTrainer:
an instance of :class:`~transformers.WarmUp`.
"""
model: TFPreTrainedModel
args: TFTrainingArguments
train_dataset: Optional[tf.data.Dataset]
eval_dataset: Optional[tf.data.Dataset]
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
tb_writer: Optional[tf.summary.SummaryWriter] = None
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
global_step: Optional[int] = None
epoch_logging: Optional[float] = None
def __init__(
self,
model: TFPreTrainedModel,
@@ -325,6 +314,15 @@ class TFTrainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def log(self, logs: Dict[str, float]) -> None:
"""
Log :obj:`logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (:obj:`Dict[str, float]`):
The values to log.
"""
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
@@ -356,6 +354,7 @@ class TFTrainer:
Args:
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
@@ -577,9 +576,12 @@ class TFTrainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
features: the batched features.
labels: the batched labels.
training: run the model in training mode or not
features (:obj:`tf.Tensor`): A batch of input features.
labels (:obj:`tf.Tensor`): A batch of labels.
training (:obj:`bool`): Whether or not to run the model in training mode.
Returns:
A tuple of two :obj:`tf.Tensor`: The loss and logits.
"""
if hasattr(self, "_run_model"):
warnings.warn(