Add kernelize to transformers (#38205)
* fix
* fix
* fix flow
* remove non compiling path
* change
* style
* fix
* update
* update pin
* revert

parent be10d4df60
commit 08bf7f1afe
setup.py
@@ -128,7 +128,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
-    "kernels>=0.4.4,<0.5",
+    "kernels>=0.6.1,<0.7",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
src/transformers/dependency_versions_table.py

@@ -34,7 +34,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
-    "kernels": "kernels>=0.4.4,<0.5",
+    "kernels": "kernels>=0.6.1,<0.7",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
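Both pins move together: setup.py defines the requirement and dependency_versions_table.py mirrors it. The bump to the kernels 0.6 line is what enables the kernelize() path added below. A quick way to check a local environment against the new pin, as a minimal sketch assuming `packaging` is installed (it ships as a transformers dependency):

    # Verify the installed `kernels` package satisfies the new pin from setup.py.
    from importlib.metadata import version

    from packaging.specifiers import SpecifierSet

    installed = version("kernels")
    if installed not in SpecifierSet(">=0.6.1,<0.7"):
        raise RuntimeError(f"kernels=={installed} does not satisfy >=0.6.1,<0.7")
    print(f"kernels=={installed} matches the pin")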
src/transformers/integrations/hub_kernels.py

@@ -13,8 +13,6 @@
 # limitations under the License.
 from typing import Union
 
-from ..utils import is_torchdynamo_compiling
-
 
 try:
     from kernels import (
@@ -22,9 +20,7 @@ try:
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-    )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
     )
 
     _hub_kernels_available = True
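With the alias removed, transformers now consumes kernels' stock use_kernel_forward_from_hub decorator directly. For orientation, a minimal sketch of how a layer opts into a hub kernel; the decorator and mapping key follow this file, while the layer body is an illustrative RMSNorm, not the real transformers class:

    # A layer decorated this way keeps its eager forward until kernelize() runs.
    import torch
    from kernels import use_kernel_forward_from_hub
    from torch import nn


    @use_kernel_forward_from_hub("RMSNorm")  # key must match _KERNEL_MAPPING
    class DemoRMSNorm(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Reference implementation; on CUDA, kernelize() may swap this
            # forward for the LigerRMSNorm kernel mapped below.
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return (self.weight * hidden_states).to(input_dtype)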
@@ -45,9 +41,9 @@ try:
         },
         "RMSNorm": {
             "cuda": LayerRepository(
-                repo_id="kernels-community/triton-layer-norm",
-                layer_name="LlamaRMSNorm",
-                revision="pure-layer-test",
+                repo_id="kernels-community/liger_kernels",
+                layer_name="LigerRMSNorm",
+                # revision="pure-layer-test",
             )
         },
         "MLP": {
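The hunk above retargets the default CUDA RMSNorm kernel from kernels-community/triton-layer-norm to the Liger implementation (the pinned test revision is commented out rather than carried over). Downstream code can override the mapping the same way; a sketch using only the API already shown in this file:

    # Re-point the "RMSNorm" kernel at a hub repository of your choosing.
    from kernels import LayerRepository, register_kernel_mapping

    register_kernel_mapping(
        {
            "RMSNorm": {
                "cuda": LayerRepository(
                    repo_id="kernels-community/liger_kernels",  # any repo with a matching layer
                    layer_name="LigerRMSNorm",
                )
            }
        }
    )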
@@ -60,39 +56,6 @@ try:
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
-
-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-
-            return cls
-
-        return decorator_with_compile_path
-
-
 except ImportError:
     # Stub to make decorators in transformers work when `kernels`
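The deleted wrapper was a stopgap, as its own docstring says: it fell back to the original forward while Dynamo was tracing because hub kernels could break under torch.compile. With the kernels 0.6 API, kernels are only applied when kernelize() is explicitly invoked (see the modeling_utils.py change below), so an unkernelized model keeps its plain PyTorch forwards and the is_torchdynamo_compiling() guard, along with the import removed above, is no longer needed.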
src/transformers/modeling_utils.py

@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         tp_size = kwargs.pop("tp_size", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)
 
         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
 
+        # check if using kernels
+        if use_kernels:
+            from kernels import Device, kernelize
+
+            kernelize(model, device=Device(type=model.device.type))
+
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and generation_config is not None:
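Together the two modeling_utils.py hunks wire the flag end to end: from_pretrained() pops use_kernels, and right after model.eval() the model is kernelized for its current device. A hedged usage sketch; the checkpoint name is illustrative, and a CUDA device is assumed so the mappings above apply:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",  # illustrative checkpoint
        device_map="cuda",
        use_kernels=True,  # new kwarg from this commit; triggers kernelize()
    )
    # Matching layers (e.g. RMSNorm -> LigerRMSNorm on CUDA) now run hub kernels.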