Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Clean push to hub API (#12187)
* Clean push to hub API
* Create working dir if it does not exist
* Different tweak
* New API + all models + test Flax
* Adds the Trainer clean up
* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments
* (nit) output types
* No need to set clone_from when folder exists
* Update src/transformers/trainer.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Add generated_from_trainer tag
* Update to new version
* Fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent 625f512d5e
commit 53c60babe4
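A quick orientation before the diff: the commit replaces the old copy-files-into-a-temp-clone flow with `Repository`-backed local clones from `huggingface_hub`. A minimal sketch of the resulting user-facing workflow, assuming a fine-tuned BERT model (the repo name and the fine-tuning step are illustrative, not from this commit)::

    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-cased")
    # ... fine-tuning code ...

    # Creates (or reuses) a local clone named `my-finetuned-bert`, saves into it, commits and pushes.
    model.push_to_hub("my-finetuned-bert")

    # Same push routed through a temporary directory, leaving no local clone behind.
    model.push_to_hub("my-finetuned-bert", use_temp_dir=True)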
setup.py (2 lines changed)
@@ -100,7 +100,7 @@ _deps = [
     "flake8>=3.8.3",
     "flax>=0.3.4",
     "fugashi>=1.0",
-    "huggingface-hub==0.0.8",
+    "huggingface-hub==0.0.11",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
src/transformers/configuration_utils.py
@@ -337,12 +337,25 @@ class PretrainedConfig(PushToHubMixin):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
             push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                .. warning::
+
+                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
+                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
+                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
+                    instead.
+
             kwargs:
                 Additional key word arguments passed along to the
                 :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)
         # If we save using the predefined names, we can load using `from_pretrained`
         output_config_file = os.path.join(save_directory, CONFIG_NAME)
@@ -351,7 +364,7 @@ class PretrainedConfig(PushToHubMixin):
         logger.info(f"Configuration saved in {output_config_file}")

         if push_to_hub:
-            url = self._push_to_hub(save_files=[output_config_file], **kwargs)
+            url = self._push_to_hub(repo, commit_message=commit_message)
             logger.info(f"Configuration pushed to the hub in this commit: {url}")

     @classmethod
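In practice, the reworked `save_pretrained` now pushes by syncing a clone of `save_directory`; a hedged sketch (the path and token handling are illustrative)::

    from transformers import BertConfig

    config = BertConfig(vocab_size=99, hidden_size=32)
    # If `./test-config` already exists it must be a local clone of the target repo;
    # otherwise the repo is created and cloned there before the commit is pushed.
    config.save_pretrained("./test-config", push_to_hub=True, use_auth_token=True)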
src/transformers/dependency_versions_table.py
@@ -17,7 +17,7 @@ deps = {
     "flake8": "flake8>=3.8.3",
     "flax": "flax>=0.3.4",
     "fugashi": "fugashi>=1.0",
-    "huggingface-hub": "huggingface-hub==0.0.8",
+    "huggingface-hub": "huggingface-hub==0.0.11",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
src/transformers/file_utils.py
@@ -24,6 +24,7 @@ import json
 import os
 import re
 import shutil
+import subprocess
 import sys
 import tarfile
 import tempfile
@@ -31,7 +32,6 @@ import types
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
-from distutils.dir_util import copy_tree
 from enum import Enum
 from functools import partial, wraps
 from hashlib import sha256
@@ -1907,6 +1907,30 @@ def copy_func(f):
     return g


+def is_local_clone(repo_path, repo_url):
+    """
+    Checks if the folder in `repo_path` is a local clone of `repo_url`.
+    """
+    # First double-check that `repo_path` is a git repo
+    if not os.path.exists(os.path.join(repo_path, ".git")):
+        return False
+    test_git = subprocess.run("git branch".split(), cwd=repo_path)
+    if test_git.returncode != 0:
+        return False
+
+    # Then look at its remotes
+    remotes = subprocess.run(
+        "git remote -v".split(),
+        stderr=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        check=True,
+        encoding="utf-8",
+        cwd=repo_path,
+    ).stdout
+
+    return repo_url in remotes.split()
+
+
 class PushToHubMixin:
     """
     A Mixin containing the functionality to push a model or tokenizer to the hub.
@@ -1914,24 +1938,31 @@ class PushToHubMixin:

     def push_to_hub(
         self,
-        repo_name: Optional[str] = None,
+        repo_path_or_name: Optional[str] = None,
         repo_url: Optional[str] = None,
+        use_temp_dir: bool = False,
         commit_message: Optional[str] = None,
         organization: Optional[str] = None,
-        private: bool = None,
+        private: Optional[bool] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
     ) -> str:
         """
-        Upload model checkpoint or tokenizer files to the 🤗 model hub.
+        Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in
+        :obj:`repo_path_or_name`.

         Parameters:
-            repo_name (:obj:`str`, `optional`):
-                Repository name for your model or tokenizer in the hub. If not specified, the repository name will be
-                the stem of :obj:`save_directory`.
+            repo_path_or_name (:obj:`str`, `optional`):
+                Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in
+                which case the repository will have the name of that local folder). If not specified, will default to
+                the name given by :obj:`repo_url` and a local directory with that name will be created.
             repo_url (:obj:`str`, `optional`):
                 Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
                 repository will be created in your namespace (unless you specify an :obj:`organization`) with
                 :obj:`repo_name`.
+            use_temp_dir (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to clone the distant repo in a temporary directory or in :obj:`repo_path_or_name` inside
+                the current working directory. This will slow things down if you are making changes in an existing repo
+                since you will need to clone the repo before every push.
             commit_message (:obj:`str`, `optional`):
                 Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or
                 :obj:`"add model"` depending on the type of the class.
@@ -1948,42 +1979,66 @@ class PushToHubMixin:

         Returns:
             The url of the commit of your model in the given repository.
-        """
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            self.save_pretrained(tmp_dir)
-            self._push_to_hub(
-                save_directory=tmp_dir,
-                repo_name=repo_name,
-                repo_url=repo_url,
-                commit_message=commit_message,
-                organization=organization,
-                private=private,
-                use_auth_token=use_auth_token,
-            )
-
-    @classmethod
-    def _push_to_hub(
-        cls,
-        save_directory: Optional[str] = None,
-        save_files: Optional[List[str]] = None,
-        repo_name: Optional[str] = None,
-        repo_url: Optional[str] = None,
-        commit_message: Optional[str] = None,
+
+        Examples::
+
+            # Upload a model to the Hub:
+            from transformers import AutoModel
+
+            model = BertModel.from_pretrained("bert-base-cased")
+            # Fine-tuning code
+
+            # Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
+            # `my-finetuned-bert` folder.
+            model.push_to_hub("my-finetuned-bert")
+
+            # Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
+            model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
+
+            # Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
+            # `my-finetuned-bert` folder.
+            model.push_to_hub("my-finetuned-bert", organization="huggingface")
+
+            # Make a change to an existing repo that has been cloned locally in `my-finetuned-bert`.
+            model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
+        """
+        if use_temp_dir:
+            # Make sure we use the right `repo_name` for the `repo_url` before replacing it.
+            if repo_url is None:
+                if use_auth_token is None:
+                    use_auth_token = True
+                repo_name = Path(repo_path_or_name).name
+                repo_url = self._get_repo_url_from_name(
+                    repo_name, organization=organization, private=private, use_auth_token=use_auth_token
+                )
+            repo_path_or_name = tempfile.mkdtemp()
+
+        # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
+        repo = self._create_or_get_repo(
+            repo_path_or_name=repo_path_or_name,
+            repo_url=repo_url,
+            organization=organization,
+            private=private,
+            use_auth_token=use_auth_token,
+        )
+        # Save the files in the cloned repo
+        self.save_pretrained(repo_path_or_name)
+        # Commit and push!
+        url = self._push_to_hub(repo, commit_message=commit_message)
+
+        # Clean up! Clean up! Everybody everywhere!
+        if use_temp_dir:
+            shutil.rmtree(repo_path_or_name)
+
+        return url
+
+    @staticmethod
+    def _get_repo_url_from_name(
+        repo_name: str,
         organization: Optional[str] = None,
         private: bool = None,
         use_auth_token: Optional[Union[bool, str]] = None,
     ) -> str:
-        # Private version of push_to_hub, that either accepts a folder to push or a list of files.
-        if save_directory is None and save_files is None:
-            raise ValueError("_push_to_hub requires either a `save_directory` or a list of `save_files`.")
-        if repo_name is None and repo_url is None and save_directory is None:
-            raise ValueError("Need either a `repo_name` or `repo_url` to know where to push!")
-
-        if repo_name is None and repo_url is None and save_files is None:
-            repo_name = Path(save_directory).name
-        if use_auth_token is None and repo_url is None:
-            use_auth_token = True
-
         if isinstance(use_auth_token, str):
             token = use_auth_token
         elif use_auth_token:
@@ -1997,33 +2052,56 @@ class PushToHubMixin:
         else:
             token = None

-        if repo_url is None:
-            # Special provision for the test endpoint (CI)
-            repo_url = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
-                token,
-                repo_name,
-                organization=organization,
-                private=private,
-                repo_type=None,
-                exist_ok=True,
-            )
+        # Special provision for the test endpoint (CI)
+        return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
+            token,
+            repo_name,
+            organization=organization,
+            private=private,
+            repo_type=None,
+            exist_ok=True,
+        )
+
+    @classmethod
+    def _create_or_get_repo(
+        cls,
+        repo_path_or_name: Optional[str] = None,
+        repo_url: Optional[str] = None,
+        organization: Optional[str] = None,
+        private: bool = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+    ) -> Repository:
+        if repo_path_or_name is None and repo_url is None:
+            raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
+
+        if use_auth_token is None and repo_url is None:
+            use_auth_token = True
+
+        if repo_path_or_name is None:
+            repo_path_or_name = repo_url.split("/")[-1]
+
+        if repo_url is None and not os.path.exists(repo_path_or_name):
+            repo_name = Path(repo_path_or_name).name
+            repo_url = cls._get_repo_url_from_name(
+                repo_name, organization=organization, private=private, use_auth_token=use_auth_token
+            )
+
+        # Create a working directory if it does not exist.
+        if not os.path.exists(repo_path_or_name):
+            os.makedirs(repo_path_or_name)
+
+        repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
+        repo.git_pull()
+        return repo
+
+    @classmethod
+    def _push_to_hub(cls, repo: Repository, commit_message: Optional[str] = None) -> str:
         if commit_message is None:
             if "Tokenizer" in cls.__name__:
                 commit_message = "add tokenizer"
-            if "Config" in cls.__name__:
+            elif "Config" in cls.__name__:
                 commit_message = "add config"
             else:
                 commit_message = "add model"

-        with tempfile.TemporaryDirectory() as tmp_dir:
-            # First create the repo (and clone its content if it's nonempty), then add the files (otherwise there is
-            # no diff so nothing is pushed).
-            repo = Repository(tmp_dir, clone_from=repo_url, use_auth_token=use_auth_token)
-            if save_directory is None:
-                for filename in save_files:
-                    shutil.copy(filename, Path(tmp_dir) / Path(filename).name)
-            else:
-                copy_tree(save_directory, tmp_dir)
-
-            return repo.push_to_hub(commit_message=commit_message)
+        return repo.push_to_hub(commit_message=commit_message)
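The old monolithic `_push_to_hub` is thus split into three helpers: `_get_repo_url_from_name` (name to URL, creating the repo if needed), `_create_or_get_repo` (clone management) and `_push_to_hub` (commit and push on a `Repository`). A sketch of the new `is_local_clone` check on its own (the URL and folder are hypothetical)::

    from transformers.file_utils import is_local_clone

    url = "https://huggingface.co/sgugger/my-finetuned-bert"  # hypothetical repo
    # True only if `./my-finetuned-bert` is a working git repo listing `url` among its remotes.
    print(is_local_clone("./my-finetuned-bert", url))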
src/transformers/modelcard.py
@@ -565,6 +565,14 @@ class TrainingSummary:
         if model_name is None:
             model_name = Path(trainer.args.output_dir).name

+        # Add `generated_from_trainer` to the tags
+        if tags is None:
+            tags = ["generated_from_trainer"]
+        elif isinstance(tags, str) and tags != "generated_from_trainer":
+            tags = [tags, "generated_from_trainer"]
+        elif "generated_from_trainer" not in tags:
+            tags.append("generated_from_trainer")
+
         _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
         hyperparameters = extract_hyperparameters_from_trainer(trainer)
src/transformers/modeling_flax_utils.py
@@ -28,7 +28,6 @@ from jax.random import PRNGKey

 from .configuration_utils import PretrainedConfig
 from .file_utils import (
-    CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,
@@ -409,6 +408,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 Directory to which to save. Will be created if it doesn't exist.
             push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                .. warning::
+
+                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
+                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
+                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
+                    instead.
+
             kwargs:
                 Additional key word arguments passed along to the
                 :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -416,6 +423,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)

         # get abs dir
@@ -434,8 +446,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         logger.info(f"Model weights saved in {output_model_file}")

         if push_to_hub:
-            saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
-            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            url = self._push_to_hub(repo, commit_message=commit_message)
             logger.info(f"Model pushed to the hub in this commit: {url}")
src/transformers/modeling_tf_utils.py
@@ -30,7 +30,6 @@ from tensorflow.python.keras.saving import hdf5_format

 from .configuration_utils import PretrainedConfig
 from .file_utils import (
-    CONFIG_NAME,
     DUMMY_INPUTS,
     TF2_WEIGHTS_NAME,
     WEIGHTS_NAME,
@@ -1029,6 +1028,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 https://www.tensorflow.org/tfx/serving/serving_basic
             push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                .. warning::
+
+                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
+                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
+                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
+                    instead.
+
             kwargs:
                 Additional key word arguments passed along to the
                 :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -1036,6 +1043,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)

         if saved_model:
@@ -1053,8 +1065,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         logger.info(f"Model weights saved in {output_model_file}")

         if push_to_hub:
-            saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
-            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            url = self._push_to_hub(repo, commit_message=commit_message)
             logger.info(f"Model pushed to the hub in this commit: {url}")

     @classmethod
src/transformers/modeling_utils.py
@@ -30,7 +30,6 @@ from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .file_utils import (
-    CONFIG_NAME,
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
@@ -852,6 +851,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 need to replace :obj:`torch.save` by another method.
             push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                .. warning::
+
+                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
+                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
+                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
+                    instead.
+
             kwargs:
                 Additional key word arguments passed along to the
                 :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
@@ -859,6 +866,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)

         # Only save the model itself if we are using distributed training
@@ -886,10 +898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         logger.info(f"Model weights saved in {output_model_file}")

         if push_to_hub:
-            saved_files = [output_model_file]
-            if save_config:
-                saved_files.append(os.path.join(save_directory, CONFIG_NAME))
-            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            url = self._push_to_hub(repo, commit_message=commit_message)
             logger.info(f"Model pushed to the hub in this commit: {url}")

     @classmethod
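With these changes, the `save_pretrained(..., push_to_hub=True)` pattern behaves the same for PyTorch, TensorFlow and Flax models: the whole save directory is committed and pushed in one go. A hedged example (the directory and commit message are illustrative)::

    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-cased")
    # Saves weights and config into `./my-bert`, then commits and pushes the folder.
    model.save_pretrained("./my-bert", push_to_hub=True, commit_message="add model")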
src/transformers/tokenization_utils_base.py
@@ -1884,6 +1884,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 value error is raised.
             filename_prefix: (:obj:`str`, `optional`):
                 A prefix to add to the names of the files saved by the tokenizer.
+            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                .. warning::
+
+                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
+                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
+                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
+                    instead.

         Returns:
             A tuple of :obj:`str`: The files saved.
@@ -1891,6 +1900,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)

         special_tokens_map_file = os.path.join(
@@ -1949,9 +1963,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         )

         if push_to_hub:
-            # Annoyingly, the return contains files that don't exist.
-            existing_files = [f for f in save_files if os.path.isfile(f)]
-            url = self._push_to_hub(save_files=existing_files, **kwargs)
+            url = self._push_to_hub(repo, commit_message=commit_message)
             logger.info(f"Tokenizer pushed to the hub in this commit: {url}")

         return save_files
src/transformers/trainer.py
@@ -24,7 +24,6 @@ import random
 import re
 import shutil
 import sys
-import tempfile
 import time
 import warnings
 from logging import StreamHandler
@@ -391,9 +390,12 @@ class Trainer:
         # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
         self._loggers_initialized = False

-        # Create output directory if needed
+        # Create clone of distant repo and output directory if needed
+        if self.args.push_to_hub:
+            self.init_git_repo()
         if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)

         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
             raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
@@ -2430,6 +2432,27 @@ class Trainer:
         else:
             return 0

+    def init_git_repo(self):
+        """
+        Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
+        """
+        if not self.is_world_process_zero():
+            return
+        use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
+        repo_url = PushToHubMixin._get_repo_url_from_name(
+            self.args.push_to_hub_model_id,
+            organization=self.args.push_to_hub_organization,
+            use_auth_token=use_auth_token,
+        )
+        self.repo = PushToHubMixin._create_or_get_repo(
+            self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
+        )
+
+        # By default, ignore the checkpoint folders
+        if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
+            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+                writer.writelines(["checkpoint-*/"])
+
     def create_model_card(
         self,
         language: Optional[str] = None,
@@ -2458,38 +2481,13 @@ class Trainer:
         with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
             f.write(model_card)

-    def push_to_hub(
-        self,
-        repo_name: Optional[str] = None,
-        repo_url: Optional[str] = None,
-        commit_message: Optional[str] = "add model",
-        organization: Optional[str] = None,
-        private: bool = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        **kwargs,
-    ):
+    def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
         """
-        Upload `self.model` to the 🤗 model hub.
+        Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`.

         Parameters:
-            repo_name (:obj:`str`, `optional`):
-                Repository name for your model or tokenizer in the hub. If not specified and :obj:`repo_url` is not
-                specified either, will default to the stem of :obj:`self.args.output_dir`.
-            repo_url (:obj:`str`, `optional`):
-                Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
-                repository will be created in your namespace (unless you specify an :obj:`organization`) with
-                :obj:`repo_name`.
             commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
                 Message to commit while pushing.
-            organization (:obj:`str`, `optional`):
-                Organization in which you want to push your model or tokenizer (you must be a member of this
-                organization).
-            private (:obj:`bool`, `optional`):
-                Whether or not the repository created should be private (requires a paying subscription).
-            use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
-                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
-                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
-                :obj:`True` if :obj:`repo_url` is not specified.
             kwargs:
                 Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.

@@ -2499,37 +2497,9 @@ class Trainer:
         if not self.is_world_process_zero():
             return

-        if not isinstance(unwrap_model(self.model), PushToHubMixin):
-            raise ValueError(
-                "The `upload_model_to_hub` method only works for models that inherit from `PushToHubMixin` models."
-            )
-
-        if repo_url is None and repo_name is None:
-            repo_name = Path(self.args.output_dir).name
-
-        if repo_name is not None:
-            model_name = repo_name
-        elif repo_url is not None:
-            model_name = repo_url.split("/")[-1]
-        else:
-            model_name = None
-        self.create_model_card(model_name=model_name, **kwargs)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            shutil.copy(os.path.join(self.args.output_dir, "README.md"), os.path.join(tmp_dir, "README.md"))
-            unwrap_model(self.model).save_pretrained(tmp_dir)
-            if self.tokenizer is not None:
-                self.tokenizer.save_pretrained(tmp_dir)
-
-            return unwrap_model(self.model)._push_to_hub(
-                save_directory=tmp_dir,
-                repo_name=repo_name,
-                repo_url=repo_url,
-                commit_message=commit_message,
-                organization=organization,
-                private=private,
-                use_auth_token=use_auth_token,
-            )
+        self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
+        self.save_model()
+        return self.repo.push_to_hub(commit_message=commit_message)

 #
 # Deprecated code
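Putting the Trainer pieces together, an end-to-end sketch under the new arguments (the model and dataset are placeholders, not from this commit)::

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="my-finetuned-bert",  # also becomes the repo name by default
        push_to_hub=True,                # clones/creates the repo during Trainer init via init_git_repo()
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # placeholders
    trainer.train()
    # Writes the model card (with the `generated_from_trainer` tag), saves the model, then pushes.
    url = trainer.push_to_hub(commit_message="add model")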
src/transformers/training_args.py
@@ -17,6 +17,7 @@ import os
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
+from pathlib import Path
 from typing import Any, Dict, List, Optional

 from .debug_utils import DebugOption
@@ -157,7 +158,7 @@ class TrainingArguments:
             node.
         logging_dir (:obj:`str`, `optional`):
             `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
-            `runs/**CURRENT_DATETIME_HOSTNAME**`.
+            `output_dir/runs/**CURRENT_DATETIME_HOSTNAME**`.
         logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
             The logging strategy to adopt during training. Possible values are:

@@ -318,15 +319,22 @@ class TrainingArguments:
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.
         push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to upload the trained model to the hub after training. This argument is not directly used by
-            :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
-            the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
-            details.
+            Whether or not to upload the trained model to the hub after training. If this is activated, and
+            :obj:`output_dir` exists, it needs to be a local clone of the repository to which the
+            :class:`~transformers.Trainer` will be pushed.
         resume_from_checkpoint (:obj:`str`, `optional`):
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
             the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
             details.
+        push_to_hub_model_id (:obj:`str`, `optional`):
+            The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
+            Will default to the name of :obj:`output_dir`.
+        push_to_hub_organization (:obj:`str`, `optional`):
+            The name of the organization in with to which push the :class:`~transformers.Trainer`.
+        push_to_hub_token (:obj:`str`, `optional`):
+            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+            :obj:`huggingface-cli login`.
     """

     output_dir: str = field(
@@ -590,6 +598,13 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    push_to_hub_model_id: str = field(
+        default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
+    )
+    push_to_hub_organization: str = field(
+        default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
+    )
+    push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
@@ -612,6 +627,8 @@ class TrainingArguments:
         # see https://github.com/huggingface/transformers/issues/10628
         if self.output_dir is not None:
             self.output_dir = os.path.expanduser(self.output_dir)
+        if self.logging_dir is None and self.output_dir is not None:
+            self.logging_dir = os.path.join(self.output_dir, default_logdir())
         if self.logging_dir is not None:
             self.logging_dir = os.path.expanduser(self.logging_dir)

@@ -705,6 +722,9 @@ class TrainingArguments:
         self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
         self.hf_deepspeed_config.trainer_config_process(self)

+        if self.push_to_hub_model_id is None:
+            self.push_to_hub_model_id = Path(self.output_dir).name
+
     def __str__(self):
         self_as_dict = asdict(self)
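The defaulting of the three new fields can be read off :obj:`__post_init__`; a small sketch (paths are illustrative)::

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="runs/my-model", push_to_hub=True)
    # push_to_hub_model_id falls back to the last component of output_dir.
    assert args.push_to_hub_model_id == "my-model"

    # Pushing to an organization repo instead of the user namespace:
    args = TrainingArguments(
        output_dir="my-model",
        push_to_hub=True,
        push_to_hub_organization="my-org",  # hypothetical organization
    )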
tests/test_configuration_common.py
@@ -113,7 +113,7 @@ class ConfigPushToHubTester(unittest.TestCase):
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
         )
         with tempfile.TemporaryDirectory() as tmp_dir:
-            config.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-config", use_auth_token=self._token)
+            config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)

             new_config = BertConfig.from_pretrained(f"{USER}/test-config")
             for k, v in config.__dict__.items():
@@ -127,9 +127,8 @@ class ConfigPushToHubTester(unittest.TestCase):

         with tempfile.TemporaryDirectory() as tmp_dir:
             config.save_pretrained(
-                tmp_dir,
+                os.path.join(tmp_dir, "test-config-org"),
                 push_to_hub=True,
-                repo_name="test-config-org",
                 use_auth_token=self._token,
                 organization="valid_org",
             )
tests/test_modeling_common.py
@@ -1613,7 +1613,7 @@ class ModelPushToHubTester(unittest.TestCase):
         )
         model = BertModel(config)
         with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
+            model.save_pretrained(os.path.join(tmp_dir, "test-model"), push_to_hub=True, use_auth_token=self._token)

             new_model = BertModel.from_pretrained(f"{USER}/test-model")
             for p1, p2 in zip(model.parameters(), new_model.parameters()):
@@ -1626,9 +1626,8 @@ class ModelPushToHubTester(unittest.TestCase):
         model = BertModel(config)
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(
-                tmp_dir,
+                os.path.join(tmp_dir, "test-model-org"),
                 push_to_hub=True,
-                repo_name="test-model-org",
                 use_auth_token=self._token,
                 organization="valid_org",
             )
tests/test_modeling_flax_common.py
@@ -16,14 +16,25 @@ import copy
 import inspect
 import random
 import tempfile
+import unittest
 from typing import List, Tuple

 import numpy as np

 import transformers
-from transformers import is_flax_available, is_torch_available
+from huggingface_hub import HfApi
+from requests.exceptions import HTTPError
+from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
+from transformers.testing_utils import (
+    ENDPOINT_STAGING,
+    PASS,
+    USER,
+    is_pt_flax_cross_test,
+    is_staging_test,
+    require_flax,
+    slow,
+)


 if is_flax_available():
@@ -504,3 +515,65 @@ class FlaxModelTesterMixin:
             list(self_attentions[0].shape[-3:]),
             [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
         )
+
+
+@require_flax
+@is_staging_test
+class FlaxModelPushToHubTester(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._api = HfApi(endpoint=ENDPOINT_STAGING)
+        cls._token = cls._api.login(username=USER, password=PASS)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls._api.delete_repo(token=cls._token, name="test-model-flax")
+        except HTTPError:
+            pass
+
+        try:
+            cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = FlaxBertModel(config)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(
+                os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
+            )
+
+            new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
+
+            base_params = flatten_dict(unfreeze(model.params))
+            new_params = flatten_dict(unfreeze(new_model.params))
+
+            for key in base_params.keys():
+                max_diff = (base_params[key] - new_params[key]).sum().item()
+                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    def test_push_to_hub_in_organization(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = FlaxBertModel(config)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(
+                os.path.join(tmp_dir, "test-model-flax-org"),
+                push_to_hub=True,
+                use_auth_token=self._token,
+                organization="valid_org",
+            )
+
+            new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
+
+            base_params = flatten_dict(unfreeze(model.params))
+            new_params = flatten_dict(unfreeze(new_model.params))
+
+            for key in base_params.keys():
+                max_diff = (base_params[key] - new_params[key]).sum().item()
+                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
tests/test_modeling_tf_common.py
@@ -1487,7 +1487,7 @@ class TFModelPushToHubTester(unittest.TestCase):
         # Make sure model is properly initialized
         _ = model(model.dummy_inputs)
         with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model-tf", use_auth_token=self._token)
+            model.save_pretrained(os.path.join(tmp_dir, "test-model-tf"), push_to_hub=True, use_auth_token=self._token)

             new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
             models_equal = True
@@ -1503,9 +1503,8 @@ class TFModelPushToHubTester(unittest.TestCase):
         model = TFBertModel(config)
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(
-                tmp_dir,
+                os.path.join(tmp_dir, "test-model-tf-org"),
                 push_to_hub=True,
-                repo_name="test-model-tf-org",
                 use_auth_token=self._token,
                 organization="valid_org",
             )
tests/test_tokenization_common.py
@@ -3173,7 +3173,7 @@ class TokenizerPushToHubTester(unittest.TestCase):
             vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
             tokenizer = BertTokenizer(vocab_file)
             tokenizer.save_pretrained(
-                tmp_dir, push_to_hub=True, repo_name="test-tokenizer", use_auth_token=self._token
+                os.path.join(tmp_dir, "test-tokenizer"), push_to_hub=True, use_auth_token=self._token
             )

             new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
@@ -3186,9 +3186,8 @@ class TokenizerPushToHubTester(unittest.TestCase):
             vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
             tokenizer = BertTokenizer(vocab_file)
             tokenizer.save_pretrained(
-                tmp_dir,
+                os.path.join(tmp_dir, "test-tokenizer-org"),
                 push_to_hub=True,
-                repo_name="test-tokenizer-org",
                 use_auth_token=self._token,
                 organization="valid_org",
             )
tests/test_trainer.py
@@ -1274,8 +1274,12 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):

     def test_push_to_hub(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            trainer = get_regression_trainer(output_dir=tmp_dir)
-            url = trainer.push_to_hub(repo_name="test-trainer", use_auth_token=self._token)
+            trainer = get_regression_trainer(
+                output_dir=os.path.join(tmp_dir, "test-trainer"),
+                push_to_hub=True,
+                push_to_hub_token=self._token,
+            )
+            url = trainer.push_to_hub()

             # Extract repo_name from the url
             re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
@@ -1292,9 +1296,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(output_dir=tmp_dir)
             trainer.save_model()
-            url = trainer.push_to_hub(
-                repo_name="test-trainer-org", organization="valid_org", use_auth_token=self._token
+            trainer = get_regression_trainer(
+                output_dir=os.path.join(tmp_dir, "test-trainer-org"),
+                push_to_hub=True,
+                push_to_hub_organization="valid_org",
+                push_to_hub_token=self._token,
             )
+            url = trainer.push_to_hub()

             # Extract repo_name from the url
             re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)