Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
feat: add support for tensor parallel training workflow with accelerate (#34194)
* feat: add support for tensor parallel flow using accelerate
* fix: add tp degree to env variable
* fix: add version check for accelerate to allow TP
* docs: tensor parallelism
* nit: rename plugin name
* fix: guard accelerate version before allowing TP
* docs: add more docs and updates related to TP

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
e6cc410d5b
commit
c3ba53303b
@ -673,6 +673,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate_launch`.
@ -55,7 +55,7 @@ To give some examples of how much VRAM it roughly takes to load a model in bfloa
As of writing this document, the largest GPU chips on the market are the A100 & H100, offering 80GB of VRAM. Most of the models listed before require more than 80GB just to be loaded and therefore necessarily require [tensor parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) and/or [pipeline parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).

🤗 Transformers does not support tensor parallelism out of the box as it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
🤗 Transformers now supports tensor parallelism for models that have a `base_tp_plan` in their respective config classes. Learn more about tensor parallelism [here](perf_train_gpu_many#tensor-parallelism). Furthermore, if you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
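
For models that ship a `base_tp_plan`, the flow looks roughly like the following minimal sketch. It assumes a transformers release in which `from_pretrained` accepts `tp_plan="auto"`, and that the script (here given the hypothetical name `tp_inference.py`) is launched with one process per GPU, e.g. `torchrun --nproc-per-node 4 tp_inference.py`:

```python
# Hedged sketch of tensor-parallel inference for a model that defines `base_tp_plan`.
# Assumes `tp_plan="auto"` is available in the installed transformers version and
# that the script runs under torchrun so each rank owns one GPU.
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model with a base_tp_plan

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
if not dist.is_initialized():
    dist.init_process_group("nccl")

# Shard the supported layers of the model across the participating GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism splits each weight matrix", return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits  # plain forward pass; every rank holds only its shard
print(logits.shape)
```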

Naive pipeline parallelism is supported out of the box. For this, simply load the model with `device_map="auto"`, which will automatically place the different layers on the available GPUs as explained [here](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference).
Note, however, that while very effective, this naive pipeline parallelism does not tackle the issue of GPU idling. For this, more advanced pipeline parallelism is required, as explained [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
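
As a quick illustration of such naive pipeline parallelism, here is a sketch; it assumes several GPUs are visible, and the checkpoint name is just an example:

```python
# Hedged sketch of naive pipeline parallelism: Accelerate places whole layers on
# the available GPUs, and each forward pass flows through the devices sequentially.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-7b1"  # illustrative checkpoint

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",        # split layers over the visible GPUs (and CPU if needed)
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Inputs go to the device that holds the first layers; Accelerate's hooks move
# intermediate activations between devices automatically.
inputs = tokenizer("Pipeline parallelism places layers", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```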
@ -450,12 +450,13 @@ Implementations:
- [parallelformers](https://github.com/tunib-ai/parallelformers) (only inference at the moment)
- [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS.
- [OSLO](https://github.com/tunib-ai/oslo) has a tensor parallelism implementation based on Transformers.
- [`transformers` integration](main_classes/trainer) - tensor parallelism is available through the `tp_size` attribute for models that have a `base_tp_plan`. See the [example usage](perf_infer_gpu_multi) and the training sketch below.

SageMaker combines TP with DP for more efficient processing.

🤗 Transformers status:
- core: not yet implemented in the core
- but if you want inference, [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core, you can use theirs. And hopefully training mode will be supported too.
- core: uses PyTorch 2 APIs to support tensor parallelism for models that have a `base_tp_plan` in their respective config classes.
- Alternatively, you can try [parallelformers](https://github.com/tunib-ai/parallelformers), which provides this support for most of our models. Training with TP is also supported natively in transformers.
- Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/)

🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm).
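
To make the `transformers` integration above concrete, here is a minimal, hedged training sketch. It assumes 4 GPUs, `accelerate` > 1.3.0, and a model whose config class defines `base_tp_plan`; the checkpoint, dataset, and script name are illustrative placeholders, not part of this commit.

```python
# Hedged sketch: tensor-parallel fine-tuning through the Trainer.
# Setting tp_size > 1 makes the Trainer build a TorchTensorParallelPlugin for the
# Accelerator (requires accelerate > 1.3.0) and shard weights over the device mesh.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "meta-llama/Meta-Llama-3-8B"  # illustrative model that defines base_tp_plan

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

# Tiny illustrative dataset; any tokenized causal-LM dataset works the same way.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=dataset.column_names,
)

args = TrainingArguments(
    output_dir="llama-tp",
    tp_size=4,                        # > 1 activates tensor parallelism
    per_device_train_batch_size=1,
    num_train_epochs=1,
    bf16=True,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```

Running the script with plain `python` would not spawn the distributed processes; launching with `accelerate launch --config_file config_file.yaml train_tp.py` (using a TP config like the one shown earlier) or `torchrun --nproc-per-node 4` is what creates one process per GPU.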
@ -535,7 +536,7 @@ Important papers:
- [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](
https://arxiv.org/abs/2201.11990)

🤗 Transformers status: not yet implemented, since we have no PP and TP.
🤗 Transformers status: not yet implemented, since we have no PP.

## FlexFlow
@ -799,6 +799,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>
@ -361,6 +361,30 @@ use_cpu: false
```

</hfoption>

<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>
@ -548,6 +548,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
  tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>
@ -241,6 +241,8 @@ if is_accelerate_available():
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("1.3.0"):
        from accelerate.utils import TorchTensorParallelPlugin
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler
@ -5094,6 +5096,14 @@ class Trainer:
            args["dataloader_config"] = dataloader_config
        else:
            args.update(accelerator_config)
        # tp is initialized at Accelerator init phase so
        # args should be prepared here
        if self.args.tp_size > 1:
            self.is_tp_enabled = True
            if version.parse(accelerate_version) > version.parse("1.3.0"):
                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
            else:
                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")

        # create accelerator object
        self.accelerator = Accelerator(**args)
@ -5108,7 +5118,7 @@ class Trainer:
        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
@ -569,7 +569,10 @@ class TrainingArguments:
                Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
                used when the xla flag is set to true, and an auto wrapping policy is specified through
                fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.

        tp_size (`int`, *optional*):
            Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available to models
            having `base_tp_plan` in their respective config classes.
            Set a value greater than 1 to activate TP. The same value is used to prepare the device mesh internally.
            Requires accelerate>1.3.0.
        deepspeed (`str` or `dict`, *optional*):
            Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
            evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@ -1250,6 +1253,18 @@ class TrainingArguments:
            )
        },
    )
    tp_size: Optional[int] = field(
        default=0,
        metadata={
            "help": (
                "Use tp_size to enable PyTorch tensor parallelism. "
                "Tensor parallelism support is only available to models having `base_tp_plan` in their respective config classes. "
                "Set a value greater than 1 to activate TP. "
                "The same is used to prepare device mesh internally. "
                "Requires accelerate>1.3.0."
            )
        },
    )
    fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
        default=None,
        metadata={
@ -1975,6 +1990,14 @@ class TrainingArguments:
        if self.fsdp_config["xla_fsdp_grad_ckpt"]:
            warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

        if self.tp_size > 1:
            if not is_accelerate_available("1.3.1"):
                raise NotImplementedError(
                    "TP using PyTorch requires `accelerate` >= 1.3.1. "
                    "Your current version is not supported; we recommend you update it."
                )
            os.environ["ACCELERATE_USE_TP"] = "true"
            os.environ["TP_SIZE"] = str(self.tp_size)
        # accelerate integration for FSDP
        if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
            os.environ["ACCELERATE_USE_FSDP"] = "true"