mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
[TF] Fix creating a PR while pushing in TF framework (#21968)
* add create pr arg * style * add test * ficup * update test * last nit fix typo * add `is_pt_tf_cross_test` marker for the tsts
This commit is contained in:
parent
d128f2ffab
commit
2156662dea
@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
use_temp_dir: Optional[bool] = None,
|
use_temp_dir: Optional[bool] = None,
|
||||||
commit_message: Optional[str] = None,
|
commit_message: Optional[str] = None,
|
||||||
private: Optional[bool] = None,
|
private: Optional[bool] = None,
|
||||||
use_auth_token: Optional[Union[bool, str]] = None,
|
|
||||||
max_shard_size: Optional[Union[int, str]] = "10GB",
|
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||||
**model_card_kwargs,
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
|
create_pr: bool = False,
|
||||||
|
**base_model_card_args,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
||||||
@ -2931,8 +2932,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
|
Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
|
||||||
will then be each of size lower than this size. If expressed as a string, needs to be digits followed
|
will then be each of size lower than this size. If expressed as a string, needs to be digits followed
|
||||||
by a unit (like `"5MB"`).
|
by a unit (like `"5MB"`).
|
||||||
model_card_kwargs:
|
create_pr (`bool`, *optional*, defaults to `False`):
|
||||||
Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
|
Whether or not to create a PR with the uploaded files or directly commit.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@ -2948,15 +2949,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
model.push_to_hub("huggingface/my-finetuned-bert")
|
model.push_to_hub("huggingface/my-finetuned-bert")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if "repo_path_or_name" in model_card_kwargs:
|
if "repo_path_or_name" in base_model_card_args:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
|
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
|
||||||
"`repo_id` instead."
|
"`repo_id` instead."
|
||||||
)
|
)
|
||||||
repo_id = model_card_kwargs.pop("repo_path_or_name")
|
repo_id = base_model_card_args.pop("repo_path_or_name")
|
||||||
# Deprecation warning will be sent after for repo_url and organization
|
# Deprecation warning will be sent after for repo_url and organization
|
||||||
repo_url = model_card_kwargs.pop("repo_url", None)
|
repo_url = base_model_card_args.pop("repo_url", None)
|
||||||
organization = model_card_kwargs.pop("organization", None)
|
organization = base_model_card_args.pop("organization", None)
|
||||||
|
|
||||||
if os.path.isdir(repo_id):
|
if os.path.isdir(repo_id):
|
||||||
working_dir = repo_id
|
working_dir = repo_id
|
||||||
@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
"output_dir": work_dir,
|
"output_dir": work_dir,
|
||||||
"model_name": Path(repo_id).name,
|
"model_name": Path(repo_id).name,
|
||||||
}
|
}
|
||||||
base_model_card_args.update(model_card_kwargs)
|
base_model_card_args.update(base_model_card_args)
|
||||||
self.create_model_card(**base_model_card_args)
|
self.create_model_card(**base_model_card_args)
|
||||||
|
|
||||||
self._upload_modified_files(
|
self._upload_modified_files(
|
||||||
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
|
work_dir,
|
||||||
|
repo_id,
|
||||||
|
files_timestamps,
|
||||||
|
commit_message=commit_message,
|
||||||
|
token=use_auth_token,
|
||||||
|
create_pr=create_pr,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -85,6 +85,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
PreTrainedModel,
|
||||||
PushToHubCallback,
|
PushToHubCallback,
|
||||||
RagRetriever,
|
RagRetriever,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
@ -92,6 +93,7 @@ if is_tf_available():
|
|||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
|
TFPreTrainedModel,
|
||||||
TFRagModel,
|
TFRagModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
)
|
)
|
||||||
@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
break
|
break
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
def test_push_to_hub_callback(self):
|
def test_push_to_hub_callback(self):
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
break
|
break
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
||||||
|
tf_push_to_hub_params.pop("base_model_card_args")
|
||||||
|
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
||||||
|
pt_push_to_hub_params.pop("deprecated_kwargs")
|
||||||
|
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
def test_push_to_hub_in_organization(self):
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
Loading…
Reference in New Issue
Block a user