mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
CLI: use hub's create_commit
(#17755)
* use create_commit * better commit message and description * touch setup.py to trigger cache update * add hub version gating
This commit is contained in:
parent
c366ce1011
commit
0d0c392c45
2
.github/workflows/add-model-like.yml
vendored
2
.github/workflows/add-model-like.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: ~/venv/
|
||||
key: v3-tests_model_like-${{ hashFiles('setup.py') }}
|
||||
key: v4-tests_model_like-${{ hashFiles('setup.py') }}
|
||||
|
||||
- name: Create virtual environment on cache miss
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
|
2
.github/workflows/model-templates.yml
vendored
2
.github/workflows/model-templates.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: ~/venv/
|
||||
key: v3-tests_templates-${{ hashFiles('setup.py') }}
|
||||
key: v4-tests_templates-${{ hashFiles('setup.py') }}
|
||||
|
||||
- name: Create virtual environment on cache miss
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
|
2
.github/workflows/update_metdata.yml
vendored
2
.github/workflows/update_metdata.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: ~/venv/
|
||||
key: v2-metadata-${{ hashFiles('setup.py') }}
|
||||
key: v3-metadata-${{ hashFiles('setup.py') }}
|
||||
|
||||
- name: Create virtual environment on cache miss
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
|
@ -18,8 +18,9 @@ from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from huggingface_hub import Repository, upload_file
|
||||
import huggingface_hub
|
||||
|
||||
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from ..utils import logging
|
||||
@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace):
|
||||
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
|
||||
return PTtoTFCommand(
|
||||
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
|
||||
)
|
||||
|
||||
|
||||
class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
@ -89,6 +92,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
action="store_true",
|
||||
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--extra-commit-description",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
|
||||
)
|
||||
train_parser.set_defaults(func=convert_command_factory)
|
||||
|
||||
@staticmethod
|
||||
@ -134,13 +143,23 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
|
||||
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
|
||||
|
||||
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
local_dir: str,
|
||||
new_weights: bool,
|
||||
no_pr: bool,
|
||||
push: bool,
|
||||
extra_commit_description: str,
|
||||
*args
|
||||
):
|
||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||
self._model_name = model_name
|
||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||
self._new_weights = new_weights
|
||||
self._no_pr = no_pr
|
||||
self._push = push
|
||||
self._extra_commit_description = extra_commit_description
|
||||
|
||||
def get_text_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
||||
@ -170,10 +189,17 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
return pt_input, tf_input
|
||||
|
||||
def run(self):
|
||||
if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"):
|
||||
raise ImportError(
|
||||
"The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub"
|
||||
" installation."
|
||||
)
|
||||
else:
|
||||
from huggingface_hub import Repository, create_commit
|
||||
from huggingface_hub._commit_api import CommitOperationAdd
|
||||
|
||||
# Fetch remote data
|
||||
# TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
|
||||
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
||||
repo.git_pull() # in case the repo already exists locally, but with an older commit
|
||||
|
||||
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
|
||||
config = AutoConfig.from_pretrained(self._local_dir)
|
||||
@ -240,32 +266,29 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
)
|
||||
|
||||
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
|
||||
if self._push:
|
||||
repo.git_add(auto_lfs_track=True)
|
||||
repo.git_commit("Add TF weights")
|
||||
repo.git_commit(commit_message)
|
||||
repo.git_push(blocking=True) # this prints a progress bar with the upload
|
||||
self._logger.warn(f"TF weights pushed into {self._model_name}")
|
||||
elif not self._no_pr:
|
||||
# TODO: remove try/except when the upload to PR feature is released
|
||||
# (https://github.com/huggingface/huggingface_hub/pull/884)
|
||||
try:
|
||||
self._logger.warn("Uploading the weights into a new PR...")
|
||||
hub_pr_url = upload_file(
|
||||
path_or_fileobj=tf_weights_path,
|
||||
path_in_repo=TF_WEIGHTS_NAME,
|
||||
repo_id=self._model_name,
|
||||
create_pr=True,
|
||||
pr_commit_summary="Add TF weights",
|
||||
pr_commit_description=(
|
||||
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
|
||||
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
|
||||
f" difference={max_crossload_diff:.3e}; Maximum converted output"
|
||||
f" difference={max_conversion_diff:.3e}."
|
||||
),
|
||||
)
|
||||
self._logger.warn(f"PR open in {hub_pr_url}")
|
||||
except TypeError:
|
||||
self._logger.warn(
|
||||
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
|
||||
f" uploading the file in {tf_weights_path}"
|
||||
)
|
||||
self._logger.warn("Uploading the weights into a new PR...")
|
||||
commit_descrition = (
|
||||
"Model converted by the [`transformers`' `pt_to_tf`"
|
||||
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
|
||||
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
|
||||
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
|
||||
f" difference={max_conversion_diff:.3e}."
|
||||
)
|
||||
if self._extra_commit_description:
|
||||
commit_descrition += "\n\n" + self._extra_commit_description
|
||||
hub_pr_url = create_commit(
|
||||
repo_id=self._model_name,
|
||||
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_descrition,
|
||||
repo_type="model",
|
||||
create_pr=True,
|
||||
)
|
||||
self._logger.warn(f"PR open in {hub_pr_url}")
|
||||
|
Loading…
Reference in New Issue
Block a user