Mirror of https://github.com/huggingface/transformers.git
Add kernelize to transformers (#38205)

Squashed commit messages:
* fix
* fix
* fix flow
* remove non compiling path
* change
* style
* fix
* update
* update pin
* revert
Parent: be10d4df60
Commit: 08bf7f1afe
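For orientation before the diff itself: a minimal usage sketch of the `use_kernels` flag this commit adds to `from_pretrained`. The checkpoint name is a placeholder, and the `kernels` package (per the updated pin below) must be installed for the kernelization step to take effect.

from transformers import AutoModelForCausalLM

# Load a model and let transformers kernelize it right after `model.eval()`;
# `use_kernels` is the new keyword argument introduced in this commit.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder checkpoint, any supported model works
    use_kernels=True,
)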
setup.py
@@ -128,7 +128,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
-    "kernels>=0.4.4,<0.5",
+    "kernels>=0.6.1,<0.7",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
src/transformers/dependency_versions_table.py
@@ -34,7 +34,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
-    "kernels": "kernels>=0.4.4,<0.5",
+    "kernels": "kernels>=0.6.1,<0.7",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
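As a quick environment sanity check (a sketch, not part of the commit), the locally installed `kernels` version can be validated against the new pin:

# Not part of the diff: verify the local install matches "kernels>=0.6.1,<0.7".
from importlib.metadata import version

print(version("kernels"))  # expect a 0.6.x release after this change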
src/transformers/integrations/hub_kernels.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 from typing import Union
 
-from ..utils import is_torchdynamo_compiling
-
 
 try:
     from kernels import (
@@ -22,9 +20,7 @@ try:
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-    )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
     )
 
     _hub_kernels_available = True
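With the wrapper gone, `use_kernel_forward_from_hub` is now re-exported directly from `kernels`. A hedged sketch of how the decorator is typically applied; the layer class below is hypothetical, and the name passed to the decorator must match a key in the registered kernel mapping (such as "RMSNorm" in the mapping changed below).

import torch
from torch import nn
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("RMSNorm")  # must match a _KERNEL_MAPPING key
class MyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Reference (eager) path; kernelize() may later swap in the hub kernel.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states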
@@ -45,9 +41,9 @@ try:
         },
         "RMSNorm": {
             "cuda": LayerRepository(
-                repo_id="kernels-community/triton-layer-norm",
-                layer_name="LlamaRMSNorm",
-                revision="pure-layer-test",
+                repo_id="kernels-community/liger_kernels",
+                layer_name="LigerRMSNorm",
+                # revision="pure-layer-test",
             )
         },
         "MLP": {
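For reference, registering a mapping like the one changed above looks as follows; the repo and layer names are taken straight from the diff, and pinning a revision (commented out above) is optional.

from kernels import LayerRepository, register_kernel_mapping

register_kernel_mapping(
    {
        "RMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            ),
        },
    }
)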
@@ -60,39 +56,6 @@ try:
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-
-            return cls
-
-        return decorator_with_compile_path
-
-
 except ImportError:
     # Stub to make decorators int transformers work when `kernels`
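For the record, the deleted decorator (the "non compiling path" from the commit message) boiled down to the guard pattern below: route to the eager forward while torch.compile is tracing, otherwise dispatch to the hub kernel. This standalone sketch uses hypothetical names and `torch.compiler.is_compiling()` (PyTorch 2.3+) purely to document the removed behavior, which the `kernels` library is now expected to handle itself.

import torch

def with_compile_fallback(kernel_forward, eager_forward):
    # Roughly what the removed wrapper did: bypass the hub kernel while
    # dynamo is tracing, otherwise call the kernel-backed forward.
    def forward(*args, **kwargs):
        if torch.compiler.is_compiling():
            return eager_forward(*args, **kwargs)
        return kernel_forward(*args, **kwargs)
    return forward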
src/transformers/modeling_utils.py
@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         tp_size = kwargs.pop("tp_size", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)
 
         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
 
+        # check if using kernels
+        if use_kernels:
+            from kernels import Device, kernelize
+
+            kernelize(model, device=Device(type=model.device.type))
+
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and generation_config is not None:
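The equivalent manual call, for a model loaded without the new flag, mirrors the code added above; a sketch assuming `kernels>=0.6.1` is installed and using a placeholder checkpoint.

from kernels import Device, kernelize
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
# Same call from_pretrained() now makes internally when use_kernels=True:
kernelize(model, device=Device(type=model.device.type))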