Merge pull request #36 from huggingface/sparse-llama4-moe
Add support for sparse `Llama4TextMoe` layer from the kernel hub
Commit ccda19f050
@@ -26,6 +26,13 @@ try:
     _hub_kernels_available = True
 
     _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+        "Llama4TextMoe": {
+            "cuda": LayerRepository(
+                # Move to kernels-community/moe once we release.
+                repo_id="kernels-community/moe-new-models",
+                layer_name="Llama4TextMoe",
+            )
+        },
         "MultiScaleDeformableAttention": {
             "cuda": LayerRepository(
                 repo_id="kernels-community/deformable-detr",
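The hunk above registers the new layer in the kernel mapping: a layer class name and a device type point to a Hub repository that hosts an optimized implementation. The sketch below is only an illustration of that registry pattern, not the real `kernels` package API; `SimpleLayerRepository` and `resolve_kernel` are hypothetical names introduced for this example.

```python
# Illustrative sketch of the (layer name, device) -> Hub repo lookup pattern.
# The real LayerRepository and lookup machinery come from the `kernels` package;
# the names below are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class SimpleLayerRepository:
    """Points at a Hub repo and the layer class it exports."""
    repo_id: str
    layer_name: str


# Layer class name -> device type -> kernel repository, mirroring _KERNEL_MAPPING.
KERNEL_MAPPING: Dict[str, Dict[str, SimpleLayerRepository]] = {
    "Llama4TextMoe": {
        "cuda": SimpleLayerRepository(
            repo_id="kernels-community/moe-new-models",
            layer_name="Llama4TextMoe",
        )
    },
}


def resolve_kernel(layer_name: str, device: str) -> Optional[SimpleLayerRepository]:
    """Return the kernel repository for (layer, device), or None to keep the default forward."""
    return KERNEL_MAPPING.get(layer_name, {}).get(device)


if __name__ == "__main__":
    print(resolve_kernel("Llama4TextMoe", "cuda"))  # repo found -> hub kernel can be used
    print(resolve_kernel("Llama4TextMoe", "cpu"))   # None -> fall back to the eager forward
```

When no entry matches, the model keeps its ordinary PyTorch forward, so the mapping is purely an opt-in acceleration path.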
@@ -33,6 +33,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -150,6 +151,7 @@ class Llama4TextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
+@use_kernel_forward_from_hub("Llama4TextMoe")
 class Llama4TextMoe(nn.Module):
     def __init__(self, config):
         super().__init__()
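The only model-side change is the `@use_kernel_forward_from_hub("Llama4TextMoe")` decorator, which links the module class to its entry in the kernel mapping. The sketch below shows the general shape of such a class decorator under stated assumptions; it is not the actual implementation from `transformers.integrations.hub_kernels`, and `use_kernel_forward_from_hub_sketch`, `ToyMoe`, and `maybe_swap_forward` are hypothetical names for illustration only.

```python
# Illustrative sketch of a "tag the class, swap the forward later" decorator,
# in the spirit of use_kernel_forward_from_hub. Not the real integration code.
import torch
import torch.nn as nn


def use_kernel_forward_from_hub_sketch(layer_name: str):
    """Tag a module class with the name used to look it up in the kernel mapping."""
    def decorator(cls):
        cls.kernel_layer_name = layer_name  # consumed by a later swap step
        return cls
    return decorator


@use_kernel_forward_from_hub_sketch("Llama4TextMoe")
class ToyMoe(nn.Module):
    """Stand-in for Llama4TextMoe; the real module routes tokens to experts."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Default eager path; a kernel-hub forward could replace this on CUDA.
        return self.proj(hidden_states)


def maybe_swap_forward(module: nn.Module, device: str) -> None:
    """Hypothetical swap step: if a hub kernel exists for (class, device),
    its forward would be bound to the module class here."""
    name = getattr(type(module), "kernel_layer_name", None)
    if name is not None and device == "cuda":
        # In the real integration, the forward loaded from the Hub repo is used here.
        pass


print(ToyMoe.kernel_layer_name)  # "Llama4TextMoe"
```

Because the decorator only annotates the class, models built without the kernel hub available behave exactly as before; the sparse forward is picked up only when a matching kernel repository is resolved.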