Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Remove triton mlp kernel, not compiling for some models (#37449)
* remove mlp for now
* disable on docker
This commit is contained in:
parent f797e3d98a
commit 3c39c07939
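Editor's note on the change: `@use_kernel_forward_from_hub("MLP")` tags a module class so the Hub kernels integration can later swap its forward for a compiled Triton kernel, and because that MLP kernel currently fails to compile for some models, this commit drops the tag from the affected MLP classes so they always run their plain PyTorch forward. A rough, illustrative sketch of the pattern being removed follows (not the transformers implementation; the registry below is purely hypothetical):

# Illustrative sketch only, not the transformers implementation. The registry and the
# "later mapping step" it mentions are hypothetical; they only show the shape of the hook.
_KERNEL_CANDIDATES = {}  # hypothetical: kernel name -> classes eligible for a Hub kernel


def use_kernel_forward_from_hub(kernel_name):
    def decorator(cls):
        # Tag the class so a later mapping step could swap cls.forward for a compiled
        # Triton kernel fetched from the Hub. Without the tag, the eager forward always runs.
        _KERNEL_CANDIDATES.setdefault(kernel_name, []).append(cls)
        return cls

    return decorator
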
@@ -14,6 +14,8 @@ ARG PYTORCH='2.6.0'
ARG INTEL_TORCH_EXT='2.3.0'
# Example: `cu102`, `cu113`, etc.
ARG CUDA='cu121'
# Disable kernel mapping for now until all tests pass
ENV DISABLE_KERNEL_MAPPING=1

RUN apt update
RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs

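The Dockerfile half of the change bakes `DISABLE_KERNEL_MAPPING=1` into the CI image so Hub kernel mapping stays off until all tests pass. A minimal sketch of how such an environment switch is typically honored (an assumption for illustration; the exact check inside transformers may be named and shaped differently):

# Hedged sketch only: hypothetical helper showing the usual way a DISABLE_* env flag
# is consumed; the actual transformers check may differ.
import os


def kernel_mapping_enabled() -> bool:
    # ENV DISABLE_KERNEL_MAPPING=1 in the docker image turns mapping off for every test run.
    return os.environ.get("DISABLE_KERNEL_MAPPING", "").lower() not in ("1", "true", "yes")
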
@@ -228,7 +228,6 @@ class AriaProjector(nn.Module):
        return out


@use_kernel_forward_from_hub("MLP")
class AriaSharedExpertsMLP(nn.Module):
    """
    Shared Expert MLP for shared experts.

@@ -882,7 +882,6 @@ class BambaMixer(nn.Module):
        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)


@use_kernel_forward_from_hub("MLP")
class BambaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -36,7 +36,6 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

@@ -118,7 +117,6 @@ class CohereRotaryEmbedding(nn.Module):
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@use_kernel_forward_from_hub("MLP")
class CohereMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -28,7 +28,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

@@ -268,7 +267,6 @@ class Cohere2Attention(nn.Module):
        return attn_output, attn_weights


@use_kernel_forward_from_hub("MLP")
class Cohere2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -74,7 +74,6 @@ _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"


@use_kernel_forward_from_hub("MLP")
class DiffLlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -84,7 +84,6 @@ class Emu3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


@use_kernel_forward_from_hub("MLP")
class Emu3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -27,7 +27,6 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (

@@ -85,7 +84,6 @@ class GemmaRMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


@use_kernel_forward_from_hub("MLP")
class GemmaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -28,7 +28,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
    BaseModelOutputWithPast,

@@ -78,7 +77,6 @@ class Gemma2RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


@use_kernel_forward_from_hub("MLP")
class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -31,7 +31,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

@@ -107,7 +106,6 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


@use_kernel_forward_from_hub("MLP")
class Gemma3MLP(nn.Module):
    def __init__(self, config: Gemma3TextConfig):
        super().__init__()

@@ -228,7 +228,6 @@ class GraniteRMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


@use_kernel_forward_from_hub("MLP")
class GraniteMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -29,7 +29,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (

@@ -118,7 +117,6 @@ class HeliumRotaryEmbedding(nn.Module):
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@use_kernel_forward_from_hub("MLP")
class HeliumMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -160,7 +160,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return q_embed, k_embed


@use_kernel_forward_from_hub("MLP")
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

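With the decorator gone, classes such as LlamaMLP simply keep their stock eager forward, the usual gated (SwiGLU-style) MLP. The sketch below approximates that path; per-model details such as bias flags and activation choice are simplified:

# Approximate eager path the affected MLP classes fall back to (sketch; field names follow
# the common transformers layout, per-model details vary).
import torch
import torch.nn as nn


class GatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down(act(gate(x)) * up(x)): the computation a fused MLP kernel would replace.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
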
@@ -45,7 +45,6 @@ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CONFIG_FOR_DOC = "MistralConfig"


@use_kernel_forward_from_hub("MLP")
class MistralMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -14,7 +14,6 @@ import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

@@ -58,7 +57,6 @@ class OlmoLayerNorm(nn.Module):
        )


@use_kernel_forward_from_hub("MLP")
class OlmoMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -218,7 +218,6 @@ class Olmo2Attention(nn.Module):
        return attn_output, attn_weights


@use_kernel_forward_from_hub("MLP")
class Olmo2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -45,7 +45,6 @@ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
_CONFIG_FOR_DOC = "Qwen2Config"


@use_kernel_forward_from_hub("MLP")
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -81,7 +81,6 @@ class Qwen3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


@use_kernel_forward_from_hub("MLP")
class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
