diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index 56f929884a5..3572fb4385d 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -493,6 +493,33 @@ training_args = TrainingArguments(
 )
 ```
 
+You can also configure which specific kernels to apply using the `liger_kernel_config` parameter. This dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, allowing fine-grained control over kernel usage. Available options vary by model but typically include `rope`, `swiglu`, `cross_entropy`, `fused_linear_cross_entropy`, and `rms_norm`, among others.
+
+```py
+from transformers import TrainingArguments
+
+# Apply only specific kernels
+training_args = TrainingArguments(
+    output_dir="your-model",
+    learning_rate=2e-5,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    num_train_epochs=2,
+    weight_decay=0.01,
+    eval_strategy="epoch",
+    save_strategy="epoch",
+    load_best_model_at_end=True,
+    push_to_hub=True,
+    use_liger_kernel=True,
+    liger_kernel_config={
+        "rope": True,
+        "cross_entropy": True,
+        "rms_norm": False,  # Don't apply Liger's RMSNorm kernel
+        "swiglu": True,
+    }
+)
+```
+
 ### NEFTune
 
 [NEFTune](https://hf.co/papers/2310.05914) adds noise to the embedding vectors during training to improve model performance. Enable it in [`Trainer`] with the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a1f7902c91c..39448cd5ab0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -526,12 +526,15 @@ class Trainer:
             if is_liger_kernel_available():
                 from liger_kernel.transformers import _apply_liger_kernel_to_instance
 
+                # Prepare the kernel config: use the provided config, or an empty dict for the default behavior
+                kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
+
                 if isinstance(model, PreTrainedModel):
-                    # Patch the model with liger kernels. Use the default kernel configurations.
-                    _apply_liger_kernel_to_instance(model=model)
+                    # Patch the model with liger kernels. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model, **kernel_config)
                 elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
-                    # Patch the base model with liger kernels where model is a PeftModel. Use the default kernel configurations.
-                    _apply_liger_kernel_to_instance(model=model.get_base_model())
+                    # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
                 else:
                     logger.warning(
                         "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 679975309c4..ceb5a6b132a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -793,6 +793,11 @@ class TrainingArguments:
             It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works
             out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama,
             mistral, mixtral and gemma models.
+        liger_kernel_config (`Optional[dict]`, *optional*):
+            Configuration to be used for the Liger Kernel. When `use_liger_kernel=True`, this dict is passed as keyword arguments to
+            the `_apply_liger_kernel_to_instance` function to specify which kernels to apply. Available options vary by model but
+            typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, the default kernel configurations are used.
+
         average_tokens_across_devices (`bool`, *optional*, defaults to `False`):
             Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
             num_tokens_in_batch for precise loss calculation. Reference:
@@ -1525,6 +1530,19 @@ class TrainingArguments:
         metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
     )
 
+    liger_kernel_config: Optional[dict[str, bool]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration to be used for the Liger Kernel. When use_liger_kernel=True, "
+                "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function "
+                "to specify which kernels to apply. Available options vary by model "
+                "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
+                "'rms_norm', etc. If None, the default kernel configurations are used."
+            )
+        },
+    )
+
     eval_use_gather_object: Optional[bool] = field(
         default=False,
         metadata={
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a11ff9bcbc2..2594edcdef8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1792,6 +1792,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
             self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
+    @require_liger_kernel
+    def test_use_liger_kernel_custom_config_patching(self):
+        # Ensure any monkey patching is cleaned up for subsequent tests
+        with patch("transformers.models.llama.modeling_llama"):
+            from liger_kernel.transformers import LigerRMSNorm
+
+            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+            tiny_llama = LlamaForCausalLM(config)
+
+            args = TrainingArguments(
+                self.get_auto_remove_tmp_dir(),
+                use_liger_kernel=True,
+                liger_kernel_config={"rms_norm": False},  # Don't apply Liger's RMSNorm
+            )
+            Trainer(tiny_llama, args)
+
+            # Check that the RMSNorm kernel is not applied, as specified in the config
+            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
+
     @require_liger_kernel
     @require_torch_accelerator
     def test_use_liger_kernel_trainer(self):
@@ -1810,6 +1829,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # Check this works
         _ = trainer.train()
 
+    @require_liger_kernel
+    @require_torch_accelerator
+    def test_use_liger_kernel_custom_config_trainer(self):
+        # Check that the trainer still works when the Liger kernel is applied with a custom config
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        args = TrainingArguments(
+            self.get_auto_remove_tmp_dir(),
+            learning_rate=1e-2,
+            logging_steps=5,
+            max_steps=20,
+            use_liger_kernel=True,
+            liger_kernel_config={"rms_norm": False, "cross_entropy": True, "fused_linear_cross_entropy": False},
+        )
+        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+        # Check this works
+        _ = trainer.train()
+
     @require_lomo
     @require_torch_accelerator
     def test_lomo(self):
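For illustration, here is a minimal sketch of the patching path this diff exercises: `Trainer` forwards the `liger_kernel_config` dict verbatim as keyword arguments to Liger's `_apply_liger_kernel_to_instance`. It is not part of the patch itself; it assumes `liger-kernel` is installed and reuses the tiny Llama config and kernel flags shown in the docs and tests above.

```py
# Sketch only: what Trainer effectively does with use_liger_kernel=True and a
# liger_kernel_config dict. Assumes `liger-kernel` and `transformers` are installed.
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
model = LlamaForCausalLM(config)

# Same flags as the docs example: patch RoPE, cross entropy, and SwiGLU kernels,
# but keep the model's own RMSNorm in place.
kernel_config = {"rope": True, "cross_entropy": True, "rms_norm": False, "swiglu": True}

# Equivalent to Trainer's internal call:
# _apply_liger_kernel_to_instance(model=model, **kernel_config)
_apply_liger_kernel_to_instance(model=model, **kernel_config)

print(type(model.model.norm))  # still the original LlamaRMSNorm, since rms_norm=False
```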