[TF] Fix creating a PR while pushing in TF framework (#21968)

* add create pr arg

* style

* add test

* fixup

* update test

* last nit: fix typo

* add `is_pt_tf_cross_test` marker for the tests
Arthur 2023-03-07 17:32:08 +01:00 committed by GitHub
parent d128f2ffab
commit 2156662dea
2 changed files with 25 additions and 10 deletions

src/transformers/modeling_tf_utils.py

@@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         use_temp_dir: Optional[bool] = None,
         commit_message: Optional[str] = None,
         private: Optional[bool] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
         max_shard_size: Optional[Union[int, str]] = "10GB",
-        **model_card_kwargs,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        create_pr: bool = False,
+        **base_model_card_args,
     ) -> str:
         """
         Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
@@ -2931,8 +2932,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                 will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                 by a unit (like `"5MB"`).
-            model_card_kwargs:
-                Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.

         Examples:
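In practice the new flag turns a push into a draft pull request on the Hub. A minimal usage sketch, assuming a Hub token is already configured; the checkpoint and repo id below are placeholders, not part of this diff:

```python
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-cased")

# With create_pr=True the upload is opened as a pull request on the Hub
# instead of being committed directly to the main branch.
model.push_to_hub("my-user/my-finetuned-bert", create_pr=True)
```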
@@ -2948,15 +2949,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         model.push_to_hub("huggingface/my-finetuned-bert")
         ```
         """
-        if "repo_path_or_name" in model_card_kwargs:
+        if "repo_path_or_name" in base_model_card_args:
             warnings.warn(
                 "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                 "`repo_id` instead."
             )
-            repo_id = model_card_kwargs.pop("repo_path_or_name")
+            repo_id = base_model_card_args.pop("repo_path_or_name")
         # Deprecation warning will be sent after for repo_url and organization
-        repo_url = model_card_kwargs.pop("repo_url", None)
-        organization = model_card_kwargs.pop("organization", None)
+        repo_url = base_model_card_args.pop("repo_url", None)
+        organization = base_model_card_args.pop("organization", None)

         if os.path.isdir(repo_id):
             working_dir = repo_id
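This hunk only renames the kwargs dict the deprecated arguments are popped from; the deprecation itself is unchanged. For reference, the two call styles (repo names are placeholders):

```python
# Deprecated, warns and will be removed in v5 of Transformers:
model.push_to_hub(repo_path_or_name="my-finetuned-bert")

# Current style: pass the full repo id on the Hub.
model.push_to_hub("my-user/my-finetuned-bert")
```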
@@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 "output_dir": work_dir,
                 "model_name": Path(repo_id).name,
             }
-            base_model_card_args.update(model_card_kwargs)
+            base_model_card_args.update(base_model_card_args)
             self.create_model_card(**base_model_card_args)

         self._upload_modified_files(
-            work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
+            work_dir,
+            repo_id,
+            files_timestamps,
+            commit_message=commit_message,
+            token=use_auth_token,
+            create_pr=create_pr,
         )

     @classmethod
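`_upload_modified_files` ultimately defers to `huggingface_hub`'s commit helpers, which already expose a `create_pr` flag, so the fix is to thread it through. A rough sketch of the equivalent direct hub call (repo id and folder path are placeholders):

```python
from huggingface_hub import upload_folder

# create_pr=True opens a pull request with the uploaded files rather than
# committing them straight to the main branch of the repo.
upload_folder(
    repo_id="my-user/my-finetuned-bert",
    folder_path="./my-finetuned-bert",
    commit_message="Upload TF model",
    create_pr=True,
)
```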

tests/test_modeling_tf_common.py

@@ -85,6 +85,7 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
         TFAutoModel,
@@ -92,6 +93,7 @@ if is_tf_available():
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
+        TFPreTrainedModel,
         TFRagModel,
         TFSharedEmbeddings,
     )
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break
         self.assertTrue(models_equal)

+    @is_pt_tf_cross_test
     def test_push_to_hub_callback(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
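`is_pt_tf_cross_test` restricts the test to runs where both PyTorch and TensorFlow are available and cross tests are explicitly enabled. A simplified sketch of how such a marker can be implemented; this is an approximation, not transformers' exact `testing_utils` code:

```python
import os
import unittest

import pytest


def is_pt_tf_cross_test(test_case):
    """Skip the test unless PT/TF cross tests are explicitly enabled."""
    if os.environ.get("RUN_PT_TF_CROSS_TESTS", "0") != "1":
        return unittest.skip("PT+TF cross test, set RUN_PT_TF_CROSS_TESTS=1 to run")(test_case)
    return pytest.mark.is_pt_tf_cross_test()(test_case)
```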
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break
         self.assertTrue(models_equal)

+        tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
+        tf_push_to_hub_params.pop("base_model_card_args")
+        pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
+        pt_push_to_hub_params.pop("deprecated_kwargs")
+        self.assertDictEqual(tf_push_to_hub_params, pt_push_to_hub_params)
+
     def test_push_to_hub_in_organization(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
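The added assertion keeps the TF and PT `push_to_hub` signatures in sync, ignoring each framework's catch-all kwargs. A self-contained sketch of the same parity check on two stand-in functions (the function names here are hypothetical):

```python
import inspect


def pt_push_to_hub(repo_id, private=None, create_pr=False, **deprecated_kwargs):
    ...


def tf_push_to_hub(repo_id, private=None, create_pr=False, **base_model_card_args):
    ...


tf_params = dict(inspect.signature(tf_push_to_hub).parameters)
tf_params.pop("base_model_card_args")
pt_params = dict(inspect.signature(pt_push_to_hub).parameters)
pt_params.pop("deprecated_kwargs")

# inspect.Parameter compares by name, kind, default and annotation, so this
# passes only when the remaining parameters really match.
assert tf_params == pt_params
```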