fix: Race Condition when using Sagemaker Checkpointing and Model Repository (#21614)

* Add _add_sm_patterns_to_gitignore

* Add _is_world_process_zero() call and move patterns arg to constant

* Update git status time.sleep

* Apply make style
This commit is contained in:
Douglas Trajano 2023-02-14 18:11:37 -03:00 committed by GitHub
parent 7bce804260
commit 26ef0f1991
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3395,6 +3395,10 @@ class Trainer:
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"])
# Add "*.sagemaker" to .gitignore if using SageMaker
if os.environ.get("SM_TRAINING_ENV"):
self._add_sm_patterns_to_gitignore()
self.push_in_progress = None
def create_model_card(
@ -3716,3 +3720,42 @@ class Trainer:
tensors = distributed_concat(tensors)
return nested_numpify(tensors)
def _add_sm_patterns_to_gitignore(self) -> None:
"""Add SageMaker Checkpointing patterns to .gitignore file."""
# Make sure we only do this on the main process
if not self.is_world_process_zero():
return
patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
# Get current .gitignore content
if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
current_content = f.read()
else:
current_content = ""
# Add the patterns to .gitignore
content = current_content
for pattern in patterns:
if pattern not in content:
if content.endswith("\n"):
content += pattern
else:
content += f"\n{pattern}"
# Write the .gitignore file if it has changed
if content != current_content:
with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
logger.debug(f"Writing .gitignore file. Content: {content}")
f.write(content)
self.repo.git_add(".gitignore")
# avoid race condition with git status
time.sleep(0.5)
if not self.repo.is_repo_clean():
self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
self.repo.git_push()