Add SageMakerTrainer for model parallelism (#10122)

* Refactor things out of main train

* Store signature

* Add SageMakerTrainer

* Init + Copyright

* Address review comments
Sylvain Gugger 2021-02-11 18:44:18 -05:00 committed by GitHub
parent b54cb0bd82
commit 31245775e5
5 changed files with 349 additions and 50 deletions


@@ -0,0 +1,20 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving the other warnings, so don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .trainer_sm import SageMakerTrainer
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
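
For context, a minimal usage sketch of the new entry points (not part of this commit). It assumes the package is importable as `transformers.sagemaker`; the model name, dataset, and output path are placeholders, and on SageMaker `mp_parameters` is injected by the launcher rather than set by hand:

```python
from transformers import AutoModelForSequenceClassification
from transformers.sagemaker import SageMakerTrainer, SageMakerTrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
args = SageMakerTrainingArguments(
    output_dir="/opt/ml/model",  # conventional SageMaker output path
    per_device_train_batch_size=8,
    num_train_epochs=1,
)
# train_dataset is a placeholder: any torch Dataset yielding model kwargs works.
trainer = SageMakerTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```

When `smdistributed` is not installed or `mp_parameters` is empty, every overridden method falls back to the stock `Trainer` behavior.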


@@ -0,0 +1,178 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from ..trainer import Trainer
from ..trainer_pt_utils import (
DistributedLengthGroupedSampler,
SequentialDistributedSampler,
nested_detach,
nested_numpify,
)
from ..utils import logging
from .training_args_sm import is_smdistributed_available
logger = logging.get_logger(__name__)
if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def forward_backward(model, inputs):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        model.backward(loss)
        return loss

    @smp.step()
    def forward_only(model, inputs):
        return model(**inputs)

    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def nested_smp_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_smp_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()


class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        super().__init__(args=args, **kwargs)
        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")

    def _get_train_sampler(self):
        if self.is_model_parallel_enabled:
            if self.args.group_by_length:
                return DistributedLengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                )
            else:
                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_train_sampler()

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if self.is_model_parallel_enabled:
            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_eval_sampler(eval_dataset)

    def _wrap_model(self, model, training=True):
        if self.is_model_parallel_enabled:
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model)
        else:
            return super()._wrap_model(model)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        super().create_optimizer_and_scheduler(num_training_steps)
        if self.is_model_parallel_enabled:
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        if self.is_model_parallel_enabled:
            model.train()
            inputs = self._prepare_inputs(inputs)
            loss_mb = forward_backward(model, inputs)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        else:
            return super().training_step(model, inputs)

    def _gather_and_numpify(self, tensors, name):
        if self.is_model_parallel_enabled:
            tensors = smp_gather(tensors)
            return nested_numpify(tensors)
        else:
            return super()._gather_and_numpify(tensors, name)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.is_model_parallel_enabled:
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            inputs = self._prepare_inputs(inputs)

            if ignore_keys is None:
                if hasattr(self.model, "config"):
                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
                else:
                    ignore_keys = []

            with torch.no_grad():
                raw_outputs = forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]
                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = nested_smp_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = nested_smp_concat(logits_mb)

            if prediction_loss_only:
                return (loss, None, None)

            if len(logits) == 1:
                logits = logits[0]

            if has_labels:
                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
                if len(labels) == 1:
                    labels = labels[0]
            else:
                labels = None

            return (loss, logits, labels)
        else:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
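
The two recursive helpers above mirror how the stock `Trainer` handles nested model outputs. Here is a self-contained illustration of the recursion with a stand-in for `smp.allgather`, so it runs without SageMaker (the two-rank duplication is purely illustrative):

```python
import torch

def fake_allgather(t):
    # Stand-in for smp.allgather across a 2-rank data-parallel group.
    return [t, t]

def nested_gather(obj):
    # Same shape-preserving recursion as smp_gather above.
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_gather(x) for x in obj)
    if isinstance(obj, dict):
        return type(obj)({k: nested_gather(v) for k, v in obj.items()})
    if not isinstance(obj, torch.Tensor):
        raise TypeError(f"Can't gather values of type {type(obj)}")
    return torch.cat([t.cpu() for t in fake_allgather(obj)], dim=0)

outputs = {"logits": torch.ones(2, 3), "hidden": (torch.zeros(2, 4),)}
print(nested_gather(outputs)["logits"].shape)  # torch.Size([4, 3])
```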


@@ -0,0 +1,89 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
from dataclasses import dataclass, field
import torch
from transformers.file_utils import cached_property, is_sagemaker_distributed_available
from transformers.training_args import TrainingArguments
from transformers.utils import logging
logger = logging.get_logger(__name__)
def is_smdistributed_available():
    return importlib.util.find_spec("smdistributed") is not None


if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp


@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    mp_parameters: str = field(
        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
    )

    def __post_init__(self):
        super().__post_init__()
        if is_smdistributed_available() and self.mp_parameters != "":
            smp.init()

    @cached_property
    def _setup_devices(self) -> "torch.device":
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_smdistributed_available() and self.mp_parameters != "":
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            self._n_gpu = 1
        elif is_sagemaker_distributed_available():
            import smdistributed.dataparallel.torch.distributed as dist

            dist.init_process_group()
            self.local_rank = dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.local_rank == -1:
            # if n_gpu is > 1 we'll use nn.DataParallel.
            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
            # trigger an error that a device index is missing. Index 0 takes into account the
            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
            # will use the first GPU in that env, i.e. GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
            # the default value.
            self._n_gpu = torch.cuda.device_count()
        else:
            # Here, we'll use torch.distributed.
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device

    @property
    def place_model_on_device(self):
        return not (is_smdistributed_available() and self.mp_parameters != "")
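
`is_smdistributed_available` probes the import machinery without importing the package, which is cheap and side-effect free. The same pattern works for any optional top-level dependency; `"some_optional_package"` below is a placeholder:

```python
import importlib.util

def is_package_available(name: str) -> bool:
    # find_spec only consults the import system; it does not execute the package.
    return importlib.util.find_spec(name) is not None

if is_package_available("some_optional_package"):
    pass  # safe to import and use it here
```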


@@ -272,7 +272,7 @@ class Trainer:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        # and we only use deepspeed for training at the moment
-       if not self.is_model_parallel and not (args.deepspeed and args.do_train):
+       if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
            model = model.to(args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
@@ -319,6 +319,7 @@
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

+       self._signature_columns = None
        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
@@ -425,16 +426,18 @@
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
-       # Inspect model forward signature to keep only the arguments it accepts.
-       signature = inspect.signature(self.model.forward)
-       signature_columns = list(signature.parameters.keys())
-       # Labels may be named label or label_ids, the default data collator handles that.
-       signature_columns += ["label", "label_ids"]
-       columns = [k for k in signature_columns if k in dataset.column_names]
-       ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+       if self._signature_columns is None:
+           # Inspect model forward signature to keep only the arguments it accepts.
+           signature = inspect.signature(self.model.forward)
+           self._signature_columns = list(signature.parameters.keys())
+           # Labels may be named label or label_ids, the default data collator handles that.
+           self._signature_columns += ["label", "label_ids"]
+       columns = [k for k in self._signature_columns if k in dataset.column_names]
+       ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
-           f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+           f"The following columns {dset_description}don't have a corresponding argument in "
+           f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)
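
The caching matters because `_remove_unused_columns` can now run several times (training and evaluation sets) while `inspect.signature` is only consulted once. A standalone sketch of the signature-based filtering, with `DummyModel` as a stand-in:

```python
import inspect
import torch.nn as nn

class DummyModel(nn.Module):
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

model = DummyModel()
# Bound method, so `self` doesn't show up among the parameters.
signature_columns = list(inspect.signature(model.forward).parameters.keys())
signature_columns += ["label", "label_ids"]  # aliases the default collator handles

dataset_columns = ["input_ids", "attention_mask", "text", "label"]
keep = [c for c in signature_columns if c in dataset_columns]
ignored = set(dataset_columns) - set(signature_columns)
print(keep)     # ['input_ids', 'attention_mask', 'label']
print(ignored)  # {'text'}
```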
@@ -684,6 +687,45 @@
        return model

+   def _wrap_model(self, model, training=True):
+       # Mixed precision training with apex (torch < 1.6)
+       if self.use_apex and training:
+           model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+       # Multi-gpu training (should be after apex fp16 initialization)
+       if self.args.n_gpu > 1:
+           model = torch.nn.DataParallel(model)
+
+       # Note: in torch.distributed mode, there's no point in wrapping the model
+       # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+       if not training:
+           return model
+
+       # Distributed training (should be after apex fp16 initialization)
+       if self.sharded_dpp:
+           model = ShardedDDP(model, self.optimizer)
+       elif is_sagemaker_distributed_available():
+           model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+       elif self.deepspeed:
+           pass  # already initialized its own DDP earlier
+       elif self.args.local_rank != -1:
+           if self.args.ddp_find_unused_parameters is not None:
+               find_unused_parameters = self.args.ddp_find_unused_parameters
+           elif isinstance(model, PreTrainedModel):
+               # find_unused_parameters breaks checkpointing as per
+               # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+               find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+           else:
+               find_unused_parameters = True
+           model = torch.nn.parallel.DistributedDataParallel(
+               model,
+               device_ids=[self.args.local_rank],
+               output_device=self.args.local_rank,
+               find_unused_parameters=find_unused_parameters,
+           )
+
+       return model
+
    def train(
        self,
        resume_from_checkpoint: Optional[str] = None,
@@ -736,7 +778,7 @@
        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
-           if not self.is_model_parallel:
+           if not self.is_model_parallel and self.args.place_model_on_device:
                self.model = self.model.to(self.args.device)
            self.model_wrapped = self.model
@@ -783,38 +825,7 @@
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

-       model = self.model_wrapped
-       # Mixed precision training with apex (torch < 1.6)
-       if self.use_apex:
-           model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
-       # Multi-gpu training (should be after apex fp16 initialization)
-       if self.args.n_gpu > 1:
-           model = torch.nn.DataParallel(model)
-       # Distributed training (should be after apex fp16 initialization)
-       if self.sharded_dpp:
-           model = ShardedDDP(model, self.optimizer)
-       elif is_sagemaker_distributed_available():
-           model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-       elif self.deepspeed:
-           pass  # already initialized its own DDP earlier
-       elif self.args.local_rank != -1:
-           if self.args.ddp_find_unused_parameters is not None:
-               find_unused_parameters = self.args.ddp_find_unused_parameters
-           elif isinstance(model, PreTrainedModel):
-               # find_unused_parameters breaks checkpointing as per
-               # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-               find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
-           else:
-               find_unused_parameters = True
-           model = torch.nn.parallel.DistributedDataParallel(
-               model,
-               device_ids=[self.args.local_rank],
-               output_device=self.args.local_rank,
-               find_unused_parameters=find_unused_parameters,
-           )
+       model = self._wrap_model(self.model_wrapped)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
@@ -1020,7 +1031,7 @@
            )
            if isinstance(self.model, PreTrainedModel):
                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-               if not self.is_model_parallel:
+               if not self.is_model_parallel and self.args.place_model_on_device:
                    self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
@@ -1610,13 +1621,7 @@
            # flagging only for when --do_train wasn't passed as only then it's redundant
            logger.info("Detected the deepspeed argument but it will not be used for evaluation")

-       model = self.model
-       # multi-gpu eval
-       if self.args.n_gpu > 1:
-           model = torch.nn.DataParallel(model)
-       # Note: in torch.distributed mode, there's no point in wrapping the model
-       # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+       model = self._wrap_model(self.model, training=False)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)


@@ -637,6 +637,13 @@ class TrainingArguments:
        else:
            return ParallelMode.NOT_PARALLEL

+   @property
+   def place_model_on_device(self):
+       """
+       Can be subclassed and overridden for some specific integrations.
+       """
+       return True
+
    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` by their values (for JSON serialization support).
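
The new property is deliberately trivial so that subclasses can opt out of automatic device placement. A minimal sketch of such an override (the class name is hypothetical), mirroring what `SageMakerTrainingArguments` does when model parallelism is active:

```python
from dataclasses import dataclass
from transformers.training_args import TrainingArguments

@dataclass
class ExternallyPlacedTrainingArguments(TrainingArguments):
    @property
    def place_model_on_device(self):
        # The framework managing the model handles device placement itself,
        # so Trainer should skip model.to(device).
        return False
```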