Sm trainer smp init fix (#10870)

* rewrote is_sagemaker_model_parallel_available

* added is_sagemaker_model_parallel_available to SageMakerTrainer

* removed unnecessary mp_parameters check from TrainingArguments

* make style happy

* added mp_parameters again to parse mp-specific args.
Author: Philipp Schmid, 2021-03-23 20:07:55 +01:00 (committed by GitHub)
commit 8c297cdb30, parent d4d4447d53
2 changed files with 36 additions and 14 deletions
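
The gist of the change, as a minimal sketch condensed from the diffs below (not part of the commit itself): model parallelism used to be detected as "`smdistributed` is importable and `mp_parameters` is non-empty"; it is now detected from the environment variables the SageMaker launcher sets.

```python
import importlib.util
import json
import os

# Before (simplified): an importable package plus a non-empty TrainingArguments field.
def before(args) -> bool:
    return importlib.util.find_spec("smdistributed") is not None and args.mp_parameters != ""

# After (simplified): the launcher-provided env vars decide, as implemented in
# is_sagemaker_model_parallel_available() below.
def after() -> bool:
    smp_options = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    mpi_options = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    return (
        "partitions" in smp_options
        and mpi_options.get("sagemaker_mpi_enabled", False)
        and importlib.util.find_spec("smdistributed") is not None
    )
```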

src/transformers/sagemaker/trainer_sm.py

@@ -34,13 +34,13 @@ from ..trainer_pt_utils import (
 )
 from ..trainer_utils import PREFIX_CHECKPOINT_DIR
 from ..utils import logging
-from .training_args_sm import is_smdistributed_available
+from .training_args_sm import is_sagemaker_model_parallel_available


 logger = logging.get_logger(__name__)


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

     @smp.step()
@@ -79,7 +79,7 @@ if is_smdistributed_available():
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
+        self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
         super().__init__(args=args, **kwargs)

     def is_world_process_zero(self) -> bool:
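
Because the flag is now derived entirely from the environment, constructing the trainer no longer depends on `mp_parameters`. A minimal usage sketch, assuming the package exports both classes as it did at the time (the model checkpoint name is illustrative, not from the commit):

```python
from transformers import AutoModelForSequenceClassification
from transformers.sagemaker import SageMakerTrainer, SageMakerTrainingArguments
from transformers.sagemaker.training_args_sm import is_sagemaker_model_parallel_available

# Illustrative model; any PreTrainedModel works here.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = SageMakerTrainingArguments(output_dir="/opt/ml/model")  # mp_parameters left at its "" default

trainer = SageMakerTrainer(model=model, args=args)

# The flag now mirrors the environment probe, regardless of args.mp_parameters:
assert trainer.is_model_parallel_enabled == is_sagemaker_model_parallel_available()
```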

src/transformers/sagemaker/training_args_sm.py

@@ -13,6 +13,8 @@
 # limitations under the License.

 import importlib.util
+import json
+import os
 from dataclasses import dataclass, field

 import torch
@@ -24,33 +26,53 @@ from transformers.utils import logging

 logger = logging.get_logger(__name__)


 # TODO: should be moved to `file_utils` after refactoring of SageMakerTrainer
-def is_smdistributed_available():
+def is_sagemaker_model_parallel_available():
+    # Get the sagemaker specific mp parameters from smp_options variable.
+    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
+    try:
+        # Parse it and check the field "partitions" is included, it is required for model parallel.
+        smp_options = json.loads(smp_options)
+        if "partitions" not in smp_options:
+            return False
+    except json.JSONDecodeError:
+        return False
+
+    # Get the sagemaker specific framework parameters from mpi_options variable.
+    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_mpi_enabled" is set.
+        mpi_options = json.loads(mpi_options)
+        if not mpi_options.get("sagemaker_mpi_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
     return importlib.util.find_spec("smdistributed") is not None


-if is_smdistributed_available():
+if is_sagemaker_model_parallel_available():
     import smdistributed.modelparallel.torch as smp

+    smp.init()
+

 @dataclass
 class SageMakerTrainingArguments(TrainingArguments):
     mp_parameters: str = field(
-        default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args."}
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
     )

-    def __post_init__(self):
-        super().__post_init__()
-        if is_smdistributed_available() and self.mp_parameters != "":
-            smp.init()
-
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-        elif is_smdistributed_available() and self.mp_parameters != "":
+        elif is_sagemaker_model_parallel_available():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
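
The new probe can be exercised locally by mimicking the variables the SageMaker launcher sets. A hedged sketch (the JSON payloads are illustrative; only the fields checked above matter):

```python
import json
import os

# Simulate what the SageMaker launcher would provide:
os.environ["SM_HP_MP_PARAMETERS"] = json.dumps({"partitions": 2})
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps({"sagemaker_mpi_enabled": True})

from transformers.sagemaker.training_args_sm import is_sagemaker_model_parallel_available

# True only if the `smdistributed` package is also installed; note that importing
# the module above calls smp.init() as a side effect when everything is available.
print(is_sagemaker_model_parallel_available())
```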
@@ -86,14 +108,14 @@ class SageMakerTrainingArguments(TrainingArguments):

     @property
     def world_size(self):
-        if is_smdistributed_available() and self.mp_parameters != "":
+        if is_sagemaker_model_parallel_available():
             return smp.dp_size()

         return super().world_size

     @property
     def place_model_on_device(self):
-        return not (is_smdistributed_available() and self.mp_parameters != "")
+        return not is_sagemaker_model_parallel_available()

     @property
     def _no_sync_in_gradient_accumulation(self):
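
Downstream, the rewritten properties mean that under model parallelism `world_size` reports the data-parallel degree and the `Trainer` does not move the model itself, since smp partitions and places it. A short sketch of that behavior (an assumption-level illustration, not from the commit):

```python
from transformers.sagemaker.training_args_sm import (
    SageMakerTrainingArguments,
    is_sagemaker_model_parallel_available,
)

args = SageMakerTrainingArguments(output_dir="/opt/ml/model")

if is_sagemaker_model_parallel_available():
    import smdistributed.modelparallel.torch as smp

    assert args.world_size == smp.dp_size()     # data-parallel replicas only
    assert args.place_model_on_device is False  # smp handles model placement
```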