Merge pull request #36 from huggingface/sparse-llama4-moe

Add support for sparse `Llama4TextMoe` layer from the kernel hub
Arthur 2025-04-04 17:39:19 +02:00 committed by GitHub
commit ccda19f050
2 changed files with 9 additions and 0 deletions
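
As a usage sketch (illustrative, not part of this diff): with the optional `kernels` package installed, loading Llama 4 on a CUDA device transparently picks up the hub-backed sparse MoE forward registered below; without it, the eager `Llama4TextMoe.forward` is used unchanged. The checkpoint name and model class in the snippet are assumptions for the example.

```python
# Illustration only; the checkpoint and class are assumptions, not part of this PR.
import torch
from transformers import Llama4ForConditionalGeneration

model = Llama4ForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# When `kernels` is available, each Llama4TextMoe submodule's forward runs the
# CUDA kernel from kernels-community/moe-new-models; otherwise nothing changes.
```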

View File

@@ -26,6 +26,13 @@ try:
     _hub_kernels_available = True
 
     _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+        "Llama4TextMoe": {
+            "cuda": LayerRepository(
+                # Move to kernels-community/moe once we release.
+                repo_id="kernels-community/moe-new-models",
+                layer_name="Llama4TextMoe",
+            )
+        },
         "MultiScaleDeformableAttention": {
             "cuda": LayerRepository(
                 repo_id="kernels-community/deformable-detr",

View File

@@ -33,6 +33,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -150,6 +151,7 @@ class Llama4TextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
+@use_kernel_forward_from_hub("Llama4TextMoe")
 class Llama4TextMoe(nn.Module):
     def __init__(self, config):
         super().__init__()
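
The decorator registers `Llama4TextMoe` for kernel replacement: when the `kernels` package is installed and the module runs on a device listed in `_KERNEL_MAPPING`, its `forward` is served from the Hub repository; otherwise the eager implementation stays in place. Below is a simplified, self-contained sketch of that behaviour, not the actual implementation; `load_hub_forward` and `_KERNEL_MAPPING_SKETCH` are hypothetical stand-ins.

```python
# Simplified sketch of the decorator's effect; the real logic lives in the
# `kernels` package and transformers.integrations.hub_kernels.
import torch

_KERNEL_MAPPING_SKETCH = {"Llama4TextMoe": {"cuda": "kernels-community/moe-new-models"}}


def load_hub_forward(repo_id, layer_name):
    # Hypothetical stand-in: a real implementation would fetch the named layer
    # from the Hub and return its forward; here it is just a stub.
    def forward(self, hidden_states):
        return hidden_states
    return forward


def use_kernel_forward_from_hub_sketch(layer_name):
    def decorator(cls):
        repo_id = _KERNEL_MAPPING_SKETCH.get(layer_name, {}).get("cuda")
        if repo_id is not None and torch.cuda.is_available():
            # Swap in the hub kernel's forward; the eager forward remains the
            # fallback when no kernel is mapped for the current device.
            cls.forward = load_hub_forward(repo_id, layer_name)
        return cls
    return decorator
```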