Add SageMakerTrainer for model parallelism (#10122)
* Refactor things out of main train
* Store signature
* Add SageMakerTrainer
* Init + Copyright
* Address review comments
parent b54cb0bd82
commit 31245775e5
src/transformers/sagemaker/__init__.py  (new file, +20 lines)
@@ -0,0 +1,20 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .trainer_sm import SageMakerTrainer
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
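For orientation, a minimal sketch (illustration only, not part of this diff) of how these exports are meant to be used in a SageMaker entry-point script; `my_model`, `my_train_dataset`, and the hyperparameter values are placeholders:

# Hypothetical SageMaker training script; model and dataset are placeholders.
from transformers.sagemaker import SageMakerTrainer, SageMakerTrainingArguments

training_args = SageMakerTrainingArguments(
    output_dir="/opt/ml/model",  # standard SageMaker model directory
    per_device_train_batch_size=8,
    num_train_epochs=1,
    # mp_parameters is filled in by the SageMaker launcher when model
    # parallelism is configured; user code normally leaves it alone.
)

trainer = SageMakerTrainer(
    model=my_model,                  # placeholder: any PreTrainedModel
    args=training_args,
    train_dataset=my_train_dataset,  # placeholder: any torch Dataset
)
trainer.train()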
src/transformers/sagemaker/trainer_sm.py  (new file, +178 lines)
@@ -0,0 +1,178 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler

from ..trainer import Trainer
from ..trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    SequentialDistributedSampler,
    nested_detach,
    nested_numpify,
)
from ..utils import logging
from .training_args_sm import is_smdistributed_available


logger = logging.get_logger(__name__)


if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def forward_backward(model, inputs):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        model.backward(loss)
        return loss

    @smp.step()
    def forward_only(model, inputs):
        return model(**inputs)

    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def nested_smp_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_smp_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()


class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        super().__init__(args=args, **kwargs)
        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")

    def _get_train_sampler(self):
        if self.is_model_parallel_enabled:
            if self.args.group_by_length:
                return DistributedLengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                )
            else:
                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_train_sampler()

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if self.is_model_parallel_enabled:
            return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_eval_sampler(eval_dataset)

    def _wrap_model(self, model, training=True):
        if self.is_model_parallel_enabled:
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model)
        else:
            return super()._wrap_model(model)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        super().create_optimizer_and_scheduler(num_training_steps)
        if self.is_model_parallel_enabled:
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        if self.is_model_parallel_enabled:
            model.train()
            inputs = self._prepare_inputs(inputs)
            loss_mb = forward_backward(model, inputs)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        else:
            return super().training_step(model, inputs)

    def _gather_and_numpify(self, tensors, name):
        if self.is_model_parallel_enabled:
            tensors = smp_gather(tensors)
            return nested_numpify(tensors)
        else:
            return super()._gather_and_numpify(tensors, name)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.is_model_parallel_enabled:
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            inputs = self._prepare_inputs(inputs)

            if ignore_keys is None:
                if hasattr(self.model, "config"):
                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
                else:
                    ignore_keys = []

            with torch.no_grad():
                raw_outputs = forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = nested_smp_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = nested_smp_concat(logits_mb)

            if prediction_loss_only:
                return (loss, None, None)

            if len(logits) == 1:
                logits = logits[0]

            if has_labels:
                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
                if len(labels) == 1:
                    labels = labels[0]
            else:
                labels = None

            return (loss, logits, labels)
        else:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
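As a side note, here is a small self-contained sketch (assuming only torch, no smdistributed) of the nested-traversal pattern that `nested_smp_concat` relies on; `FakeStepOutput` and `nested_concat_cpu` are made-up stand-ins for smp's `StepOutput` handling, written purely for illustration:

import torch

class FakeStepOutput:
    # Stand-in for smp.StepOutput: holds one tensor per microbatch.
    def __init__(self, microbatch_tensors):
        self.microbatch_tensors = microbatch_tensors

    def concat(self):
        return torch.cat(self.microbatch_tensors, dim=0)

def nested_concat_cpu(obj):
    # Same recursion as nested_smp_concat: walk lists/tuples/dicts, then
    # concatenate each leaf's per-microbatch outputs on CPU.
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_concat_cpu(t) for t in obj)
    if isinstance(obj, dict):
        return type(obj)({k: nested_concat_cpu(v) for k, v in obj.items()})
    return obj.concat().detach().cpu()

outputs = {"logits": FakeStepOutput([torch.randn(2, 5), torch.randn(2, 5)])}
print(nested_concat_cpu(outputs)["logits"].shape)  # torch.Size([4, 5])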
src/transformers/sagemaker/training_args_sm.py  (new file, +89 lines)
@@ -0,0 +1,89 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
from dataclasses import dataclass, field

import torch

from transformers.file_utils import cached_property, is_sagemaker_distributed_available
from transformers.training_args import TrainingArguments
from transformers.utils import logging


logger = logging.get_logger(__name__)


def is_smdistributed_available():
    return importlib.util.find_spec("smdistributed") is not None


if is_smdistributed_available():
    import smdistributed.modelparallel.torch as smp


@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    mp_parameters: str = field(
        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
    )

    def __post_init__(self):
        super().__post_init__()
        if is_smdistributed_available() and self.mp_parameters != "":
            smp.init()

    @cached_property
    def _setup_devices(self) -> "torch.device":
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_smdistributed_available() and self.mp_parameters != "":
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            self._n_gpu = 1
        elif is_sagemaker_distributed_available():
            import smdistributed.dataparallel.torch.distributed as dist

            dist.init_process_group()
            self.local_rank = dist.get_local_rank()
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.local_rank == -1:
            # if n_gpu is > 1 we'll use nn.DataParallel.
            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
            # trigger an error that a device index is missing. Index 0 takes into account the
            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
            # will use the first GPU in that env, i.e. GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
            # the default value.
            self._n_gpu = torch.cuda.device_count()
        else:
            # Here, we'll use torch.distributed.
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device

    @property
    def place_model_on_device(self):
        return not (is_smdistributed_available() and self.mp_parameters != "")
src/transformers/trainer.py
@@ -272,7 +272,7 @@ class Trainer:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         # and we only use deepspeed for training at the moment
-        if not self.is_model_parallel and not (args.deepspeed and args.do_train):
+        if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device:
             model = model.to(args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
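A minimal sketch (not from this diff) of the hook introduced here: the Trainer now consults `args.place_model_on_device` before calling `model.to(args.device)`, so an integration-specific TrainingArguments subclass can opt out of device placement. `NoPlacementArguments` below is a made-up example; SageMakerTrainingArguments does the equivalent when model parallelism is active.

from transformers import TrainingArguments

class NoPlacementArguments(TrainingArguments):
    @property
    def place_model_on_device(self):
        # Hypothetical integration where the launcher places the model itself.
        return False

args = NoPlacementArguments(output_dir="tmp")
print(args.place_model_on_device)  # False -> Trainer skips model.to(device)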
@@ -319,6 +319,7 @@ class Trainer:
         if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
             raise ValueError("eval_dataset must implement __len__")
 
+        self._signature_columns = None
         if is_datasets_available():
             if isinstance(train_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.train_dataset, description="training")
@@ -425,16 +426,18 @@ class Trainer:
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return
-        # Inspect model forward signature to keep only the arguments it accepts.
-        signature = inspect.signature(self.model.forward)
-        signature_columns = list(signature.parameters.keys())
-        # Labels may be named label or label_ids, the default data collator handles that.
-        signature_columns += ["label", "label_ids"]
-        columns = [k for k in signature_columns if k in dataset.column_names]
-        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            signature = inspect.signature(self.model.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids, the default data collator handles that.
+            self._signature_columns += ["label", "label_ids"]
+        columns = [k for k in self._signature_columns if k in dataset.column_names]
+        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
         dset_description = "" if description is None else f"in the {description} set "
         logger.info(
-            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+            f"The following columns {dset_description}don't have a corresponding argument in "
+            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
         )
         dataset.set_format(type=dataset.format["type"], columns=columns)
 
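A small, self-contained illustration of what this column filtering does (the model and column names below are made up): only dataset columns that match the model's forward() signature, plus the label aliases, are kept.

import inspect

class DummyModel:
    # Made-up forward signature standing in for a real PreTrainedModel.
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        pass

signature = inspect.signature(DummyModel().forward)
signature_columns = list(signature.parameters.keys()) + ["label", "label_ids"]

dataset_columns = ["input_ids", "attention_mask", "label", "sentence", "idx"]
columns = [k for k in signature_columns if k in dataset_columns]
ignored_columns = list(set(dataset_columns) - set(signature_columns))

print(columns)          # ['input_ids', 'attention_mask', 'label']
print(ignored_columns)  # ['sentence', 'idx'] (set order may vary)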
@@ -684,6 +687,45 @@ class Trainer:
 
         return model
 
+    def _wrap_model(self, model, training=True):
+        # Mixed precision training with apex (torch < 1.6)
+        if self.use_apex and training:
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization)
+        if self.args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
+
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        if not training:
+            return model
+
+        # Distributed training (should be after apex fp16 initialization)
+        if self.sharded_dpp:
+            model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
+        elif self.deepspeed:
+            pass  # already initialized its own DDP earlier
+        elif self.args.local_rank != -1:
+            if self.args.ddp_find_unused_parameters is not None:
+                find_unused_parameters = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+            else:
+                find_unused_parameters = True
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.args.local_rank],
+                output_device=self.args.local_rank,
+                find_unused_parameters=find_unused_parameters,
+            )
+
+        return model
+
     def train(
         self,
         resume_from_checkpoint: Optional[str] = None,
@@ -736,7 +778,7 @@ class Trainer:
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
-            if not self.is_model_parallel:
+            if not self.is_model_parallel and self.args.place_model_on_device:
                 self.model = self.model.to(self.args.device)
             self.model_wrapped = self.model
 
@@ -783,38 +825,7 @@ class Trainer:
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(resume_from_checkpoint)
 
-        model = self.model_wrapped
-
-        # Mixed precision training with apex (torch < 1.6)
-        if self.use_apex:
-            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
-
-        # Multi-gpu training (should be after apex fp16 initialization)
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-
-        # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
-        elif is_sagemaker_distributed_available():
-            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-        elif self.deepspeed:
-            pass  # already initialized its own DDP earlier
-        elif self.args.local_rank != -1:
-            if self.args.ddp_find_unused_parameters is not None:
-                find_unused_parameters = self.args.ddp_find_unused_parameters
-            elif isinstance(model, PreTrainedModel):
-                # find_unused_parameters breaks checkpointing as per
-                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
-            else:
-                find_unused_parameters = True
-            model = torch.nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank],
-                output_device=self.args.local_rank,
-                find_unused_parameters=find_unused_parameters,
-            )
+        model = self._wrap_model(self.model_wrapped)
 
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
@@ -1020,7 +1031,7 @@ class Trainer:
             )
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if not self.is_model_parallel:
+                if not self.is_model_parallel and self.args.place_model_on_device:
                     self.model = self.model.to(self.args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
@@ -1610,13 +1621,7 @@ class Trainer:
             # flagging only for when --do_train wasn't passed as only then it's redundant
             logger.info("Detected the deepspeed argument but it will not be used for evaluation")
 
-        model = self.model
-
-        # multi-gpu eval
-        if self.args.n_gpu > 1:
-            model = torch.nn.DataParallel(model)
-        # Note: in torch.distributed mode, there's no point in wrapping the model
-        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        model = self._wrap_model(self.model, training=False)
 
         batch_size = dataloader.batch_size
         num_examples = self.num_examples(dataloader)
src/transformers/training_args.py
@@ -637,6 +637,13 @@ class TrainingArguments:
         else:
             return ParallelMode.NOT_PARALLEL
 
+    @property
+    def place_model_on_device(self):
+        """
+        Can be subclassed and overridden for some specific integrations.
+        """
+        return True
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).