CLI: use hub's create_commit (#17755)

* use create_commit

* better commit message and description

* touch setup.py to trigger cache update

* add hub version gating
This commit is contained in:
Joao Gante 2022-06-22 16:50:21 +01:00 committed by GitHub
parent c366ce1011
commit 0d0c392c45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 32 deletions

View File

@ -27,7 +27,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v3-tests_model_like-${{ hashFiles('setup.py') }}
key: v4-tests_model_like-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'

View File

@ -21,7 +21,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v3-tests_templates-${{ hashFiles('setup.py') }}
key: v4-tests_templates-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'

View File

@ -21,7 +21,7 @@ jobs:
id: cache
with:
path: ~/venv/
key: v2-metadata-${{ hashFiles('setup.py') }}
key: v3-metadata-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true'

View File

@ -18,8 +18,9 @@ from importlib import import_module
import numpy as np
from datasets import load_dataset
from packaging import version
from huggingface_hub import Repository, upload_file
import huggingface_hub
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from ..utils import logging
@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
return PTtoTFCommand(
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
)
class PTtoTFCommand(BaseTransformersCLICommand):
@ -89,6 +92,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
action="store_true",
help="Optional flag to push the weights directly to `main` (requires permissions)",
)
train_parser.add_argument(
"--extra-commit-description",
type=str,
default="",
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
)
train_parser.set_defaults(func=convert_command_factory)
@staticmethod
@ -134,13 +143,23 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
def __init__(
self,
model_name: str,
local_dir: str,
new_weights: bool,
no_pr: bool,
push: bool,
extra_commit_description: str,
*args
):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push
self._extra_commit_description = extra_commit_description
def get_text_inputs(self):
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
@ -170,10 +189,17 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return pt_input, tf_input
def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"):
raise ImportError(
"The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub"
" installation."
)
else:
from huggingface_hub import Repository, create_commit
from huggingface_hub._commit_api import CommitOperationAdd
# Fetch remote data
# TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
repo.git_pull() # in case the repo already exists locally, but with an older commit
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir)
@ -240,32 +266,29 @@ class PTtoTFCommand(BaseTransformersCLICommand):
)
)
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
if self._push:
repo.git_add(auto_lfs_track=True)
repo.git_commit("Add TF weights")
repo.git_commit(commit_message)
repo.git_push(blocking=True) # this prints a progress bar with the upload
self._logger.warn(f"TF weights pushed into {self._model_name}")
elif not self._no_pr:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
self._logger.warn("Uploading the weights into a new PR...")
hub_pr_url = upload_file(
path_or_fileobj=tf_weights_path,
path_in_repo=TF_WEIGHTS_NAME,
repo_id=self._model_name,
create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
)
self._logger.warn(f"PR open in {hub_pr_url}")
except TypeError:
self._logger.warn(
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
f" uploading the file in {tf_weights_path}"
)
self._logger.warn("Uploading the weights into a new PR...")
commit_descrition = (
"Model converted by the [`transformers`' `pt_to_tf`"
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
)
if self._extra_commit_description:
commit_descrition += "\n\n" + self._extra_commit_description
hub_pr_url = create_commit(
repo_id=self._model_name,
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
commit_message=commit_message,
commit_description=commit_descrition,
repo_type="model",
create_pr=True,
)
self._logger.warn(f"PR open in {hub_pr_url}")