Multiple fixes in SageMakerTrainer (#10687)

* Handle save differently

* Missing imports

* Fix typo

* Adapt to recent changes in save_pretrained

* Forgotten brackets

* Optimizer load

* Fix world size

* Deal with None

* Remove needless self
Sylvain Gugger 2021-03-15 09:28:15 -04:00 committed by GitHub
parent 3f1714f8a7
commit 6bef764506
4 changed files with 149 additions and 15 deletions
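
Among the fixes listed above, "Handle save differently" means the final model (saved with output_dir=None) is also copied to the directory SageMaker exports via SM_MODEL_DIR, so it can be reloaded with from_pretrained() once the training job finishes. A minimal reload sketch; the auto classes and the fallback path are illustrative and not part of this commit:

    import os

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # SM_MODEL_DIR is set by SageMaker; fall back to a local output dir when testing outside it.
    model_dir = os.getenv("SM_MODEL_DIR", "./output")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)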

src/transformers/sagemaker/trainer_sm.py

@@ -11,21 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from ..file_utils import WEIGHTS_NAME, is_torch_tpu_available
from ..modeling_utils import PreTrainedModel, unwrap_model
from ..trainer import Trainer
from ..trainer_pt_utils import (
DistributedLengthGroupedSampler,
SequentialDistributedSampler,
nested_detach,
nested_numpify,
reissue_pt_warnings,
)
from ..trainer_utils import PREFIX_CHECKPOINT_DIR
from ..utils import logging
from .training_args_sm import is_smdistributed_available
@@ -83,7 +89,7 @@ class SageMakerTrainer(Trainer):
if self.is_model_parallel_enabled:
return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
else:
-return super.is_world_process_zero()
+return super().is_world_process_zero()
def _get_train_sampler(self):
if self.is_model_parallel_enabled:
@@ -126,12 +132,123 @@ class SageMakerTrainer(Trainer):
return super().training_step(model, inputs)
def _gather_and_numpify(self, tensors, name):
if tensors is None:
return
if self.is_model_parallel_enabled:
tensors = smp_gather(tensors)
return nested_numpify(tensors)
else:
return super()._gather_and_numpify(tensors, name)
def save_model(self, output_dir: Optional[str] = None):
"""
Will save the model, so you can reload it using :obj:`from_pretrained()`.
Will only save from the world_master process (unless on TPU).
"""
if self.is_model_parallel_enabled:
self._save_smp(output_dir)
elif is_torch_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_process_zero():
self._save(output_dir)
# If on SageMaker and we are saving the main model (not a checkpoint so output_dir=None), save a copy to
# SM_MODEL_DIR for easy deployment.
if output_dir is None and os.getenv("SM_MODEL_DIR") is not None:
self.save_model(output_dir=os.getenv("SM_MODEL_DIR"))
def _save_smp(self, output_dir: Optional[str] = None):
if smp.dp_rank() != 0:
return
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Calling the state_dict needs to be done on the wrapped model
state_dict = self.model_wrapped.state_dict()
# Rest of the save is done for the main process only
if self.is_world_process_zero():
model = self.model
if not isinstance(model, PreTrainedModel):
model = unwrap_model(model)
if isinstance(model, PreTrainedModel):
model.save_pretrained(output_dir, state_dict=state_dict)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def _save_checkpoint(self, model, trial, metrics=None):
if self.is_model_parallel_enabled:
if smp.dp_rank() != 0:
return
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self.args.output_dir
self.store_flos()
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir)
# Consolidate the state dict on all processes of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.is_world_process_zero():
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
else:
super()._save_checkpoint(model, trial, metrics=metrics)
def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
if self.is_model_parallel_enabled:
if checkpoint is None:
return
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
os.path.join(checkpoint, "scheduler.pt")
):
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
else:
super()._load_optimizer_and_scheduler(checkpoint)
def prediction_step(
self,
model: nn.Module,

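With save_model, _save_checkpoint, and _load_optimizer_and_scheduler now covering the model-parallel case, resuming a SageMaker run works the same way as with the base Trainer. A caller-side sketch, assuming an already-constructed SageMakerTrainer named trainer; the checkpoint path is an example (folders are named checkpoint-<global_step> via PREFIX_CHECKPOINT_DIR):

    # Resume from a checkpoint directory written by _save_checkpoint;
    # optimizer.pt and scheduler.pt are reloaded by _load_optimizer_and_scheduler.
    trainer.train(resume_from_checkpoint="output/checkpoint-500")
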
src/transformers/sagemaker/training_args_sm.py

@@ -84,6 +84,13 @@ class SageMakerTrainingArguments(TrainingArguments):
return device
@property
def world_size(self):
if is_smdistributed_available() and self.mp_parameters != "":
return smp.dp_size()
return super().world_size
@property
def place_model_on_device(self):
return not (is_smdistributed_available() and self.mp_parameters != "")
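
Under SageMaker model parallelism the effective world size for data parallelism is smp.dp_size(), and the Trainer must not move the model to a device itself since smdistributed handles placement. An illustrative reading of the two properties above, only meaningful inside a SageMaker model-parallel job (the mp_parameters value is a placeholder):

    # Illustrative only: requires smdistributed.modelparallel to be importable.
    from transformers.sagemaker import SageMakerTrainingArguments

    args = SageMakerTrainingArguments(output_dir="out", mp_parameters="partitions=2")
    print(args.world_size)             # smp.dp_size(): number of data-parallel replicas
    print(args.place_model_on_device)  # False: smdistributed handles model placement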

src/transformers/trainer.py

@@ -915,9 +915,6 @@ class Trainer:
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
-# Check if saved optimizer or scheduler states exist
-self._load_optimizer_and_scheduler(resume_from_checkpoint)
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -927,6 +924,9 @@
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+# Check if saved optimizer or scheduler states exist
+self._load_optimizer_and_scheduler(resume_from_checkpoint)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
@@ -1782,12 +1782,7 @@ class Trainer:
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
-world_size = 1
-if is_torch_tpu_available():
-    world_size = xm.xrt_world_size()
-elif self.args.local_rank != -1:
-    world_size = dist.get_world_size()
-world_size = max(1, world_size)
+world_size = max(1, self.args.world_size)
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
if not prediction_loss_only:

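Two changes matter in the base Trainer: optimizer and scheduler state is loaded only after the (possibly delayed) optimizer has been created, and the prediction loop takes its process count from self.args.world_size instead of re-detecting the backend. The resulting order in train(), mirroring the two hunks above:

    # Sketch of the reordered lines: the optimizer now exists (even when its
    # creation is delayed, e.g. under model parallelism) before its saved
    # state is loaded from a checkpoint.
    if delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
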
src/transformers/training_args.py

@@ -36,6 +36,9 @@ if is_torch_available():
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as sm_dist
logger = logging.get_logger(__name__)
@@ -631,10 +634,8 @@ class TrainingArguments:
device = xm.xla_device()
self._n_gpu = 0
elif is_sagemaker_distributed_available():
-import smdistributed.dataparallel.torch.distributed as dist
-
-dist.init_process_group()
-self.local_rank = dist.get_local_rank()
+sm_dist.init_process_group()
+self.local_rank = sm_dist.get_local_rank()
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
elif self.deepspeed:
@@ -725,6 +726,20 @@ class TrainingArguments:
else:
return ParallelMode.NOT_PARALLEL
@property
@torch_required
def world_size(self):
"""
The number of processes used in parallel.
"""
if is_torch_tpu_available():
return xm.xrt_world_size()
elif is_sagemaker_distributed_available():
return sm_dist.get_world_size()
elif self.local_rank != -1:
return torch.distributed.get_world_size()
return 1
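
The new world_size property gives callers one backend-agnostic place to read the process count (TPU, SageMaker data parallel, torch.distributed, or a single process). A small usage sketch; the batch-size arithmetic is illustrative and not part of this commit:

    from transformers import TrainingArguments

    # Effective global batch size per optimizer step, using the unified world_size.
    args = TrainingArguments(output_dir="out")
    global_batch_size = (
        args.per_device_train_batch_size * args.gradient_accumulation_steps * args.world_size
    )
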
@property
def place_model_on_device(self):
"""