mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
CLI: Add flag to push TF weights directly into main (#17720)
* Add flag to push weights directly into main
This commit is contained in:
parent
6ebeeeef81
commit
c3c62b5d2c
@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):
|
||||
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
|
||||
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)
|
||||
|
||||
|
||||
class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
@ -76,14 +76,19 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
default="",
|
||||
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--new-weights",
|
||||
action="store_true",
|
||||
help="Optional flag to create new TensorFlow weights, even if they already exist.",
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--push",
|
||||
action="store_true",
|
||||
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
||||
)
|
||||
train_parser.set_defaults(func=convert_command_factory)
|
||||
|
||||
@staticmethod
|
||||
@ -129,12 +134,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
|
||||
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
|
||||
|
||||
def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
|
||||
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
|
||||
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
||||
self._model_name = model_name
|
||||
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
||||
self._no_pr = no_pr
|
||||
self._new_weights = new_weights
|
||||
self._no_pr = no_pr
|
||||
self._push = push
|
||||
|
||||
def get_text_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
|
||||
@ -234,7 +240,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
)
|
||||
|
||||
if not self._no_pr:
|
||||
if self._push:
|
||||
repo.git_add(auto_lfs_track=True)
|
||||
repo.git_commit("Add TF weights")
|
||||
repo.git_push(blocking=True) # this prints a progress bar with the upload
|
||||
self._logger.warn(f"TF weights pushed into {self._model_name}")
|
||||
elif not self._no_pr:
|
||||
# TODO: remove try/except when the upload to PR feature is released
|
||||
# (https://github.com/huggingface/huggingface_hub/pull/884)
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user