Mirror of https://github.com/huggingface/transformers.git
feat: add flexible Liger Kernel configuration to TrainingArguments (#38911)
* feat: add flexible Liger Kernel configuration to TrainingArguments

Add support for granular Liger Kernel configuration through a new `liger_kernel_config` parameter in TrainingArguments. This allows users to selectively enable/disable specific kernels (rope, swiglu, cross_entropy, etc.) instead of the current approach, which relies on the default configuration.

Features:
- Add a `liger_kernel_config` dict parameter to TrainingArguments
- Support selective kernel application for all supported models
- Maintain full backward compatibility with the existing `use_liger_kernel` flag

Example usage:

```python
TrainingArguments(
    use_liger_kernel=True,
    liger_kernel_config={
        "rope": True,
        "swiglu": True,
        "cross_entropy": False,
        "fused_linear_cross_entropy": True
    }
)
```

Closes #38905

* Address comments and update Liger section in Trainer docs
commit 797860c68c (parent 89b35be618)
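For orientation before the diff: backward compatibility comes from a simple fallback inside `Trainer`. The sketch below mirrors the `trainer.py` hunk further down and is only illustrative; the tiny random Llama config and the `output_dir="out"` value are placeholders, and `liger-kernel` must be installed for the import to succeed.

```python
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from transformers import LlamaConfig, LlamaForCausalLM, TrainingArguments

# Tiny randomly initialised model, purely for illustration.
model = LlamaForCausalLM(LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4))
args = TrainingArguments(output_dir="out", use_liger_kernel=True)  # no liger_kernel_config set

# None -> {}: an empty dict contributes no keyword arguments, so Liger's default
# kernel selection is applied, exactly as before this change.
kernel_config = args.liger_kernel_config if args.liger_kernel_config is not None else {}
_apply_liger_kernel_to_instance(model=model, **kernel_config)
```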
@@ -493,6 +493,33 @@ training_args = TrainingArguments(
 )
 ```
 
+You can also configure which specific kernels to apply using the `liger_kernel_config` parameter. This dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, allowing fine-grained control over kernel usage. Available options vary by model but typically include: `rope`, `swiglu`, `cross_entropy`, `fused_linear_cross_entropy`, `rms_norm`, etc.
+
+```py
+from transformers import TrainingArguments
+
+# Apply only specific kernels
+training_args = TrainingArguments(
+    output_dir="your-model",
+    learning_rate=2e-5,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    num_train_epochs=2,
+    weight_decay=0.01,
+    eval_strategy="epoch",
+    save_strategy="epoch",
+    load_best_model_at_end=True,
+    push_to_hub=True,
+    use_liger_kernel=True,
+    liger_kernel_config={
+        "rope": True,
+        "cross_entropy": True,
+        "rms_norm": False,  # Don't apply Liger's RMSNorm kernel
+        "swiglu": True,
+    }
+)
+```
+
 ### NEFTune
 
 [NEFTune](https://hf.co/papers/2310.05914) adds noise to the embedding vectors during training to improve model performance. Enable it in [`Trainer`] with the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added.
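To make the docs paragraph above concrete: since the dict is forwarded verbatim as keyword arguments, a `liger_kernel_config` at the `TrainingArguments` level corresponds roughly to a direct call like the sketch below. This is illustrative only; the tiny random Llama model stands in for any Liger-supported checkpoint, and the chosen kwargs mirror the config used in the new tests later in this commit.

```python
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny randomly initialised Llama, purely for illustration.
model = LlamaForCausalLM(LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4))

# What Trainer effectively does for
# liger_kernel_config={"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False}:
_apply_liger_kernel_to_instance(
    model=model,
    rms_norm=False,                    # keep the stock LlamaRMSNorm
    cross_entropy=True,                # use Liger's cross-entropy kernel
    fused_linear_cross_entropy=False,  # left off here because cross_entropy is enabled
)
```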
@@ -526,12 +526,15 @@ class Trainer:
             if is_liger_kernel_available():
                 from liger_kernel.transformers import _apply_liger_kernel_to_instance
 
+                # Prepare kernel config - use provided config or default (empty dict for default behavior)
+                kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
+
                 if isinstance(model, PreTrainedModel):
-                    # Patch the model with liger kernels. Use the default kernel configurations.
-                    _apply_liger_kernel_to_instance(model=model)
+                    # Patch the model with liger kernels. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model, **kernel_config)
                 elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
-                    # Patch the base model with liger kernels where model is a PeftModel. Use the default kernel configurations.
-                    _apply_liger_kernel_to_instance(model=model.get_base_model())
+                    # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
                 else:
                     logger.warning(
                         "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
@@ -793,6 +793,11 @@ class TrainingArguments:
             It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
             flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
 
+        liger_kernel_config (`Optional[dict]`, *optional*):
+            Configuration to be used for Liger Kernel. When use_liger_kernel=True, this dict is passed as keyword arguments to the
+            `_apply_liger_kernel_to_instance` function, which specifies which kernels to apply. Available options vary by model but typically
+            include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, use the default kernel configurations.
+
         average_tokens_across_devices (`bool`, *optional*, defaults to `False`):
             Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
             num_tokens_in_batch for precise loss calculation. Reference:
@@ -1525,6 +1530,19 @@ class TrainingArguments:
         metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
     )
 
+    liger_kernel_config: Optional[dict[str, bool]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration to be used for Liger Kernel. When use_liger_kernel=True, "
+                "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, "
+                "which specifies which kernels to apply. Available options vary by model "
+                "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
+                "'rms_norm', etc. If None, use the default kernel configurations."
+            )
+        },
+    )
+
     eval_use_gather_object: Optional[bool] = field(
         default=False,
         metadata={
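A small sketch of the new field's behaviour at the `TrainingArguments` level, assuming nothing beyond what the hunk above defines (the `output_dir` values are placeholders): leaving the field unset keeps it at `None`, which `Trainer` later maps to an empty kwargs dict, so existing `use_liger_kernel=True` setups are unchanged.

```python
from transformers import TrainingArguments

# Default: the field stays None, so Trainer forwards no kwargs and Liger's default kernels apply.
default_args = TrainingArguments(output_dir="liger-default", use_liger_kernel=True)
assert default_args.liger_kernel_config is None

# Custom: only the listed kernels are toggled; anything unspecified keeps Liger's own defaults.
custom_args = TrainingArguments(
    output_dir="liger-custom",
    use_liger_kernel=True,
    liger_kernel_config={"rms_norm": False, "swiglu": True},
)
assert custom_args.liger_kernel_config == {"rms_norm": False, "swiglu": True}
```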
@@ -1792,6 +1792,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
         self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
+    @require_liger_kernel
+    def test_use_liger_kernel_custom_config_patching(self):
+        # Ensure any monkey patching is cleaned up for subsequent tests
+        with patch("transformers.models.llama.modeling_llama"):
+            from liger_kernel.transformers import LigerRMSNorm
+
+            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+            tiny_llama = LlamaForCausalLM(config)
+
+            args = TrainingArguments(
+                self.get_auto_remove_tmp_dir(),
+                use_liger_kernel=True,
+                liger_kernel_config={"rms_norm": False},  # Don't apply Liger's RMSNorm
+            )
+            Trainer(tiny_llama, args)
+
+            # Check that the RMSNorm kernel is not applied as specified in the config
+            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
+
     @require_liger_kernel
     @require_torch_accelerator
     def test_use_liger_kernel_trainer(self):
@@ -1810,6 +1829,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # Check this works
         _ = trainer.train()
 
+    @require_liger_kernel
+    @require_torch_accelerator
+    def test_use_liger_kernel_custom_config_trainer(self):
+        # Check that trainer still works with liger kernel applied when using a custom config
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        args = TrainingArguments(
+            self.get_auto_remove_tmp_dir(),
+            learning_rate=1e-2,
+            logging_steps=5,
+            max_steps=20,
+            use_liger_kernel=True,
+            liger_kernel_config={"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False},
+        )
+        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+        # Check this works
+        _ = trainer.train()
+
     @require_lomo
     @require_torch_accelerator
     def test_lomo(self):