Add kernelize to transformers (#38205)

* fix

* fix

* fix flow

* remove non compiling path

* change

* style

* fix

* update

* update pin

* revert
Mohamed Mekkouri 2025-06-24 17:38:54 +02:00 committed by GitHub
parent be10d4df60
commit 08bf7f1afe
4 changed files with 13 additions and 43 deletions

setup.py

@@ -128,7 +128,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
-    "kernels>=0.4.4,<0.5",
+    "kernels>=0.6.1,<0.7",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",

src/transformers/dependency_versions_table.py

@@ -34,7 +34,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
-    "kernels": "kernels>=0.4.4,<0.5",
+    "kernels": "kernels>=0.6.1,<0.7",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",

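Both pins move `kernels` to the 0.6 series, which ships the `kernelize` entry point used further down. As a quick environment sanity check, something like the following can be used (a minimal sketch relying only on `importlib.metadata` and `packaging`, which transformers already depends on; the message strings are illustrative):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet

# Pin introduced by this commit; keep in sync with setup.py if it changes again.
REQUIRED = SpecifierSet(">=0.6.1,<0.7")

try:
    installed = version("kernels")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("kernels is not installed; the hub-kernel integration stays disabled.")
elif installed not in REQUIRED:
    print(f"kernels=={installed} does not satisfy {REQUIRED}")
else:
    print(f"kernels=={installed} satisfies {REQUIRED}")
```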
src/transformers/integrations/hub_kernels.py

@@ -13,8 +13,6 @@
 # limitations under the License.
 from typing import Union

-from ..utils import is_torchdynamo_compiling
-

 try:
     from kernels import (
@@ -22,9 +20,7 @@ try:
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-    )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
     )

     _hub_kernels_available = True
@@ -45,9 +41,9 @@ try:
         },
         "RMSNorm": {
             "cuda": LayerRepository(
-                repo_id="kernels-community/triton-layer-norm",
-                layer_name="LlamaRMSNorm",
-                revision="pure-layer-test",
+                repo_id="kernels-community/liger_kernels",
+                layer_name="LigerRMSNorm",
+                # revision="pure-layer-test",
             )
         },
         "MLP": {
@@ -60,39 +56,6 @@ try:
     register_kernel_mapping(_KERNEL_MAPPING)

-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-            return cls
-
-        return decorator_with_compile_path
-

 except ImportError:
     # Stub to make decorators int transformers work when `kernels`

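The net effect of these hunks: the hand-rolled compile-time fallback around `use_kernel_forward_from_hub` is removed (kernel placement is now done explicitly by `kernels.kernelize`, wired up in modeling_utils.py below), and the default RMSNorm kernel is taken from `kernels-community/liger_kernels` instead of a pinned revision of `triton-layer-norm`. Below is a minimal sketch of how the decorator-plus-mapping pattern is consumed, assuming a CUDA device and Hub access; `MyRMSNorm` is an illustrative stand-in, not a transformers class:

```python
import torch
import torch.nn as nn
from kernels import Device, LayerRepository, kernelize, register_kernel_mapping, use_kernel_forward_from_hub


@use_kernel_forward_from_hub("RMSNorm")  # key must match an entry in the registered mapping
class MyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Reference forward; kernelize() may replace it with the hub kernel's forward.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


# Same shape as the _KERNEL_MAPPING entry added in this diff.
register_kernel_mapping(
    {
        "RMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        }
    }
)

model = nn.Sequential(MyRMSNorm(64)).to("cuda")
# Mirrors the call added in modeling_utils.py: swap in hub kernels for the current device.
kernelize(model, device=Device(type="cuda"))
print(model(torch.randn(2, 8, 64, device="cuda")).shape)
```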
src/transformers/modeling_utils.py

@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         tp_size = kwargs.pop("tp_size", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)
         key_mapping = kwargs.pop("key_mapping", None)

         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()

+        # check if using kernels
+        if use_kernels:
+            from kernels import Device, kernelize
+
+            kernelize(model, device=Device(type=model.device.type))
+
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and generation_config is not None:
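
With the two hunks above, callers can opt in to hub kernels at load time instead of touching the integration module. A small usage sketch (the checkpoint id is only an example; `use_kernels` is the kwarg introduced here, forwarded through the `Auto*` classes into `PreTrainedModel.from_pretrained`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any checkpoint whose layers carry @use_kernel_forward_from_hub decorators (e.g. Llama's RMSNorm) benefits.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.float16,
    device_map="cuda",
    use_kernels=True,  # new flag: triggers kernelize(model, device=Device(type=model.device.type)) after loading
)

inputs = tokenizer("Hub kernels are", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```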