[TF] Fix creating a PR while pushing in TF framework (#21968)

* add create pr arg

* style

* add test

* fixup

* update test

* last nit fix typo

* add `is_pt_tf_cross_test` marker for the tests
Arthur 2023-03-07 17:32:08 +01:00 committed by GitHub
parent d128f2ffab
commit 2156662dea
2 changed files with 25 additions and 10 deletions

View File

@@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         use_temp_dir: Optional[bool] = None,
         commit_message: Optional[str] = None,
         private: Optional[bool] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
         max_shard_size: Optional[Union[int, str]] = "10GB",
-        **model_card_kwargs,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        create_pr: bool = False,
+        **base_model_card_args,
     ) -> str:
         """
         Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
@@ -2931,8 +2932,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                 will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                 by a unit (like `"5MB"`).
-            model_card_kwargs:
-                Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
 
         Examples:
@@ -2948,15 +2949,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         model.push_to_hub("huggingface/my-finetuned-bert")
         ```
         """
-        if "repo_path_or_name" in model_card_kwargs:
+        if "repo_path_or_name" in base_model_card_args:
             warnings.warn(
                 "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                 "`repo_id` instead."
             )
-            repo_id = model_card_kwargs.pop("repo_path_or_name")
+            repo_id = base_model_card_args.pop("repo_path_or_name")
         # Deprecation warning will be sent after for repo_url and organization
-        repo_url = model_card_kwargs.pop("repo_url", None)
-        organization = model_card_kwargs.pop("organization", None)
+        repo_url = base_model_card_args.pop("repo_url", None)
+        organization = base_model_card_args.pop("organization", None)
 
         if os.path.isdir(repo_id):
             working_dir = repo_id
@@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                     "output_dir": work_dir,
                     "model_name": Path(repo_id).name,
                 }
-                base_model_card_args.update(model_card_kwargs)
+                base_model_card_args.update(base_model_card_args)
                 self.create_model_card(**base_model_card_args)
 
             self._upload_modified_files(
-                work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
+                work_dir,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=use_auth_token,
+                create_pr=create_pr,
             )
 
     @classmethod
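For context, a minimal usage sketch of the new flag (not part of the commit; the repository id is a placeholder and it assumes you are logged in via `huggingface-cli login`):

```python
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")

# With create_pr=True the upload opens a pull request on the Hub
# instead of committing directly to the repository's main branch.
model.push_to_hub(
    "my-user/my-finetuned-bert",  # placeholder repo id
    commit_message="Add TF weights",
    use_auth_token=True,
    create_pr=True,
)
```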

View File

@@ -85,6 +85,7 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
         TFAutoModel,
@@ -92,6 +93,7 @@ if is_tf_available():
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
+        TFPreTrainedModel,
         TFRagModel,
         TFSharedEmbeddings,
     )
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break
         self.assertTrue(models_equal)
 
+    @is_pt_tf_cross_test
     def test_push_to_hub_callback(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break
         self.assertTrue(models_equal)
 
+        tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
+        tf_push_to_hub_params.pop("base_model_card_args")
+        pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
+        pt_push_to_hub_params.pop("deprecated_kwargs")
+        self.assertDictEqual(tf_push_to_hub_params, pt_push_to_hub_params)
+
     def test_push_to_hub_in_organization(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
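The new assertion keeps the TF `push_to_hub` signature in lockstep with the PyTorch one. As a standalone sketch of the same check outside the test suite (assuming both PyTorch and TensorFlow are installed, since `PreTrainedModel` requires torch):

```python
import inspect

from transformers import PreTrainedModel, TFPreTrainedModel

# Collect the push_to_hub parameters of both base classes and drop the
# catch-all kwargs, whose names legitimately differ between frameworks.
tf_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
tf_params.pop("base_model_card_args")
pt_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_params.pop("deprecated_kwargs")

# inspect.Parameter compares by name, kind, default and annotation,
# so equal dicts mean the two signatures are aligned.
assert tf_params == pt_params, "PT and TF push_to_hub signatures have drifted apart"
```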