Keras callback to push to hub each epoch, or after N steps (#13773)
* Keras callback to push to hub each epoch, or after N steps
* Reworked the callback to use Repository
* Use an Enum for save_strategy
* Style pass
* Correct type for tokenizer
* Update src/transformers/keras_callbacks.py (review suggestions, applied six times)
* Adding print message to the final upload
* Change how we wait for the last process to finish
* is_done is a property, not a method, derp
* Docstrings and documentation
* Style pass
* Style edit
* Docstring reformat
* Docstring rewrite
* Replacing print with internal logger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
  parent aa018a795d
  commit 3a8a8013ad
docs/source/index.rst
@@ -533,6 +533,7 @@ Flax), PyTorch, and/or TensorFlow.
     main_classes/callback
     main_classes/configuration
     main_classes/data_collator
+    main_classes/keras_callbacks
     main_classes/logging
     main_classes/model
     main_classes/optimizer_schedules
docs/source/main_classes/keras_callbacks.rst (new file, 22 lines)
@@ -0,0 +1,22 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Keras callbacks
=======================================================================================================================

When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
tasks:

PushToHubCallback
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: transformers.keras_callbacks.PushToHubCallback
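A minimal usage sketch of the default epoch strategy (not part of this diff; the dataset name tf_train_dataset, the checkpoint name, and the compile arguments are placeholder assumptions):

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers.keras_callbacks import PushToHubCallback

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(optimizer="adam")  # loss/metrics omitted for brevity

# With hub_model_id unset, the repo name defaults to "<username>/test-model-tf",
# resolved from the output_dir name via get_full_repo_name.
push_callback = PushToHubCallback(output_dir="test-model-tf", tokenizer=tokenizer)
model.fit(tf_train_dataset, epochs=3, callbacks=[push_callback])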
src/transformers/keras_callbacks.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import logging
from pathlib import Path
from time import sleep
from typing import Optional, Union

from tensorflow.keras.callbacks import Callback

from huggingface_hub import Repository

from . import IntervalStrategy, PreTrainedTokenizerBase
from .file_utils import get_full_repo_name


logger = logging.getLogger(__name__)

class PushToHubCallback(Callback):
    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
    ):
        """
        output_dir (:obj:`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

            * :obj:`"no"`: No save is done during training.
            * :obj:`"epoch"`: Save is done at the end of each epoch.
            * :obj:`"steps"`: Save is done every :obj:`save_steps`.
        save_steps (:obj:`int`, `optional`):
            The number of steps between saves when using the "steps" save_strategy.
        tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (:obj:`str`, `optional`):
            The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
            name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
            of with :obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with
            `output_dir_name` being the name of :obj:`output_dir`.
        hub_token (:obj:`str`, `optional`):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            :obj:`huggingface-cli login`.
        """
        super().__init__()
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)
        if hub_model_id is None:
            repo_name = get_full_repo_name(output_dir.absolute().name, token=hub_token)
        else:
            repo_name = hub_model_id
        self.output_dir = output_dir
        self.repo = Repository(str(output_dir), clone_from=repo_name)
        self.tokenizer = tokenizer
        self.last_job = None

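    # The non-blocking upload pattern used in both hooks below: push_to_hub(blocking=False)
    # returns immediately, handing back an in-progress command object whose `is_done`
    # property can be polled. If the previous upload has not finished yet, the new
    # checkpoint is skipped rather than queueing a second push behind it.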
    def on_train_batch_end(self, batch, logs=None):
        # `%` binds tighter than `+`, so the unparenthesized condition
        # `batch + 1 % self.save_steps == 0` reduces to `batch + 1 == 0` for any
        # save_steps > 1 and never fires; `(batch + 1)` must be parenthesized.
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress steps {batch}", blocking=False
            )

    def on_epoch_end(self, epoch, logs=None):
        if self.save_strategy == IntervalStrategy.EPOCH:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )

    def on_train_end(self, logs=None):
        if self.last_job is not None and not self.last_job.is_done:
            logger.info("Waiting for existing upload to finish...")
            while not self.last_job.is_done:
                sleep(1)
        self.model.save_pretrained(self.output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(self.output_dir)
        self.repo.push_to_hub(commit_message="End of training", blocking=True)
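A companion sketch for the "steps" strategy, reusing the model and tf_train_dataset assumed in the earlier example; the organization repo name is hypothetical:

from transformers.keras_callbacks import PushToHubCallback

# Save and push every 100 optimizer steps; with the parenthesized condition the
# upload triggers on batch indices 99, 199, 299, ...
push_callback = PushToHubCallback(
    output_dir="test-model-tf",
    save_strategy="steps",
    save_steps=100,
    hub_model_id="my-org/test-model-tf",  # hypothetical organization repo
)
model.fit(tf_train_dataset, epochs=2, callbacks=[push_callback])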