Mirror of https://github.com/huggingface/transformers.git
TF Checkpoints (#4831)
* Align checkpoint dir with the PT trainer
* Use args for max to keep checkpoints
parent 439f1cab20
commit 36dfc317b3
@@ -230,8 +230,9 @@ class TFTrainer:
         with self.args.strategy.scope():
             optimizer, lr_scheduler = self.get_optimizers()
             iterations = optimizer.iterations
+            folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
-            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5)
+            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
 
             if self.model.ckpt_manager.latest_checkpoint:
                 logger.info(
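For context, a minimal standalone sketch (not part of this diff) of the behavior the new lines rely on: tf.train.CheckpointManager writes numbered checkpoints into the given folder and prunes the oldest ones once max_to_keep is exceeded, which is the role args.save_total_limit now plays. The toy model, optimizer, and directory below are illustrative stand-ins, not the trainer's own objects.

import os
import tensorflow as tf

# Illustrative stand-ins for self.model, its optimizer, and
# os.path.join(args.output_dir, PREFIX_CHECKPOINT_DIR) in the trainer.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam()
folder = os.path.join("output", "checkpoint")
os.makedirs(folder, exist_ok=True)

ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
# max_to_keep plays the role of args.save_total_limit: once the limit is
# reached, the oldest checkpoints are deleted automatically.
manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=2)

for step in range(5):
    manager.save()          # writes ckpt-1, ckpt-2, ... and prunes older ones

print(manager.checkpoints)  # only the two most recent checkpoints remain

# Restoring the latest checkpoint, mirroring the trainer's resume logic.
if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)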
@@ -401,17 +402,12 @@ class TFTrainer:
 
     def save_model(self, output_dir: Optional[str] = None):
         """
-        Save the pretrained model and create a Tensorflow saved model.
+        Save the pretrained model.
         """
         output_dir = output_dir if output_dir is not None else self.args.output_dir
 
         logger.info("Saving model in {}".format(output_dir))
 
-        path = os.path.join(self.args.output_dir, "saved_model")
-
-        logger.info("Saving model in {}".format(path))
-        os.makedirs(path, exist_ok=True)
-
         if not isinstance(self.model, TFPreTrainedModel):
             raise ValueError("Trainer.model appears to not be a PreTrainedModel")
 
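After this change, save_model no longer creates a separate TensorFlow SavedModel export under output_dir/saved_model; the assumption is that saving now goes only through the standard pretrained-checkpoint path into output_dir. A rough usage sketch under that assumption, with an illustrative model name and path rather than anything taken from the PR:

from transformers import TFAutoModel

output_dir = "./tf_trainer_output"            # illustrative path
model = TFAutoModel.from_pretrained("bert-base-uncased")

# save_pretrained() writes the TF weights (tf_model.h5) plus config.json,
# which from_pretrained() can load back from the same directory.
model.save_pretrained(output_dir)
reloaded = TFAutoModel.from_pretrained(output_dir)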