Update the update metadata job to use upload_folder (#23917)

Sylvain Gugger 2023-05-31 14:10:14 -04:00 committed by GitHub
parent 3ff443a6d9
commit 4aa13224a5


@@ -21,7 +21,7 @@ import tempfile
 
 import pandas as pd
 from datasets import Dataset
-from huggingface_hub import Repository
+from huggingface_hub import hf_hub_download, upload_folder
 
 from transformers.utils import direct_transformers_import
@@ -209,43 +209,49 @@ def update_metadata(token, commit_sha):
     """
     Update the metadata for the Transformers repo.
     """
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        repo = Repository(tmp_dir, clone_from="huggingface/transformers-metadata", repo_type="dataset", token=token)
-
-        frameworks_table = get_frameworks_table()
-        frameworks_dataset = Dataset.from_pandas(frameworks_table)
-        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
-
-        tags_dataset = Dataset.from_json(os.path.join(tmp_dir, "pipeline_tags.json"))
-        table = {
-            tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
-            for i in range(len(tags_dataset))
-        }
-        table = update_pipeline_and_auto_class_table(table)
-
-        # Sort the model classes to avoid some nondeterministic updates to create false update commits.
-        model_classes = sorted(table.keys())
-        tags_table = pd.DataFrame(
-            {
-                "model_class": model_classes,
-                "pipeline_tag": [table[m][0] for m in model_classes],
-                "auto_class": [table[m][1] for m in model_classes],
-            }
-        )
-        tags_dataset = Dataset.from_pandas(tags_table)
-        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
-
-        if repo.is_repo_clean():
-            print("Nothing to commit!")
-        else:
-            if commit_sha is not None:
-                commit_message = (
-                    f"Update with commit {commit_sha}\n\nSee: "
-                    f"https://github.com/huggingface/transformers/commit/{commit_sha}"
-                )
-            else:
-                commit_message = "Update"
-            repo.push_to_hub(commit_message)
+    frameworks_table = get_frameworks_table()
+    frameworks_dataset = Dataset.from_pandas(frameworks_table)
+
+    resolved_tags_file = hf_hub_download(
+        "huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
+    )
+    tags_dataset = Dataset.from_json(resolved_tags_file)
+    table = {
+        tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
+        for i in range(len(tags_dataset))
+    }
+    table = update_pipeline_and_auto_class_table(table)
+
+    # Sort the model classes to avoid some nondeterministic updates to create false update commits.
+    model_classes = sorted(table.keys())
+    tags_table = pd.DataFrame(
+        {
+            "model_class": model_classes,
+            "pipeline_tag": [table[m][0] for m in model_classes],
+            "auto_class": [table[m][1] for m in model_classes],
+        }
+    )
+    tags_dataset = Dataset.from_pandas(tags_table)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
+        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
+
+        if commit_sha is not None:
+            commit_message = (
+                f"Update with commit {commit_sha}\n\nSee: "
+                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
+            )
+        else:
+            commit_message = "Update"
+
+        upload_folder(
+            repo_id="huggingface/transformers-metadata",
+            folder_path=tmp_dir,
+            repo_type="dataset",
+            token=token,
+            commit_message=commit_message,
+        )
 
 
 def check_pipeline_tags():
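
The change replaces the git-based Repository helper, which cloned the whole huggingface/transformers-metadata dataset repo into a temporary directory, with two stateless HTTP calls: hf_hub_download fetches only the one file the job needs to read, and upload_folder pushes the regenerated files back as a single commit. Note that the old repo.is_repo_clean() guard is gone; instead of skipping the push when nothing changed, the job now uploads unconditionally, which is why keeping the file contents deterministic (the sorted model classes) matters. Below is a minimal sketch of the same download-then-upload pattern; the repo id and file name are hypothetical placeholders, not part of this commit.

    import json
    import os
    import tempfile

    from huggingface_hub import hf_hub_download, upload_folder

    # Hypothetical: any dataset repo you have write access to would work.
    REPO_ID = "your-username/your-dataset"

    # 1. Fetch only the file we need; returns a path into the local HF cache.
    local_path = hf_hub_download(REPO_ID, "metadata.json", repo_type="dataset")
    with open(local_path) as f:
        metadata = json.load(f)

    metadata["last_update"] = "2023-05-31"  # stand-in for the real table rebuild

    # 2. Write the regenerated file into a scratch directory and push it as one
    #    commit. No local git clone or working tree is required; the token from
    #    `huggingface-cli login` is used unless token= is passed explicitly.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(os.path.join(tmp_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)
        upload_folder(
            repo_id=REPO_ID,
            folder_path=tmp_dir,
            repo_type="dataset",
            commit_message="Update metadata",
        )

The practical win is that the CI job no longer clones or tracks the full dataset repo; it only ever touches the two JSON files it regenerates.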