From c3c62b5d2c777ce50039323412599ff5f570ce3c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 15 Jun 2022 19:25:50 +0100 Subject: [PATCH] CLI: Add flag to push TF weights directly into main (#17720) * Add flag to push weights directly into main --- src/transformers/commands/pt_to_tf.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py index 77a822544f4..3a2465093c4 100644 --- a/src/transformers/commands/pt_to_tf.py +++ b/src/transformers/commands/pt_to_tf.py @@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace): Returns: ServeCommand """ - return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights) + return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push) class PTtoTFCommand(BaseTransformersCLICommand): @@ -76,14 +76,19 @@ class PTtoTFCommand(BaseTransformersCLICommand): default="", help="Optional local directory of the model repository. Defaults to /tmp/{model_name}", ) - train_parser.add_argument( - "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights." - ) train_parser.add_argument( "--new-weights", action="store_true", help="Optional flag to create new TensorFlow weights, even if they already exist.", ) + train_parser.add_argument( + "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights." + ) + train_parser.add_argument( + "--push", + action="store_true", + help="Optional flag to push the weights directly to `main` (requires permissions)", + ) train_parser.set_defaults(func=convert_command_factory) @staticmethod @@ -129,12 +134,13 @@ class PTtoTFCommand(BaseTransformersCLICommand): return _find_pt_tf_differences(pt_outputs, tf_outputs, {}) - def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args): + def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args): self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._model_name = model_name self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) - self._no_pr = no_pr self._new_weights = new_weights + self._no_pr = no_pr + self._push = push def get_text_inputs(self): tokenizer = AutoTokenizer.from_pretrained(self._local_dir) @@ -234,7 +240,12 @@ class PTtoTFCommand(BaseTransformersCLICommand): ) ) - if not self._no_pr: + if self._push: + repo.git_add(auto_lfs_track=True) + repo.git_commit("Add TF weights") + repo.git_push(blocking=True) # this prints a progress bar with the upload + self._logger.warn(f"TF weights pushed into {self._model_name}") + elif not self._no_pr: # TODO: remove try/except when the upload to PR feature is released # (https://github.com/huggingface/huggingface_hub/pull/884) try: