Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer (#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
Jason (Siyu) Zhu 2024-08-23 04:20:49 -07:00 committed by GitHub
parent 970a16ec7f
commit adb91179b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 118 additions and 0 deletions

View File

@ -382,6 +382,41 @@ trainer.train()
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
## Liger Kernel
[Liger-Kernel](https://github.com/linkedin/Liger-Kernel) Kernel is a collection of Triton kernels developed by Linkedin designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed.
<Tip>
Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B model training. Achieve longer context lengths and larger batch sizes. Its also useful if you want to scale up your model to multi-head training or large vocabulary sizes. Unleash multi-head training (medusa) and more. See details and examples in [Liger](https://github.com/linkedin/Liger-Kernel/tree/main/examples)
</Tip>
First make sure to install Liger official repository:
```bash
pip install liger-kernel
```
You should pass `use_liger_kernel=True` to apply liger kernel on your model, for example:
```py
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="your-model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
use_liger_kernel=True
)
```
The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
## LOMO optimizer
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).

View File

@ -84,6 +84,7 @@ from .utils import (
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_liger_kernel_available,
is_lomo_available,
is_natten_available,
is_nltk_available,
@ -1162,6 +1163,13 @@ def require_librosa(test_case):
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
def require_liger_kernel(test_case):
"""
Decorator marking a test that requires liger_kernel
"""
return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
def require_essentia(test_case):
"""
Decorator marking a test that requires essentia

View File

@ -155,6 +155,7 @@ from .utils import (
is_grokadamw_available,
is_in_notebook,
is_ipex_available,
is_liger_kernel_available,
is_lomo_available,
is_peft_available,
is_safetensors_available,
@ -464,6 +465,24 @@ class Trainer:
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)
if self.args.use_liger_kernel:
if is_liger_kernel_available():
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
if model_type:
# Monkey patch the model with liger kernels. Use the default kernel configurations.
_apply_liger_kernel(model_type=model_type)
else:
logger.warning(
"The model does not have a valid `model_type` specified. No liger kernels will be applied."
)
else:
raise ImportError(
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
"Please install it with `pip install liger-kernel`"
)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)

View File

@ -793,6 +793,11 @@ class TrainingArguments:
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
use_liger_kernel (`bool`, *optional*, defaults to `False`):
Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
"""
framework = "pt"
@ -1493,6 +1498,11 @@ class TrainingArguments:
},
)
use_liger_kernel: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
)
eval_use_gather_object: Optional[bool] = field(
default=False,
metadata={

View File

@ -148,6 +148,7 @@ from .import_utils import (
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_liger_kernel_available,
is_lomo_available,
is_mlx_available,
is_natten_available,

View File

@ -177,6 +177,7 @@ _torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
_liger_kernel_available = _is_package_available("liger_kernel")
_torch_version = "N/A"
@ -1164,6 +1165,13 @@ def is_mlx_available():
return _mlx_available
def is_liger_kernel_available():
if not _liger_kernel_available:
return False
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:

View File

@ -64,6 +64,7 @@ from transformers.testing_utils import (
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_liger_kernel,
require_lomo,
require_optuna,
require_peft,
@ -1325,6 +1326,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(first_dataloader, first_dataloader_repeated)
self.assertEqual(second_dataloader, second_dataloader_repeated)
@require_liger_kernel
def test_use_liger_kernel_patching(self):
# Test that the model code actually gets patched with Liger kernel
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from transformers.models.llama import modeling_llama
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
args = TrainingArguments(
"./test",
use_liger_kernel=True,
)
Trainer(tiny_llama, args)
# Check that one of the Llama model layers has been correctly patched with Liger kernel
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
@require_liger_kernel
@require_torch_gpu
def test_use_liger_kernel_trainer(self):
# Check that trainer still works with liger kernel applied
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_lomo
@require_torch_gpu
def test_lomo(self):