From 2156662deafc2ce3da8fd5e7bf54089f8fcc0b05 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 7 Mar 2023 17:32:08 +0100
Subject: [PATCH] [TF] Fix creating a PR while pushing in TF framework (#21968)

* add create pr arg

* style

* add test

* fixup

* update test

* last nit fix typo

* add `is_pt_tf_cross_test` marker for the tests
---
 src/transformers/modeling_tf_utils.py | 26 ++++++++++++++++----------
 tests/test_modeling_tf_common.py      |  9 +++++++++
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 29c4d67510c..dcedd2946b8 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         use_temp_dir: Optional[bool] = None,
         commit_message: Optional[str] = None,
         private: Optional[bool] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
         max_shard_size: Optional[Union[int, str]] = "10GB",
-        **model_card_kwargs,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        create_pr: bool = False,
+        **base_model_card_args,
     ) -> str:
         """
         Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
@@ -2931,32 +2932,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                 will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                 by a unit (like `"5MB"`).
-            model_card_kwargs:
-                Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
 
         Examples:
 
         ```python
         from transformers import TFAutoModel
 
         model = TFAutoModel.from_pretrained("bert-base-cased")
 
         # Push the model to your namespace with the name "my-finetuned-bert".
         model.push_to_hub("my-finetuned-bert")
 
         # Push the model to an organization with the name "my-finetuned-bert".
         model.push_to_hub("huggingface/my-finetuned-bert")
         ```
         """
-        if "repo_path_or_name" in model_card_kwargs:
+        if "repo_path_or_name" in base_model_card_args:
             warnings.warn(
                 "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                 "`repo_id` instead."
             )
-            repo_id = model_card_kwargs.pop("repo_path_or_name")
+            repo_id = base_model_card_args.pop("repo_path_or_name")
         # Deprecation warning will be sent after for repo_url and organization
-        repo_url = model_card_kwargs.pop("repo_url", None)
-        organization = model_card_kwargs.pop("organization", None)
+        repo_url = base_model_card_args.pop("repo_url", None)
+        organization = base_model_card_args.pop("organization", None)
 
         if os.path.isdir(repo_id):
             working_dir = repo_id
@@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 "output_dir": work_dir,
                 "model_name": Path(repo_id).name,
             }
-            base_model_card_args.update(model_card_kwargs)
+            base_model_card_args.update(base_model_card_args)
             self.create_model_card(**base_model_card_args)
 
         self._upload_modified_files(
-            work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
+            work_dir,
+            repo_id,
+            files_timestamps,
+            commit_message=commit_message,
+            token=use_auth_token,
+            create_pr=create_pr,
         )
 
     @classmethod
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 9cfd314dfcb..42db9aea269 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -85,6 +85,7 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
         TFAutoModel,
@@ -92,6 +93,7 @@ if is_tf_available():
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
+        TFPreTrainedModel,
         TFRagModel,
         TFSharedEmbeddings,
     )
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
                     break
             self.assertTrue(models_equal)
 
+    @is_pt_tf_cross_test
     def test_push_to_hub_callback(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
                     break
             self.assertTrue(models_equal)
 
+        tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
+        tf_push_to_hub_params.pop("base_model_card_args")
+        pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
+        pt_push_to_hub_params.pop("deprecated_kwargs")
+        self.assertDictEqual(tf_push_to_hub_params, pt_push_to_hub_params)
+
     def test_push_to_hub_in_organization(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
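Reviewer note: a minimal usage sketch of the new flag (the repo id and commit message below are placeholders, not part of the patch). With `create_pr=True`, the upload opens a pull request on the Hub repo instead of committing directly to `main`:

```python
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")

# "your-namespace/my-finetuned-bert" is a placeholder repo id. Before this
# patch, TFPreTrainedModel.push_to_hub() had no `create_pr` argument; with
# it applied, the files land in a Hub pull request rather than on `main`.
model.push_to_hub(
    "your-namespace/my-finetuned-bert",
    commit_message="Add TF weights",
    create_pr=True,
)
```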
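The patch itself only threads `create_pr` through `_upload_modified_files`; the PR creation happens downstream in `huggingface_hub`, whose commit APIs accept the same flag. For comparison, a sketch of a roughly equivalent low-level call (assuming `huggingface_hub.HfApi.upload_folder`; folder path and repo id are placeholders):

```python
from huggingface_hub import HfApi

api = HfApi()
# Upload a locally saved checkpoint directory and open a PR on the Hub
# instead of pushing to main. Both path and repo id are placeholders.
api.upload_folder(
    folder_path="./my-finetuned-bert",
    repo_id="your-namespace/my-finetuned-bert",
    commit_message="Add TF weights",
    create_pr=True,
)
```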
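The new test pins the TF and PT `push_to_hub` signatures to each other by comparing `inspect.signature(...).parameters` after dropping each side's kwargs catch-all. A self-contained sketch of why that comparison works (toy functions, not the real methods):

```python
import inspect

def pt_style(repo_id, create_pr=False, **deprecated_kwargs):
    pass

def tf_style(repo_id, create_pr=False, **base_model_card_args):
    pass

pt_params = dict(inspect.signature(pt_style).parameters)
pt_params.pop("deprecated_kwargs")
tf_params = dict(inspect.signature(tf_style).parameters)
tf_params.pop("base_model_card_args")

# inspect.Parameter compares by name, kind, default and annotation, so the
# dicts are equal only when the remaining signatures match exactly.
assert pt_params == tf_params
```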