Introduce PartialState as the device handler in the Trainer (#22752)

* Use accelerate for device management
* Add accelerate to setup

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Commit: 03462875cc
Parent: 50caa20628
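The heart of the change: instead of `TrainingArguments` parsing `LOCAL_RANK` and calling `torch.distributed.init_process_group` itself, it instantiates `accelerate.PartialState` and reads process and device information off that object. A minimal sketch of the idea (not part of the diff; attribute names follow the accelerate API the hunks below rely on):

# Illustrative only: what PartialState gives the Trainer, assuming accelerate is installed.
from accelerate import PartialState
from accelerate.utils import DistributedType

state = PartialState()            # reads the env set by torchrun / accelerate launch, inits the process group if needed
print(state.device)               # the torch.device this process should use (cuda:N, xla, or cpu)
print(state.local_process_index)  # what used to be read from LOCAL_RANK
print(state.num_processes)        # world size
print(state.distributed_type)     # DistributedType.NO when not launched in distributed mode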
setup.py

@@ -260,7 +260,7 @@ extras["sklearn"] = deps_list("scikit-learn")
 extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
 extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")

-extras["torch"] = deps_list("torch")
+extras["torch"] = deps_list("torch", "accelerate")
 extras["accelerate"] = deps_list("accelerate")

 if os.name == "nt":  # windows
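With accelerate folded into the `torch` extra, a plain `pip install transformers[torch]` should now pull it in. A quick, hedged way to confirm an installed build declares the dependency:

# Illustrative check only: list declared requirements mentioning accelerate.
from importlib.metadata import requires

reqs = requires("transformers") or []
print([r for r in reqs if r.startswith("accelerate")])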
src/transformers/trainer.py

@@ -416,8 +416,7 @@ class Trainer:
                 raise ValueError(
                     "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                 )
-
-            if args.local_rank == -1:
+            if args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise ValueError("Using sharded DDP only works in distributed training.")
             elif not is_fairscale_available():
                 raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
@@ -439,7 +438,7 @@ class Trainer:
                 raise ValueError(
                     "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                 )
-            if not args.fsdp_config["xla"] and args.local_rank == -1:
+            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise ValueError("Using fsdp only works in distributed training.")

             # dep_version_check("torch>=1.12.0")
@@ -551,7 +550,7 @@ class Trainer:
        # In case of pull, we need to make sure every process has the latest.
        if is_torch_tpu_available():
            xm.rendezvous("init git repo")
-        elif args.local_rank != -1:
+        elif args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()

        if self.args.should_save:
@@ -929,7 +928,7 @@ class Trainer:
                rank=smp.dp_rank(),
                batch_size=self.args.per_device_eval_batch_size,
            )
-        elif self.args.local_rank != -1:
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)
@@ -1551,7 +1550,7 @@ class Trainer:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
-        elif self.args.local_rank != -1:
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
@@ -1919,7 +1918,7 @@ class Trainer:

                if (
                    (total_batched_samples % args.gradient_accumulation_steps != 0)
-                    and args.local_rank != -1
+                    and args.parallel_mode == ParallelMode.DISTRIBUTED
                    and args._no_sync_in_gradient_accumulation
                ):
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
@@ -2041,7 +2040,7 @@ class Trainer:
            # Wait for everyone to get here so we are sur the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
-            elif args.local_rank != -1:
+            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()
@@ -2319,7 +2318,7 @@ class Trainer:
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
-            if self.args.local_rank != -1:
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
                try:
@@ -2413,7 +2412,7 @@ class Trainer:
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
-            if self.args.local_rank == -1:
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
@@ -2895,7 +2894,7 @@ class Trainer:

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
-        if self.args.local_rank != -1:
+        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
@@ -3310,7 +3309,7 @@ class Trainer:
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
-        elif self.args.local_rank != -1:
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)
        return tensors

@@ -3834,7 +3833,7 @@ class Trainer:
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
-        elif self.args.local_rank != -1:
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)
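All of the trainer.py hunks above apply the same substitution: distributed-only code paths stop testing the raw launcher variable (`args.local_rank != -1`) and instead ask `TrainingArguments` for the parallel mode it now derives from the accelerate state. A small illustration of the idiom (not taken from the diff; `args` is assumed to be a `TrainingArguments` instance):

# Sketch of the recurring pattern in the hunks above.
import torch.distributed as dist
from transformers.training_args import ParallelMode


def wait_for_everyone(args):
    # Old style: `if args.local_rank != -1:` keyed off the env-derived local rank.
    # New style: parallel_mode already accounts for TPU, SageMaker and accelerate-launched jobs.
    if args.parallel_mode == ParallelMode.DISTRIBUTED:
        dist.barrier()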
src/transformers/training_args.py

@@ -38,10 +38,8 @@ from .trainer_utils import (
 from .utils import (
     ExplicitEnum,
     cached_property,
-    ccl_version,
     get_full_repo_name,
     is_accelerate_available,
-    is_psutil_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -65,6 +63,10 @@ if is_torch_available():
     import torch
     import torch.distributed as dist

+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import DistributedType
+
 if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

@@ -1122,12 +1124,6 @@ class TrainingArguments:
    )

    def __post_init__(self):
-        # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
-        # This needs to happen before any call to self.device or self.n_gpu.
-        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-        if env_local_rank != -1 and env_local_rank != self.local_rank:
-            self.local_rank = env_local_rank
-
        # expand paths, if not os.makedirs("~/bar") will make directory
        # in the current directory instead of the actual home
        # see https://github.com/huggingface/transformers/issues/10628
@@ -1535,104 +1531,40 @@ class TrainingArguments:
    def _setup_devices(self) -> "torch.device":
        requires_backends(self, ["torch"])
        logger.info("PyTorch: setting up devices")
-        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
-            logger.warning(
-                "torch.distributed process group is initialized, but local_rank == -1. "
-                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
-            )
        if self.no_cuda:
-            device = torch.device("cpu")
-            self._n_gpu = 0
-            self.local_rank = get_int_from_env(
-                ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"],
-                self.local_rank,
-            )
-            if self.local_rank != -1 and not torch.distributed.is_initialized():
-                # Initializes distributed backend for cpu
-                if self.xpu_backend not in ("mpi", "ccl", "gloo"):
-                    raise ValueError(
-                        "CPU distributed training backend is not properly set. "
-                        "Please set '--xpu_backend' to either 'mpi' or 'ccl' or 'gloo'."
-                    )
-                if self.xpu_backend == "ccl":
-                    requires_backends(self, "oneccl_bind_pt")
-                    if ccl_version >= "1.12":
-                        import oneccl_bindings_for_pytorch  # noqa: F401
-                    else:
-                        import torch_ccl  # noqa: F401
-                    if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
-                        raise ValueError(
-                            "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
-                            "Please use like 'export CCL_WORKER_COUNT = 1' to set."
-                        )
-
-                # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
-                rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
-                size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
-                local_size = get_int_from_env(
-                    ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
-                )
-                os.environ["RANK"] = str(rank)
-                os.environ["WORLD_SIZE"] = str(size)
-                os.environ["LOCAL_RANK"] = str(self.local_rank)
-                if not os.environ.get("MASTER_PORT", None):
-                    os.environ["MASTER_PORT"] = "29500"
-                if not os.environ.get("MASTER_ADDR", None):
-                    if local_size != size or self.xpu_backend != "mpi":
-                        raise ValueError(
-                            "Looks like distributed multinode run but MASTER_ADDR env not set, "
-                            "please try exporting rank 0's hostname as MASTER_ADDR"
-                        )
-                if (
-                    torch.get_num_threads() == 1
-                    and get_int_from_env(["OMP_NUM_THREADS", "MKL_NUM_THREADS"], 0) == 0
-                    and is_psutil_available()
-                ):
-                    import psutil
-
-                    num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
-                    if num_cpu_threads_per_process == 0:
-                        num_cpu_threads_per_process = 1
-                    torch.set_num_threads(num_cpu_threads_per_process)
-                    logger.info(
-                        f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
-                        " performance."
-                    )
-                torch.distributed.init_process_group(
-                    backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
-                )
-        elif is_torch_tpu_available():
-            device = xm.xla_device()
+            self.distributed_state = PartialState(cpu=True)
+            device = self.distributed_state.device
            self._n_gpu = 0
+            self.local_rank = self.distributed_state.local_process_index
        elif is_sagemaker_mp_enabled():
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            self._n_gpu = 1
-        elif is_sagemaker_dp_enabled():
-            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
-
-            dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
-            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
-            device = torch.device("cuda", self.local_rank)
-            self._n_gpu = 1
+            torch.cuda.set_device(device)
        elif self.deepspeed:
-            # deepspeed inits torch.distributed internally
-            from .deepspeed import is_deepspeed_available
-
-            if not is_deepspeed_available():
-                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
-            import deepspeed
-
-            deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout))
-
-            # workaround for setups like notebooks where the launcher can't be used,
-            # but deepspeed requires a dist env.
-            # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
-            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
-
-            device = torch.device("cuda", self.local_rank)
+            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
            self._n_gpu = 1
-        elif self.local_rank == -1:
+            device = self.distributed_state.device
+        else:
+            self.distributed_state = PartialState(backend=self.xpu_backend)
+            device = self.distributed_state.device
+            self._n_gpu = 1
+        if (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+            and self.distributed_state.distributed_type != DistributedType.NO
+        ):
+            logger.warning(
+                "torch.distributed process group is initialized, but parallel_mode == ParallelMode.DISTRIBUTED. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
+
+        if is_torch_tpu_available():
+            device = self.distributed_state.device
+            self._n_gpu = 0
+        elif is_sagemaker_dp_enabled():
+            self._n_gpu = 1
+        elif self.distributed_state.distributed_type == DistributedType.NO:
            if self.use_mps_device:
                if not torch.backends.mps.is_available():
                    if not torch.backends.mps.is_built():
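The rewritten `_setup_devices` above leans on three `PartialState` constructor variants: `PartialState(cpu=True)` for `--no_cuda`, `PartialState(timeout=...)` on the DeepSpeed path, and `PartialState(backend=self.xpu_backend)` otherwise. A hedged, standalone sketch of what replaces the old `init_process_group` boilerplate:

# Not from the commit: PartialState keeps shared state, so construct it once per process.
from datetime import timedelta

from accelerate import PartialState

# The hunk above uses three flavours of this call:
#   PartialState(cpu=True)               -- the --no_cuda branch
#   PartialState(timeout=timedelta(...)) -- the DeepSpeed branch
#   PartialState(backend=...)            -- everything else
state = PartialState(timeout=timedelta(seconds=1800))

device = state.device                   # replaces the torch.device("cuda", local_rank) juggling
local_rank = state.local_process_index  # replaces reading LOCAL_RANK from the environment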
@@ -1665,24 +1597,13 @@ class TrainingArguments:
                # trigger an error that a device index is missing. Index 0 takes into account the
                # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
                # will use the first GPU in that env, i.e. GPU#1
+                # device = self.distributed_state.device
                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
                # the default value.
                self._n_gpu = torch.cuda.device_count()
-        else:
-            # Here, we'll use torch.distributed.
-            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            if not torch.distributed.is_initialized():
-                if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"):
-                    torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta)
-                else:
-                    torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
-            device = torch.device("cuda", self.local_rank)
-            self._n_gpu = 1
-
        if device.type == "cuda":
            torch.cuda.set_device(device)
-
        return device

    @property
@@ -1725,7 +1646,7 @@ class TrainingArguments:
            return ParallelMode.SAGEMAKER_MODEL_PARALLEL
        elif is_sagemaker_dp_enabled():
            return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif self.local_rank != -1:
+        elif hasattr(self, "distributed_state") and (self.distributed_state.distributed_type != DistributedType.NO):
            return ParallelMode.DISTRIBUTED
        elif self.n_gpu > 1:
            return ParallelMode.NOT_DISTRIBUTED
@@ -1745,7 +1666,7 @@ class TrainingArguments:
            return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
        elif is_sagemaker_dp_enabled():
            return dist.get_world_size()
-        elif self.local_rank != -1:
+        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return torch.distributed.get_world_size()
        return 1

@@ -1761,7 +1682,7 @@ class TrainingArguments:
            return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
        elif is_sagemaker_dp_enabled():
            return dist.get_rank()
-        elif self.local_rank != -1:
+        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return torch.distributed.get_rank()
        return 0

@@ -1777,7 +1698,7 @@ class TrainingArguments:
            return smp.local_rank()
        elif is_sagemaker_dp_enabled():
            return dist.get_rank()
-        elif self.local_rank != -1:
+        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return self.local_rank
        return 0

tests/trainer/test_trainer_distributed.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import sys
 from typing import Dict

 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
@@ -23,6 +22,7 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     require_torch_neuroncore,
 )
+from transformers.training_args import ParallelMode
 from transformers.utils import logging


@@ -66,15 +66,13 @@ if is_torch_available():
 class TestTrainerDistributedNeuronCore(TestCasePlus):
     @require_torch_neuroncore
     def test_trainer(self):
-        distributed_args = f"""
-            -m torch.distributed.run
-            --nproc_per_node=2
+        distributed_args = f"""--nproc_per_node=2
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
         output_dir = self.get_auto_remove_tmp_dir()
         args = f"--output_dir {output_dir}".split()
-        cmd = [sys.executable] + distributed_args + args
+        cmd = ["torchrun"] + distributed_args + args
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call

@@ -82,15 +80,13 @@ class TestTrainerDistributedNeuronCore(TestCasePlus):
 class TestTrainerDistributed(TestCasePlus):
     @require_torch_multi_gpu
     def test_trainer(self):
-        distributed_args = f"""
-            -m torch.distributed.run
-            --nproc_per_node={torch.cuda.device_count()}
+        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
         output_dir = self.get_auto_remove_tmp_dir()
         args = f"--output_dir {output_dir}".split()
-        cmd = [sys.executable] + distributed_args + args
+        cmd = ["torchrun"] + distributed_args + args
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call

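Both tests now prepend the `torchrun` console script instead of re-launching the interpreter with `-m torch.distributed.run`. For reference, a hedged sketch of the command the updated test ends up running (port, script path and output dir are made-up placeholders):

# Illustration only; the real test builds this via get_torch_dist_unique_port() and execute_subprocess_async().
import subprocess

cmd = [
    "torchrun",
    "--nproc_per_node=2",                        # torch.cuda.device_count() in the multi-GPU test
    "--master_port=29601",                       # placeholder port
    "tests/trainer/test_trainer_distributed.py",
    "--output_dir", "/tmp/test_trainer_distributed",  # placeholder output dir
]
subprocess.run(cmd, check=True)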
@@ -105,7 +101,7 @@ if __name__ == "__main__":

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
-        f"distributed training: {training_args.local_rank != -1}"
+        f"distributed training: {training_args.parallel_mode != ParallelMode.NOT_DISTRIBUTED}"
    )

    # Essentially, what we want to verify in the distributed case is that we get all samples back,