Add option to save on each training node (#12421)
* Add option to save on each training node

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
parent 990540b72d
commit 31a8110918
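The flag this commit introduces is meant for multi-node runs where the nodes do not share a filesystem, so every node has to write its own copy of the checkpoints. A minimal usage sketch (not taken from the commit; the output path and step values are placeholders, and model/dataset setup is omitted):

from transformers import TrainingArguments

# Multi-node run without shared storage: every node keeps its own checkpoints,
# so load_best_model_at_end can find the best checkpoint locally on each node.
args = TrainingArguments(
    output_dir="outputs",          # per-node local path (placeholder)
    save_on_each_node=True,        # the option added by this commit
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,
)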
src/transformers/trainer.py

@@ -393,7 +393,7 @@ class Trainer:
         # Create clone of distant repo and output directory if needed
         if self.args.push_to_hub:
             self.init_git_repo()
-        if self.is_world_process_zero():
+        if self.args.should_save:
             os.makedirs(self.args.output_dir, exist_ok=True)
 
         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
@@ -899,7 +899,7 @@ class Trainer:
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
             output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
             self.save_model(output_dir)
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -1357,10 +1357,18 @@ class Trainer:
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            # If the model is on the GPU, it still works!
-            self._load_state_dict_in_model(state_dict)
+
+            best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+            if os.path.exists(best_model_path):
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = torch.load(best_model_path, map_location="cpu")
+                # If the model is on the GPU, it still works!
+                self._load_state_dict_in_model(state_dict)
+            else:
+                logger.warn(
+                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                    "on multiple nodes, you should activate `--save_on_each_node`."
+                )
 
             if self.deepspeed:
                 self.deepspeed.load_checkpoint(
@@ -1500,14 +1508,14 @@ class Trainer:
                 # Consolidate the state dict on all processed of dp_rank 0
                 opt_state_dict = self.optimizer.state_dict()
                 # Save it and the scheduler on the main process
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                     with warnings.catch_warnings(record=True) as caught_warnings:
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                     reissue_pt_warnings(caught_warnings)
                     if self.use_amp:
                         torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
-        elif self.is_world_process_zero() and not self.deepspeed:
+        elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1533,7 +1541,7 @@ class Trainer:
                 self.state.best_model_checkpoint = output_dir
 
         # Save the Trainer state
-        if self.is_world_process_zero():
+        if self.args.should_save:
             self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
 
         # Save RNG state in non-distributed training
@@ -1562,7 +1570,7 @@ class Trainer:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
         # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
+        if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
     def _load_optimizer_and_scheduler(self, checkpoint):
@@ -1831,19 +1839,19 @@ class Trainer:
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
             state_dict = self.model_wrapped.state_dict()
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
         ):
             state_dict = self.model.state_dict()
 
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif self.deepspeed:
 
             # this takes care of everything as long as we aren't under zero3
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir)
 
             if is_deepspeed_zero3_enabled():
@@ -1851,7 +1859,7 @@ class Trainer:
                 # saved, so since under zero3 the file is bogus, simply delete it. The user should
                 # either user deepspeed checkpoint to resume or to recover full weights use
                 # zero_to_fp32.py stored in the checkpoint.
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     file = os.path.join(output_dir, WEIGHTS_NAME)
                     if os.path.isfile(file):
                         # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@@ -1862,7 +1870,7 @@ class Trainer:
                 # This must be called on all ranks
                 self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
 
-        elif self.is_world_process_zero():
+        elif self.args.should_save:
             self._save(output_dir)
 
     def _save_tpu(self, output_dir: Optional[str] = None):
@@ -1880,7 +1888,7 @@ class Trainer:
             if isinstance(unwrap_model(self.model), PreTrainedModel):
                 unwrap_model(self.model).save_pretrained(
                     output_dir,
-                    save_config=self.is_world_process_zero(),
+                    save_config=self.args.should_save,
                     state_dict=self.model.state_dict(),
                     save_function=xm.save,
                 )
@@ -1889,8 +1897,8 @@ class Trainer:
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
-        if self.tokenizer is not None and self.is_world_process_zero():
+            self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+        if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
 
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
@@ -1960,7 +1968,7 @@ class Trainer:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
 
-        # If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
         # we don't do to allow resuming.
         save_total_limit = self.args.save_total_limit
         if (
@@ -2436,7 +2444,7 @@ class Trainer:
         """
         Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return
         use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
         repo_url = PushToHubMixin._get_repo_url_from_name(
@@ -2494,11 +2502,16 @@ class Trainer:
         Returns:
             The url of the commit of your model in the given repository.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return
 
         self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
         self.save_model()
+
+        # Only push from one node.
+        if not self.is_world_process_zero():
+            return
+
         return self.repo.push_to_hub(commit_message=commit_message)
 
     #
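The -1357,10 +1357,18 hunk above only loads the best model when its weights file is actually present on the local disk; on a node that never wrote the checkpoint the file is missing unless --save_on_each_node is active. A standalone sketch of that guard (simplified: it assumes WEIGHTS_NAME is "pytorch_model.bin" and uses plain load_state_dict instead of the Trainer's internal _load_state_dict_in_model):

import os

import torch

WEIGHTS_NAME = "pytorch_model.bin"

def load_best_checkpoint(model, best_model_checkpoint):
    best_model_path = os.path.join(best_model_checkpoint, WEIGHTS_NAME)
    if not os.path.exists(best_model_path):
        # Mirrors the new warning: in multi-node training without shared storage,
        # activate --save_on_each_node so the checkpoint exists on every node.
        print(f"Could not locate the best model at {best_model_path}")
        return model
    # Load on the CPU first to avoid a GPU OOM; loading into a GPU model still works.
    state_dict = torch.load(best_model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model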
src/transformers/training_args.py

@@ -183,6 +183,12 @@ class TrainingArguments:
         save_total_limit (:obj:`int`, `optional`):
             If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
             :obj:`output_dir`.
+        save_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+            the main one.
+
+            This should not be activated when the different nodes use the same storage as the files will be saved with
+            the same names for each node.
         no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to not use CUDA even when it is available or not.
         seed (:obj:`int`, `optional`, defaults to 42):
@@ -456,6 +462,12 @@ class TrainingArguments:
             )
         },
     )
+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+        },
+    )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
 
@@ -937,6 +949,19 @@ class TrainingArguments:
             else:
                 return self.process_index == 0
 
+    @property
+    def should_save(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
     def get_process_log_level(self):
         """
         Returns the log level to be used depending on whether this process is the main process of node 0, main process
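For reference, the decision the new should_save property makes, written out as a standalone sketch (the SageMaker model-parallel branch is left out):

def should_save(save_on_each_node: bool, process_index: int, local_process_index: int) -> bool:
    """Sketch of TrainingArguments.should_save without the SageMaker MP branch."""
    if save_on_each_node:
        # One writer per node: the local main process of every node saves.
        return local_process_index == 0
    # Default: only the global main process (process_index 0 on node 0) saves.
    return process_index == 0

# Example with 2 nodes x 2 GPUs: global ranks 0-3, local ranks 0-1 on each node.
assert should_save(False, process_index=2, local_process_index=0) is False  # node 1 skips saving
assert should_save(True, process_index=2, local_process_index=0) is True    # node 1 now saves too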