feat: add support for tensor parallel training workflow with accelerate (#34194)

* feat: add support for tensor parallel flow using accelerate

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: add tp degree to env variable

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: add version check for accelerate to allow TP

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* docs: tensor parallelism

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* nit: rename plugin name

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: guard accelerate version before allow tp

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* docs: add more docs and updates related to TP

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Mehant Kammakomati 2025-02-18 18:35:46 +05:30 committed by GitHub
parent e6cc410d5b
commit c3ba53303b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 133 additions and 6 deletions

View File

@ -673,6 +673,29 @@ tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>
يُعد أمر [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) هو الطريقة المُوصى بها لتشغيل نص البرمجى للتدريب على نظام موزع باستخدام Accelerate و [`Trainer`] مع المعلمات المحددة في `config_file.yaml`. يتم حفظ هذا الملف في مجلد ذاكرة التخزين المؤقت لـ Accelerate ويتم تحميله تلقائيًا عند تشغيل `accelerate_launch`.

View File

@ -55,7 +55,7 @@ To give some examples of how much VRAM it roughly takes to load a model in bfloa
As of writing this document, the largest GPU chip on the market is the A100 & H100 offering 80GB of VRAM. Most of the models listed before require more than 80GB just to be loaded and therefore necessarily require [tensor parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) and/or [pipeline parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
🤗 Transformers does not support tensor parallelism out of the box as it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
🤗 Transformers now supports tensor parallelism for supported models having `base_tp_plan` in their respecitve config classes. Learn more about Tensor Parallelism [here](perf_train_gpu_many#tensor-parallelism). Furthermore, if you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
Naive pipeline parallelism is supported out of the box. For this, simply load the model with `device="auto"` which will automatically place the different layers on the available GPUs as explained [here](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference).
Note, however that while very effective, this naive pipeline parallelism does not tackle the issues of GPU idling. For this more advanced pipeline parallelism is required as explained [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).

View File

@ -450,12 +450,13 @@ Implementations:
- [parallelformers](https://github.com/tunib-ai/parallelformers) (only inference at the moment)
- [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS.
- [OSLO](https://github.com/tunib-ai/oslo) has the tensor parallelism implementation based on the Transformers.
- [`transformers` integration](main_classes/trainer) tensor parallelism is available through tp_size attribute for models having `base_tp_plan`. Further you can look at [example usage](perf_infer_gpu_multi)
SageMaker combines TP with DP for a more efficient processing.
🤗 Transformers status:
- core: not yet implemented in the core
- but if you want inference [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core you can use theirs. And hopefully training mode will be supported too.
- core: uses PyTorch 2 APIs to support tensor parallelism to models having base_tp_plan in their respective config classes.
- Alternatively, you can as well try [parallelformers](https://github.com/tunib-ai/parallelformers) that provides this support for most of our models. Training mode with TP is as well supported natively in transformers.
- Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/)
🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm).
@ -535,7 +536,7 @@ Important papers:
- [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](
https://arxiv.org/abs/2201.11990)
🤗 Transformers status: not yet implemented, since we have no PP and TP.
🤗 Transformers status: not yet implemented, since we have no PP.
## FlexFlow

View File

@ -799,6 +799,29 @@ tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>

View File

@ -361,6 +361,30 @@ use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>

View File

@ -548,6 +548,29 @@ tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">
```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>

View File

@ -241,6 +241,8 @@ if is_accelerate_available():
)
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("1.3.0"):
from accelerate.utils import TorchTensorParallelPlugin
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
@ -5094,6 +5096,14 @@ class Trainer:
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)
# tp is initialized at Accelerator init phase so
# args should be prepared here
if self.args.tp_size > 1:
self.is_tp_enabled = True
if version.parse(accelerate_version) > version.parse("1.3.0"):
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
else:
raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
# create accelerator object
self.accelerator = Accelerator(**args)
@ -5108,7 +5118,7 @@ class Trainer:
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin

View File

@ -569,7 +569,10 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
tp_size (`int`, *optional*):
Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available to models having `base_tp_plan`
in their respective config classes.
Set a value greater than 1 to activate TP. The same is used to prepare device mesh internally. Requires accelerate>1.3.0.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@ -1250,6 +1253,18 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable pytorch tensor parallelism."
"Tensor parallelism support is only available to models having `base_tp_plan` in their respective config classes."
"Set a value greater than 1 to activate TP."
"The same is used to prepare device mesh internally."
"Requires accelerate>1.3.0."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
@ -1975,6 +1990,14 @@ class TrainingArguments:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
if self.tp_size > 1:
if not is_accelerate_available("1.3.1"):
raise NotImplementedError(
"TP using PyTorch requires Accelerate version `accelerate` >= 1.3.1. "
"This is not supported and we recommend you to update your version."
)
os.environ["ACCELERATE_USE_TP"] = "true"
os.environ["TP_SIZE"] = str(self.tp_size)
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"