mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer (#32860)
* add liger integration * fix syntax * fix import issue * add trainer.md * Use _apply_liger_kernel() * Fixed log message * Update docs/source/en/trainer.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update docs/source/en/trainer.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Update docs/source/en/trainer.md Co-authored-by: Byron Hsu <byronhsu1230@gmail.com> * Fixed checkstyle and updated readme * Added test * Fixed checkstyle * fix docstring * rename use_liger to use_liger_kernel * Trigger Build * Added test * add fix-copies * Fixed copy inconsistencies --------- Co-authored-by: shimizust <sshimizu@linkedin.com> Co-authored-by: Steven Shimizu <shimizust@gmail.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
parent
970a16ec7f
commit
adb91179b9
@ -382,6 +382,41 @@ trainer.train()
|
||||
|
||||
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
|
||||
|
||||
## Liger Kernel
|
||||
|
||||
[Liger-Kernel](https://github.com/linkedin/Liger-Kernel) Kernel is a collection of Triton kernels developed by Linkedin designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed.
|
||||
|
||||
<Tip>
|
||||
Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B model training. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes. Unleash multi-head training (medusa) and more. See details and examples in [Liger](https://github.com/linkedin/Liger-Kernel/tree/main/examples)
|
||||
</Tip>
|
||||
|
||||
First make sure to install Liger official repository:
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
You should pass `use_liger_kernel=True` to apply liger kernel on your model, for example:
|
||||
|
||||
```py
|
||||
from transformers import TrainingArguments
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="your-model",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=16,
|
||||
num_train_epochs=2,
|
||||
weight_decay=0.01,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
push_to_hub=True,
|
||||
use_liger_kernel=True
|
||||
)
|
||||
```
|
||||
|
||||
The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
|
||||
|
||||
## LOMO optimizer
|
||||
|
||||
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
|
||||
|
@ -84,6 +84,7 @@ from .utils import (
|
||||
is_keras_nlp_available,
|
||||
is_levenshtein_available,
|
||||
is_librosa_available,
|
||||
is_liger_kernel_available,
|
||||
is_lomo_available,
|
||||
is_natten_available,
|
||||
is_nltk_available,
|
||||
@ -1162,6 +1163,13 @@ def require_librosa(test_case):
|
||||
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
|
||||
|
||||
|
||||
def require_liger_kernel(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires liger_kernel
|
||||
"""
|
||||
return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
|
||||
|
||||
|
||||
def require_essentia(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires essentia
|
||||
|
@ -155,6 +155,7 @@ from .utils import (
|
||||
is_grokadamw_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_liger_kernel_available,
|
||||
is_lomo_available,
|
||||
is_peft_available,
|
||||
is_safetensors_available,
|
||||
@ -464,6 +465,24 @@ class Trainer:
|
||||
" to `True` to avoid any unexpected behavior such as device placement mismatching."
|
||||
)
|
||||
|
||||
if self.args.use_liger_kernel:
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
|
||||
|
||||
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
|
||||
if model_type:
|
||||
# Monkey patch the model with liger kernels. Use the default kernel configurations.
|
||||
_apply_liger_kernel(model_type=model_type)
|
||||
else:
|
||||
logger.warning(
|
||||
"The model does not have a valid `model_type` specified. No liger kernels will be applied."
|
||||
)
|
||||
else:
|
||||
raise ImportError(
|
||||
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
|
||||
"Please install it with `pip install liger-kernel`"
|
||||
)
|
||||
|
||||
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
|
||||
model, "_hf_peft_config_loaded", False
|
||||
)
|
||||
|
@ -793,6 +793,11 @@ class TrainingArguments:
|
||||
|
||||
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
|
||||
Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
|
||||
|
||||
use_liger_kernel (`bool`, *optional*, defaults to `False`):
|
||||
Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
|
||||
It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
|
||||
flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@ -1493,6 +1498,11 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
use_liger_kernel: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
|
||||
)
|
||||
|
||||
eval_use_gather_object: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
|
@ -148,6 +148,7 @@ from .import_utils import (
|
||||
is_keras_nlp_available,
|
||||
is_levenshtein_available,
|
||||
is_librosa_available,
|
||||
is_liger_kernel_available,
|
||||
is_lomo_available,
|
||||
is_mlx_available,
|
||||
is_natten_available,
|
||||
|
@ -177,6 +177,7 @@ _torchdistx_available = _is_package_available("torchdistx")
|
||||
_torchvision_available = _is_package_available("torchvision")
|
||||
_mlx_available = _is_package_available("mlx")
|
||||
_hqq_available = _is_package_available("hqq")
|
||||
_liger_kernel_available = _is_package_available("liger_kernel")
|
||||
|
||||
|
||||
_torch_version = "N/A"
|
||||
@ -1164,6 +1165,13 @@ def is_mlx_available():
|
||||
return _mlx_available
|
||||
|
||||
|
||||
def is_liger_kernel_available():
|
||||
if not _liger_kernel_available:
|
||||
return False
|
||||
|
||||
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
AV_IMPORT_ERROR = """
|
||||
{0} requires the PyAv library but it was not found in your environment. You can install it with:
|
||||
|
@ -64,6 +64,7 @@ from transformers.testing_utils import (
|
||||
require_galore_torch,
|
||||
require_grokadamw,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_liger_kernel,
|
||||
require_lomo,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
@ -1325,6 +1326,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(first_dataloader, first_dataloader_repeated)
|
||||
self.assertEqual(second_dataloader, second_dataloader_repeated)
|
||||
|
||||
@require_liger_kernel
|
||||
def test_use_liger_kernel_patching(self):
|
||||
# Test that the model code actually gets patched with Liger kernel
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
args = TrainingArguments(
|
||||
"./test",
|
||||
use_liger_kernel=True,
|
||||
)
|
||||
Trainer(tiny_llama, args)
|
||||
|
||||
# Check that one of the Llama model layers has been correctly patched with Liger kernel
|
||||
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
|
||||
|
||||
@require_liger_kernel
|
||||
@require_torch_gpu
|
||||
def test_use_liger_kernel_trainer(self):
|
||||
# Check that trainer still works with liger kernel applied
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_lomo
|
||||
@require_torch_gpu
|
||||
def test_lomo(self):
|
||||
|
Loading…
Reference in New Issue
Block a user