diff --git a/docs/source/ar/trainer.md b/docs/source/ar/trainer.md
index 7da7cbf4e17..e70dbb255ea 100644
--- a/docs/source/ar/trainer.md
+++ b/docs/source/ar/trainer.md
@@ -673,6 +673,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
 
+
+
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+ tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
 يُعد أمر [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) هو الطريقة المُوصى بها لتشغيل نص البرمجى للتدريب على نظام موزع باستخدام Accelerate و [`Trainer`] مع المعلمات المحددة في `config_file.yaml`. يتم حفظ هذا الملف في مجلد ذاكرة التخزين المؤقت لـ Accelerate ويتم تحميله تلقائيًا عند تشغيل `accelerate_launch`.
diff --git a/docs/source/en/llm_tutorial_optimization.md b/docs/source/en/llm_tutorial_optimization.md
index 3414725fc37..1fd458d430f 100644
--- a/docs/source/en/llm_tutorial_optimization.md
+++ b/docs/source/en/llm_tutorial_optimization.md
@@ -55,7 +55,7 @@ To give some examples of how much VRAM it roughly takes to load a model in bfloa
 
 As of writing this document, the largest GPU chip on the market is the A100 & H100 offering 80GB of VRAM. Most of the models listed before require more than 80GB just to be loaded and therefore necessarily require [tensor parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) and/or [pipeline parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
 
-🤗 Transformers does not support tensor parallelism out of the box as it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
+🤗 Transformers now supports tensor parallelism for models that have a `base_tp_plan` in their respective config classes. Learn more about Tensor Parallelism [here](perf_train_gpu_many#tensor-parallelism). Furthermore, if you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
 
 Naive pipeline parallelism is supported out of the box. For this, simply load the model with `device="auto"` which will automatically place the different layers on the available GPUs as explained [here](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference). Note, however that while very effective, this naive pipeline parallelism does not tackle the issues of GPU idling. For this more advanced pipeline parallelism is required as explained [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
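The paragraph added to `llm_tutorial_optimization.md` above says that models whose config defines a `base_tp_plan` can now be sharded with tensor parallelism. As a rough illustration only (a sketch that assumes the `tp_plan="auto"` loading path described in the linked tensor-parallelism docs; the model id and dtype are placeholders), loading and running such a model across several GPUs could look like the following, launched with one process per GPU, e.g. `torchrun --nproc-per-node 4 infer_tp.py`:

```python
# infer_tp.py - illustrative sketch, not part of this diff.
# Shards a model that has a `base_tp_plan` across the visible GPUs via `tp_plan="auto"`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder: any model with a `base_tp_plan`
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # apply the model's built-in tensor-parallel plan
)

prompt = "Tensor parallelism shards each layer across GPUs."
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    outputs = model(inputs)  # each rank computes its shard of every layer
print(outputs.logits.shape)
```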
diff --git a/docs/source/en/perf_train_gpu_many.md b/docs/source/en/perf_train_gpu_many.md
index d60c61020c7..bf9467d19d2 100644
--- a/docs/source/en/perf_train_gpu_many.md
+++ b/docs/source/en/perf_train_gpu_many.md
@@ -450,12 +450,13 @@ Implementations:
 - [parallelformers](https://github.com/tunib-ai/parallelformers) (only inference at the moment)
 - [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS.
 - [OSLO](https://github.com/tunib-ai/oslo) has the tensor parallelism implementation based on the Transformers.
+- [`transformers` integration](main_classes/trainer) - tensor parallelism is available through the `tp_size` attribute for models that have a `base_tp_plan`. See [example usage](perf_infer_gpu_multi) for more details.
 
 SageMaker combines TP with DP for a more efficient processing.
 
 🤗 Transformers status:
-- core: not yet implemented in the core
-- but if you want inference [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core you can use theirs. And hopefully training mode will be supported too.
+- core: uses PyTorch 2 APIs to support tensor parallelism for models that have a `base_tp_plan` in their respective config classes.
+- Alternatively, you can try [parallelformers](https://github.com/tunib-ai/parallelformers), which provides this support for most of our models. Training with TP is also supported natively in 🤗 Transformers.
 - Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/)
 
 🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm).
@@ -535,7 +536,7 @@ Important papers:
 - [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](
 https://arxiv.org/abs/2201.11990)
 
-🤗 Transformers status: not yet implemented, since we have no PP and TP.
+🤗 Transformers status: not yet implemented, since we have no PP.
 
 ## FlexFlow
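To make the `tp_size`-based Trainer integration mentioned in the list above more concrete, here is a minimal training sketch (the model id, dataset, and hyperparameters are placeholders; it assumes `accelerate` > 1.3.0 and is meant to be launched with as many processes as `tp_size`, for example via `accelerate launch` with a TP config like the ones added to the trainer docs below):

```python
# train_tp.py - illustrative sketch, not part of this diff.
# Fine-tunes a model whose config defines a `base_tp_plan`, sharding each layer over 4 GPUs.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder: any model with a `base_tp_plan`
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

# tiny placeholder dataset, tokenized for causal language modeling
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=dataset.column_names,
)

training_args = TrainingArguments(
    output_dir="llama-tp-finetune",
    tp_size=4,  # the new flag: shard each layer across 4 GPUs
    per_device_train_batch_size=1,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```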
diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index 92bb4367139..8cfe5dfc6af 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -799,6 +799,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
 
+
+
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+ tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
diff --git a/docs/source/es/trainer.md b/docs/source/es/trainer.md
index dab83e9a9d9..0362fe1d7d2 100644
--- a/docs/source/es/trainer.md
+++ b/docs/source/es/trainer.md
@@ -361,6 +361,30 @@ use_cpu: false
 ```
 
+
+
+
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+ tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
diff --git a/docs/source/ko/trainer.md b/docs/source/ko/trainer.md
index 42789fc0c2f..976072730c8 100644
--- a/docs/source/ko/trainer.md
+++ b/docs/source/ko/trainer.md
@@ -548,6 +548,29 @@ tpu_use_sudo: false
 use_cpu: false
 ```
 
+
+
+
+```yml
+compute_environment: LOCAL_MACHINE
+tp_config:
+ tp_size: 4
+distributed_type: TP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+
+```
+
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f970885314f..94659d98841 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -241,6 +241,8 @@ if is_accelerate_available():
     )
 
     DATA_SAMPLERS = [RandomSampler]
 
+    if version.parse(accelerate_version) > version.parse("1.3.0"):
+        from accelerate.utils import TorchTensorParallelPlugin
     if version.parse(accelerate_version) > version.parse("0.23.0"):
         from accelerate.data_loader import SeedableRandomSampler
@@ -5094,6 +5096,14 @@ class Trainer:
             args["dataloader_config"] = dataloader_config
         else:
             args.update(accelerator_config)
+        # TP is initialized at the Accelerator init phase, so
+        # the args need to be prepared here
+        if self.args.tp_size > 1:
+            self.is_tp_enabled = True
+            if version.parse(accelerate_version) > version.parse("1.3.0"):
+                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
+            else:
+                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
 
         # create accelerator object
         self.accelerator = Accelerator(**args)
@@ -5108,7 +5118,7 @@ class Trainer:
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
+        self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6afdfb33249..005c035ca62 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -569,7 +569,10 @@ class TrainingArguments:
                 Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
                 used when the xla flag is set to true, and an auto wrapping policy is specified through
                 fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
-
+        tp_size (`int`, *optional*):
+            Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available for
+            models that have a `base_tp_plan` in their respective config classes.
+            Set a value greater than 1 to activate TP. This value is used to prepare the device mesh internally. Requires accelerate>1.3.0.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1250,6 +1253,18 @@
             )
         },
     )
+    tp_size: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Use tp_size to enable PyTorch tensor parallelism. "
+                "Tensor parallelism support is only available for models that have a `base_tp_plan` in their respective config classes. "
+                "Set a value greater than 1 to activate TP. "
+                "This value is used to prepare the device mesh internally. "
+                "Requires accelerate>1.3.0."
+            )
+        },
+    )
     fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
         default=None,
         metadata={
@@ -1975,6 +1990,14 @@
             if self.fsdp_config["xla_fsdp_grad_ckpt"]:
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
 
+        if self.tp_size > 1:
+            if not is_accelerate_available("1.3.1"):
+                raise NotImplementedError(
+                    "Tensor parallelism with PyTorch requires `accelerate` >= 1.3.1. "
+                    "Please update your `accelerate` installation."
+                )
+            os.environ["ACCELERATE_USE_TP"] = "true"
+            os.environ["TP_SIZE"] = str(self.tp_size)
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
             os.environ["ACCELERATE_USE_FSDP"] = "true"
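Since the `__post_init__` change above routes `tp_size` into environment variables read by Accelerate, a quick sanity check of the new flag could look like the sketch below (assumes `accelerate` >= 1.3.1 is installed; `tmp_out` is just a placeholder output directory):

```python
# Sketch: verify that tp_size > 1 sets the env vars consumed by the Accelerate TP integration.
import os

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out", tp_size=2)  # raises if accelerate is older than 1.3.1
assert args.tp_size == 2
assert os.environ["ACCELERATE_USE_TP"] == "true"
assert os.environ["TP_SIZE"] == "2"
print("tp_size wiring looks good")
```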