mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Harmonize both Trainers API (#6157)
* Harmonize both Trainers API * Fix test * main_prcess -> process_zero
This commit is contained in:
parent
603cd81a01
commit
86caab1e0b
@ -11,6 +11,23 @@ customization during training.
|
||||
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
|
||||
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
|
||||
|
||||
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
|
||||
previous features. To inject custom behavior you can subclass them and override the following methods:
|
||||
|
||||
- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
|
||||
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
|
||||
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
|
||||
- **log** -- Logs information on the various objects watching training.
|
||||
- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
||||
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
|
||||
init.
|
||||
- **training_step** -- Performs a training step.
|
||||
- **prediction_step** -- Performs an evaluation/test step.
|
||||
- **run_model** (TensorFlow only) -- Basic pass through the model.
|
||||
- **evaluate** -- Runs an evaluation loop and returns metrics.
|
||||
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
|
||||
|
||||
|
||||
``Trainer``
|
||||
~~~~~~~~~~~
|
||||
|
||||
|
@ -172,18 +172,6 @@ class Trainer:
|
||||
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
||||
"""
|
||||
|
||||
model: PreTrainedModel
|
||||
args: TrainingArguments
|
||||
data_collator: DataCollator
|
||||
train_dataset: Optional[Dataset]
|
||||
eval_dataset: Optional[Dataset]
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||
prediction_loss_only: bool
|
||||
tb_writer: Optional["SummaryWriter"] = None
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
|
||||
global_step: Optional[int] = None
|
||||
epoch: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
@ -194,7 +182,7 @@ class Trainer:
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
prediction_loss_only=False,
|
||||
tb_writer: Optional["SummaryWriter"] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
):
|
||||
self.model = model.to(args.device)
|
||||
self.args = args
|
||||
@ -203,10 +191,9 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
self.prediction_loss_only = prediction_loss_only
|
||||
self.optimizers = optimizers
|
||||
if tb_writer is not None:
|
||||
self.tb_writer = tb_writer
|
||||
elif is_tensorboard_available() and self.is_world_master():
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
self.tb_writer = tb_writer
|
||||
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
|
||||
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
||||
if not is_tensorboard_available():
|
||||
logger.warning(
|
||||
@ -221,7 +208,7 @@ class Trainer:
|
||||
)
|
||||
set_seed(self.args.seed)
|
||||
# Create output directory if needed
|
||||
if self.is_world_master():
|
||||
if self.is_world_process_zero():
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
if is_torch_tpu_available():
|
||||
# Set an xla_device flag on the model's config.
|
||||
@ -236,6 +223,8 @@ class Trainer:
|
||||
),
|
||||
FutureWarning,
|
||||
)
|
||||
self.global_step = None
|
||||
self.epoch = None
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
@ -333,39 +322,35 @@ class Trainer:
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
)
|
||||
|
||||
def get_optimizers(
|
||||
self, num_training_steps: int
|
||||
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
|
||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||
"""
|
||||
Set up the optimizer and the learning rate scheduler.
|
||||
|
||||
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||
Trainer's init through :obj:`optimizers`, or subclass and override this method.
|
||||
"""
|
||||
if self.optimizers is not None:
|
||||
return self.optimizers
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||
eps=self.args.adam_epsilon,
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||
)
|
||||
return optimizer, scheduler
|
||||
if self.optimizer is None:
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.args.learning_rate,
|
||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
||||
eps=self.args.adam_epsilon,
|
||||
)
|
||||
if self.lr_scheduler is None:
|
||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
def setup_wandb(self):
|
||||
"""
|
||||
@ -390,7 +375,7 @@ class Trainer:
|
||||
)
|
||||
return self._setup_wandb()
|
||||
|
||||
if self.is_world_master():
|
||||
if self.is_world_process_zero():
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
@ -426,7 +411,7 @@ class Trainer:
|
||||
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
|
||||
optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if (
|
||||
@ -435,16 +420,16 @@ class Trainer:
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
|
||||
model = self.model
|
||||
if self.args.fp16 and _use_apex:
|
||||
if not is_apex_available():
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if self.args.n_gpu > 1:
|
||||
@ -506,7 +491,7 @@ class Trainer:
|
||||
logging_loss = 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
|
||||
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
|
||||
)
|
||||
for epoch in train_iterator:
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
@ -516,9 +501,9 @@ class Trainer:
|
||||
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
||||
self.args.device
|
||||
)
|
||||
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
|
||||
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
|
||||
else:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
|
||||
|
||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||
if self.args.past_index >= 0:
|
||||
@ -531,7 +516,7 @@ class Trainer:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
tr_loss += self.training_step(model, inputs, optimizer)
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
@ -539,23 +524,22 @@ class Trainer:
|
||||
and (step + 1) == len(epoch_iterator)
|
||||
):
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler.unscale_(optimizer)
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||
elif self.args.fp16 and _use_apex:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.optimizer_step(optimizer)
|
||||
|
||||
xm.optimizer_step(self.optimizer)
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
optimizer.step()
|
||||
self.optimizer.step()
|
||||
|
||||
scheduler.step()
|
||||
self.lr_scheduler.step()
|
||||
model.zero_grad()
|
||||
self.global_step += 1
|
||||
self.epoch = epoch + (step + 1) / len(epoch_iterator)
|
||||
@ -567,9 +551,9 @@ class Trainer:
|
||||
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
|
||||
# backward compatibility for pytorch schedulers
|
||||
logs["learning_rate"] = (
|
||||
scheduler.get_last_lr()[0]
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else scheduler.get_lr()[0]
|
||||
else self.lr_scheduler.get_lr()[0]
|
||||
)
|
||||
logging_loss = tr_loss
|
||||
|
||||
@ -590,16 +574,16 @@ class Trainer:
|
||||
|
||||
self.save_model(output_dir)
|
||||
|
||||
if self.is_world_master():
|
||||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
elif self.is_world_master():
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
elif self.is_world_process_zero():
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||
epoch_iterator.close()
|
||||
@ -660,7 +644,7 @@ class Trainer:
|
||||
)
|
||||
self.tb_writer.flush()
|
||||
if is_wandb_available():
|
||||
if self.is_world_master():
|
||||
if self.is_world_process_zero():
|
||||
wandb.log(logs, step=self.global_step)
|
||||
output = {**logs, **{"step": self.global_step}}
|
||||
if iterator is not None:
|
||||
@ -684,11 +668,9 @@ class Trainer:
|
||||
|
||||
return inputs
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
|
||||
) -> float:
|
||||
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
|
||||
"""
|
||||
Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
|
||||
Perform a training step on a batch of inputs.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
@ -700,19 +682,16 @@ class Trainer:
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
optimizer (:obj:`torch.optim.Optimizer`):
|
||||
The optimizer to use to make a step.
|
||||
|
||||
Return:
|
||||
`float`:
|
||||
The training loss on this batch.
|
||||
:obj:`float`: The training loss on this batch.
|
||||
"""
|
||||
if hasattr(self, "_training_step"):
|
||||
warnings.warn(
|
||||
"The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._training_step(model, inputs, optimizer)
|
||||
return self._training_step(model, inputs, self.optimizer)
|
||||
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs, model)
|
||||
@ -738,7 +717,7 @@ class Trainer:
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler.scale(loss).backward()
|
||||
elif self.args.fp16 and _use_apex:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
@ -746,6 +725,22 @@ class Trainer:
|
||||
return loss.item()
|
||||
|
||||
def is_local_master(self) -> bool:
|
||||
"""
|
||||
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
||||
several machines) main process.
|
||||
|
||||
.. warning::
|
||||
|
||||
This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
|
||||
"""
|
||||
warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
|
||||
return self.is_local_process_zero()
|
||||
|
||||
def is_local_process_zero(self) -> bool:
|
||||
"""
|
||||
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
||||
several machines) main process.
|
||||
"""
|
||||
if is_torch_tpu_available():
|
||||
return xm.is_master_ordinal(local=True)
|
||||
else:
|
||||
@ -753,8 +748,20 @@ class Trainer:
|
||||
|
||||
def is_world_master(self) -> bool:
|
||||
"""
|
||||
This will be True only in one process, even in distributed mode,
|
||||
even when training on multiple machines.
|
||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||
several machines, this is only going to be :obj:`True` for one process).
|
||||
|
||||
.. warning::
|
||||
|
||||
This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
|
||||
"""
|
||||
warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
|
||||
return self.is_world_process_zero()
|
||||
|
||||
def is_world_process_zero(self) -> bool:
|
||||
"""
|
||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||
several machines, this is only going to be :obj:`True` for one process).
|
||||
"""
|
||||
if is_torch_tpu_available():
|
||||
return xm.is_master_ordinal(local=False)
|
||||
@ -770,7 +777,7 @@ class Trainer:
|
||||
|
||||
if is_torch_tpu_available():
|
||||
self._save_tpu(output_dir)
|
||||
elif self.is_world_master():
|
||||
elif self.is_world_process_zero():
|
||||
self._save(output_dir)
|
||||
|
||||
def _save_tpu(self, output_dir: Optional[str] = None):
|
||||
@ -846,6 +853,7 @@ class Trainer:
|
||||
Args:
|
||||
eval_dataset (:obj:`Dataset`, `optional`):
|
||||
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
||||
"""
|
||||
@ -871,6 +879,7 @@ class Trainer:
|
||||
Args:
|
||||
test_dataset (:obj:`Dataset`):
|
||||
Dataset to run the predictions on.
|
||||
|
||||
Returns:
|
||||
`NamedTuple`:
|
||||
predictions (:obj:`np.ndarray`):
|
||||
|
@ -63,17 +63,6 @@ class TFTrainer:
|
||||
an instance of :class:`~transformers.WarmUp`.
|
||||
"""
|
||||
|
||||
model: TFPreTrainedModel
|
||||
args: TFTrainingArguments
|
||||
train_dataset: Optional[tf.data.Dataset]
|
||||
eval_dataset: Optional[tf.data.Dataset]
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||
prediction_loss_only: bool
|
||||
tb_writer: Optional[tf.summary.SummaryWriter] = None
|
||||
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
|
||||
global_step: Optional[int] = None
|
||||
epoch_logging: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: TFPreTrainedModel,
|
||||
@ -325,6 +314,15 @@ class TFTrainer:
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log :obj:`logs` on the various objects watching training.
|
||||
|
||||
Subclass and override this method to inject custom behavior.
|
||||
|
||||
Args:
|
||||
logs (:obj:`Dict[str, float]`):
|
||||
The values to log.
|
||||
"""
|
||||
if hasattr(self, "_log"):
|
||||
warnings.warn(
|
||||
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
||||
@ -356,6 +354,7 @@ class TFTrainer:
|
||||
Args:
|
||||
eval_dataset (:class:`~tf.data.Dataset`, `optional`):
|
||||
Pass a dataset if you wish to override :obj:`self.eval_dataset`.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
||||
"""
|
||||
@ -577,9 +576,12 @@ class TFTrainer:
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
|
||||
Args:
|
||||
features: the batched features.
|
||||
labels: the batched labels.
|
||||
training: run the model in training mode or not
|
||||
features (:obj:`tf.Tensor`): A batch of input features.
|
||||
labels (:obj:`tf.Tensor`): A batch of labels.
|
||||
training (:obj:`bool`): Whether or not to run the model in training mode.
|
||||
|
||||
Returns:
|
||||
A tuple of two :obj:`tf.Tensor`: The loss and logits.
|
||||
"""
|
||||
if hasattr(self, "_run_model"):
|
||||
warnings.warn(
|
||||
|
Loading…
Reference in New Issue
Block a user