mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
[TF] Fix creating a PR while pushing in TF framework (#21968)
* add create pr arg * style * add test * ficup * update test * last nit fix typo * add `is_pt_tf_cross_test` marker for the tsts
This commit is contained in:
parent
d128f2ffab
commit
2156662dea
@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
use_temp_dir: Optional[bool] = None,
|
||||
commit_message: Optional[str] = None,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||
**model_card_kwargs,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
create_pr: bool = False,
|
||||
**base_model_card_args,
|
||||
) -> str:
|
||||
"""
|
||||
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
||||
@ -2931,8 +2932,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
|
||||
will then be each of size lower than this size. If expressed as a string, needs to be digits followed
|
||||
by a unit (like `"5MB"`).
|
||||
model_card_kwargs:
|
||||
Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -2948,15 +2949,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
model.push_to_hub("huggingface/my-finetuned-bert")
|
||||
```
|
||||
"""
|
||||
if "repo_path_or_name" in model_card_kwargs:
|
||||
if "repo_path_or_name" in base_model_card_args:
|
||||
warnings.warn(
|
||||
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
|
||||
"`repo_id` instead."
|
||||
)
|
||||
repo_id = model_card_kwargs.pop("repo_path_or_name")
|
||||
repo_id = base_model_card_args.pop("repo_path_or_name")
|
||||
# Deprecation warning will be sent after for repo_url and organization
|
||||
repo_url = model_card_kwargs.pop("repo_url", None)
|
||||
organization = model_card_kwargs.pop("organization", None)
|
||||
repo_url = base_model_card_args.pop("repo_url", None)
|
||||
organization = base_model_card_args.pop("organization", None)
|
||||
|
||||
if os.path.isdir(repo_id):
|
||||
working_dir = repo_id
|
||||
@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
"output_dir": work_dir,
|
||||
"model_name": Path(repo_id).name,
|
||||
}
|
||||
base_model_card_args.update(model_card_kwargs)
|
||||
base_model_card_args.update(base_model_card_args)
|
||||
self.create_model_card(**base_model_card_args)
|
||||
|
||||
self._upload_modified_files(
|
||||
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
|
||||
work_dir,
|
||||
repo_id,
|
||||
files_timestamps,
|
||||
commit_message=commit_message,
|
||||
token=use_auth_token,
|
||||
create_pr=create_pr,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -85,6 +85,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
PreTrainedModel,
|
||||
PushToHubCallback,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
@ -92,6 +93,7 @@ if is_tf_available():
|
||||
TFBertForMaskedLM,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFPreTrainedModel,
|
||||
TFRagModel,
|
||||
TFSharedEmbeddings,
|
||||
)
|
||||
@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_push_to_hub_callback(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
||||
tf_push_to_hub_params.pop("base_model_card_args")
|
||||
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
||||
pt_push_to_hub_params.pop("deprecated_kwargs")
|
||||
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
|
Loading…
Reference in New Issue
Block a user