Merge trainers (#10975)

* Replace is_sagemaker_distributed_available

* Merge SageMakerTrainer into Trainer

* Test with shorter condition

* Put back deleted line

* Deprecate SageMakerTrainer and SageMakerTrainingArguments

* Apply suggestions from code review

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger 2021-03-31 10:01:30 -04:00 committed by GitHub
parent b6dddda4d2
commit cd56f3fe7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 210 additions and 44 deletions

View File

@ -352,7 +352,7 @@ def is_pandas_available():
return importlib.util.find_spec("pandas") is not None
def is_sagemaker_distributed_available():
def is_sagemaker_dp_enabled():
# Get the sagemaker specific env variable.
sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
try:
@ -366,6 +366,30 @@ def is_sagemaker_distributed_available():
return importlib.util.find_spec("smdistributed") is not None
def is_sagemaker_mp_enabled():
# Get the sagemaker specific mp parameters from smp_options variable.
smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
try:
# Parse it and check the field "partitions" is included, it is required for model parallel.
smp_options = json.loads(smp_options)
if "partitions" not in smp_options:
return False
except json.JSONDecodeError:
return False
# Get the sagemaker specific framework parameters from mpi_options variable.
mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
try:
# Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
mpi_options = json.loads(mpi_options)
if not mpi_options.get("sagemaker_mpi_enabled", False):
return False
except json.JSONDecodeError:
return False
# Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None
def is_training_run_on_sagemaker():
return "SAGEMAKER_JOB_NAME" in os.environ

View File

@ -17,4 +17,4 @@
# limitations under the License.
from .trainer_sm import SageMakerTrainer
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled

View File

@ -79,6 +79,11 @@ if is_sagemaker_model_parallel_available():
class SageMakerTrainer(Trainer):
def __init__(self, args=None, **kwargs):
warnings.warn(
"`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
"instead.",
FutureWarning,
)
self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
super().__init__(args=args, **kwargs)

View File

@ -15,11 +15,12 @@
import importlib.util
import json
import os
import warnings
from dataclasses import dataclass, field
import torch
from transformers.file_utils import cached_property, is_sagemaker_distributed_available
from transformers.file_utils import cached_property, is_sagemaker_dp_enabled
from transformers.training_args import TrainingArguments
from transformers.utils import logging
@ -66,6 +67,14 @@ class SageMakerTrainingArguments(TrainingArguments):
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
)
def __post_init__(self):
super().__post_init__()
warnings.warn(
"`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
"`TrainingArguments` instead.",
FutureWarning,
)
@cached_property
def _setup_devices(self) -> "torch.device":
logger.info("PyTorch: setting up devices")
@ -76,7 +85,7 @@ class SageMakerTrainingArguments(TrainingArguments):
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
elif is_sagemaker_distributed_available():
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
dist.init_process_group()

View File

@ -59,7 +59,8 @@ from .file_utils import (
is_apex_available,
is_datasets_available,
is_in_notebook,
is_sagemaker_distributed_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
is_training_run_on_sagemaker,
)
@ -149,12 +150,17 @@ if is_fairscale_available():
else:
FullyShardedDDP = None
if is_sagemaker_distributed_available():
if is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
else:
import torch.distributed as dist
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
@ -522,7 +528,10 @@ class Trainer:
else:
if self.args.world_size <= 1:
return RandomSampler(self.train_dataset)
elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
elif (
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
and not self.args.dataloader_drop_last
):
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
return DistributedSamplerWithLoop(
self.train_dataset,
@ -561,6 +570,13 @@ class Trainer:
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
if is_torch_tpu_available():
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif is_sagemaker_mp_enabled():
return SequentialDistributedSampler(
eval_dataset,
num_replicas=smp.dp_size(),
rank=smp.dp_rank(),
batch_size=self.args.per_device_eval_batch_size,
)
elif self.args.local_rank != -1:
return SequentialDistributedSampler(eval_dataset)
else:
@ -674,6 +690,9 @@ class Trainer:
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
def create_scheduler(self, num_training_steps: int):
"""
Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
@ -775,6 +794,12 @@ class Trainer:
return model
def _wrap_model(self, model, training=True):
if is_sagemaker_mp_enabled():
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
return self.model_wrapped
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
# already initialized its own DDP and AMP
if self.deepspeed:
return self.deepspeed
@ -815,7 +840,7 @@ class Trainer:
cpu_offload=cpu_offload,
).to(self.args.device)
elif is_sagemaker_distributed_available():
elif is_sagemaker_dp_enabled():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
elif self.args.local_rank != -1:
if self.args.ddp_find_unused_parameters is not None:
@ -1280,6 +1305,15 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
# Consolidate the state dict on all processed of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.is_world_process_zero():
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero() and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@ -1337,8 +1371,9 @@ class Trainer:
self.optimizer.load_state_dict(optimizer_state)
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
@ -1478,6 +1513,10 @@ class Trainer:
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
@ -1535,6 +1574,8 @@ class Trainer:
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
elif is_sagemaker_mp_enabled():
return smp.local_rank() == 0
else:
return self.args.local_rank in [-1, 0]
@ -1545,8 +1586,10 @@ class Trainer:
"""
if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
elif is_sagemaker_mp_enabled():
return smp.rank() == 0
else:
return self.args.local_rank == -1 or dist.get_rank() == 0
return self.args.process_index == 0
def save_model(self, output_dir: Optional[str] = None):
"""
@ -1556,6 +1599,11 @@ class Trainer:
"""
if is_torch_tpu_available():
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
if self.is_world_process_zero():
self._save(output_dir, state_dict=state_dict)
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
@ -1905,6 +1953,8 @@ class Trainer:
return
if is_torch_tpu_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors)
elif self.args.local_rank != -1:
tensors = distributed_concat(tensors)
@ -1957,27 +2007,47 @@ class Trainer:
labels = None
with torch.no_grad():
if has_labels:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
if has_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
else:
loss_mb = raw_outputs[0]
logits_mb = raw_outputs[1:]
loss = loss_mb.reduce_mean().detach().cpu()
logits = smp_nested_concat(logits_mb)
else:
logits = outputs[1:]
loss = None
if isinstance(raw_outputs, dict):
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
else:
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
loss = None
if self.use_amp:
with autocast():
if has_labels:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
logits = outputs[1:]
else:
loss = None
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
else:
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)

View File

@ -32,11 +32,11 @@ from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .utils import logging
if is_sagemaker_distributed_available():
if is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
else:
import torch.distributed as dist
@ -805,3 +805,40 @@ def get_parameter_names(model, forbidden_layer_types):
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())
return result
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
@smp.step()
def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
outputs = model(**inputs)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss /= gradient_accumulation_steps
model.backward(loss)
return loss
@smp.step()
def smp_forward_only(model, inputs):
return model(**inputs)
def smp_gather(tensor):
if isinstance(tensor, (list, tuple)):
return type(tensor)(smp_gather(t) for t in tensor)
elif isinstance(tensor, dict):
return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
elif not isinstance(tensor, torch.Tensor):
raise TypeError(
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
return torch.cat([t.cpu() for t in all_tensors], dim=0)
def smp_nested_concat(tensor):
if isinstance(tensor, (list, tuple)):
return type(tensor)(smp_nested_concat(t) for t in tensor)
elif isinstance(tensor, dict):
return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()

View File

@ -31,7 +31,7 @@ import numpy as np
from .file_utils import (
ExplicitEnum,
is_psutil_available,
is_sagemaker_distributed_available,
is_sagemaker_dp_enabled,
is_tf_available,
is_torch_available,
is_torch_cuda_available,
@ -214,7 +214,7 @@ def total_processes_number(local_rank):
import torch_xla.core.xla_model as xm
return xm.xrt_world_size()
elif is_sagemaker_distributed_available():
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as dist
return dist.get_world_size()

View File

@ -21,7 +21,8 @@ from typing import Any, Dict, List, Optional
from .file_utils import (
cached_property,
is_sagemaker_distributed_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_available,
is_torch_tpu_available,
torch_required,
@ -36,9 +37,14 @@ if is_torch_available():
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if is_sagemaker_distributed_available():
if is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.distributed as sm_dist
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
smp.init()
logger = logging.get_logger(__name__)
@ -519,6 +525,10 @@ class TrainingArguments:
default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
)
def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
@ -646,7 +656,11 @@ class TrainingArguments:
elif is_torch_tpu_available():
device = xm.xla_device()
self._n_gpu = 0
elif is_sagemaker_distributed_available():
elif is_sagemaker_mp_enabled():
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
sm_dist.init_process_group()
self.local_rank = sm_dist.get_local_rank()
device = torch.device("cuda", self.local_rank)
@ -730,8 +744,10 @@ class TrainingArguments:
"""
if is_torch_tpu_available():
return ParallelMode.TPU
elif is_sagemaker_distributed_available():
return ParallelMode.SAGEMAKER_DISTRIBUTED
elif is_sagemaker_mp_enabled():
return ParallelMode.SAGEMAKER_MODEL_PARALLEL
elif is_sagemaker_dp_enabled():
return ParallelMode.SAGEMAKER_DATA_PARALLEL
elif self.local_rank != -1:
return ParallelMode.DISTRIBUTED
elif self.n_gpu > 1:
@ -747,7 +763,9 @@ class TrainingArguments:
"""
if is_torch_tpu_available():
return xm.xrt_world_size()
elif is_sagemaker_distributed_available():
elif is_sagemaker_mp_enabled():
return smp.dp_size()
elif is_sagemaker_dp_enabled():
return sm_dist.get_world_size()
elif self.local_rank != -1:
return torch.distributed.get_world_size()
@ -761,7 +779,9 @@ class TrainingArguments:
"""
if is_torch_tpu_available():
return xm.get_ordinal()
elif is_sagemaker_distributed_available():
elif is_sagemaker_mp_enabled():
return smp.dp_rank()
elif is_sagemaker_dp_enabled():
return sm_dist.get_rank()
elif self.local_rank != -1:
return torch.distributed.get_rank()
@ -772,14 +792,14 @@ class TrainingArguments:
"""
Can be subclassed and overridden for some specific integrations.
"""
return True
return not is_sagemaker_mp_enabled()
@property
def _no_sync_in_gradient_accumulation(self):
"""
Whether or not to use no_sync for the gradients when doing gradient accumulation.
"""
return not self.deepspeed
return not (self.deepspeed or is_sagemaker_mp_enabled())
def to_dict(self):
"""
@ -817,5 +837,6 @@ class ParallelMode(Enum):
NOT_PARALLEL = "not_parallel"
NOT_DISTRIBUTED = "not_distributed"
DISTRIBUTED = "distributed"
SAGEMAKER_DISTRIBUTED = "sm_distributed"
SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
TPU = "tpu"

View File

@ -9,10 +9,10 @@ from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers.file_utils import is_sagemaker_distributed_available
from transformers.file_utils import is_sagemaker_dp_enabled
if os.environ.get("SDP_ENABLED") or is_sagemaker_distributed_available():
if os.environ.get("SDP_ENABLED") or is_sagemaker_dp_enabled():
SDP_ENABLED = True
os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge"
import smdistributed.dataparallel.tensorflow as sdp