Add-support for commit description (#26704)

* fix

* update

* revert

* add dosctring

* good to go

* update

* add a test
This commit is contained in:
Arthur 2023-10-26 12:37:09 +02:00 committed by GitHub
parent 15cd096288
commit 4864d08d3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View File

@ -724,6 +724,7 @@ class PushToHubMixin:
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
revision: str = None,
commit_description: str = None,
):
"""
Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
@ -778,6 +779,7 @@ class PushToHubMixin:
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
commit_description=commit_description,
token=token,
create_pr=create_pr,
revision=revision,
@ -794,6 +796,7 @@ class PushToHubMixin:
create_pr: bool = False,
safe_serialization: bool = False,
revision: str = None,
commit_description: str = None,
**deprecated_kwargs,
) -> str:
"""
@ -825,6 +828,8 @@ class PushToHubMixin:
Whether or not to convert the model weights in safetensors format for safer serialization.
revision (`str`, *optional*):
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
Examples:
@ -901,6 +906,7 @@ class PushToHubMixin:
token=token,
create_pr=create_pr,
revision=revision,
commit_description=commit_description,
)

View File

@ -1119,6 +1119,23 @@ class ModelPushToHubTester(unittest.TestCase):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_with_description(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
COMMIT_DESCRIPTION = """
The commit description supports markdown synthax see:
```python
>>> form transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
```
"""
commit_details = model.push_to_hub(
"test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
)
self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)
@unittest.skip("This test is flaky")
def test_push_to_hub_in_organization(self):
config = BertConfig(