Clean push to hub API (#12187)

* Clean push to hub API

* Create working dir if it does not exist

* Different tweak

* New API + all models + test Flax

* Adds the Trainer clean up

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* (nit) output types

* No need to set clone_from when folder exists

* Update src/transformers/trainer.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Add generated_from_trainer tag

* Update to new version

* Fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Sylvain Gugger 2021-06-23 10:11:19 -04:00 committed by GitHub
parent 625f512d5e
commit 53c60babe4
17 changed files with 368 additions and 159 deletions

setup.py

@@ -100,7 +100,7 @@ _deps = [
"flake8>=3.8.3",
"flax>=0.3.4",
"fugashi>=1.0",
"huggingface-hub==0.0.8",
"huggingface-hub==0.0.11",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",

src/transformers/configuration_utils.py

@@ -337,12 +337,25 @@ class PretrainedConfig(PushToHubMixin):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`use_temp_dir=True` to use a temporary directory
instead.
kwargs:
Additional keyword arguments passed along to the
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
@@ -351,7 +364,7 @@ class PretrainedConfig(PushToHubMixin):
logger.info(f"Configuration saved in {output_config_file}")
if push_to_hub:
url = self._push_to_hub(save_files=[output_config_file], **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Configuration pushed to the hub in this commit: {url}")
@classmethod
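
As an illustration of the new flow, here is a minimal sketch of saving a config while pushing in one call (it mirrors the updated test in tests/test_configuration_common.py further down; being logged in via `huggingface-cli login` is assumed, and the folder name is illustrative):

```python
import os
import tempfile

from transformers import BertConfig

config = BertConfig(hidden_size=32, num_hidden_layers=5)
with tempfile.TemporaryDirectory() as tmp_dir:
    # The save directory now doubles as the local clone of the Hub repo, so it
    # must either not exist yet or already be a clone of the target repo.
    config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True)
```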

src/transformers/dependency_versions_table.py

@@ -17,7 +17,7 @@ deps = {
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.3.4",
"fugashi": "fugashi>=1.0",
"huggingface-hub": "huggingface-hub==0.0.8",
"huggingface-hub": "huggingface-hub==0.0.11",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",

src/transformers/file_utils.py

@@ -24,6 +24,7 @@ import json
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
@@ -31,7 +32,6 @@ import types
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import fields
from distutils.dir_util import copy_tree
from enum import Enum
from functools import partial, wraps
from hashlib import sha256
@@ -1907,6 +1907,30 @@ def copy_func(f):
return g
def is_local_clone(repo_path, repo_url):
"""
Checks if the folder in `repo_path` is a local clone of `repo_url`.
"""
# First double-check that `repo_path` is a git repo
if not os.path.exists(os.path.join(repo_path, ".git")):
return False
test_git = subprocess.run("git branch".split(), cwd=repo_path)
if test_git.returncode != 0:
return False
# Then look at its remotes
remotes = subprocess.run(
"git remote -v".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=repo_path,
).stdout
return repo_url in remotes.split()
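
A hedged usage sketch of this helper (the path and URL are placeholders):

```python
# Reuse an existing folder as a push target only if it already tracks the repo.
if is_local_clone("./my-finetuned-bert", "https://huggingface.co/sgugger/my-finetuned-bert"):
    print("Folder is a clone of the target repo; safe to push from it.")
```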
class PushToHubMixin:
"""
A Mixin containing the functionality to push a model or tokenizer to the hub.
@@ -1914,24 +1938,31 @@
def push_to_hub(
self,
repo_name: Optional[str] = None,
repo_path_or_name: Optional[str] = None,
repo_url: Optional[str] = None,
use_temp_dir: bool = False,
commit_message: Optional[str] = None,
organization: Optional[str] = None,
private: bool = None,
private: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
) -> str:
"""
Upload model checkpoint or tokenizer files to the 🤗 model hub.
Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in
:obj:`repo_path_or_name`.
Parameters:
repo_name (:obj:`str`, `optional`):
Repository name for your model or tokenizer in the hub. If not specified, the repository name will be
the stem of :obj:`save_directory`.
repo_path_or_name (:obj:`str`, `optional`):
Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in
which case the repository will have the name of that local folder). If not specified, will default to
the name given by :obj:`repo_url` and a local directory with that name will be created.
repo_url (:obj:`str`, `optional`):
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an :obj:`organization`) with
:obj:`repo_name`.
use_temp_dir (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to clone the distant repo in a temporary directory or in :obj:`repo_path_or_name` inside
the current working directory. This will slow things down if you are making changes in an existing repo
since you will need to clone the repo before every push.
commit_message (:obj:`str`, `optional`):
Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or
:obj:`"add model"` depending on the type of the class.
@@ -1948,42 +1979,66 @@
Returns:
The url of the commit of your model in the given repository.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
self.save_pretrained(tmp_dir)
self._push_to_hub(
save_directory=tmp_dir,
repo_name=repo_name,
repo_url=repo_url,
commit_message=commit_message,
organization=organization,
private=private,
use_auth_token=use_auth_token,
)
@classmethod
def _push_to_hub(
cls,
save_directory: Optional[str] = None,
save_files: Optional[List[str]] = None,
repo_name: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = None,
Examples::
# Upload a model to the Hub:
from transformers import AutoModel
model = BertModel.from_pretrained("bert-base-cased")
# Fine-tuning code
# Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
# `my-finetuned-bert` folder.
model.push_to_hub("my-finetuned-bert")
# Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
# Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
# `my-finetuned-bert` folder.
model.push_to_hub("my-finetuned-bert", organization="huggingface")
# Make a change to an existing repo that has been cloned locally in `my-finetuned-bert`.
model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
"""
if use_temp_dir:
# Make sure we use the right `repo_name` for the `repo_url` before replacing it.
if repo_url is None:
if use_auth_token is None:
use_auth_token = True
repo_name = Path(repo_path_or_name).name
repo_url = self._get_repo_url_from_name(
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
)
repo_path_or_name = tempfile.mkdtemp()
# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
repo = self._create_or_get_repo(
repo_path_or_name=repo_path_or_name,
repo_url=repo_url,
organization=organization,
private=private,
use_auth_token=use_auth_token,
)
# Save the files in the cloned repo
self.save_pretrained(repo_path_or_name)
# Commit and push!
url = self._push_to_hub(repo, commit_message=commit_message)
# Clean up! Clean up! Everybody everywhere!
if use_temp_dir:
shutil.rmtree(repo_path_or_name)
return url
@staticmethod
def _get_repo_url_from_name(
repo_name: str,
organization: Optional[str] = None,
private: bool = None,
use_auth_token: Optional[Union[bool, str]] = None,
) -> str:
# Private version of push_to_hub, that either accepts a folder to push or a list of files.
if save_directory is None and save_files is None:
raise ValueError("_push_to_hub requires either a `save_directory` or a list of `save_files`.")
if repo_name is None and repo_url is None and save_directory is None:
raise ValueError("Need either a `repo_name` or `repo_url` to know where to push!")
if repo_name is None and repo_url is None and save_files is None:
repo_name = Path(save_directory).name
if use_auth_token is None and repo_url is None:
use_auth_token = True
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token:
@@ -1997,33 +2052,56 @@
else:
token = None
if repo_url is None:
# Special provision for the test endpoint (CI)
repo_url = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
token,
repo_name,
organization=organization,
private=private,
repo_type=None,
exist_ok=True,
# Special provision for the test endpoint (CI)
return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
token,
repo_name,
organization=organization,
private=private,
repo_type=None,
exist_ok=True,
)
@classmethod
def _create_or_get_repo(
cls,
repo_path_or_name: Optional[str] = None,
repo_url: Optional[str] = None,
organization: Optional[str] = None,
private: bool = None,
use_auth_token: Optional[Union[bool, str]] = None,
) -> Repository:
if repo_path_or_name is None and repo_url is None:
raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
if use_auth_token is None and repo_url is None:
use_auth_token = True
if repo_path_or_name is None:
repo_path_or_name = repo_url.split("/")[-1]
if repo_url is None and not os.path.exists(repo_path_or_name):
repo_name = Path(repo_path_or_name).name
repo_url = cls._get_repo_url_from_name(
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
)
# Create a working directory if it does not exist.
if not os.path.exists(repo_path_or_name):
os.makedirs(repo_path_or_name)
repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
repo.git_pull()
return repo
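
For reference, the same clone-pull-push cycle written directly against `huggingface_hub` (a sketch assuming the `huggingface-hub==0.0.11` pin above; the URL is a placeholder):

```python
from huggingface_hub import Repository

# Clone the repo into a local working directory (or reuse an existing clone).
repo = Repository("my-finetuned-bert", clone_from="https://huggingface.co/<user>/my-finetuned-bert")
repo.git_pull()  # sync the clone before writing new files into it
# ... save model/config/tokenizer files into "my-finetuned-bert" here ...
url = repo.push_to_hub(commit_message="add model")  # commit + push, returns the commit url
```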
@classmethod
def _push_to_hub(cls, repo: Repository, commit_message: Optional[str] = None) -> str:
if commit_message is None:
if "Tokenizer" in cls.__name__:
commit_message = "add tokenizer"
if "Config" in cls.__name__:
elif "Config" in cls.__name__:
commit_message = "add config"
else:
commit_message = "add model"
with tempfile.TemporaryDirectory() as tmp_dir:
# First create the repo (and clone its content if it's nonempty), then add the files (otherwise there is
# no diff so nothing is pushed).
repo = Repository(tmp_dir, clone_from=repo_url, use_auth_token=use_auth_token)
if save_directory is None:
for filename in save_files:
shutil.copy(filename, Path(tmp_dir) / Path(filename).name)
else:
copy_tree(save_directory, tmp_dir)
return repo.push_to_hub(commit_message=commit_message)
return repo.push_to_hub(commit_message=commit_message)

src/transformers/modelcard.py

@@ -565,6 +565,14 @@ class TrainingSummary:
if model_name is None:
model_name = Path(trainer.args.output_dir).name
# Add `generated_from_trainer` to the tags
if tags is None:
tags = ["generated_from_trainer"]
elif isinstance(tags, str) and tags != "generated_from_trainer":
tags = [tags, "generated_from_trainer"]
elif "generated_from_trainer" not in tags:
tags.append("generated_from_trainer")
_, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
hyperparameters = extract_hyperparameters_from_trainer(trainer)
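
The normalization above behaves as follows (a comment-only sketch; the tag values are illustrative):

```python
# tags=None                   -> ["generated_from_trainer"]
# tags="text-classification"  -> ["text-classification", "generated_from_trainer"]
# tags=["a", "b"]             -> ["a", "b", "generated_from_trainer"]  (appended in place)
```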

src/transformers/modeling_flax_utils.py

@@ -28,7 +28,6 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
PushToHubMixin,
@@ -409,6 +408,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
Directory to which to save. Will be created if it doesn't exist.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`use_temp_dir=True` to use a temporary directory
instead.
kwargs:
Additional keyword arguments passed along to the
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -416,6 +423,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
# get abs dir
@@ -434,8 +446,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
logger.info(f"Model weights saved in {output_model_file}")
if push_to_hub:
saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
url = self._push_to_hub(save_files=saved_files, **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Model pushed to the hub in this commit: {url}")

src/transformers/modeling_tf_utils.py

@@ -30,7 +30,6 @@ from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
from .file_utils import (
CONFIG_NAME,
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
@@ -1029,6 +1028,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
https://www.tensorflow.org/tfx/serving/serving_basic
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`use_temp_dir=True` to use a temporary directory
instead.
kwargs:
Additional keyword arguments passed along to the
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -1036,6 +1043,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
if saved_model:
@@ -1053,8 +1065,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
logger.info(f"Model weights saved in {output_model_file}")
if push_to_hub:
saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
url = self._push_to_hub(save_files=saved_files, **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Model pushed to the hub in this commit: {url}")
@classmethod

src/transformers/modeling_utils.py

@@ -30,7 +30,6 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .file_utils import (
CONFIG_NAME,
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
@@ -852,6 +851,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
need to replace :obj:`torch.save` by another method.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`use_temp_dir=True` to use a temporary directory
instead.
kwargs:
Additional keyword arguments passed along to the
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -859,6 +866,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
# Only save the model itself if we are using distributed training
@@ -886,10 +898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info(f"Model weights saved in {output_model_file}")
if push_to_hub:
saved_files = [output_model_file]
if save_config:
saved_files.append(os.path.join(save_directory, CONFIG_NAME))
url = self._push_to_hub(save_files=saved_files, **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Model pushed to the hub in this commit: {url}")
@classmethod

src/transformers/tokenization_utils_base.py

@@ -1884,6 +1884,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
value error is raised.
filename_prefix: (:obj:`str`, `optional`):
A prefix to add to the names of the files saved by the tokenizer.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`use_temp_dir=True` to use a temporary directory
instead.
Returns:
A tuple of :obj:`str`: The files saved.
@@ -1891,6 +1900,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
special_tokens_map_file = os.path.join(
@@ -1949,9 +1963,7 @@
)
if push_to_hub:
# Annoyingly, the return contains files that don't exist.
existing_files = [f for f in save_files if os.path.isfile(f)]
url = self._push_to_hub(save_files=existing_files, **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
return save_files

src/transformers/trainer.py

@@ -24,7 +24,6 @@ import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from logging import StreamHandler
@@ -391,9 +390,12 @@ class Trainer:
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False
# Create output directory if needed
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
@@ -2430,6 +2432,27 @@
else:
return 0
def init_git_repo(self):
"""
Initializes a git repo in :obj:`self.args.output_dir`, cloned from the remote repository named by :obj:`self.args.push_to_hub_model_id`.
"""
if not self.is_world_process_zero():
return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
repo_url = PushToHubMixin._get_repo_url_from_name(
self.args.push_to_hub_model_id,
organization=self.args.push_to_hub_organization,
use_auth_token=use_auth_token,
)
self.repo = PushToHubMixin._create_or_get_repo(
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
)
# By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"])
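
Schematically, after a `Trainer` is created with `push_to_hub=True`, the output directory looks like this (user and folder names are illustrative):

```python
# output_dir="my-model" is now a local clone of https://huggingface.co/<user>/my-model:
#   my-model/.git/        <- created by PushToHubMixin._create_or_get_repo
#   my-model/.gitignore   <- contains "checkpoint-*/", so checkpoint folders are never pushed
```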
def create_model_card(
self,
language: Optional[str] = None,
@@ -2458,38 +2481,13 @@
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
f.write(model_card)
def push_to_hub(
self,
repo_name: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = "add model",
organization: Optional[str] = None,
private: bool = None,
use_auth_token: Optional[Union[bool, str]] = None,
**kwargs,
):
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
"""
Upload `self.model` to the 🤗 model hub.
Upload `self.model` and `self.tokenizer` to the repo `self.args.push_to_hub_model_id` on the 🤗 model hub.
Parameters:
repo_name (:obj:`str`, `optional`):
Repository name for your model or tokenizer in the hub. If not specified and :obj:`repo_url` is not
specified either, will default to the stem of :obj:`self.args.output_dir`.
repo_url (:obj:`str`, `optional`):
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an :obj:`organization`) with
:obj:`repo_name`.
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
Message to commit while pushing.
organization (:obj:`str`, `optional`):
Organization in which you want to push your model or tokenizer (you must be a member of this
organization).
private (:obj:`bool`, `optional`):
Whether or not the repository created should be private (requires a paying subscription).
use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
:obj:`True` if :obj:`repo_url` is not specified.
kwargs:
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
@@ -2499,37 +2497,9 @@
if not self.is_world_process_zero():
return
if not isinstance(unwrap_model(self.model), PushToHubMixin):
raise ValueError(
"The `upload_model_to_hub` method only works for models that inherit from `PushToHubMixin` models."
)
if repo_url is None and repo_name is None:
repo_name = Path(self.args.output_dir).name
if repo_name is not None:
model_name = repo_name
elif repo_url is not None:
model_name = repo_url.split("/")[-1]
else:
model_name = None
self.create_model_card(model_name=model_name, **kwargs)
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy(os.path.join(self.args.output_dir, "README.md"), os.path.join(tmp_dir, "README.md"))
unwrap_model(self.model).save_pretrained(tmp_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(tmp_dir)
return unwrap_model(self.model)._push_to_hub(
save_directory=tmp_dir,
repo_name=repo_name,
repo_url=repo_url,
commit_message=commit_message,
organization=organization,
private=private,
use_auth_token=use_auth_token,
)
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
self.save_model()
return self.repo.push_to_hub(commit_message=commit_message)
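
Putting the pieces together, a minimal sketch of the new end-to-end `Trainer` workflow (`model` and `train_dataset` are assumed to be defined elsewhere; the repo name is illustrative):

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="my-finetuned-bert",  # becomes (or must already be) a local clone of the repo
    push_to_hub=True,                # clones the Hub repo during Trainer.__init__
    # push_to_hub_model_id defaults to Path(output_dir).name, i.e. "my-finetuned-bert"
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
url = trainer.push_to_hub()  # writes the model card, saves model + tokenizer, pushes
```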
#
# Deprecated code

src/transformers/training_args.py

@@ -17,6 +17,7 @@ import os
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption
@@ -157,7 +158,7 @@ class TrainingArguments:
node.
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
`output_dir/runs/**CURRENT_DATETIME_HOSTNAME**`.
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
The logging strategy to adopt during training. Possible values are:
@@ -318,15 +319,22 @@
Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
down the training and evaluation speed.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to upload the trained model to the hub after training. This argument is not directly used by
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
Whether or not to upload the trained model to the hub after training. If this is activated and
:obj:`output_dir` exists, it needs to be a local clone of the repository to which the
:class:`~transformers.Trainer` will push.
resume_from_checkpoint (:obj:`str`, `optional`):
The path to a folder with a valid checkpoint for your model. This argument is not directly used by
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
push_to_hub_model_id (:obj:`str`, `optional`):
The name of the repository to which to push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
Will default to the name of :obj:`output_dir`.
push_to_hub_organization (:obj:`str`, `optional`):
The name of the organization in which to push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
"""
output_dir: str = field(
@@ -590,6 +598,13 @@
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
push_to_hub_organization: str = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
)
push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
@@ -612,6 +627,8 @@
#  see https://github.com/huggingface/transformers/issues/10628
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
if self.logging_dir is None and self.output_dir is not None:
self.logging_dir = os.path.join(self.output_dir, default_logdir())
if self.logging_dir is not None:
self.logging_dir = os.path.expanduser(self.logging_dir)
@@ -705,6 +722,9 @@
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self)
if self.push_to_hub_model_id is None:
self.push_to_hub_model_id = Path(self.output_dir).name
def __str__(self):
self_as_dict = asdict(self)

tests/test_configuration_common.py

@@ -113,7 +113,7 @@ class ConfigPushToHubTester(unittest.TestCase):
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-config", use_auth_token=self._token)
config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.__dict__.items():
@@ -127,9 +127,8 @@
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(
tmp_dir,
os.path.join(tmp_dir, "test-config-org"),
push_to_hub=True,
repo_name="test-config-org",
use_auth_token=self._token,
organization="valid_org",
)

tests/test_modeling_common.py

@@ -1613,7 +1613,7 @@ class ModelPushToHubTester(unittest.TestCase):
)
model = BertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
model.save_pretrained(os.path.join(tmp_dir, "test-model"), push_to_hub=True, use_auth_token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
@@ -1626,9 +1626,8 @@
model = BertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
os.path.join(tmp_dir, "test-model-org"),
push_to_hub=True,
repo_name="test-model-org",
use_auth_token=self._token,
organization="valid_org",
)

tests/test_modeling_flax_common.py

@@ -16,14 +16,25 @@ import copy
import inspect
import random
import tempfile
import unittest
from typing import List, Tuple
import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
slow,
)
if is_flax_available():
@@ -504,3 +515,65 @@ class FlaxModelTesterMixin:
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
@require_flax
@is_staging_test
class FlaxModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model-flax")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
os.path.join(tmp_dir, "test-model-flax-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

tests/test_modeling_tf_common.py

@@ -1487,7 +1487,7 @@ class TFModelPushToHubTester(unittest.TestCase):
# Make sure model is properly initialized
_ = model(model.dummy_inputs)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model-tf", use_auth_token=self._token)
model.save_pretrained(os.path.join(tmp_dir, "test-model-tf"), push_to_hub=True, use_auth_token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
@@ -1503,9 +1503,8 @@
model = TFBertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
os.path.join(tmp_dir, "test-model-tf-org"),
push_to_hub=True,
repo_name="test-model-tf-org",
use_auth_token=self._token,
organization="valid_org",
)

tests/test_tokenization_common.py

@@ -3173,7 +3173,7 @@ class TokenizerPushToHubTester(unittest.TestCase):
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.save_pretrained(
tmp_dir, push_to_hub=True, repo_name="test-tokenizer", use_auth_token=self._token
os.path.join(tmp_dir, "test-tokenizer"), push_to_hub=True, use_auth_token=self._token
)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
@@ -3186,9 +3186,8 @@
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.save_pretrained(
tmp_dir,
os.path.join(tmp_dir, "test-tokenizer-org"),
push_to_hub=True,
repo_name="test-tokenizer-org",
use_auth_token=self._token,
organization="valid_org",
)

tests/test_trainer.py

@@ -1274,8 +1274,12 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
url = trainer.push_to_hub(repo_name="test-trainer", use_auth_token=self._token)
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer"),
push_to_hub=True,
push_to_hub_token=self._token,
)
url = trainer.push_to_hub()
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
@@ -1292,9 +1296,13 @@
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
trainer.save_model()
url = trainer.push_to_hub(
repo_name="test-trainer-org", organization="valid_org", use_auth_token=self._token
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
push_to_hub=True,
push_to_hub_organization="valid_org",
push_to_hub_token=self._token,
)
url = trainer.push_to_hub()
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)