Mirror of https://github.com/huggingface/transformers.git

Fix LR decay in TF Trainer (#5269)

* Recover old PR
* Apply style
* Trigger CI

commit 7cb52f53ef (parent 321c05abab)
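Why this fix matters: before this patch, TFTrainer handed self.train_steps — the number of update steps in a single epoch — to its optimizer factory, so a linearly decaying learning rate hit its floor by the end of epoch one. The patch computes the true total number of optimization steps (t_total) up front and threads it through get_optimizers. A minimal sketch of the corrected bookkeeping, with illustrative values (the numbers below are assumptions, not taken from the repo):

    import math

    # The decay horizon must be the TOTAL number of optimizer updates,
    # not the number of updates in one epoch.
    num_examples = 10_000              # assumed dataset size
    train_batch_size = 32              # assumed effective batch across replicas
    gradient_accumulation_steps = 2
    num_train_epochs = 3

    steps_per_epoch = math.ceil(num_examples / (train_batch_size * gradient_accumulation_steps))
    t_total = steps_per_epoch * num_train_epochs

    print(steps_per_epoch, t_total)  # 157 471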
@@ -3,6 +3,7 @@
 import logging
+import math
 import os
 import random
 from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
@@ -21,6 +22,12 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)
 
 
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    tf.random.set_seed(seed)
+
+
 class TFTrainer:
     model: TFPreTrainedModel
     args: TFTrainingArguments
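The new module-level set_seed helper above seeds every random-number source the TF trainer touches in one call. A self-contained usage sketch (the seed value is arbitrary):

    import random

    import numpy as np
    import tensorflow as tf

    def set_seed(seed: int):
        # Seed Python's, NumPy's and TensorFlow's generators together,
        # as the helper in the hunk above does.
        random.seed(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)

    set_seed(42)
    print(tf.random.uniform((1,)).numpy())  # reproducible across runs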
@@ -59,6 +66,7 @@ class TFTrainer:
             self.tb_writer = tb_writer
         else:
             self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
+
         if is_wandb_available():
             self._setup_wandb()
         else:
@@ -67,6 +75,8 @@ class TFTrainer:
                 "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
             )
 
+        set_seed(self.args.seed)
+
     def get_train_tfdataset(self) -> tf.data.Dataset:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -109,7 +119,7 @@
         return self.args.strategy.experimental_distribute_dataset(ds)
 
     def get_optimizers(
-        self,
+        self, num_training_steps: int,
     ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
         """
         Setup the optimizer and the learning rate scheduler.
@@ -123,7 +133,7 @@
 
         optimizer, scheduler = create_optimizer(
             self.args.learning_rate,
-            self.train_steps,
+            num_training_steps,
             self.args.warmup_steps,
             adam_epsilon=self.args.adam_epsilon,
             weight_decay_rate=self.args.weight_decay,
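The renamed argument is the heart of the fix: decay schedules of this family interpolate from the initial learning rate down to an end value over exactly the number of steps they are given, and clamp afterwards. Passing one epoch's step count therefore freezes the LR at its floor for every later epoch. A sketch of that failure mode using Keras's stock PolynomialDecay purely as an illustration (it is not what create_optimizer builds internally; the step counts are assumed):

    import tensorflow as tf

    steps_per_epoch, total_steps = 157, 471  # assumed values from a 3-epoch run

    # Decay over one epoch (the old bug) vs. over the full run (the fix).
    buggy = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=5e-5, decay_steps=steps_per_epoch, end_learning_rate=0.0
    )
    fixed = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=5e-5, decay_steps=total_steps, end_learning_rate=0.0
    )

    step = 200  # early in epoch 2
    print(buggy(step).numpy())  # 0.0      -- decay already exhausted after epoch 1
    print(fixed(step).numpy())  # ~2.9e-05 -- still decaying as intended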
@@ -238,14 +248,19 @@
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _log(self, logs: Dict[str, float]) -> None:
+        logs["epoch"] = self.epoch_logging
+
         if self.tb_writer:
             with self.tb_writer.as_default():
                 for k, v in logs.items():
                     tf.summary.scalar(k, v, step=self.global_step)
             self.tb_writer.flush()
+
         if is_wandb_available():
             wandb.log(logs, step=self.global_step)
+
         output = {**logs, **{"step": self.global_step}}
+
         logger.info(output)
 
     def evaluate(
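_log now stamps every record with the fractional epoch before fanning it out to TensorBoard and Weights & Biases. The tf.summary pattern it relies on, as a standalone sketch (the log directory and metric values are assumptions):

    import tensorflow as tf

    writer = tf.summary.create_file_writer("/tmp/tb_logs")  # assumed scratch dir
    logs = {"loss": 0.42, "learning_rate": 3e-5, "epoch": 1.5}
    global_step = 100

    with writer.as_default():
        # Each dict entry becomes one scalar curve keyed by the global step.
        for k, v in logs.items():
            tf.summary.scalar(k, v, step=global_step)
    writer.flush()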
@@ -260,6 +275,7 @@
 
         logs = {**output.metrics}
+        logs["epoch"] = self.epoch_logging
 
         self._log(logs)
 
         return output.metrics
@@ -275,25 +291,45 @@
 
         self.gradient_accumulator.reset()
 
+        if self.args.max_steps > 0:
+            t_total = self.args.max_steps
+            steps_per_epoch = self.args.max_steps
+        else:
+            if self.args.dataloader_drop_last:
+                approx = math.floor
+            else:
+                approx = math.ceil
+
+            steps_per_epoch = approx(
+                self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
+            )
+            t_total = steps_per_epoch * self.args.num_train_epochs
+
         with self.args.strategy.scope():
-            optimizer, lr_scheduler = self.get_optimizers()
+            optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
             iterations = optimizer.iterations
             self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
 
             if self.model.ckpt_manager.latest_checkpoint:
+                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
+                steps_trained_in_current_epoch = self.global_step % (
+                    self.num_train_examples // self.args.gradient_accumulation_steps
+                )
+
+                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+                logger.info("  Continuing training from epoch %d", epochs_trained)
+                logger.info("  Continuing training from global step %d", self.global_step)
+                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
                 logger.info(
                     "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                 )
 
                 ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
 
-                if iterations.numpy() > 0:
-                    logger.info("Start the training from the last checkpoint")
-                    start_epoch = (iterations.numpy() // self.train_steps) + 1
-                else:
-                    start_epoch = 1
+            else:
+                epochs_trained = 1
 
             tf.summary.experimental.set_step(iterations)
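The resume arithmetic above derives the epoch to restart from out of the optimizer's restored iteration counter, using floor division and modulo against the same per-epoch divisor. A worked example that mirrors those two expressions with made-up numbers:

    # Illustrative values only; the formulas are copied from the hunk above.
    num_train_examples = 10_000
    gradient_accumulation_steps = 2
    global_step = 12_500  # restored along with the checkpoint

    divisor = num_train_examples // gradient_accumulation_steps  # 5000
    epochs_trained = global_step // divisor                      # 2
    steps_trained_in_current_epoch = global_step % divisor       # 2500

    print(epochs_trained, steps_trained_in_current_epoch)  # 2 2500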
@@ -311,17 +347,23 @@ class TFTrainer:
             logger.info("***** Running training *****")
             logger.info("  Num examples = %d", self.num_train_examples)
             logger.info("  Num Epochs = %d", epochs)
-            logger.info("  Total optimization steps = %d", self.train_steps)
+            logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
+            logger.info(
+                "  Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size
+            )
+            logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
+            logger.info("  Total optimization steps = %d", t_total)
 
-            for epoch_iter in range(start_epoch, int(epochs + 1)):
+            for epoch_iter in range(epochs_trained, int(epochs + 1)):
                 for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
                     self.global_step = iterations.numpy()
-                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.train_steps
+                    self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
 
                     if self.args.debug:
                         logs = {}
                         logs["loss"] = training_loss.numpy()
+                        logs["epoch"] = self.epoch_logging
 
                         self._log(logs)
 
                     if self.global_step == 1 and self.args.debug:
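epoch_logging now divides by steps_per_epoch, the real per-epoch step count, instead of self.train_steps, giving a fractional epoch for logging. A quick check with assumed numbers:

    steps_per_epoch = 157
    epoch_iter, step = 2, 78  # 79th batch of the second epoch

    epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
    print(round(epoch_logging, 3))  # 1.503 -- about halfway through epoch 2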
@@ -333,18 +375,23 @@
                     if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                         self.evaluate()
 
-                    if self.global_step % self.args.logging_steps == 0:
+                    if (
+                        self.global_step % self.args.logging_steps == 0
+                        or self.global_step == 1
+                        and self.args.logging_first_step
+                    ):
                         logs = {}
                         logs["loss"] = training_loss.numpy()
                         logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
+                        logs["epoch"] = self.epoch_logging
 
                         self._log(logs)
 
                     if self.global_step % self.args.save_steps == 0:
                         ckpt_save_path = self.model.ckpt_manager.save()
                         logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
 
-                    if self.global_step % self.train_steps == 0:
+                    if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
                         break
 
     def _training_steps(self, ds, optimizer):
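One subtlety in the widened logging condition: Python binds "and" tighter than "or", so it reads as "log on every logging_steps-th step, or on step 1 when logging_first_step is set". A tiny demonstration (should_log is a hypothetical helper that mirrors the condition, not Trainer API):

    def should_log(global_step: int, logging_steps: int, logging_first_step: bool) -> bool:
        # a or b and c  ==  a or (b and c)
        return (
            global_step % logging_steps == 0
            or global_step == 1
            and logging_first_step
        )

    print(should_log(1, 500, True))     # True  -- first step, flag set
    print(should_log(1, 500, False))    # False
    print(should_log(500, 500, False))  # True  -- regular interval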