Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)

Rewrite push_to_hub to use upload_files (#18366)

* Rewrite push_to_hub to use upload_files
* Adapt the doc a bit
* Address review comments and clean doc

This commit is contained in: parent 3909d7f139, commit 01db72abd4
@@ -813,13 +813,9 @@ checkpoint and to get the required access rights to be able to upload the model
*brand_new_bert*. The `push_to_hub` method, present in all models in `transformers`, is a quick and efficient way to push your checkpoint to the hub. A little snippet is pasted below:

```python
brand_new_bert.push_to_hub(
    repo_path_or_name="brand_new_bert",
    # Uncomment the following line to push to an organization
    # organization="<ORGANIZATION>",
    commit_message="Add model",
    use_temp_dir=True,
)
brand_new_bert.push_to_hub("brand_new_bert")
# Uncomment the following line to push to an organization.
# brand_new_bert.push_to_hub("<organization>/brand_new_bert")
```

It is worth spending some time to create fitting model cards for each checkpoint. The model cards should highlight the
@@ -179,10 +179,10 @@ This creates a repository under your username with the model name `my-awesome-mo
>>> model = AutoModel.from_pretrained("your_username/my-awesome-model")
```

If you belong to an organization and want to push your model under the organization name instead, add the `organization` parameter:
If you belong to an organization and want to push your model under the organization name instead, just add it to the `repo_id`:

```py
>>> pt_model.push_to_hub("my-awesome-model", organization="my-awesome-org")
>>> pt_model.push_to_hub("my-awesome-org/my-awesome-model")
```

The `push_to_hub` function can also be used to add other files to a model repository. For example, add a tokenizer to a model repository:
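The snippet that follows this sentence falls outside the hunk shown above. As a hedged illustration only (the `bert-base-cased` checkpoint and the `my-awesome-model` repository name are placeholders, not taken from this patch), pushing a tokenizer to the same repository looks roughly like:

```py
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
>>> # Same repo_id convention as above: "name" for your namespace, "organization/name" for an org.
>>> tokenizer.push_to_hub("my-awesome-model")
```
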
@@ -417,27 +417,22 @@ class PretrainedConfig(PushToHubMixin):
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
@@ -451,8 +446,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Configuration pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
|
@@ -318,32 +318,28 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
custom_object_save(self, save_directory, config=self)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
|
||||
|
||||
@@ -351,8 +347,9 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
return [output_feature_extractor_file]
|
||||
|
||||
|
@@ -941,16 +941,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
|
||||
@@ -969,11 +962,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# get abs dir
|
||||
save_directory = os.path.abspath(save_directory)
|
||||
@@ -1028,8 +1023,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
logger.info(f"Model weights saved in {output_model_file}")
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Model pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
|
||||
|
@@ -24,6 +24,7 @@ import pickle
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
@@ -58,7 +59,6 @@ from .utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
find_labels,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
@@ -66,6 +66,7 @@ from .utils import (
|
||||
is_remote_url,
|
||||
logging,
|
||||
requires_backends,
|
||||
working_or_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
@@ -1919,6 +1920,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
version=1,
|
||||
push_to_hub=False,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
create_pr: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -1935,16 +1937,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
TensorFlow Serving as detailed in the official documentation
|
||||
https://www.tensorflow.org/tfx/serving/serving_basic
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
|
||||
@@ -1956,6 +1951,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
</Tip>
|
||||
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
@@ -1963,11 +1961,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
if saved_model:
|
||||
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
|
||||
@@ -2030,8 +2030,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
param_dset[:] = layer.numpy()
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Model pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
@@ -2475,12 +2476,95 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
return model
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_id: str,
|
||||
use_temp_dir: Optional[bool] = None,
|
||||
commit_message: Optional[str] = None,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||
**model_card_kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
||||
|
||||
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
|
||||
TFPreTrainedModel.push_to_hub = copy_func(TFPreTrainedModel.push_to_hub)
|
||||
TFPreTrainedModel.push_to_hub.__doc__ = TFPreTrainedModel.push_to_hub.__doc__.format(
|
||||
object="model", object_class="TFAutoModel", object_files="model checkpoint"
|
||||
)
|
||||
Parameters:
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push your model to. It should contain your organization name
|
||||
when pushing to a given organization.
|
||||
use_temp_dir (`bool`, *optional*):
|
||||
Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
|
||||
Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
|
||||
commit_message (`str`, *optional*):
|
||||
Message to commit while pushing. Will default to `"Upload model"`.
|
||||
private (`bool`, *optional*):
|
||||
Whether or not the repository created should be private (requires a paying subscription).
|
||||
use_auth_token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
|
||||
`repo_url` is not specified.
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
Only applicable for models. The maximum size for a checkpoint before being sharded. Each checkpoint shard
will then be of a size lower than this limit. If expressed as a string, it needs to be digits followed
by a unit (like `"5MB"`).
|
||||
model_card_kwargs:
|
||||
Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
from transformers import TFAutoModel
|
||||
|
||||
model = TFAutoModel.from_pretrained("bert-base-cased")
|
||||
|
||||
# Push the model to your namespace with the name "my-finetuned-bert".
|
||||
model.push_to_hub("my-finetuned-bert")
|
||||
|
||||
# Push the model to an organization with the name "my-finetuned-bert".
|
||||
model.push_to_hub("huggingface/my-finetuned-bert")
|
||||
```
|
||||
"""
|
||||
if "repo_path_or_name" in model_card_kwargs:
|
||||
warnings.warn(
|
||||
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
|
||||
"`repo_id` instead."
|
||||
)
|
||||
repo_id = model_card_kwargs.pop("repo_path_or_name")
|
||||
# Deprecation warning will be sent after for repo_url and organization
|
||||
repo_url = model_card_kwargs.pop("repo_url", None)
|
||||
organization = model_card_kwargs.pop("organization", None)
|
||||
|
||||
if os.path.isdir(repo_id):
|
||||
working_dir = repo_id
|
||||
repo_id = repo_id.split(os.path.sep)[-1]
|
||||
else:
|
||||
working_dir = repo_id.split("/")[-1]
|
||||
|
||||
repo_id, token = self._create_repo(
|
||||
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
|
||||
)
|
||||
|
||||
if use_temp_dir is None:
|
||||
use_temp_dir = not os.path.isdir(working_dir)
|
||||
|
||||
with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
|
||||
files_timestamps = self._get_files_timestamps(work_dir)
|
||||
|
||||
# Save all files.
|
||||
self.save_pretrained(work_dir, max_shard_size=max_shard_size)
|
||||
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
||||
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
||||
base_model_card_args = {
|
||||
"output_dir": work_dir,
|
||||
"model_name": Path(repo_id).name,
|
||||
}
|
||||
base_model_card_args.update(model_card_kwargs)
|
||||
self.create_model_card(**base_model_card_args)
|
||||
|
||||
self._upload_modified_files(
|
||||
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
|
||||
class TFConv1D(tf.keras.layers.Layer):
|
||||
|
@@ -24,7 +24,6 @@ import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -64,6 +63,7 @@ from .utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_accelerate_available,
|
||||
@@ -1473,16 +1473,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
needs to replace `torch.save` with another method.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
|
||||
@@ -1507,11 +1500,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# Only save the model itself if we are using distributed training
|
||||
model_to_save = unwrap_model(self)
|
||||
@@ -1583,8 +1578,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Model pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
@@ -2548,109 +2544,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_path_or_name: Optional[str] = None,
|
||||
repo_url: Optional[str] = None,
|
||||
use_temp_dir: bool = False,
|
||||
commit_message: str = "add model",
|
||||
organization: Optional[str] = None,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
**model_card_kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
||||
|
||||
Parameters:
|
||||
repo_path_or_name (`str`, *optional*):
|
||||
Can either be a repository name for your model in the Hub or a path to a local folder (in which case
|
||||
the repository will have the name of that local folder). If not specified, will default to the name
|
||||
given by `repo_url` and a local directory with that name will be created.
|
||||
repo_url (`str`, *optional*):
|
||||
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
|
||||
repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
|
||||
use_temp_dir (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
|
||||
current working directory. This will slow things down if you are making changes in an existing repo
|
||||
since you will need to clone the repo before every push.
|
||||
commit_message (`str`, *optional*, defaults to `"add model"`):
|
||||
Message to commit while pushing.
|
||||
organization (`str`, *optional*):
|
||||
Organization in which you want to push your {object} (you must be a member of this organization).
|
||||
private (`bool`, *optional*):
|
||||
Whether or not the repository created should be private (requires a paying subscription).
|
||||
use_auth_token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
|
||||
`repo_url` is not specified.
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||
which will be bigger than `max_shard_size`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`str`: The url of the commit of your {object} in the given repository.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained("bert-base-cased")
|
||||
|
||||
# Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
|
||||
# *my-finetuned-bert* folder.
|
||||
model.push_to_hub("my-finetuned-bert")
|
||||
|
||||
# Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
|
||||
model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
|
||||
|
||||
# Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
|
||||
# *my-finetuned-bert* folder.
|
||||
model.push_to_hub("my-finetuned-bert", organization="huggingface")
|
||||
|
||||
# Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
|
||||
model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
|
||||
```
|
||||
"""
|
||||
if use_temp_dir:
|
||||
# Make sure we use the right `repo_name` for the `repo_url` before replacing it.
|
||||
if repo_url is None:
|
||||
if use_auth_token is None:
|
||||
use_auth_token = True
|
||||
repo_name = Path(repo_path_or_name).name
|
||||
repo_url = self._get_repo_url_from_name(
|
||||
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
|
||||
)
|
||||
repo_path_or_name = tempfile.mkdtemp()
|
||||
|
||||
# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
|
||||
repo = self._create_or_get_repo(
|
||||
repo_path_or_name=repo_path_or_name,
|
||||
repo_url=repo_url,
|
||||
organization=organization,
|
||||
private=private,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
# Save the files in the cloned repo
|
||||
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||
|
||||
# Commit and push!
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
|
||||
# Clean up! Clean up! Everybody everywhere!
|
||||
if use_temp_dir:
|
||||
shutil.rmtree(repo_path_or_name)
|
||||
|
||||
return url
|
||||
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
|
||||
PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
|
||||
object="model", object_class="AutoModel", object_files="model file"
|
||||
)
|
||||
|
||||
|
||||
class PoolerStartLogits(nn.Module):
|
||||
|
@@ -109,24 +109,19 @@ class ProcessorMixin(PushToHubMixin):
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your processor to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
@@ -150,8 +145,9 @@ class ProcessorMixin(PushToHubMixin):
|
||||
del attribute.init_kwargs["auto_map"]
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Processor pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
|
@@ -2077,15 +2077,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
filename_prefix (`str`, *optional*):
|
||||
A prefix to add to the names of the files saved by the tokenizer.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
|
||||
Returns:
|
||||
A tuple of `str`: The files saved.
|
||||
@@ -2094,11 +2090,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id, token = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
special_tokens_map_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
|
||||
@@ -2167,8 +2165,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
|
||||
self._upload_modified_files(
|
||||
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
|
||||
)
|
||||
|
||||
return save_files
|
||||
|
||||
|
@@ -42,6 +42,7 @@ from .generic import (
|
||||
is_tensor,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
working_or_temp_dir,
|
||||
)
|
||||
from .hub import (
|
||||
CLOUDFRONT_DISTRIB_PREFIX,
|
||||
|
@@ -16,9 +16,10 @@ Generic utilities
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
from collections import OrderedDict, UserDict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import ExitStack
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields
|
||||
from enum import Enum
|
||||
from typing import Any, ContextManager, List, Tuple
|
||||
@@ -325,3 +326,12 @@ def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
|
||||
yield key, v
|
||||
|
||||
return dict(_flatten_dict(d, parent_key, delimiter))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
|
||||
if use_temp_dir:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield tmp_dir
|
||||
else:
|
||||
yield working_dir
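
Editor's note: this small context manager is what lets the rewritten `push_to_hub` save files either into an existing local folder or into a throwaway temporary directory that is cleaned up on exit. A minimal usage sketch, assuming the `transformers.utils` export added in this same commit (the `my-model` directory name is a placeholder):

```python
from transformers.utils import working_or_temp_dir

# With use_temp_dir=True a TemporaryDirectory is created and removed when the block exits;
# with use_temp_dir=False the given working_dir is yielded unchanged.
with working_or_temp_dir(working_dir="my-model", use_temp_dir=True) as work_dir:
    print(f"Files would be saved into {work_dir}")
```
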
@@ -36,12 +36,13 @@ from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
||||
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, list_repo_files, whoami
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.models import Response
|
||||
from transformers.utils.logging import tqdm
|
||||
|
||||
from . import __version__, logging
|
||||
from .generic import working_or_temp_dir
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
_tf_version,
|
||||
@@ -869,48 +870,122 @@ class PushToHubMixin:
|
||||
A Mixin containing the functionality to push a model or tokenizer to the hub.
|
||||
"""
|
||||
|
||||
def _create_repo(
|
||||
self,
|
||||
repo_id: str,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
repo_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Creates the repo if needed, cleans up `repo_id` with the deprecated kwargs `repo_url` and `organization`, and retrieves the
|
||||
token.
|
||||
"""
|
||||
if repo_url is not None:
|
||||
warnings.warn(
|
||||
"The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
|
||||
"instead."
|
||||
)
|
||||
repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
|
||||
if organization is not None:
|
||||
warnings.warn(
|
||||
"The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
|
||||
"organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
|
||||
)
|
||||
if not repo_id.startswith(organization):
|
||||
if "/" in repo_id:
|
||||
repo_id = repo_id.split("/")[-1]
|
||||
repo_id = f"{organization}/{repo_id}"
|
||||
|
||||
token = HfFolder.get_token() if use_auth_token is True else use_auth_token
|
||||
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
# If the namespace is not there, add it or `upload_file` will complain
|
||||
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
|
||||
repo_id = get_full_repo_name(repo_id, token=token)
|
||||
return repo_id, token
|
||||
|
||||
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
|
||||
"""
|
||||
Returns the list of files with their last modification timestamp.
|
||||
"""
|
||||
return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
|
||||
|
||||
def _upload_modified_files(
|
||||
self,
|
||||
working_dir: Union[str, os.PathLike],
|
||||
repo_id: str,
|
||||
files_timestamps: Dict[str, float],
|
||||
commit_message: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
create_pr: bool = False,
|
||||
):
|
||||
"""
|
||||
Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
|
||||
"""
|
||||
if commit_message is None:
|
||||
if "Model" in self.__class__.__name__:
|
||||
commit_message = "Upload model"
|
||||
elif "Config" in self.__class__.__name__:
|
||||
commit_message = "Upload config"
|
||||
elif "Tokenizer" in self.__class__.__name__:
|
||||
commit_message = "Upload tokenizer"
|
||||
elif "FeatureExtractor" in self.__class__.__name__:
|
||||
commit_message = "Upload feature extractor"
|
||||
elif "Processor" in self.__class__.__name__:
|
||||
commit_message = "Upload processor"
|
||||
else:
|
||||
commit_message = f"Upload {self.__class__.__name__}"
|
||||
modified_files = [
|
||||
f
|
||||
for f in os.listdir(working_dir)
|
||||
if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
|
||||
]
|
||||
operations = []
|
||||
for file in modified_files:
|
||||
operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file))
|
||||
logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
|
||||
return create_commit(
|
||||
repo_id=repo_id, operations=operations, commit_message=commit_message, token=token, create_pr=create_pr
|
||||
)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_path_or_name: Optional[str] = None,
|
||||
repo_url: Optional[str] = None,
|
||||
use_temp_dir: bool = False,
|
||||
repo_id: str,
|
||||
use_temp_dir: Optional[bool] = None,
|
||||
commit_message: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||
**model_card_kwargs
|
||||
create_pr: bool = False,
|
||||
**deprecated_kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
|
||||
`repo_path_or_name`.
|
||||
|
||||
Parameters:
|
||||
repo_path_or_name (`str`, *optional*):
|
||||
Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
|
||||
the repository will have the name of that local folder). If not specified, will default to the name
|
||||
given by `repo_url` and a local directory with that name will be created.
|
||||
repo_url (`str`, *optional*):
|
||||
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
|
||||
repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
|
||||
use_temp_dir (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
|
||||
current working directory. This will slow things down if you are making changes in an existing repo
|
||||
since you will need to clone the repo before every push.
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push your {object} to. It should contain your organization name
|
||||
when pushing to a given organization.
|
||||
use_temp_dir (`bool`, *optional*):
|
||||
Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
|
||||
Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
|
||||
commit_message (`str`, *optional*):
|
||||
Message to commit while pushing. Will default to `"add {object}"`.
|
||||
organization (`str`, *optional*):
|
||||
Organization in which you want to push your {object} (you must be a member of this organization).
|
||||
Message to commit while pushing. Will default to `"Upload {object}"`.
|
||||
private (`bool`, *optional*):
|
||||
Whether or not the repository created should be private (requires a paying subscription).
|
||||
use_auth_token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
|
||||
`repo_url` is not specified.
|
||||
|
||||
|
||||
Returns:
|
||||
`str`: The url of the commit of your {object} in the given repository.
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
Only applicable for models. The maximum size for a checkpoint before being sharded. Each checkpoint shard
will then be of a size lower than this limit. If expressed as a string, it needs to be digits followed
by a unit (like `"5MB"`).
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -919,134 +994,45 @@ class PushToHubMixin:
|
||||
|
||||
{object} = {object_class}.from_pretrained("bert-base-cased")
|
||||
|
||||
# Push the {object} to your namespace with the name "my-finetuned-bert" and have a local clone in the
|
||||
# *my-finetuned-bert* folder.
|
||||
# Push the {object} to your namespace with the name "my-finetuned-bert".
|
||||
{object}.push_to_hub("my-finetuned-bert")
|
||||
|
||||
# Push the {object} to your namespace with the name "my-finetuned-bert" with no local clone.
|
||||
{object}.push_to_hub("my-finetuned-bert", use_temp_dir=True)
|
||||
|
||||
# Push the {object} to an organization with the name "my-finetuned-bert" and have a local clone in the
|
||||
# *my-finetuned-bert* folder.
|
||||
{object}.push_to_hub("my-finetuned-bert", organization="huggingface")
|
||||
|
||||
# Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
|
||||
{object}.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
|
||||
# Push the {object} to an organization with the name "my-finetuned-bert".
|
||||
{object}.push_to_hub("huggingface/my-finetuned-bert")
|
||||
```
|
||||
"""
|
||||
if use_temp_dir:
|
||||
# Make sure we use the right `repo_name` for the `repo_url` before replacing it.
|
||||
if repo_url is None:
|
||||
if use_auth_token is None:
|
||||
use_auth_token = True
|
||||
repo_name = Path(repo_path_or_name).name
|
||||
repo_url = self._get_repo_url_from_name(
|
||||
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
|
||||
)
|
||||
repo_path_or_name = tempfile.mkdtemp()
|
||||
|
||||
# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
|
||||
repo = self._create_or_get_repo(
|
||||
repo_path_or_name=repo_path_or_name,
|
||||
repo_url=repo_url,
|
||||
organization=organization,
|
||||
private=private,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
# Save the files in the cloned repo
|
||||
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
||||
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
||||
base_model_card_args = {
|
||||
"output_dir": repo_path_or_name,
|
||||
"model_name": Path(repo_path_or_name).name,
|
||||
}
|
||||
base_model_card_args.update(model_card_kwargs)
|
||||
self.create_model_card(**base_model_card_args)
|
||||
|
||||
# Commit and push!
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
|
||||
# Clean up! Clean up! Everybody everywhere!
|
||||
if use_temp_dir:
|
||||
shutil.rmtree(repo_path_or_name)
|
||||
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def _get_repo_url_from_name(
|
||||
repo_name: str,
|
||||
organization: Optional[str] = None,
|
||||
private: bool = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
) -> str:
|
||||
if isinstance(use_auth_token, str):
|
||||
token = use_auth_token
|
||||
elif use_auth_token:
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
raise ValueError(
|
||||
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
|
||||
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
|
||||
"token as the `use_auth_token` argument."
|
||||
)
|
||||
else:
|
||||
token = None
|
||||
|
||||
# Special provision for the test endpoint (CI)
|
||||
return create_repo(
|
||||
token,
|
||||
repo_name,
|
||||
organization=organization,
|
||||
private=private,
|
||||
repo_type=None,
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_or_get_repo(
|
||||
cls,
|
||||
repo_path_or_name: Optional[str] = None,
|
||||
repo_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
private: bool = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
) -> Repository:
|
||||
if repo_path_or_name is None and repo_url is None:
|
||||
raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
|
||||
|
||||
if use_auth_token is None and repo_url is None:
|
||||
use_auth_token = True
|
||||
|
||||
if repo_path_or_name is None:
|
||||
repo_path_or_name = repo_url.split("/")[-1]
|
||||
|
||||
if repo_url is None and not os.path.exists(repo_path_or_name):
|
||||
repo_name = Path(repo_path_or_name).name
|
||||
repo_url = cls._get_repo_url_from_name(
|
||||
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
|
||||
if "repo_path_or_name" in deprecated_kwargs:
|
||||
warnings.warn(
|
||||
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
|
||||
"`repo_id` instead."
|
||||
)
|
||||
repo_id = deprecated_kwargs.pop("repo_path_or_name")
|
||||
# Deprecation warning will be sent after for repo_url and organization
|
||||
repo_url = deprecated_kwargs.pop("repo_url", None)
|
||||
organization = deprecated_kwargs.pop("organization", None)
|
||||
|
||||
# Create a working directory if it does not exist.
|
||||
if not os.path.exists(repo_path_or_name):
|
||||
os.makedirs(repo_path_or_name)
|
||||
if os.path.isdir(repo_id):
|
||||
working_dir = repo_id
|
||||
repo_id = repo_id.split(os.path.sep)[-1]
|
||||
else:
|
||||
working_dir = repo_id.split("/")[-1]
|
||||
|
||||
repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
|
||||
repo.git_pull()
|
||||
return repo
|
||||
repo_id, token = self._create_repo(
|
||||
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _push_to_hub(cls, repo: Repository, commit_message: Optional[str] = None) -> str:
|
||||
if commit_message is None:
|
||||
if "Tokenizer" in cls.__name__:
|
||||
commit_message = "add tokenizer"
|
||||
elif "Config" in cls.__name__:
|
||||
commit_message = "add config"
|
||||
else:
|
||||
commit_message = "add model"
|
||||
if use_temp_dir is None:
|
||||
use_temp_dir = not os.path.isdir(working_dir)
|
||||
|
||||
return repo.push_to_hub(commit_message=commit_message)
|
||||
with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
|
||||
files_timestamps = self._get_files_timestamps(work_dir)
|
||||
|
||||
# Save all files.
|
||||
self.save_pretrained(work_dir, max_shard_size=max_shard_size)
|
||||
|
||||
return self._upload_modified_files(
|
||||
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr
|
||||
)
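
Editor's note: a hedged usage sketch of the rewritten mixin method above (the checkpoint and repository names are placeholders). Files are now uploaded through `create_commit` instead of a local git clone, and the new `create_pr` flag opens a pull request rather than committing to the default branch:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

# Push to your own namespace; the repository is created on the fly if it does not exist.
model.push_to_hub("my-finetuned-bert")

# Push to an organization and open a pull request instead of committing directly.
model.push_to_hub("my-org/my-finetuned-bert", create_pr=True)
```
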
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
|
@@ -23,7 +23,7 @@ import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -243,46 +243,58 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)
|
||||
config.push_to_hub("test-config", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-config")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-config-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_dynamic_config(self):
|
||||
CustomConfig.register_for_auto_class()
|
||||
config = CustomConfig(attribute=42)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
|
||||
config.save_pretrained(tmp_dir)
|
||||
config.push_to_hub("test-dynamic-config", use_auth_token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
||||
# The code has been copied from fixtures
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_configuration.py")))
|
||||
|
||||
repo.push_to_hub()
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
||||
|
||||
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
|
||||
|
@@ -22,7 +22,7 @@ import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
|
||||
@@ -167,47 +167,57 @@ class FeatureExtractorPushToHubTester(unittest.TestCase):
|
||||
|
||||
def test_push_to_hub(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub("test-feature-extractor", use_auth_token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-feature-extractor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
|
||||
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub("valid_org/test-feature-extractor", use_auth_token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token)
|
||||
feature_extractor.save_pretrained(tmp_dir)
|
||||
feature_extractor.push_to_hub("test-dynamic-feature-extractor", use_auth_token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
feature_extractor.auto_map,
|
||||
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
||||
)
|
||||
# The code has been copied from fixtures
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
|
||||
|
||||
repo.push_to_hub()
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
feature_extractor.auto_map,
|
||||
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
||||
)
|
||||
|
||||
new_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
|
||||
|
@@ -32,7 +32,7 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@ -2962,39 +2962,51 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-model-config")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = BertModel(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(os.path.join(tmp_dir, "test-model"), push_to_hub=True, use_auth_token=self._token)
|
||||
model.push_to_hub("test-model", use_auth_token=self._token)
|
||||
|
||||
new_model = BertModel.from_pretrained(f"{USER}/test-model")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
new_model = BertModel.from_pretrained(f"{USER}/test-model")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-model")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, use_auth_token=self._token)
|
||||
|
||||
new_model = BertModel.from_pretrained(f"{USER}/test-model")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = BertModel(config)
|
||||
model.push_to_hub("valid_org/test-model-org", use_auth_token=self._token)
|
||||
|
||||
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-model-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-model-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-org"
|
||||
)
|
||||
|
||||
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_push_to_hub_dynamic_model(self):
|
||||
CustomConfig.register_for_auto_class()
|
||||
@ -3003,16 +3015,12 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
config = CustomConfig(hidden_size=32)
|
||||
model = CustomModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model", use_auth_token=self._token)
|
||||
model.save_pretrained(tmp_dir)
|
||||
# checks
|
||||
self.assertDictEqual(
|
||||
config.auto_map,
|
||||
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
|
||||
)
|
||||
|
||||
repo.push_to_hub()
|
||||
model.push_to_hub("test-dynamic-model", use_auth_token=self._token)
|
||||
# checks
|
||||
self.assertDictEqual(
|
||||
config.auto_map,
|
||||
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
|
||||
)
|
||||
|
||||
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
|
||||
|
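The PyTorch model tests now exercise a full round trip: push a tiny randomly initialized model, reload it from the Hub, compare weights, then delete the temporary repo. A hedged sketch of that flow outside the test class; the repo id, username, and token are placeholders:

```python
import torch
from huggingface_hub import delete_repo
from transformers import BertConfig, BertModel

# Tiny config so the round trip stays fast; repo id, username and token are placeholders.
config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = BertModel(config)

model.push_to_hub("tiny-bert-push-test", use_auth_token="hf_xxx")
reloaded = BertModel.from_pretrained("your-username/tiny-bert-push-test")

# Every parameter should survive the upload/download round trip unchanged.
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), reloaded.parameters()))

# Clean up the throwaway repo.
delete_repo(repo_id="tiny-bert-push-test", token="hf_xxx")
```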
@ -1153,38 +1153,63 @@ class FlaxModelPushToHubTester(unittest.TestCase):
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("test-model-flax", use_auth_token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# Reset repo
delete_repo(token=self._token, repo_id="test-model-flax")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
)
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, use_auth_token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("valid_org/test-model-flax-org", use_auth_token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
os.path.join(tmp_dir, "test-model-flax-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, use_auth_token=self._token
)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
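For Flax, the round trip is verified by flattening both parameter trees and comparing them key by key, as in the hunk above. A minimal sketch of the same comparison, assuming `flax` is installed; the repo id, username, and token are placeholders:

```python
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import BertConfig, FlaxBertModel

# Placeholder repo id, username and token; pushing needs write access to the Hub.
config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = FlaxBertModel(config)

model.push_to_hub("tiny-flax-bert-push-test", use_auth_token="hf_xxx")
reloaded = FlaxBertModel.from_pretrained("your-username/tiny-flax-bert-push-test")

# Flatten both parameter trees and compare leaf by leaf.
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(reloaded.params))
for key in base_params:
    max_diff = (base_params[key] - new_params[key]).sum().item()
    assert max_diff <= 1e-3, f"{key} not identical"
```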
@ -33,17 +33,18 @@ from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import get_values
from transformers.testing_utils import tooslow # noqa: F401
from transformers.testing_utils import (
from transformers.testing_utils import ( # noqa: F401
TOKEN,
USER,
CaptureLogger,
CaptureStdout,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_tf,
require_tf2onnx,
slow,
tooslow,
torch_device,
)
from transformers.utils import logging

@ -2189,41 +2190,65 @@ class TFModelPushToHubTester(unittest.TestCase):
model = TFBertModel(config)
# Make sure model is properly initialized
_ = model(model.dummy_inputs)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(os.path.join(tmp_dir, "test-model-tf"), push_to_hub=True, use_auth_token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub")
with CaptureLogger(logger) as cl:
model.push_to_hub("test-model-tf", use_auth_token=self._token)
logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading README.md to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
def test_push_to_hub_with_model_card(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
# Reset repo
delete_repo(token=self._token, repo_id="test-model-tf")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.push_to_hub(os.path.join(tmp_dir, "test-model-tf"))
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "test-model-tf", "README.md")))
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, use_auth_token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
_ = model(model.dummy_inputs)
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
os.path.join(tmp_dir, "test-model-tf-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-tf-org"
)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
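The TF test additionally captures the `transformers.utils.hub` logger to confirm that a README.md (the model card) is uploaded as part of `push_to_hub`. A rough sketch of that check; the repo id and token are placeholders, and the exact log message format is taken from the assertion in the hunk above and may change between versions:

```python
from transformers import BertConfig, TFBertModel
from transformers.testing_utils import CaptureLogger
from transformers.utils import logging

# Placeholder repo id and token; a tiny config keeps the upload small.
config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = TFBertModel(config)
_ = model(model.dummy_inputs)  # build the weights before pushing

logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub")
with CaptureLogger(logger) as cl:
    model.push_to_hub("tiny-tf-bert-push-test", use_auth_token="hf_xxx")
logging.set_verbosity_warning()

# Per the assertion above, an "Uploading README.md to <repo>" message is logged when the model card is pushed.
print("model card uploaded:", "Uploading README.md" in cl.out)
```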
@ -30,7 +30,7 @@ from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from huggingface_hub import HfFolder, delete_repo, set_access_token
from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import (

@ -3875,12 +3875,20 @@ class TokenizerPushToHubTester(unittest.TestCase):
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.save_pretrained(
os.path.join(tmp_dir, "test-tokenizer"), push_to_hub=True, use_auth_token=self._token
)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
tokenizer.push_to_hub("test-tokenizer", use_auth_token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
# Reset repo
delete_repo(token=self._token, repo_id="test-tokenizer")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, use_auth_token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:

@ -3888,15 +3896,22 @@ class TokenizerPushToHubTester(unittest.TestCase):
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("valid_org/test-tokenizer-org", use_auth_token=self._token)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(
os.path.join(tmp_dir, "test-tokenizer-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, use_auth_token=self._token
)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):

@ -3908,17 +3923,7 @@ class TokenizerPushToHubTester(unittest.TestCase):
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertDictEqual(
tokenizer_config["auto_map"], {"AutoTokenizer": ["custom_tokenization.CustomTokenizer", None]}
)
repo.push_to_hub()
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module

@ -3935,23 +3940,7 @@ class TokenizerPushToHubTester(unittest.TestCase):
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertDictEqual(
tokenizer_config["auto_map"],
{
"AutoTokenizer": [
"custom_tokenization.CustomTokenizer",
"custom_tokenization_fast.CustomTokenizerFast",
]
},
)
repo.push_to_hub()
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
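The tokenizer hunks follow the same pattern: build a tokenizer from a local vocab file, call `push_to_hub` directly (or `save_pretrained(..., repo_id=..., push_to_hub=True)`), and reload it to compare vocabularies. A small sketch with a throwaway vocabulary; the repo id, username, and token are placeholders:

```python
import os
import tempfile

from transformers import BertTokenizer

# Throwaway vocabulary; repo id, username and token are placeholders.
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "hello", "world"]

with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_file = os.path.join(tmp_dir, "vocab.txt")
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
    tokenizer = BertTokenizer(vocab_file)

    # Push directly, then reload from the Hub and compare vocabularies.
    tokenizer.push_to_hub("tiny-test-tokenizer", use_auth_token="hf_xxx")

reloaded = BertTokenizer.from_pretrained("your-username/tiny-test-tokenizer")
assert reloaded.vocab == tokenizer.vocab
```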