Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Merge trainers (#10975)
* Replace is_sagemaker_distributed_available
* Merge SageMakerTrainer into Trainer
* Test with shorter condition
* Put back deleted line
* Deprecate SageMakerTrainer and SageMakerTrainingArguments
* Apply suggestions from code review

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
parent b6dddda4d2
commit cd56f3fe7e
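For downstream users, the practical effect is that the SageMaker-specific classes become thin deprecated wrappers, and plain `Trainer`/`TrainingArguments` now pick up SageMaker data and model parallelism from the environment. A minimal migration sketch follows; the model name, dataset slice, and hyperparameters are illustrative placeholders, not part of this commit:

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,            # replaces the deprecated SageMakerTrainer
    TrainingArguments,  # replaces the deprecated SageMakerTrainingArguments
)

# Illustrative placeholders only.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

raw = load_dataset("imdb", split="train[:1%]")
train_dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128),
    batched=True,
)

args = TrainingArguments(output_dir="./out", per_device_train_batch_size=8, num_train_epochs=1)

# On a SageMaker job, Trainer now detects data/model parallelism via the helpers
# added in this commit; on a plain machine it simply trains locally.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()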
@@ -352,7 +352,7 @@ def is_pandas_available():
     return importlib.util.find_spec("pandas") is not None


-def is_sagemaker_distributed_available():
+def is_sagemaker_dp_enabled():
     # Get the sagemaker specific env variable.
     sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
     try:
@@ -366,6 +366,30 @@ def is_sagemaker_distributed_available():
     return importlib.util.find_spec("smdistributed") is not None


+def is_sagemaker_mp_enabled():
+    # Get the sagemaker specific mp parameters from smp_options variable.
+    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
+    try:
+        # Parse it and check the field "partitions" is included, it is required for model parallel.
+        smp_options = json.loads(smp_options)
+        if "partitions" not in smp_options:
+            return False
+    except json.JSONDecodeError:
+        return False
+
+    # Get the sagemaker specific framework parameters from mpi_options variable.
+    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
+        mpi_options = json.loads(mpi_options)
+        if not mpi_options.get("sagemaker_mpi_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
+    return importlib.util.find_spec("smdistributed") is not None
+
+
 def is_training_run_on_sagemaker():
     return "SAGEMAKER_JOB_NAME" in os.environ

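A note on the detection logic added above: both helpers key off environment variables that the SageMaker launcher sets on the training container. A small illustration with hypothetical values (the JSON payloads below are made-up examples, not taken from this commit):

import importlib.util
import json
import os

# Hypothetical payloads of the shape the helpers above parse; on a real SageMaker
# job the launcher populates these variables, they are not set by hand.
os.environ["SM_HP_MP_PARAMETERS"] = json.dumps({"partitions": 2})
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps({"sagemaker_mpi_enabled": True})

mp_options = json.loads(os.environ["SM_HP_MP_PARAMETERS"])
mpi_options = json.loads(os.environ["SM_FRAMEWORK_PARAMS"])

# is_sagemaker_mp_enabled() returns True only when "partitions" is present,
# "sagemaker_mpi_enabled" is truthy, and the smdistributed package is importable.
print("partitions configured:", "partitions" in mp_options)
print("mpi enabled:", mpi_options.get("sagemaker_mpi_enabled", False))
print("smdistributed installed:", importlib.util.find_spec("smdistributed") is not None)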
@@ -17,4 +17,4 @@
 # limitations under the License.

 from .trainer_sm import SageMakerTrainer
-from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_distributed_available
+from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled
@@ -79,6 +79,11 @@ if is_sagemaker_model_parallel_available():

 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
+        warnings.warn(
+            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
+            "instead.",
+            FutureWarning,
+        )
         self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)

@@ -15,11 +15,12 @@
 import importlib.util
 import json
 import os
+import warnings
 from dataclasses import dataclass, field

 import torch

-from transformers.file_utils import cached_property, is_sagemaker_distributed_available
+from transformers.file_utils import cached_property, is_sagemaker_dp_enabled
 from transformers.training_args import TrainingArguments
 from transformers.utils import logging

@@ -66,6 +67,14 @@ class SageMakerTrainingArguments(TrainingArguments):
         metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
     )

+    def __post_init__(self):
+        super().__post_init__()
+        warnings.warn(
+            "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
+            "`TrainingArguments` instead.",
+            FutureWarning,
+        )
+
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
@@ -76,7 +85,7 @@ class SageMakerTrainingArguments(TrainingArguments):
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.distributed as dist

             dist.init_process_group()
@@ -59,7 +59,8 @@ from .file_utils import (
     is_apex_available,
     is_datasets_available,
     is_in_notebook,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
     is_torch_tpu_available,
     is_training_run_on_sagemaker,
 )
@@ -149,12 +150,17 @@ if is_fairscale_available():
 else:
     FullyShardedDDP = None

-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
 else:
     import torch.distributed as dist

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+
 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

@@ -522,7 +528,10 @@ class Trainer:
         else:
             if self.args.world_size <= 1:
                 return RandomSampler(self.train_dataset)
-            elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
+            elif (
+                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+                and not self.args.dataloader_drop_last
+            ):
                 # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                 return DistributedSamplerWithLoop(
                     self.train_dataset,
@@ -561,6 +570,13 @@ class Trainer:
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
         if is_torch_tpu_available():
             return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif is_sagemaker_mp_enabled():
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=smp.dp_size(),
+                rank=smp.dp_rank(),
+                batch_size=self.args.per_device_eval_batch_size,
+            )
         elif self.args.local_rank != -1:
             return SequentialDistributedSampler(eval_dataset)
         else:
@@ -674,6 +690,9 @@ class Trainer:
         else:
             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
     def create_scheduler(self, num_training_steps: int):
         """
         Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
@@ -775,6 +794,12 @@ class Trainer:
         return model

     def _wrap_model(self, model, training=True):
+        if is_sagemaker_mp_enabled():
+            # Wrapping the base model twice in a DistributedModel will raise an error.
+            if isinstance(self.model_wrapped, smp.model.DistributedModel):
+                return self.model_wrapped
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
         # already initialized its own DDP and AMP
         if self.deepspeed:
             return self.deepspeed
@@ -815,7 +840,7 @@ class Trainer:
                 cpu_offload=cpu_offload,
             ).to(self.args.device)

-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_dp_enabled():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
             if self.args.ddp_find_unused_parameters is not None:
@@ -1280,6 +1305,15 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
             reissue_pt_warnings(caught_warnings)
+        elif is_sagemaker_mp_enabled():
+            # Consolidate the state dict on all processed of dp_rank 0
+            opt_state_dict = self.optimizer.state_dict()
+            # Save it and the scheduler on the main process
+            if self.is_world_process_zero():
+                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                reissue_pt_warnings(caught_warnings)
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -1337,8 +1371,9 @@ class Trainer:
             self.optimizer.load_state_dict(optimizer_state)
             self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
+            map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
             self.optimizer.load_state_dict(
-                torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
+                torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
             )
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
@@ -1478,6 +1513,10 @@ class Trainer:
         model.train()
         inputs = self._prepare_inputs(inputs)

+        if is_sagemaker_mp_enabled():
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            return loss_mb.reduce_mean().detach().to(self.args.device)
+
         if self.use_amp:
             with autocast():
                 loss = self.compute_loss(model, inputs)
@@ -1535,6 +1574,8 @@ class Trainer:
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=True)
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank() == 0
         else:
             return self.args.local_rank in [-1, 0]

@@ -1545,8 +1586,10 @@ class Trainer:
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
+        elif is_sagemaker_mp_enabled():
+            return smp.rank() == 0
         else:
-            return self.args.local_rank == -1 or dist.get_rank() == 0
+            return self.args.process_index == 0

     def save_model(self, output_dir: Optional[str] = None):
         """
@@ -1556,6 +1599,11 @@ class Trainer:
         """
         if is_torch_tpu_available():
             self._save_tpu(output_dir)
+        elif is_sagemaker_mp_enabled():
+            # Calling the state_dict needs to be done on the wrapped model and on all processes.
+            state_dict = self.model_wrapped.state_dict()
+            if self.is_world_process_zero():
+                self._save(output_dir, state_dict=state_dict)
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
         ):
@@ -1905,6 +1953,8 @@ class Trainer:
             return
         if is_torch_tpu_available():
             tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
         elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)

@@ -1957,27 +2007,47 @@ class Trainer:
             labels = None

         with torch.no_grad():
-            if has_labels:
-                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
-                loss = loss.mean().detach()
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
-                else:
-                    logits = outputs[1:]
-            else:
-                loss = None
-                if self.use_amp:
-                    with autocast():
-                        outputs = model(**inputs)
-                else:
-                    outputs = model(**inputs)
-                if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
-                else:
-                    logits = outputs
-                # TODO: this needs to be fixed and made cleaner later.
-                if self.args.past_index >= 0:
-                    self._past = outputs[self.args.past_index - 1]
+            if is_sagemaker_mp_enabled():
+                raw_outputs = smp_forward_only(model, inputs)
+                if has_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = smp_nested_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = smp_nested_concat(logits_mb)
+            else:
+                if has_labels:
+                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss = loss.mean().detach()
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        logits = outputs[1:]
+                else:
+                    loss = None
+                    if self.use_amp:
+                        with autocast():
+                            outputs = model(**inputs)
+                    else:
+                        outputs = model(**inputs)
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    else:
+                        logits = outputs
+                    # TODO: this needs to be fixed and made cleaner later.
+                    if self.args.past_index >= 0:
+                        self._past = outputs[self.args.past_index - 1]

         if prediction_loss_only:
             return (loss, None, None)
@@ -32,11 +32,11 @@ from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

-from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
+from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
 from .utils import logging


-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
 else:
     import torch.distributed as dist
@@ -805,3 +805,40 @@ def get_parameter_names(model, forbidden_layer_types):
     # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
     result += list(model._parameters.keys())
     return result
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    @smp.step()
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+        outputs = model(**inputs)
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
+        model.backward(loss)
+        return loss
+
+    @smp.step()
+    def smp_forward_only(model, inputs):
+        return model(**inputs)
+
+    def smp_gather(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_gather(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
+        elif not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+            )
+        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        return torch.cat([t.cpu() for t in all_tensors], dim=0)
+
+    def smp_nested_concat(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_nested_concat(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
+        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
+        # which is also the name of the decorator so Python is confused.
+        return tensor.concat().detach().cpu()
@@ -31,7 +31,7 @@ import numpy as np
 from .file_utils import (
     ExplicitEnum,
     is_psutil_available,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
     is_tf_available,
     is_torch_available,
     is_torch_cuda_available,
@@ -214,7 +214,7 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
-    elif is_sagemaker_distributed_available():
+    elif is_sagemaker_dp_enabled():
         import smdistributed.dataparallel.torch.distributed as dist

         return dist.get_world_size()
@@ -21,7 +21,8 @@ from typing import Any, Dict, List, Optional

 from .file_utils import (
     cached_property,
-    is_sagemaker_distributed_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
     is_torch_available,
     is_torch_tpu_available,
     torch_required,
@@ -36,9 +37,14 @@ if is_torch_available():
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm

-if is_sagemaker_distributed_available():
+if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as sm_dist

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+

 logger = logging.get_logger(__name__)

@@ -519,6 +525,10 @@ class TrainingArguments:
         default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
     _n_gpu: int = field(init=False, repr=False, default=-1)
+    mp_parameters: str = field(
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
+    )

     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
@@ -646,7 +656,11 @@ class TrainingArguments:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
-        elif is_sagemaker_distributed_available():
+        elif is_sagemaker_mp_enabled():
+            local_rank = smp.local_rank()
+            device = torch.device("cuda", local_rank)
+            self._n_gpu = 1
+        elif is_sagemaker_dp_enabled():
             sm_dist.init_process_group()
             self.local_rank = sm_dist.get_local_rank()
             device = torch.device("cuda", self.local_rank)
@@ -730,8 +744,10 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return ParallelMode.TPU
-        elif is_sagemaker_distributed_available():
-            return ParallelMode.SAGEMAKER_DISTRIBUTED
+        elif is_sagemaker_mp_enabled():
+            return ParallelMode.SAGEMAKER_MODEL_PARALLEL
+        elif is_sagemaker_dp_enabled():
+            return ParallelMode.SAGEMAKER_DATA_PARALLEL
         elif self.local_rank != -1:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
|
|||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.xrt_world_size()
|
return xm.xrt_world_size()
|
||||||
elif is_sagemaker_distributed_available():
|
elif is_sagemaker_mp_enabled():
|
||||||
|
return smp.dp_size()
|
||||||
|
elif is_sagemaker_dp_enabled():
|
||||||
return sm_dist.get_world_size()
|
return sm_dist.get_world_size()
|
||||||
elif self.local_rank != -1:
|
elif self.local_rank != -1:
|
||||||
return torch.distributed.get_world_size()
|
return torch.distributed.get_world_size()
|
||||||
@ -761,7 +779,9 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.get_ordinal()
|
return xm.get_ordinal()
|
||||||
elif is_sagemaker_distributed_available():
|
elif is_sagemaker_mp_enabled():
|
||||||
|
return smp.dp_rank()
|
||||||
|
elif is_sagemaker_dp_enabled():
|
||||||
return sm_dist.get_rank()
|
return sm_dist.get_rank()
|
||||||
elif self.local_rank != -1:
|
elif self.local_rank != -1:
|
||||||
return torch.distributed.get_rank()
|
return torch.distributed.get_rank()
|
||||||
@ -772,14 +792,14 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
Can be subclassed and overridden for some specific integrations.
|
Can be subclassed and overridden for some specific integrations.
|
||||||
"""
|
"""
|
||||||
return True
|
return not is_sagemaker_mp_enabled()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _no_sync_in_gradient_accumulation(self):
|
def _no_sync_in_gradient_accumulation(self):
|
||||||
"""
|
"""
|
||||||
Whether or not to use no_sync for the gradients when doing gradient accumulation.
|
Whether or not to use no_sync for the gradients when doing gradient accumulation.
|
||||||
"""
|
"""
|
||||||
return not self.deepspeed
|
return not (self.deepspeed or is_sagemaker_mp_enabled())
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
@ -817,5 +837,6 @@ class ParallelMode(Enum):
|
|||||||
NOT_PARALLEL = "not_parallel"
|
NOT_PARALLEL = "not_parallel"
|
||||||
NOT_DISTRIBUTED = "not_distributed"
|
NOT_DISTRIBUTED = "not_distributed"
|
||||||
DISTRIBUTED = "distributed"
|
DISTRIBUTED = "distributed"
|
||||||
SAGEMAKER_DISTRIBUTED = "sm_distributed"
|
SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
|
||||||
|
SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
|
||||||
TPU = "tpu"
|
TPU = "tpu"
|
||||||
|
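The renamed enum members surface through the existing `TrainingArguments` properties; a quick local check (the expected output in the comments assumes a single-process run outside SageMaker and is illustrative only):

from transformers import TrainingArguments
from transformers.training_args import ParallelMode

args = TrainingArguments(output_dir="./out")

# Outside SageMaker this falls through to the local checks; on a SageMaker
# model-parallel or data-parallel job it would report the new enum members instead.
print(args.parallel_mode)   # ParallelMode.NOT_PARALLEL on a single-process, single-GPU/CPU run
print(args.world_size)      # 1
print(args.process_index)   # 0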
@@ -9,10 +9,10 @@ from datasets import load_dataset
 from tqdm import tqdm

 from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
-from transformers.file_utils import is_sagemaker_distributed_available
+from transformers.file_utils import is_sagemaker_dp_enabled


-if os.environ.get("SDP_ENABLED") or is_sagemaker_distributed_available():
+if os.environ.get("SDP_ENABLED") or is_sagemaker_dp_enabled():
     SDP_ENABLED = True
     os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge"
     import smdistributed.dataparallel.tensorflow as sdp