Add option to save on each training node (#12421)

* Add option to save on each training node

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger 2021-06-30 02:41:47 -04:00 committed by GitHub
parent 990540b72d
commit 31a8110918
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 21 deletions

View File

@ -393,7 +393,7 @@ class Trainer:
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
if self.is_world_process_zero():
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
@ -899,7 +899,7 @@ class Trainer:
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
if self.is_world_process_zero():
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@ -1357,10 +1357,18 @@ class Trainer:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
if os.path.exists(best_model_path):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
else:
logger.warn(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
"on multiple nodes, you should activate `--save_on_each_node`."
)
if self.deepspeed:
self.deepspeed.load_checkpoint(
@ -1500,14 +1508,14 @@ class Trainer:
# Consolidate the state dict on all processed of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.is_world_process_zero():
if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
elif self.is_world_process_zero() and not self.deepspeed:
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
@ -1533,7 +1541,7 @@ class Trainer:
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
if self.is_world_process_zero():
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Save RNG state in non-distributed training
@ -1562,7 +1570,7 @@ class Trainer:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _load_optimizer_and_scheduler(self, checkpoint):
@ -1831,19 +1839,19 @@ class Trainer:
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
if self.is_world_process_zero():
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()
if self.is_world_process_zero():
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
if self.is_world_process_zero():
if self.args.should_save:
self._save(output_dir)
if is_deepspeed_zero3_enabled():
@ -1851,7 +1859,7 @@ class Trainer:
# saved, so since under zero3 the file is bogus, simply delete it. The user should
# either user deepspeed checkpoint to resume or to recover full weights use
# zero_to_fp32.py stored in the checkpoint.
if self.is_world_process_zero():
if self.args.should_save:
file = os.path.join(output_dir, WEIGHTS_NAME)
if os.path.isfile(file):
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@ -1862,7 +1870,7 @@ class Trainer:
# This must be called on all ranks
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
elif self.is_world_process_zero():
elif self.args.should_save:
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
@ -1880,7 +1888,7 @@ class Trainer:
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
save_config=self.is_world_process_zero(),
save_config=self.args.should_save,
state_dict=self.model.state_dict(),
save_function=xm.save,
)
@ -1889,8 +1897,8 @@ class Trainer:
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
if self.tokenizer is not None and self.is_world_process_zero():
self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
@ -1960,7 +1968,7 @@ class Trainer:
if len(checkpoints_sorted) <= self.args.save_total_limit:
return
# If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
@ -2436,7 +2444,7 @@ class Trainer:
"""
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
"""
if not self.is_world_process_zero():
if not self.args.should_save:
return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
repo_url = PushToHubMixin._get_repo_url_from_name(
@ -2494,11 +2502,16 @@ class Trainer:
Returns:
The url of the commit of your model in the given repository.
"""
if not self.is_world_process_zero():
if not self.args.should_save:
return
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
self.save_model()
# Only push from one node.
if not self.is_world_process_zero():
return
return self.repo.push_to_hub(commit_message=commit_message)
#

View File

@ -183,6 +183,12 @@ class TrainingArguments:
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
save_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`False`):
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
the main one.
This should not be activated when the different nodes use the same storage as the files will be saved with
the same names for each node.
no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to not use CUDA even when it is available or not.
seed (:obj:`int`, `optional`, defaults to 42):
@ -456,6 +462,12 @@ class TrainingArguments:
)
},
)
save_on_each_node: bool = field(
default=False,
metadata={
"help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
},
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
@ -937,6 +949,19 @@ class TrainingArguments:
else:
return self.process_index == 0
@property
def should_save(self):
"""
Whether or not the current process should write to disk, e.g., to save models and checkpoints.
"""
if self.save_on_each_node:
return self.local_process_index == 0
else:
if is_sagemaker_mp_enabled():
return smp.rank() == 0
else:
return self.process_index == 0
def get_process_log_level(self):
"""
Returns the log level to be used depending on whether this process is the main process of node 0, main process