diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index c34d8a39237..e760ce2fd07 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -21,7 +21,7 @@ import tempfile
 
 import pandas as pd
 from datasets import Dataset
-from huggingface_hub import Repository
+from huggingface_hub import hf_hub_download, upload_folder
 
 from transformers.utils import direct_transformers_import
 
@@ -209,43 +209,74 @@ def update_metadata(token, commit_sha):
     """
     Update the metadata for the Transformers repo.
+
+    Regenerates `frameworks.json` and `pipeline_tags.json` and uploads them to the
+    `huggingface/transformers-metadata` dataset repo. The upload is skipped when both generated files are
+    identical to the copies currently on the Hub.
+
+    Args:
+        token: Token passed to the Hub download and upload calls.
+        commit_sha: SHA of the Transformers commit to reference in the commit message; when
+            `None`, the commit message is just "Update".
     """
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        repo = Repository(tmp_dir, clone_from="huggingface/transformers-metadata", repo_type="dataset", token=token)
+    frameworks_table = get_frameworks_table()
+    frameworks_dataset = Dataset.from_pandas(frameworks_table)
 
-        frameworks_table = get_frameworks_table()
-        frameworks_dataset = Dataset.from_pandas(frameworks_table)
-        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
+    resolved_tags_file = hf_hub_download(
+        "huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
+    )
+    tags_dataset = Dataset.from_json(resolved_tags_file)
+    table = {
+        tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
+        for i in range(len(tags_dataset))
+    }
+    table = update_pipeline_and_auto_class_table(table)
 
-        tags_dataset = Dataset.from_json(os.path.join(tmp_dir, "pipeline_tags.json"))
-        table = {
-            tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
-            for i in range(len(tags_dataset))
+    # Sort the model classes so that nondeterministic ordering does not produce spurious update commits.
+    model_classes = sorted(table.keys())
+    tags_table = pd.DataFrame(
+        {
+            "model_class": model_classes,
+            "pipeline_tag": [table[m][0] for m in model_classes],
+            "auto_class": [table[m][1] for m in model_classes],
         }
-        table = update_pipeline_and_auto_class_table(table)
+    )
+    tags_dataset = Dataset.from_pandas(tags_table)
 
-        # Sort the model classes to avoid some nondeterministic updates to create false update commits.
-        model_classes = sorted(table.keys())
-        tags_table = pd.DataFrame(
-            {
-                "model_class": model_classes,
-                "pipeline_tag": [table[m][0] for m in model_classes],
-                "auto_class": [table[m][1] for m in model_classes],
-            }
-        )
-        tags_dataset = Dataset.from_pandas(tags_table)
+    resolved_frameworks_file = hf_hub_download(
+        "huggingface/transformers-metadata", "frameworks.json", repo_type="dataset", token=token
+    )
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
         tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
 
-        if repo.is_repo_clean():
+        # `upload_folder` has no counterpart to `Repository.is_repo_clean()`, so compare the generated
+        # files with the Hub copies to keep skipping the push when nothing changed.
+        hub_files = {"frameworks.json": resolved_frameworks_file, "pipeline_tags.json": resolved_tags_file}
+        unchanged = True
+        for filename, hub_file in hub_files.items():
+            with open(os.path.join(tmp_dir, filename), "rb") as new_f, open(hub_file, "rb") as hub_f:
+                if new_f.read() != hub_f.read():
+                    unchanged = False
+        if unchanged:
             print("Nothing to commit!")
+            return
+
+        if commit_sha is not None:
+            commit_message = (
+                f"Update with commit {commit_sha}\n\nSee: "
+                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
+            )
         else:
-            if commit_sha is not None:
-                commit_message = (
-                    f"Update with commit {commit_sha}\n\nSee: "
-                    f"https://github.com/huggingface/transformers/commit/{commit_sha}"
-                )
-            else:
-                commit_message = "Update"
-            repo.push_to_hub(commit_message)
+            commit_message = "Update"
+
+        upload_folder(
+            repo_id="huggingface/transformers-metadata",
+            folder_path=tmp_dir,
+            repo_type="dataset",
+            token=token,
+            commit_message=commit_message,
+        )
 
 
 def check_pipeline_tags():