TF Checkpoints (#4831)

* Align checkpoint dir with the PT trainer

* Use args.save_total_limit for the maximum number of checkpoints to keep
Julien Plu 2020-06-08 15:45:23 +02:00 committed by GitHub
parent 439f1cab20
commit 36dfc317b3

@@ -230,8 +230,9 @@ class TFTrainer:
         with self.args.strategy.scope():
             optimizer, lr_scheduler = self.get_optimizers()
             iterations = optimizer.iterations
+            folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
-            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5)
+            self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

             if self.model.ckpt_manager.latest_checkpoint:
                 logger.info(
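
This hunk moves the checkpoints under output_dir (joined with PREFIX_CHECKPOINT_DIR, mirroring the PyTorch Trainer layout) and replaces the hard-coded max_to_keep=5 with args.save_total_limit. As a rough illustration of how tf.train.CheckpointManager applies max_to_keep, here is a minimal standalone sketch (plain Keras model and illustrative directory, not Trainer code):

import tensorflow as tf

# Minimal sketch (not Trainer code): CheckpointManager keeps at most
# `max_to_keep` checkpoints in the directory and deletes older ones.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam()
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, "output/checkpoint", max_to_keep=3)

for step in range(10):
    # ... one training step would run here ...
    manager.save()  # prunes checkpoints beyond the newest 3

print(manager.checkpoints)  # only the 3 most recent checkpoint paths remain

Note that tf.train.CheckpointManager treats max_to_keep=None as "keep everything", so leaving save_total_limit unset should retain every checkpoint rather than fail.
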
@@ -401,17 +402,12 @@ class TFTrainer:
     def save_model(self, output_dir: Optional[str] = None):
         """
-        Save the pretrained model and create a Tensorflow saved model.
+        Save the pretrained model.
         """
         output_dir = output_dir if output_dir is not None else self.args.output_dir

         logger.info("Saving model in {}".format(output_dir))

-        path = os.path.join(self.args.output_dir, "saved_model")
-
-        logger.info("Saving model in {}".format(path))
-
-        os.makedirs(path, exist_ok=True)
-
         if not isinstance(self.model, TFPreTrainedModel):
             raise ValueError("Trainer.model appears to not be a PreTrainedModel")
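
This hunk drops the separate "saved_model" export from save_model, whose docstring now only promises to save the pretrained model. On the user side, the two changes of this commit surface through the training arguments; a hedged sketch with illustrative values (the model/dataset wiring for TFTrainer is elided):

from transformers import TFTrainingArguments

# Hedged sketch of the user-facing side of this commit; values are illustrative.
training_args = TFTrainingArguments(
    output_dir="output",  # checkpoints now land under output_dir, as with the PyTorch Trainer
    save_total_limit=3,   # forwarded to CheckpointManager's max_to_keep (previously hard-coded to 5)
)

If a TensorFlow SavedModel export is still needed, it can be produced manually after training, e.g. with tf.saved_model.save(trainer.model, path), assuming the model has already been built by a forward pass.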