Remove Triton MLP kernel, not compiling for some models (#37449)

* remove mlp for now

* disable on docker
Mohamed Mekkouri 2025-04-11 12:47:13 +02:00 committed by GitHub
parent f797e3d98a
commit 3c39c07939
18 changed files with 2 additions and 24 deletions
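For context, the @use_kernel_forward_from_hub("MLP") decorator removed below marks a layer class so that its forward can be swapped for a compiled (e.g. Triton) kernel fetched from the Hub; without the decorator the class keeps its plain PyTorch forward. A minimal sketch of how such a decorator-based mapping might look (the registry name and the place where DISABLE_KERNEL_MAPPING is checked are assumptions, not the actual kernels integration):

import os

# Hypothetical registry: kernel name -> layer classes eligible for replacement.
_KERNEL_MAPPING: dict[str, list[type]] = {}


def use_kernel_forward_from_hub(kernel_name: str):
    """Class decorator that registers a layer for optional kernel replacement."""

    def decorator(cls):
        # Assumed opt-out: honor the env var that the Dockerfile below sets.
        if os.environ.get("DISABLE_KERNEL_MAPPING", "0") != "1":
            _KERNEL_MAPPING.setdefault(kernel_name, []).append(cls)
        return cls  # the class itself is returned unchanged either way

    return decorator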

View File

@@ -14,6 +14,8 @@ ARG PYTORCH='2.6.0'
 ARG INTEL_TORCH_EXT='2.3.0'
 # Example: `cu102`, `cu113`, etc.
 ARG CUDA='cu121'
+# Disable kernel mapping for now until all tests pass
+ENV DISABLE_KERNEL_MAPPING=1
 
 RUN apt update
 RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
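The ENV line bakes the opt-out into the CI image. The same flag can presumably also be set per run, for example with docker run -e DISABLE_KERNEL_MAPPING=1 or from Python before any model is built (hedged example, assuming the variable is read when the mapping is applied):

import os

# Assumption: the flag must be set before transformers applies its kernel
# mapping, i.e. before model classes are imported/instantiated.
os.environ["DISABLE_KERNEL_MAPPING"] = "1"

import transformers  # noqa: E402  (imported after setting the flag on purpose)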

View File

@@ -228,7 +228,6 @@ class AriaProjector(nn.Module):
         return out
 
 
-@use_kernel_forward_from_hub("MLP")
 class AriaSharedExpertsMLP(nn.Module):
     """
     Shared Expert MLP for shared experts.
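With the decorator removed, classes such as AriaSharedExpertsMLP, LlamaMLP, and Qwen2MLP below simply run their eager gated-MLP forward. A simplified stand-in for that shape (illustrative only, not the exact transformers code; config handling is abbreviated):

import torch
import torch.nn as nn


class GatedMLP(nn.Module):
    """Llama-style MLP: SiLU-gated up projection followed by a down projection."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain PyTorch path that runs when no Triton kernel is mapped in.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))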

View File

@@ -882,7 +882,6 @@ class BambaMixer(nn.Module):
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
 
 
-@use_kernel_forward_from_hub("MLP")
 class BambaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -36,7 +36,6 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -118,7 +117,6 @@ class CohereRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-@use_kernel_forward_from_hub("MLP")
 class CohereMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -28,7 +28,6 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -268,7 +267,6 @@ class Cohere2Attention(nn.Module):
         return attn_output, attn_weights
 
 
-@use_kernel_forward_from_hub("MLP")
 class Cohere2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -74,7 +74,6 @@ _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
 _CONFIG_FOR_DOC = "DiffLlamaConfig"
 
 
-@use_kernel_forward_from_hub("MLP")
 class DiffLlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -84,7 +84,6 @@ class Emu3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-@use_kernel_forward_from_hub("MLP")
 class Emu3MLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -27,7 +27,6 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -85,7 +84,6 @@ class GemmaRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-@use_kernel_forward_from_hub("MLP")
 class GemmaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -28,7 +28,6 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
@@ -78,7 +77,6 @@ class Gemma2RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-@use_kernel_forward_from_hub("MLP")
 class Gemma2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -31,7 +31,6 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -107,7 +106,6 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
         return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
 
 
-@use_kernel_forward_from_hub("MLP")
 class Gemma3MLP(nn.Module):
     def __init__(self, config: Gemma3TextConfig):
         super().__init__()

View File

@@ -228,7 +228,6 @@ class GraniteRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-@use_kernel_forward_from_hub("MLP")
 class GraniteMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -29,7 +29,6 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -118,7 +117,6 @@ class HeliumRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-@use_kernel_forward_from_hub("MLP")
 class HeliumMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -160,7 +160,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-@use_kernel_forward_from_hub("MLP")
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -45,7 +45,6 @@ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
 _CONFIG_FOR_DOC = "MistralConfig"
 
 
-@use_kernel_forward_from_hub("MLP")
 class MistralMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -14,7 +14,6 @@ import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -58,7 +57,6 @@ class OlmoLayerNorm(nn.Module):
         )
 
 
-@use_kernel_forward_from_hub("MLP")
 class OlmoMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -218,7 +218,6 @@ class Olmo2Attention(nn.Module):
         return attn_output, attn_weights
 
 
-@use_kernel_forward_from_hub("MLP")
 class Olmo2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -45,7 +45,6 @@ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
 _CONFIG_FOR_DOC = "Qwen2Config"
 
 
-@use_kernel_forward_from_hub("MLP")
 class Qwen2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()

View File

@@ -81,7 +81,6 @@ class Qwen3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-@use_kernel_forward_from_hub("MLP")
 class Qwen3MLP(nn.Module):
     def __init__(self, config):
         super().__init__()