Update the update metadata job to use upload_folder (#23917)

This commit is contained in:
Sylvain Gugger 2023-05-31 14:10:14 -04:00 committed by GitHub
parent 3ff443a6d9
commit 4aa13224a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -21,7 +21,7 @@ import tempfile
import pandas as pd
from datasets import Dataset
from huggingface_hub import Repository
from huggingface_hub import hf_hub_download, upload_folder
from transformers.utils import direct_transformers_import
@@ -209,14 +209,13 @@ def update_metadata(token, commit_sha):
""" """
Update the metadata for the Transformers repo. Update the metadata for the Transformers repo.
""" """
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from="huggingface/transformers-metadata", repo_type="dataset", token=token)
frameworks_table = get_frameworks_table()
frameworks_dataset = Dataset.from_pandas(frameworks_table)
frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
tags_dataset = Dataset.from_json(os.path.join(tmp_dir, "pipeline_tags.json"))
resolved_tags_file = hf_hub_download(
"huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
)
tags_dataset = Dataset.from_json(resolved_tags_file)
table = {
tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
for i in range(len(tags_dataset))
@@ -233,11 +232,11 @@ def update_metadata(token, commit_sha):
}
)
tags_dataset = Dataset.from_pandas(tags_table)
with tempfile.TemporaryDirectory() as tmp_dir:
frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
if repo.is_repo_clean():
print("Nothing to commit!")
else:
if commit_sha is not None:
commit_message = (
f"Update with commit {commit_sha}\n\nSee: "
@@ -245,7 +244,14 @@ def update_metadata(token, commit_sha):
)
else:
commit_message = "Update"
repo.push_to_hub(commit_message)
upload_folder(
repo_id="huggingface/transformers-metadata",
folder_path=tmp_dir,
repo_type="dataset",
token=token,
commit_message=commit_message,
)
def check_pipeline_tags():