[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself (#28405)

* v1 tags

* remove unneeded conversion

* v2

* rm unneeded warning

* add more utility methods

* Update src/transformers/utils/hub.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* more enhancements

* oops

* merge tags

* clean up

* revert unneeded change

* add extensive docs

* more docs

* more kwargs

* add test

* oops

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/modeling_utils.py

* Update src/transformers/trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add more conditions

* more logic

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
This commit is contained in:
Younes Belkada 2024-01-15 14:48:07 +01:00 committed by GitHub
parent 64bdbd888c
commit 1b9a2e4c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 159 additions and 1 deletions

View File

@ -89,7 +89,7 @@ from .utils import (
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
model_tags = None
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@ -1252,6 +1254,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Remove the attribute now that is has been consumed, so it's no saved in the config.
delattr(self.config, "gradient_checkpointing")
def add_model_tags(self, tags: Union[List[str], str]) -> None:
r"""
Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
not overwrite existing tags in the model.
Args:
tags (`Union[List[str], str]`):
The desired tags to inject in the model
Examples:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-cased")
model.add_model_tags(["custom", "custom-bert"])
# Push the model to your namespace with the name "my-custom-bert".
model.push_to_hub("my-custom-bert")
```
"""
if isinstance(tags, str):
tags = [tags]
if self.model_tags is None:
self.model_tags = []
for tag in tags:
if tag not in self.model_tags:
self.model_tags.append(tag)
@classmethod
def _from_config(cls, config, **kwargs):
"""
@ -2212,6 +2246,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
if use_auth_token is not None:
warnings.warn(
@ -2438,6 +2473,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if push_to_hub:
# Eventually create an empty model card
model_card = create_and_tag_model_card(
repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)
# Update model card if needed:
model_card.save(os.path.join(save_directory, "README.md"))
self._upload_modified_files(
save_directory,
repo_id,
@ -2446,6 +2489,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
token=token,
)
@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
tags = self.model_tags if self.model_tags is not None else []
tags_kwargs = kwargs.get("tags", [])
if isinstance(tags_kwargs, str):
tags_kwargs = [tags_kwargs]
for tag in tags_kwargs:
if tag not in tags:
tags.append(tag)
if tags:
kwargs["tags"] = tags
return super().push_to_hub(*args, **kwargs)
def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.

View File

@ -3581,6 +3581,15 @@ class Trainer:
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"
# Append existing tags in `tags`
existing_tags = ModelCard.load(model_card_filepath).data.tags
if tags is not None and existing_tags is not None:
if isinstance(tags, str):
tags = [tags]
for tag in existing_tags:
if tag not in tags:
tags.append(tag)
training_summary = TrainingSummary.from_trainer(
self,
language=language,
@ -3699,6 +3708,18 @@ class Trainer:
if not self.is_world_process_zero():
return
# Add additional tags in the case the model has already some tags and users pass
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
# from all models since Trainer does not call `model.push_to_hub`.
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
# If it is a string, convert it to a list
if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]]
for model_tag in self.model.model_tags:
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)
self.create_model_card(model_name=model_name, **kwargs)
# Wait for the current upload to be finished.

View File

@ -33,6 +33,8 @@ import requests
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
@ -762,6 +764,7 @@ class PushToHubMixin:
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: Optional[List[str]] = None,
**deprecated_kwargs,
) -> str:
"""
@ -795,6 +798,8 @@ class PushToHubMixin:
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
tags (`List[str]`, *optional*):
List of tags to push on the Hub.
Examples:
@ -811,6 +816,7 @@ class PushToHubMixin:
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
@ -855,6 +861,11 @@ class PushToHubMixin:
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)
# Create a new empty model card and eventually tag it
model_card = create_and_tag_model_card(
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)
if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)
@ -864,6 +875,9 @@ class PushToHubMixin:
# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))
return self._upload_modified_files(
work_dir,
repo_id,
@ -1081,6 +1095,43 @@ def extract_info_from_url(url):
return {"repo": cache_repo, "revision": revision, "filename": filename}
def create_and_tag_model_card(
repo_id: str,
tags: Optional[List[str]] = None,
token: Optional[str] = None,
ignore_metadata_errors: bool = False,
):
"""
Creates or loads an existing model card and tags it.
Args:
repo_id (`str`):
The repo_id where to look for the model card.
tags (`List[str]`, *optional*):
The list of tags to add in the model card
token (`str`, *optional*):
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
ignore_metadata_errors (`str`):
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
the process. Use it at your own risk.
"""
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
except EntryNotFoundError:
# Otherwise create a simple model card from template
model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
model_card = ModelCard.from_template(card_data, model_description=model_description)
if tags is not None:
for model_tag in tags:
if model_tag not in model_card.data.tags:
model_card.data.tags.append(model_tag)
return model_card
def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock

View File

@ -1435,6 +1435,11 @@ class ModelPushToHubTester(unittest.TestCase):
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass
@unittest.skip("This test is flaky")
def test_push_to_hub(self):
config = BertConfig(
@ -1522,6 +1527,28 @@ The commit description supports markdown synthax see:
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel")
def test_push_to_hub_with_tags(self):
from huggingface_hub import ModelCard
new_tags = ["tag-1", "tag-2"]
CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
self.assertTrue(model.model_tags is None)
model.add_model_tags(new_tags)
self.assertTrue(model.model_tags == new_tags)
model.push_to_hub("test-dynamic-model-with-tags", token=self._token)
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
self.assertEqual(loaded_model_card.data.tags, new_tags)
@require_torch
class AttentionMaskTester(unittest.TestCase):