Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Sm trainer smp init fix (#10870)
* rewrote is_sagemaker_model_parallel_available
* added is_sagemaker_model_parallel_available to SageMakerTrainer
* removed unnecessary mp_parameters from TrainingArguments
* make style happy
* added mp_parameters again to parse mp-specific args
This commit is contained in:
parent d4d4447d53
commit 8c297cdb30
In the SageMakerTrainer module:

@@ -34,13 +34,13 @@ from ..trainer_pt_utils import (
 )
 from ..trainer_utils import PREFIX_CHECKPOINT_DIR
 from ..utils import logging
-from .training_args_sm import is_smdistributed_available
+from .training_args_sm import is_sagemaker_model_parallel_available


 logger = logging.get_logger(__name__)


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

     @smp.step()
@@ -79,7 +79,7 @@ if is_smdistributed_available():

 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
+        self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)

     def is_world_process_zero(self) -> bool:
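With this change, SageMakerTrainer no longer looks at `args.mp_parameters` to decide whether model parallelism is enabled; the decision comes entirely from the SageMaker environment via the helper imported from `training_args_sm`. A minimal sketch of probing that helper directly, assuming `transformers` is installed and the module is importable as `transformers.sagemaker.training_args_sm` (the relative import above only guarantees the module name, not the full package path):

    # Minimal sketch, not part of the commit: call the same helper SageMakerTrainer now relies on.
    from transformers.sagemaker.training_args_sm import is_sagemaker_model_parallel_available

    if is_sagemaker_model_parallel_available():
        print("SageMaker launcher configured model parallelism; SageMakerTrainer will enable it.")
    else:
        print("No model-parallel configuration detected; the standard Trainer code path is used.")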
In training_args_sm.py:

@@ -13,6 +13,8 @@
 # limitations under the License.

+import importlib.util
+import json
 import os
 from dataclasses import dataclass, field

 import torch
@@ -24,33 +26,53 @@ from transformers.utils import logging


 logger = logging.get_logger(__name__)

 # TODO: should be moved to `file_utils` after refactoring of SageMakerTrainer


-def is_smdistributed_available():
+def is_sagemaker_model_parallel_available():
+    # Get the sagemaker specific mp parameters from smp_options variable.
+    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
+    try:
+        # Parse it and check the field "partitions" is included, it is required for model parallel.
+        smp_options = json.loads(smp_options)
+        if "partitions" not in smp_options:
+            return False
+    except json.JSONDecodeError:
+        return False
+
+    # Get the sagemaker specific framework parameters from mpi_options variable.
+    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
+        mpi_options = json.loads(mpi_options)
+        if not mpi_options.get("sagemaker_mpi_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
     return importlib.util.find_spec("smdistributed") is not None


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

+    smp.init()
+

 @dataclass
 class SageMakerTrainingArguments(TrainingArguments):
     mp_parameters: str = field(
-        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
     )

-    def __post_init__(self):
-        super().__post_init__()
-        if is_smdistributed_available() and self.mp_parameters != "":
-            smp.init()
-
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-        elif is_smdistributed_available() and self.mp_parameters != "":
+        elif is_sagemaker_model_parallel_available():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
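The new availability check encodes a small protocol: the SageMaker launcher passes the model-parallel hyperparameters through `SM_HP_MP_PARAMETERS` (a JSON object that must contain "partitions") and the framework settings through `SM_FRAMEWORK_PARAMS` (a JSON object in which "sagemaker_mpi_enabled" must be true), and the `smdistributed` package must be importable. An illustration with made-up payloads (the exact values are assumptions for demonstration, not taken from this commit):

    import json
    import os

    # Hypothetical values a SageMaker model-parallel job might see; only the keys read by
    # is_sagemaker_model_parallel_available() matter here.
    os.environ["SM_HP_MP_PARAMETERS"] = json.dumps({"partitions": 2, "microbatches": 4})
    os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps({"sagemaker_mpi_enabled": True})

    # With both variables set like this and `smdistributed` importable,
    # is_sagemaker_model_parallel_available() returns True. A missing "partitions" key,
    # invalid JSON, or a false/absent "sagemaker_mpi_enabled" all make it return False.

Because this check only passes on such jobs, the module-level `smp.init()` above runs exactly once per process at import time, which appears to be the "smp init fix" in the commit title: initialization no longer depends on `mp_parameters` being forwarded to `TrainingArguments`.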
@@ -86,14 +108,14 @@

     @property
     def world_size(self):
-        if is_smdistributed_available() and self.mp_parameters != "":
+        if is_sagemaker_model_parallel_available():
             return smp.dp_size()

         return super().world_size

     @property
     def place_model_on_device(self):
-        return not (is_smdistributed_available() and self.mp_parameters != "")
+        return not is_sagemaker_model_parallel_available()

     @property
     def _no_sync_in_gradient_accumulation(self):
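With model parallelism active, `world_size` now reports `smp.dp_size()`, i.e. the number of data-parallel model replicas rather than the total number of processes, and `place_model_on_device` returns False, presumably because `smp` handles model placement itself. A back-of-the-envelope illustration (the relationship below is an assumption about `smdistributed` semantics, not something stated in this diff):

    # Hypothetical job: 8 processes in total, "partitions": 2, no other parallelism.
    total_ranks = 8
    partitions = 2

    # Each model replica spans `partitions` ranks, so the data-parallel degree that
    # world_size would report via smp.dp_size() is expected to be:
    data_parallel_degree = total_ranks // partitions
    assert data_parallel_degree == 4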