feat: add flexible Liger Kernel configuration to TrainingArguments (#38911)

* feat: add flexible Liger Kernel configuration to TrainingArguments

Add support for granular Liger Kernel configuration through a new
`liger_kernel_config` parameter in TrainingArguments. This allows users
to selectively enable/disable specific kernels (rope, swiglu, cross_entropy,
etc.) instead of the current approach, which relies on the default configuration.

Features:
- Add `liger_kernel_config` dict parameter to TrainingArguments
- Support selective kernel application for all supported models
- Maintain full backward compatibility with existing `use_liger_kernel` flag

Example usage:
```python
TrainingArguments(
    use_liger_kernel=True,
    liger_kernel_config={
        "rope": True,
        "swiglu": True,
        "cross_entropy": False,
        "fused_linear_cross_entropy": True
    }
)
```
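
As a backward-compatibility sketch (not part of the diff; `"out"` is just a placeholder output directory), omitting `liger_kernel_config` or passing `None` is equivalent to the existing flag-only usage:
```python
from transformers import TrainingArguments

# Existing behavior: Liger patches the model with its default kernel set
TrainingArguments(output_dir="out", use_liger_kernel=True)

# Equivalent: a missing/None config falls back to the same defaults
TrainingArguments(output_dir="out", use_liger_kernel=True, liger_kernel_config=None)
```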
Closes #38905

* Address comments and update Liger section in Trainer docs
4 changed files with 94 additions and 4 deletions

docs/source/en/trainer.md

@@ -493,6 +493,33 @@ training_args = TrainingArguments(
)
```
You can also configure which specific kernels to apply using the `liger_kernel_config` parameter. This dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, allowing fine-grained control over kernel usage. Available options vary by model but typically include: `rope`, `swiglu`, `cross_entropy`, `fused_linear_cross_entropy`, `rms_norm`, etc.
```py
from transformers import TrainingArguments

# Apply only specific kernels
training_args = TrainingArguments(
    output_dir="your-model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
    use_liger_kernel=True,
    liger_kernel_config={
        "rope": True,
        "cross_entropy": True,
        "rms_norm": False,  # Don't apply Liger's RMSNorm kernel
        "swiglu": True,
    },
)
```
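Under the hood, [`Trainer`] forwards this dict as keyword arguments to Liger's `_apply_liger_kernel_to_instance` (see the `trainer.py` change below). Here is a minimal standalone sketch of that call, reusing `training_args` from above and a tiny Llama configuration like the one in the tests; the model is only a placeholder:
```py
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny placeholder model; any architecture supported by Liger Kernel works.
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
model = LlamaForCausalLM(config)

# What Trainer does internally: a None config falls back to {} (Liger's per-model defaults).
kernel_config = training_args.liger_kernel_config or {}
_apply_liger_kernel_to_instance(model=model, **kernel_config)
```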
### NEFTune
[NEFTune](https://hf.co/papers/2310.05914) adds noise to the embedding vectors during training to improve model performance. Enable it in [`Trainer`] with the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added.

src/transformers/trainer.py

@@ -526,12 +526,15 @@ class Trainer:
            if is_liger_kernel_available():
                from liger_kernel.transformers import _apply_liger_kernel_to_instance

                # Prepare kernel config - use provided config or default (empty dict for default behavior)
                kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
                if isinstance(model, PreTrainedModel):
                    # Patch the model with liger kernels. Use the specified or default kernel configurations.
                    _apply_liger_kernel_to_instance(model=model, **kernel_config)
                elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
                    # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
                    _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
                else:
                    logger.warning(
                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."

src/transformers/training_args.py

@@ -793,6 +793,11 @@ class TrainingArguments:
            It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
            flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
        liger_kernel_config (`Optional[dict]`, *optional*):
            Configuration to be used for Liger Kernel. When use_liger_kernel=True, this dict is passed as keyword arguments to the
            `_apply_liger_kernel_to_instance` function, which specifies which kernels to apply. Available options vary by model but typically
            include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, use the default kernel configurations.
        average_tokens_across_devices (`bool`, *optional*, defaults to `False`):
            Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
            num_tokens_in_batch for precise loss calculation. Reference:
@@ -1525,6 +1530,19 @@
        metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
    )
    liger_kernel_config: Optional[dict[str, bool]] = field(
        default=None,
        metadata={
            "help": (
                "Configuration to be used for Liger Kernel. When use_liger_kernel=True, "
                "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, "
                "which specifies which kernels to apply. Available options vary by model "
                "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
                "'rms_norm', etc. If None, use the default kernel configurations."
            )
        },
    )
    eval_use_gather_object: Optional[bool] = field(
        default=False,
        metadata={

tests/trainer/test_trainer.py

@@ -1792,6 +1792,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
            self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))

    @require_liger_kernel
    def test_use_liger_kernel_custom_config_patching(self):
        # Ensure any monkey patching is cleaned up for subsequent tests
        with patch("transformers.models.llama.modeling_llama"):
            from liger_kernel.transformers import LigerRMSNorm

            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
            tiny_llama = LlamaForCausalLM(config)

            args = TrainingArguments(
                self.get_auto_remove_tmp_dir(),
                use_liger_kernel=True,
                liger_kernel_config={"rms_norm": False},  # Don't apply Liger's RMSNorm
            )
            Trainer(tiny_llama, args)

            # Check that the RMSNorm kernel is not applied as specified in the config
            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))

    @require_liger_kernel
    @require_torch_accelerator
    def test_use_liger_kernel_trainer(self):
@@ -1810,6 +1829,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        # Check this works
        _ = trainer.train()

    @require_liger_kernel
    @require_torch_accelerator
    def test_use_liger_kernel_custom_config_trainer(self):
        # Check that trainer still works with liger kernel applied when using a custom config
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-2,
            logging_steps=5,
            max_steps=20,
            use_liger_kernel=True,
            liger_kernel_config={"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False},
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        _ = trainer.train()

    @require_lomo
    @require_torch_accelerator
    def test_lomo(self):