feat: add flexible Liger Kernel configuration to TrainingArguments (#38911)

* feat: add flexible Liger Kernel configuration to TrainingArguments

Add support for granular Liger Kernel configuration through a new
`liger_kernel_config` parameter in TrainingArguments. This allows users
to selectively enable/disable specific kernels (rope, swiglu, cross_entropy,
etc.) instead of the current approach, which relies on the default configuration.

Features:
- Add `liger_kernel_config` dict parameter to TrainingArguments
- Support selective kernel application for all supported models
- Maintain full backward compatibility with the existing `use_liger_kernel` flag

Example usage:
```python
TrainingArguments(
    use_liger_kernel=True,
    liger_kernel_config={
        "rope": True,
        "swiglu": True,
        "cross_entropy": False,
        "fused_linear_cross_entropy": True
    }
)
```

Closes #38905

* Address comments and update Liger section in Trainer docs
Hamza Benchekroun 2025-06-19 17:54:08 +02:00 committed by GitHub
parent 89b35be618
commit 797860c68c
4 changed files with 94 additions and 4 deletions

docs/source/en/trainer.md

@@ -493,6 +493,33 @@ training_args = TrainingArguments(
)
```
You can also configure which specific kernels to apply using the `liger_kernel_config` parameter. This dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, allowing fine-grained control over kernel usage. Available options vary by model but typically include: `rope`, `swiglu`, `cross_entropy`, `fused_linear_cross_entropy`, `rms_norm`, etc.
```py
from transformers import TrainingArguments

# Apply only specific kernels
training_args = TrainingArguments(
    output_dir="your-model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
    use_liger_kernel=True,
    liger_kernel_config={
        "rope": True,
        "cross_entropy": True,
        "rms_norm": False,  # Don't apply Liger's RMSNorm kernel
        "swiglu": True,
    }
)
```
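
To verify that a disabled kernel was actually skipped, you can inspect the patched model after the [`Trainer`] is constructed. The following is a minimal sketch modeled on the tests added in this PR; the tiny Llama config and the `model.model.norm` attribute are illustrative assumptions for a Llama-style model.

```py
# Minimal sketch (assumptions: tiny Llama config, `model.model.norm` is the final RMSNorm).
# With "rms_norm": False, the model's own RMSNorm should remain unpatched after Trainer init.
from liger_kernel.transformers import LigerRMSNorm
from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments

config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

args = TrainingArguments(
    "liger-demo",  # hypothetical output directory
    use_liger_kernel=True,
    liger_kernel_config={"rms_norm": False},
)
Trainer(tiny_llama, args)  # kernels are applied when the Trainer is constructed

print(isinstance(tiny_llama.model.norm, LigerRMSNorm))  # False: Liger's RMSNorm was skipped
```

Options omitted from `liger_kernel_config` keep Liger's per-model defaults, since the dict is simply forwarded as keyword arguments.
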
### NEFTune
[NEFTune](https://hf.co/papers/2310.05914) adds noise to the embedding vectors during training to improve model performance. Enable it in [`Trainer`] with the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added.

src/transformers/trainer.py

@@ -526,12 +526,15 @@ class Trainer:
            if is_liger_kernel_available():
                from liger_kernel.transformers import _apply_liger_kernel_to_instance

+               # Prepare kernel config - use provided config or default (empty dict for default behavior)
+               kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
                if isinstance(model, PreTrainedModel):
-                   # Patch the model with liger kernels. Use the default kernel configurations.
-                   _apply_liger_kernel_to_instance(model=model)
+                   # Patch the model with liger kernels. Use the specified or default kernel configurations.
+                   _apply_liger_kernel_to_instance(model=model, **kernel_config)
                elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
-                   # Patch the base model with liger kernels where model is a PeftModel. Use the default kernel configurations.
-                   _apply_liger_kernel_to_instance(model=model.get_base_model())
+                   # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
+                   _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
                else:
                    logger.warning(
                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."

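For reference, the forwarding above amounts to calling Liger's patching helper directly with the configured keyword arguments; a minimal sketch outside of `Trainer` (the checkpoint name is a placeholder):

```py
# Sketch: what `liger_kernel_config` turns into inside Trainer, i.e. keyword arguments
# forwarded to Liger's in-place patching helper. The checkpoint name is a placeholder.
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder
kernel_config = {"rope": True, "cross_entropy": True, "rms_norm": False, "swiglu": True}
_apply_liger_kernel_to_instance(model=model, **kernel_config)  # patches modules in place
```
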
src/transformers/training_args.py

@@ -793,6 +793,11 @@ class TrainingArguments:
            It can effectively increase multi-GPU training throughput by ~20% and reduce memory usage by ~60%, and it works out of the box with
            flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
        liger_kernel_config (`Optional[dict]`, *optional*):
            Configuration to be used for the Liger Kernel. When `use_liger_kernel=True`, this dict is passed as keyword arguments to the
            `_apply_liger_kernel_to_instance` function and specifies which kernels to apply. Available options vary by model but typically
            include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, the default kernel configurations are used.
        average_tokens_across_devices (`bool`, *optional*, defaults to `False`):
            Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
            num_tokens_in_batch for precise loss calculation. Reference:
@@ -1525,6 +1530,19 @@ class TrainingArguments:
        metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
    )
    liger_kernel_config: Optional[dict[str, bool]] = field(
        default=None,
        metadata={
            "help": (
                "Configuration to be used for the Liger Kernel. When `use_liger_kernel=True`, "
                "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function "
                "and specifies which kernels to apply. Available options vary by model "
                "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
                "'rms_norm', etc. If None, the default kernel configurations are used."
            )
        },
    )
    eval_use_gather_object: Optional[bool] = field(
        default=False,
        metadata={

tests/trainer/test_trainer.py

@@ -1792,6 +1792,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
            self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))

    @require_liger_kernel
    def test_use_liger_kernel_custom_config_patching(self):
        # Ensure any monkey patching is cleaned up for subsequent tests
        with patch("transformers.models.llama.modeling_llama"):
            from liger_kernel.transformers import LigerRMSNorm

            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
            tiny_llama = LlamaForCausalLM(config)

            args = TrainingArguments(
                self.get_auto_remove_tmp_dir(),
                use_liger_kernel=True,
                liger_kernel_config={"rms_norm": False},  # Don't apply Liger's RMSNorm
            )
            Trainer(tiny_llama, args)

            # Check that the RMSNorm kernel is not applied as specified in the config
            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))

    @require_liger_kernel
    @require_torch_accelerator
    def test_use_liger_kernel_trainer(self):
@@ -1810,6 +1829,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
        # Check this works
        _ = trainer.train()

    @require_liger_kernel
    @require_torch_accelerator
    def test_use_liger_kernel_custom_config_trainer(self):
        # Check that trainer still works with liger kernel applied when using a custom config
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-2,
            logging_steps=5,
            max_steps=20,
            use_liger_kernel=True,
            liger_kernel_config={"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False},
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        _ = trainer.train()

    @require_lomo
    @require_torch_accelerator
    def test_lomo(self):