Mirror of https://github.com/huggingface/transformers.git
Multiple fixes in SageMakerTrainer (#10687)
* Handle save differently
* Missing imports
* Fix typo
* Adapt to recent changes in save_pretrained
* Forgotten brackets
* Optimizer load
* Fix world size
* Deal with None
* Remove needless self
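For orientation, here is a minimal standalone sketch of the world-size resolution order that the reworked world_size property in the diff below follows (TPU first, then SageMaker distributed data parallel, then torch.distributed, then a fallback of 1). It is illustrative only: the function name and keyword arguments are placeholders standing in for the real backend queries named in the comments, not code from this commit.

def resolve_world_size(tpu_world_size=None, sm_world_size=None, dist_world_size=None):
    # Each keyword argument stands in for the backend query noted on its branch.
    if tpu_world_size is not None:  # is_torch_tpu_available() -> xm.xrt_world_size()
        return tpu_world_size
    if sm_world_size is not None:  # is_sagemaker_distributed_available() -> sm_dist.get_world_size()
        return sm_world_size
    if dist_world_size is not None:  # self.local_rank != -1 -> torch.distributed.get_world_size()
        return dist_world_size
    return 1

# The evaluation loop can then use max(1, args.world_size) directly, which is
# the simplification made in the trainer.py hunk further down.
assert resolve_world_size() == 1
assert resolve_world_size(sm_world_size=8) == 8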
This commit is contained in:
parent 3f1714f8a7
commit 6bef764506
@@ -11,21 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler

from ..file_utils import WEIGHTS_NAME, is_torch_tpu_available
from ..modeling_utils import PreTrainedModel, unwrap_model
from ..trainer import Trainer
from ..trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    SequentialDistributedSampler,
    nested_detach,
    nested_numpify,
    reissue_pt_warnings,
)
from ..trainer_utils import PREFIX_CHECKPOINT_DIR
from ..utils import logging
from .training_args_sm import is_smdistributed_available
@@ -83,7 +89,7 @@ class SageMakerTrainer(Trainer):
        if self.is_model_parallel_enabled:
            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
        else:
-            return super.is_world_process_zero()
+            return super().is_world_process_zero()

    def _get_train_sampler(self):
        if self.is_model_parallel_enabled:
@@ -126,12 +132,123 @@ class SageMakerTrainer(Trainer):
            return super().training_step(model, inputs)

    def _gather_and_numpify(self, tensors, name):
        if tensors is None:
            return
        if self.is_model_parallel_enabled:
            tensors = smp_gather(tensors)
            return nested_numpify(tensors)
        else:
            return super()._gather_and_numpify(tensors, name)

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
        """
        if self.is_model_parallel_enabled:
            self._save_smp(output_dir)
        elif is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

        # If on sagemaker and we are saving the main model (not a checkpoint so output_dir=None), save a copy to
        # SM_MODEL_DIR for easy deployment.
        if output_dir is None and os.getenv("SM_MODEL_DIR") is not None:
            self.save_model(output_dir=os.getenv("SM_MODEL_DIR"))

    def _save_smp(self, output_dir: Optional[str] = None):
        if smp.dp_rank() != 0:
            return
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Calling the state_dict needs to be done on the wrapped model
        state_dict = self.model_wrapped.state_dict()

        # Rest of the save is done for the main process only
        if self.is_world_process_zero():
            model = self.model
            if not isinstance(model, PreTrainedModel):
                model = unwrap_model(model)
            if isinstance(model, PreTrainedModel):
                model.save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

            # Good practice: save your training arguments together with the trained model
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _save_checkpoint(self, model, trial, metrics=None):
        if self.is_model_parallel_enabled:
            if smp.dp_rank() != 0:
                return

            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self.args.output_dir
            self.store_flos()

            output_dir = os.path.join(run_dir, checkpoint_folder)
            self.save_model(output_dir)
            # Consolidate the state dict on all processes of dp_rank 0
            opt_state_dict = self.optimizer.state_dict()
            # Save it and the scheduler on the main process
            if self.is_world_process_zero():
                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)

            # Determine the new best metric / best model checkpoint
            if metrics is not None and self.args.metric_for_best_model is not None:
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                metric_value = metrics[metric_to_check]

                operator = np.greater if self.args.greater_is_better else np.less
                if (
                    self.state.best_metric is None
                    or self.state.best_model_checkpoint is None
                    or operator(metric_value, self.state.best_metric)
                ):
                    self.state.best_metric = metric_value
                    self.state.best_model_checkpoint = output_dir

            # Save the Trainer state
            if self.is_world_process_zero():
                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

            # Maybe delete some older checkpoints.
            if self.is_world_process_zero():
                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
        else:
            super()._save_checkpoint(model, trial, metrics=metrics)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if self.is_model_parallel_enabled:
            if checkpoint is None:
                return

            if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
                os.path.join(checkpoint, "scheduler.pt")
            ):
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
                )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
                reissue_pt_warnings(caught_warnings)
        else:
            super()._load_optimizer_and_scheduler(checkpoint)

    def prediction_step(
        self,
        model: nn.Module,
@@ -84,6 +84,13 @@ class SageMakerTrainingArguments(TrainingArguments):

        return device

    @property
    def world_size(self):
        if is_smdistributed_available() and self.mp_parameters != "":
            return smp.dp_size()

        return super().world_size

    @property
    def place_model_on_device(self):
        return not (is_smdistributed_available() and self.mp_parameters != "")
@@ -915,9 +915,6 @@ class Trainer:
        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

-        # Check if saved optimizer or scheduler states exist
-        self._load_optimizer_and_scheduler(resume_from_checkpoint)
-
        model = self._wrap_model(self.model_wrapped)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -927,6 +924,9 @@ class Trainer:
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

+        # Check if saved optimizer or scheduler states exist
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
        # important: at this point:
        # self.model is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
@@ -1782,12 +1782,7 @@ class Trainer:
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

-        world_size = 1
-        if is_torch_tpu_available():
-            world_size = xm.xrt_world_size()
-        elif self.args.local_rank != -1:
-            world_size = dist.get_world_size()
-        world_size = max(1, world_size)
+        world_size = max(1, self.args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
@@ -36,6 +36,9 @@ if is_torch_available():
if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm

if is_sagemaker_distributed_available():
    import smdistributed.dataparallel.torch.distributed as sm_dist


logger = logging.get_logger(__name__)

@@ -631,10 +634,8 @@ class TrainingArguments:
            device = xm.xla_device()
            self._n_gpu = 0
        elif is_sagemaker_distributed_available():
-            import smdistributed.dataparallel.torch.distributed as dist
-
-            dist.init_process_group()
-            self.local_rank = dist.get_local_rank()
+            sm_dist.init_process_group()
+            self.local_rank = sm_dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.deepspeed:
@@ -725,6 +726,20 @@ class TrainingArguments:
        else:
            return ParallelMode.NOT_PARALLEL

    @property
    @torch_required
    def world_size(self):
        """
        The number of processes used in parallel.
        """
        if is_torch_tpu_available():
            return xm.xrt_world_size()
        elif is_sagemaker_distributed_available():
            return sm_dist.get_world_size()
        elif self.local_rank != -1:
            return torch.distributed.get_world_size()
        return 1

    @property
    def place_model_on_device(self):
        """