Trainer.push_to_hub always tries to push to the Hub (#15463)

This commit is contained in:
Sylvain Gugger 2022-02-01 15:49:04 -05:00 committed by GitHub
parent 37800f1365
commit 8e5d4e4906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -403,7 +403,7 @@ class Trainer:
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
self.init_git_repo(at_init=True)
# In case of pull, we need to make sure every process has the latest.
if is_torch_tpu_available():
xm.rendezvous("init git repo")
@ -2657,9 +2657,15 @@ class Trainer:
else:
return 0
def init_git_repo(self):
def init_git_repo(self, at_init: bool = False):
"""
Initializes a git repo in `self.args.hub_model_id`.
Args:
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
out.
"""
if not self.is_world_process_zero():
return
@ -2678,7 +2684,7 @@ class Trainer:
use_auth_token=use_auth_token,
)
except EnvironmentError:
if self.args.overwrite_output_dir:
if self.args.overwrite_output_dir and at_init:
# Try again after wiping output_dir
shutil.rmtree(self.args.output_dir)
self.repo = Repository(
@ -2790,6 +2796,10 @@ class Trainer:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
the commit and an object to track the progress of the commit if `blocking=True`
"""
# If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
# it might fail.
if not hasattr(self, "repo"):
self.init_git_repo()
if self.args.should_save:
if self.args.hub_model_id is None: