mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[core
/ FEAT] Add the possibility to push custom tags using PreTrainedModel
itself (#28405)
* v1 tags * remove unneeded conversion * v2 * rm unneeded warning * add more utility methods * Update src/transformers/utils/hub.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * more enhancements * oops * merge tags * clean up * revert unneeded change * add extensive docs * more docs * more kwargs * add test * oops * fix test * Update src/transformers/modeling_utils.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Update src/transformers/utils/hub.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/modeling_utils.py * Update src/transformers/trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add more conditions * more logic --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
This commit is contained in:
parent
64bdbd888c
commit
1b9a2e4c80
@ -89,7 +89,7 @@ from .utils import (
|
||||
replace_return_docstrings,
|
||||
strtobool,
|
||||
)
|
||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
is_sagemaker_mp_enabled,
|
||||
@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
main_input_name = "input_ids"
|
||||
model_tags = None
|
||||
|
||||
_auto_class = None
|
||||
_no_split_modules = None
|
||||
_skip_keys_device_placement = None
|
||||
@ -1252,6 +1254,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Remove the attribute now that is has been consumed, so it's no saved in the config.
|
||||
delattr(self.config, "gradient_checkpointing")
|
||||
|
||||
def add_model_tags(self, tags: Union[List[str], str]) -> None:
|
||||
r"""
|
||||
Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
|
||||
not overwrite existing tags in the model.
|
||||
|
||||
Args:
|
||||
tags (`Union[List[str], str]`):
|
||||
The desired tags to inject in the model
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained("bert-base-cased")
|
||||
|
||||
model.add_model_tags(["custom", "custom-bert"])
|
||||
|
||||
# Push the model to your namespace with the name "my-custom-bert".
|
||||
model.push_to_hub("my-custom-bert")
|
||||
```
|
||||
"""
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if self.model_tags is None:
|
||||
self.model_tags = []
|
||||
|
||||
for tag in tags:
|
||||
if tag not in self.model_tags:
|
||||
self.model_tags.append(tag)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
@ -2212,6 +2246,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@ -2438,6 +2473,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
# Eventually create an empty model card
|
||||
model_card = create_and_tag_model_card(
|
||||
repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
|
||||
)
|
||||
|
||||
# Update model card if needed:
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
self._upload_modified_files(
|
||||
save_directory,
|
||||
repo_id,
|
||||
@ -2446,6 +2489,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
token=token,
|
||||
)
|
||||
|
||||
@wraps(PushToHubMixin.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs):
|
||||
tags = self.model_tags if self.model_tags is not None else []
|
||||
|
||||
tags_kwargs = kwargs.get("tags", [])
|
||||
if isinstance(tags_kwargs, str):
|
||||
tags_kwargs = [tags_kwargs]
|
||||
|
||||
for tag in tags_kwargs:
|
||||
if tag not in tags:
|
||||
tags.append(tag)
|
||||
|
||||
if tags:
|
||||
kwargs["tags"] = tags
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
def get_memory_footprint(self, return_buffers=True):
|
||||
r"""
|
||||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
|
||||
|
@ -3581,6 +3581,15 @@ class Trainer:
|
||||
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
|
||||
is_peft_library = library_name == "peft"
|
||||
|
||||
# Append existing tags in `tags`
|
||||
existing_tags = ModelCard.load(model_card_filepath).data.tags
|
||||
if tags is not None and existing_tags is not None:
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
for tag in existing_tags:
|
||||
if tag not in tags:
|
||||
tags.append(tag)
|
||||
|
||||
training_summary = TrainingSummary.from_trainer(
|
||||
self,
|
||||
language=language,
|
||||
@ -3699,6 +3708,18 @@ class Trainer:
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
# Add additional tags in the case the model has already some tags and users pass
|
||||
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
|
||||
# from all models since Trainer does not call `model.push_to_hub`.
|
||||
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
|
||||
# If it is a string, convert it to a list
|
||||
if isinstance(kwargs["tags"], str):
|
||||
kwargs["tags"] = [kwargs["tags"]]
|
||||
|
||||
for model_tag in self.model.model_tags:
|
||||
if model_tag not in kwargs["tags"]:
|
||||
kwargs["tags"].append(model_tag)
|
||||
|
||||
self.create_model_card(model_name=model_name, **kwargs)
|
||||
|
||||
# Wait for the current upload to be finished.
|
||||
|
@ -33,6 +33,8 @@ import requests
|
||||
from huggingface_hub import (
|
||||
_CACHED_NO_EXIST,
|
||||
CommitOperationAdd,
|
||||
ModelCard,
|
||||
ModelCardData,
|
||||
constants,
|
||||
create_branch,
|
||||
create_commit,
|
||||
@ -762,6 +764,7 @@ class PushToHubMixin:
|
||||
safe_serialization: bool = True,
|
||||
revision: str = None,
|
||||
commit_description: str = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**deprecated_kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@ -795,6 +798,8 @@ class PushToHubMixin:
|
||||
Branch to push the uploaded files to.
|
||||
commit_description (`str`, *optional*):
|
||||
The description of the commit that will be created
|
||||
tags (`List[str]`, *optional*):
|
||||
List of tags to push on the Hub.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -811,6 +816,7 @@ class PushToHubMixin:
|
||||
```
|
||||
"""
|
||||
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
||||
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
@ -855,6 +861,11 @@ class PushToHubMixin:
|
||||
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
|
||||
)
|
||||
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = create_and_tag_model_card(
|
||||
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
|
||||
)
|
||||
|
||||
if use_temp_dir is None:
|
||||
use_temp_dir = not os.path.isdir(working_dir)
|
||||
|
||||
@ -864,6 +875,9 @@ class PushToHubMixin:
|
||||
# Save all files.
|
||||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
|
||||
|
||||
# Update model card if needed:
|
||||
model_card.save(os.path.join(work_dir, "README.md"))
|
||||
|
||||
return self._upload_modified_files(
|
||||
work_dir,
|
||||
repo_id,
|
||||
@ -1081,6 +1095,43 @@ def extract_info_from_url(url):
|
||||
return {"repo": cache_repo, "revision": revision, "filename": filename}
|
||||
|
||||
|
||||
def create_and_tag_model_card(
|
||||
repo_id: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
token: Optional[str] = None,
|
||||
ignore_metadata_errors: bool = False,
|
||||
):
|
||||
"""
|
||||
Creates or loads an existing model card and tags it.
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The repo_id where to look for the model card.
|
||||
tags (`List[str]`, *optional*):
|
||||
The list of tags to add in the model card
|
||||
token (`str`, *optional*):
|
||||
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
|
||||
ignore_metadata_errors (`str`):
|
||||
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
|
||||
the process. Use it at your own risk.
|
||||
"""
|
||||
try:
|
||||
# Check if the model card is present on the remote repo
|
||||
model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
|
||||
except EntryNotFoundError:
|
||||
# Otherwise create a simple model card from template
|
||||
model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
|
||||
card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
|
||||
model_card = ModelCard.from_template(card_data, model_description=model_description)
|
||||
|
||||
if tags is not None:
|
||||
for model_tag in tags:
|
||||
if model_tag not in model_card.data.tags:
|
||||
model_card.data.tags.append(model_tag)
|
||||
|
||||
return model_card
|
||||
|
||||
|
||||
def clean_files_for(file):
|
||||
"""
|
||||
Remove, if they exist, file, file.json and file.lock
|
||||
|
@ -1435,6 +1435,11 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
@unittest.skip("This test is flaky")
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
@ -1522,6 +1527,28 @@ The commit description supports markdown synthax see:
|
||||
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||
|
||||
def test_push_to_hub_with_tags(self):
|
||||
from huggingface_hub import ModelCard
|
||||
|
||||
new_tags = ["tag-1", "tag-2"]
|
||||
|
||||
CustomConfig.register_for_auto_class()
|
||||
CustomModel.register_for_auto_class()
|
||||
|
||||
config = CustomConfig(hidden_size=32)
|
||||
model = CustomModel(config)
|
||||
|
||||
self.assertTrue(model.model_tags is None)
|
||||
|
||||
model.add_model_tags(new_tags)
|
||||
|
||||
self.assertTrue(model.model_tags == new_tags)
|
||||
|
||||
model.push_to_hub("test-dynamic-model-with-tags", token=self._token)
|
||||
|
||||
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
|
||||
self.assertEqual(loaded_model_card.data.tags, new_tags)
|
||||
|
||||
|
||||
@require_torch
|
||||
class AttentionMaskTester(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user