Merge trainers (#10975)

* Replace is_sagemaker_distributed_available

* Merge SageMakerTrainer into Trainer

* Test with shorter condition

* Put back deleted line

* Deprecate SageMakerTrainer and SageMakerTrainingArguments

* Apply suggestions from code review

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
Sylvain Gugger, 2021-03-31 10:01:30 -04:00 (committed by GitHub)
commit cd56f3fe7e, parent b6dddda4d2
9 changed files with 210 additions and 44 deletions


@@ -352,7 +352,7 @@ def is_pandas_available():
     return importlib.util.find_spec("pandas") is not None


-def is_sagemaker_distributed_available():
+def is_sagemaker_dp_enabled():
     # Get the sagemaker specific env variable.
     sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
     try:
@@ -366,6 +366,30 @@ def is_sagemaker_distributed_available():
     return importlib.util.find_spec("smdistributed") is not None


+def is_sagemaker_mp_enabled():
+    # Get the sagemaker specific mp parameters from smp_options variable.
+    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
+    try:
+        # Parse it and check the field "partitions" is included, it is required for model parallel.
+        smp_options = json.loads(smp_options)
+        if "partitions" not in smp_options:
+            return False
+    except json.JSONDecodeError:
+        return False
+
+    # Get the sagemaker specific framework parameters from mpi_options variable.
+    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_mpi_enabled" is set.
+        mpi_options = json.loads(mpi_options)
+        if not mpi_options.get("sagemaker_mpi_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
+    return importlib.util.find_spec("smdistributed") is not None
+
+
 def is_training_run_on_sagemaker():
     return "SAGEMAKER_JOB_NAME" in os.environ
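For reference, a minimal sketch of what the new `is_sagemaker_mp_enabled()` check looks for. The environment-variable values below are illustrative (the SageMaker launcher normally sets them), and outside a SageMaker container the final `smdistributed` import check still makes the function return False:

```python
import json
import os

from transformers.file_utils import is_sagemaker_mp_enabled

# Illustrative values mimicking what the SageMaker launcher would set.
os.environ["SM_HP_MP_PARAMETERS"] = json.dumps({"partitions": 2})
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps({"sagemaker_mpi_enabled": True})

# True only if both variables parse as above *and* the `smdistributed` package is importable.
print(is_sagemaker_mp_enabled())
```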


@@ -17,4 +17,4 @@
 # limitations under the License.

 from .trainer_sm import SageMakerTrainer
-from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
+from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled


@@ -79,6 +79,11 @@ if is_sagemaker_model_parallel_available():

 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
+        warnings.warn(
+            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
+            "instead.",
+            FutureWarning,
+        )
         self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)


@@ -15,11 +15,12 @@
 import importlib.util
 import json
 import os
+import warnings
 from dataclasses import dataclass, field

 import torch

-from transformers.file_utils import cached_property, is_sagemaker_distributed_available
+from transformers.file_utils import cached_property, is_sagemaker_dp_enabled
 from transformers.training_args import TrainingArguments
 from transformers.utils import logging
@@ -66,6 +67,14 @@ class SageMakerTrainingArguments(TrainingArguments):
         metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
     )

+    def __post_init__(self):
+        super().__post_init__()
+        warnings.warn(
+            "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
+            "`TrainingArguments` instead.",
+            FutureWarning,
+        )
+
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
@@ -76,7 +85,7 @@ class SageMakerTrainingArguments(TrainingArguments):
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.distributed as dist

             dist.init_process_group()
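In practice the deprecation means SageMaker scripts can switch to the merged classes with no other changes, since `Trainer`/`TrainingArguments` now perform the same SageMaker DP/MP detection themselves. A minimal, hypothetical sketch of the migration (the model checkpoint and output path are placeholders):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Previously: SageMakerTrainingArguments + SageMakerTrainer (both now only emit a FutureWarning).
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = TrainingArguments(output_dir="/opt/ml/model")

# Pass train_dataset/eval_dataset exactly as the existing script already does.
trainer = Trainer(model=model, args=args)
```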


@@ -59,7 +59,8 @@ from .file_utils import (
     is_apex_available,
     is_datasets_available,
     is_in_notebook,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
     is_torch_tpu_available,
     is_training_run_on_sagemaker,
 )
@@ -149,12 +150,17 @@ if is_fairscale_available():
 else:
     FullyShardedDDP = None

-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
 else:
     import torch.distributed as dist

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+
 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))
@@ -522,7 +528,10 @@ class Trainer:
         else:
             if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
-            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+            elif (
+                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+                and not self.args.dataloader_drop_last
+            ):
                 # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                 return DistributedSamplerWithLoop(
                     self.train_dataset,
@@ -561,6 +570,13 @@ class Trainer:
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
         if is_torch_tpu_available():
             return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif is_sagemaker_mp_enabled():
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=smp.dp_size(),
+                rank=smp.dp_rank(),
+                batch_size=self.args.per_device_eval_batch_size,
+            )
         elif self.args.local_rank != -1:
             return SequentialDistributedSampler(eval_dataset)
         else:
@@ -674,6 +690,9 @@ class Trainer:
             else:
                 self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
     def create_scheduler(self, num_training_steps: int):
         """
         Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
@@ -775,6 +794,12 @@ class Trainer:
         return model

     def _wrap_model(self, model, training=True):
+        if is_sagemaker_mp_enabled():
+            # Wrapping the base model twice in a DistributedModel will raise an error.
+            if isinstance(self.model_wrapped, smp.model.DistributedModel):
+                return self.model_wrapped
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
         # already initialized its own DDP and AMP
         if self.deepspeed:
             return self.deepspeed
@@ -815,7 +840,7 @@ class Trainer:
                 cpu_offload=cpu_offload,
             ).to(self.args.device)

-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_dp_enabled():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
             if self.args.ddp_find_unused_parameters is not None:
@@ -1280,6 +1305,15 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
             reissue_pt_warnings(caught_warnings)
+        elif is_sagemaker_mp_enabled():
+            # Consolidate the state dict on all processes of dp_rank 0
+            opt_state_dict = self.optimizer.state_dict()
+            # Save it and the scheduler on the main process
+            if self.is_world_process_zero():
+                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                reissue_pt_warnings(caught_warnings)
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -1337,8 +1371,9 @@ class Trainer:
                 self.optimizer.load_state_dict(optimizer_state)
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
+            map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
             self.optimizer.load_state_dict(
-                torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
+                torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
             )
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
@@ -1478,6 +1513,10 @@ class Trainer:
         model.train()
         inputs = self._prepare_inputs(inputs)

+        if is_sagemaker_mp_enabled():
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            return loss_mb.reduce_mean().detach().to(self.args.device)
+
         if self.use_amp:
             with autocast():
                 loss = self.compute_loss(model, inputs)
@@ -1535,6 +1574,8 @@ class Trainer:
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=True)
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank() == 0
         else:
             return self.args.local_rank in [-1, 0]
@@ -1545,8 +1586,10 @@ class Trainer:
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
+        elif is_sagemaker_mp_enabled():
+            return smp.rank() == 0
         else:
-            return self.args.local_rank == -1 or dist.get_rank() == 0
+            return self.args.process_index == 0

     def save_model(self, output_dir: Optional[str] = None):
         """
@@ -1556,6 +1599,11 @@ class Trainer:
         """
         if is_torch_tpu_available():
             self._save_tpu(output_dir)
+        elif is_sagemaker_mp_enabled():
+            # Calling the state_dict needs to be done on the wrapped model and on all processes.
+            state_dict = self.model_wrapped.state_dict()
+            if self.is_world_process_zero():
+                self._save(output_dir, state_dict=state_dict)
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
         ):
@@ -1905,6 +1953,8 @@ class Trainer:
             return
         if is_torch_tpu_available():
             tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
         elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)
@@ -1957,27 +2007,47 @@ class Trainer:
             labels = None

         with torch.no_grad():
-            if has_labels:
-                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                loss = loss.mean().detach()
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
-                else:
-                    logits = outputs[1:]
-            else:
-                loss = None
-                if self.use_amp:
-                    with autocast():
-                        outputs = model(**inputs)
-                else:
-                    outputs = model(**inputs)
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
-                else:
-                    logits = outputs
-                # TODO: this needs to be fixed and made cleaner later.
-                if self.args.past_index >= 0:
-                    self._past = outputs[self.args.past_index - 1]
+            if is_sagemaker_mp_enabled():
+                raw_outputs = smp_forward_only(model, inputs)
+                if has_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = smp_nested_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = smp_nested_concat(logits_mb)
+            else:
+                if has_labels:
+                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss = loss.mean().detach()
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        logits = outputs[1:]
+                else:
+                    loss = None
+                    if self.use_amp:
+                        with autocast():
+                            outputs = model(**inputs)
+                    else:
+                        outputs = model(**inputs)
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    else:
+                        logits = outputs
+                    # TODO: this needs to be fixed and made cleaner later.
+                    if self.args.past_index >= 0:
+                        self._past = outputs[self.args.past_index - 1]

         if prediction_loss_only:
             return (loss, None, None)


@@ -32,11 +32,11 @@ from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

-from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
+from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
 from .utils import logging


-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
 else:
     import torch.distributed as dist
@@ -805,3 +805,40 @@ def get_parameter_names(model, forbidden_layer_types):
     # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
     result += list(model._parameters.keys())
     return result
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    @smp.step()
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+        outputs = model(**inputs)
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
+        model.backward(loss)
+        return loss
+
+    @smp.step()
+    def smp_forward_only(model, inputs):
+        return model(**inputs)
+
+    def smp_gather(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_gather(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
+        elif not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+            )
+        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        return torch.cat([t.cpu() for t in all_tensors], dim=0)
+
+    def smp_nested_concat(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_nested_concat(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
+        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
+        # which is also the name of the decorator so Python is confused.
+        return tensor.concat().detach().cpu()
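To see how the Trainer hunks above consume these helpers, here is a condensed, hypothetical sketch. The wrapper names `smp_training_step` and `smp_eval_forward` are not part of the library, the helpers themselves only exist when `is_sagemaker_mp_enabled()` is True, and `model` is assumed to already be wrapped in `smp.DistributedModel`:

```python
from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_nested_concat


def smp_training_step(model, inputs, args):
    # Mirror of Trainer.training_step's new SMP branch: forward + backward run
    # inside the @smp.step graph, then the per-microbatch losses are reduced
    # back to a single tensor on the configured device.
    loss_mb = smp_forward_backward(model, inputs, args.gradient_accumulation_steps)
    return loss_mb.reduce_mean().detach().to(args.device)


def smp_eval_forward(model, inputs):
    # Mirror of the prediction path: forward only, then concatenate the
    # StepOutput microbatches back into plain CPU tensors.
    raw_outputs = smp_forward_only(model, inputs)
    return smp_nested_concat(raw_outputs)
```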


@@ -31,7 +31,7 @@ import numpy as np
 from .file_utils import (
     ExplicitEnum,
     is_psutil_available,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
@@ -214,7 +214,7 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
-    elif is_sagemaker_distributed_available():
+    elif is_sagemaker_dp_enabled():
         import smdistributed.dataparallel.torch.distributed as dist

         return dist.get_world_size()


@@ -21,7 +21,8 @@ from typing import Any, Dict, List, Optional
 from .file_utils import (
     cached_property,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
     is_torch_available,
     is_torch_tpu_available,
     torch_required,
@@ -36,9 +37,14 @@ if is_torch_available():
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm

-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as sm_dist

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+

 logger = logging.get_logger(__name__)
@@ -519,6 +525,10 @@ class TrainingArguments:
         default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
     _n_gpu: int = field(init=False, repr=False, default=-1)
+    mp_parameters: str = field(
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
+    )

     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
@@ -646,7 +656,11 @@ class TrainingArguments:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_mp_enabled():
+            local_rank = smp.local_rank()
+            device = torch.device("cuda", local_rank)
+            self._n_gpu = 1
+        elif is_sagemaker_dp_enabled():
             sm_dist.init_process_group()
             self.local_rank = sm_dist.get_local_rank()
             device = torch.device("cuda", self.local_rank)
@@ -730,8 +744,10 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return ParallelMode.TPU
-        elif is_sagemaker_distributed_available():
-            return ParallelMode.SAGEMAKER_DISTRIBUTED
+        elif is_sagemaker_mp_enabled():
+            return ParallelMode.SAGEMAKER_MODEL_PARALLEL
+        elif is_sagemaker_dp_enabled():
+            return ParallelMode.SAGEMAKER_DATA_PARALLEL
         elif self.local_rank != -1:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
@@ -747,7 +763,9 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return xm.xrt_world_size()
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_mp_enabled():
+            return smp.dp_size()
+        elif is_sagemaker_dp_enabled():
             return sm_dist.get_world_size()
         elif self.local_rank != -1:
             return torch.distributed.get_world_size()
@@ -761,7 +779,9 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return xm.get_ordinal()
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_mp_enabled():
+            return smp.dp_rank()
+        elif is_sagemaker_dp_enabled():
             return sm_dist.get_rank()
         elif self.local_rank != -1:
             return torch.distributed.get_rank()
@@ -772,14 +792,14 @@ class TrainingArguments:
         """
         Can be subclassed and overridden for some specific integrations.
         """
-        return True
+        return not is_sagemaker_mp_enabled()

     @property
     def _no_sync_in_gradient_accumulation(self):
         """
         Whether or not to use no_sync for the gradients when doing gradient accumulation.
         """
-        return not self.deepspeed
+        return not (self.deepspeed or is_sagemaker_mp_enabled())

     def to_dict(self):
         """
@@ -817,5 +837,6 @@ class ParallelMode(Enum):
     NOT_PARALLEL = "not_parallel"
     NOT_DISTRIBUTED = "not_distributed"
     DISTRIBUTED = "distributed"
-    SAGEMAKER_DISTRIBUTED = "sm_distributed"
+    SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
+    SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
     TPU = "tpu"
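Outside of a SageMaker job both `is_sagemaker_dp_enabled()` and `is_sagemaker_mp_enabled()` return False, so the new enum values and the SMP-aware properties fall back to the previous behaviour. A quick sanity-check sketch, assuming a plain single-process environment (the output directory name is arbitrary):

```python
from transformers import TrainingArguments
from transformers.training_args import ParallelMode

args = TrainingArguments(output_dir="tmp_trainer_output")

# Never resolves to the new SAGEMAKER_* values outside a SageMaker container.
assert args.parallel_mode in (ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED, ParallelMode.DISTRIBUTED)
print(args.parallel_mode, args.world_size, args.process_index)  # e.g. ParallelMode.NOT_PARALLEL 1 0
```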


@@ -9,10 +9,10 @@ from datasets import load_dataset
 from tqdm import tqdm

 from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
-from transformers.file_utils import is_sagemaker_distributed_available
+from transformers.file_utils import is_sagemaker_dp_enabled


-if os.environ.get("SDP_ENABLED") or is_sagemaker_distributed_available():
+if os.environ.get("SDP_ENABLED") or is_sagemaker_dp_enabled():
     SDP_ENABLED = True
     os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge"
     import smdistributed.dataparallel.tensorflow as sdp