Keras callback to push to hub each epoch, or after N steps (#13773)

* Keras callback to push to hub each epoch, or after N steps

* Reworked the callback to use Repository

* Use an Enum for save_strategy

* Style pass

* Correct type for tokenizer

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding print message to the final upload

* Adding print message to the final upload

* Change how we wait for the last process to finish

* is_done is a property, not a method, derp

* Docstrings and documentation

* Style pass

* Style edit

* Docstring reformat

* Docstring rewrite

* Replacing print with internal logger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matt 2021-09-29 12:47:35 +01:00 committed by GitHub
parent aa018a795d
commit 3a8a8013ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 0 deletions

View File

@ -533,6 +533,7 @@ Flax), PyTorch, and/or TensorFlow.
main_classes/callback
main_classes/configuration
main_classes/data_collator
main_classes/keras_callbacks
main_classes/logging
main_classes/model
main_classes/optimizer_schedules

View File

@ -0,0 +1,22 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Keras callbacks
=======================================================================================================================
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
tasks:
PushToHubCallback
-----------------------------------------------------------------------------------------------------------------------
.. autoclass:: transformers.keras_callbacks.PushToHubCallback

View File

@ -0,0 +1,97 @@
import logging
from pathlib import Path
from time import sleep
from typing import Optional, Union
from tensorflow.keras.callbacks import Callback
from huggingface_hub import Repository
from . import IntervalStrategy, PreTrainedTokenizerBase
from .file_utils import get_full_repo_name
logger = logging.getLogger(__name__)
class PushToHubCallback(Callback):
    """Keras callback that periodically saves the model and pushes it to the Hugging Face Hub.

    Depending on ``save_strategy``, a checkpoint is saved and pushed (as a non-blocking
    background upload) either at the end of every epoch or every ``save_steps`` training
    batches. A final, blocking push is always made when training ends.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
    ):
        """
        output_dir (:obj:`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

            * :obj:`"no"`: No save is done during training.
            * :obj:`"epoch"`: Save is done at the end of each epoch.
            * :obj:`"steps"`: Save is done every :obj:`save_steps` batches.
        save_steps (:obj:`int`, `optional`):
            The number of steps between saves when using the "steps" save_strategy.
        tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (:obj:`str`, `optional`):
            The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
            name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
            of with :obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with
            `output_dir_name` being the name of :obj:`output_dir`.
        hub_token (:obj:`str`, `optional`):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            :obj:`huggingface-cli login`.
        """
        super().__init__()
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        # save_steps is only meaningful (and required) for the STEPS strategy.
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)
        if hub_model_id is None:
            # Default repo name is "<user_name>/<output_dir basename>".
            repo_name = get_full_repo_name(output_dir.absolute().name, token=hub_token)
        else:
            repo_name = hub_model_id
        self.output_dir = output_dir
        # Clones (or creates) the Hub repo into output_dir so saves can be committed and pushed.
        self.repo = Repository(str(output_dir), clone_from=repo_name)
        self.tokenizer = tokenizer
        # Handle to the most recent asynchronous push, used to avoid overlapping uploads.
        self.last_job = None

    def on_train_batch_end(self, batch, logs=None):
        # BUGFIX: the condition previously read `batch + 1 % self.save_steps == 0`, which Python
        # parses as `batch + (1 % self.save_steps) == 0` and therefore (almost) never triggered.
        # `batch` is 0-indexed, so `(batch + 1) % save_steps == 0` saves every `save_steps` batches.
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            # Non-blocking push; keep the job handle so later saves can check it.
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress steps {batch}", blocking=False
            )

    def on_epoch_end(self, epoch, logs=None):
        if self.save_strategy == IntervalStrategy.EPOCH:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            # Non-blocking push; keep the job handle so later saves can check it.
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )

    def on_train_end(self, logs=None):
        # Wait for any in-flight upload before the final save, so the last push isn't skipped.
        if self.last_job is not None and not self.last_job.is_done:
            logger.info("Waiting for existing upload to finish...")
            while not self.last_job.is_done:
                sleep(1)
        self.model.save_pretrained(self.output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(self.output_dir)
        # Final push is blocking so training doesn't report done before the upload completes.
        self.repo.push_to_hub(commit_message="End of training", blocking=True)