mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
fix: Race Condition when using Sagemaker Checkpointing and Model Repository (#21614)
* Add _add_sm_patterns_to_gitignore
* Add _is_world_process_zero() call and move patterns arg to constant
* Update git status time.sleep
* Apply make style
This commit is contained in:
parent
7bce804260
commit
26ef0f1991
@ -3395,6 +3395,10 @@ class Trainer:
|
||||
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
||||
writer.writelines(["checkpoint-*/"])
|
||||
|
||||
# Add "*.sagemaker" to .gitignore if using SageMaker
|
||||
if os.environ.get("SM_TRAINING_ENV"):
|
||||
self._add_sm_patterns_to_gitignore()
|
||||
|
||||
self.push_in_progress = None
|
||||
|
||||
def create_model_card(
|
||||
@ -3716,3 +3720,42 @@ class Trainer:
|
||||
tensors = distributed_concat(tensors)
|
||||
|
||||
return nested_numpify(tensors)
|
||||
|
||||
def _add_sm_patterns_to_gitignore(self) -> None:
    """Add SageMaker Checkpointing patterns to the repository's .gitignore file.

    Ensures that the marker-file patterns ``*.sagemaker-uploading`` and
    ``*.sagemaker-uploaded`` are listed in ``<repo.local_dir>/.gitignore``,
    then stages the file and, if the repository is not clean, commits and
    pushes the change. Does nothing on processes other than the main one.

    A pattern counts as already present only when it appears as its own
    line in the file; a mere substring match (for example inside a comment
    or a negated ``!pattern`` line) does not count.
    """
    # Make sure we only do this on the main process
    if not self.is_world_process_zero():
        return

    patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
    gitignore_path = os.path.join(self.repo.local_dir, ".gitignore")

    # Get current .gitignore content; a missing file is treated as empty.
    try:
        with open(gitignore_path, "r", encoding="utf-8") as f:
            current_content = f.read()
    except FileNotFoundError:
        current_content = ""

    # Compare whole lines rather than substrings. Trailing whitespace is
    # stripped before comparing so "pattern   " still counts as present.
    existing_lines = {line.rstrip() for line in current_content.splitlines()}
    missing_patterns = [pattern for pattern in patterns if pattern not in existing_lines]

    # Write the .gitignore file only if something has to be added
    if missing_patterns:
        content = current_content
        # Terminate a dangling last line so the first new pattern starts on
        # its own line; an empty file gets no leading blank line.
        if content and not content.endswith("\n"):
            content += "\n"
        content += "".join(f"{pattern}\n" for pattern in missing_patterns)

        with open(gitignore_path, "w", encoding="utf-8") as f:
            logger.debug("Writing .gitignore file. Content: %s", content)
            f.write(content)

    # Stage unconditionally: the file may exist on disk but still be untracked.
    self.repo.git_add(".gitignore")

    # avoid race condition with git status
    time.sleep(0.5)

    if not self.repo.is_repo_clean():
        self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
        self.repo.git_push()
|
||||
|
Loading…
Reference in New Issue
Block a user