Modular: support for importing functions from any file (#35692)

* fix function imports

* improve comment

* Update modeling_switch_function.py

* make checks more robust

* improvement

* rename

* final test update
This commit is contained in:
Cyril Vallez 2025-01-16 16:37:53 +00:00 committed by GitHub
parent 8ebe9d7166
commit 91be6a5eb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 305 additions and 43 deletions

View File

@ -0,0 +1,66 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_add_function.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_add_function.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Note that zamba does not have the `apply_rotary_pos_emb` function!
from typing import Optional, Tuple
import torch
from torch import nn
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class TestAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
(see fig. 2 in https://arxiv.org/pdf/2405.16712).
Additionally, replaced
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
"""
def __init__(self):
pass
def forward(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
_ = apply_rotary_pos_emb(1, 1, 1, 1)

View File

@ -45,13 +45,8 @@ class DummyRMSNorm(nn.Module):
class DummyRotaryEmbedding(nn.Module): class DummyRotaryEmbedding(nn.Module):
def __init__( def __init__(self, config: DummyConfig, device=None):
self,
config: DummyConfig,
device=None,
):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@ -63,7 +58,7 @@ class DummyRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@ -75,13 +70,14 @@ class DummyRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len self.max_seq_len_cached = self.original_max_seq_len
@ -356,6 +352,7 @@ class DummyPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True _supports_static_cache = True

View File

@ -45,13 +45,8 @@ class Multimodal1TextRMSNorm(nn.Module):
class Multimodal1TextRotaryEmbedding(nn.Module): class Multimodal1TextRotaryEmbedding(nn.Module):
def __init__( def __init__(self, config: Multimodal1TextConfig, device=None):
self,
config: Multimodal1TextConfig,
device=None,
):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@ -63,7 +58,7 @@ class Multimodal1TextRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@ -75,13 +70,14 @@ class Multimodal1TextRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len self.max_seq_len_cached = self.original_max_seq_len
@ -356,6 +352,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True _supports_static_cache = True

View File

@ -61,13 +61,8 @@ class MyNewModel2MLP(nn.Module):
class MyNewModel2RotaryEmbedding(nn.Module): class MyNewModel2RotaryEmbedding(nn.Module):
def __init__( def __init__(self, config: MyNewModel2Config, device=None):
self,
config: MyNewModel2Config,
device=None,
):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@ -79,7 +74,7 @@ class MyNewModel2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@ -91,13 +86,14 @@ class MyNewModel2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len self.max_seq_len_cached = self.original_max_seq_len
@ -356,6 +352,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True _supports_static_cache = True

View File

@ -107,7 +107,6 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True _supports_static_cache = True
_supports_cache_class = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True

View File

@ -45,13 +45,8 @@ class SuperRMSNorm(nn.Module):
class SuperRotaryEmbedding(nn.Module): class SuperRotaryEmbedding(nn.Module):
def __init__( def __init__(self, config: SuperConfig, device=None):
self,
config: SuperConfig,
device=None,
):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@ -63,7 +58,7 @@ class SuperRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@ -75,13 +70,14 @@ class SuperRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len self.max_seq_len_cached = self.original_max_seq_len
@ -356,6 +352,7 @@ class SuperPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True _supports_static_cache = True

View File

@ -0,0 +1,170 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_switch_function.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_switch_function.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Note that llama and cohere have different definitions for rotate_half
from typing import Callable, Optional, Tuple
import torch
from torch import nn
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from .configuration_switch_function import SwitchFunctionConfig
logger = logging.get_logger(__name__)
def rotate_half(x):
# Split and rotate. Note that this function is different from e.g. Llama.
x1 = x[..., ::2]
x2 = x[..., 1::2]
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
return rot_x
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class SwitchFunctionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: SwitchFunctionConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

View File

@ -0,0 +1,15 @@
# Note that zamba does not have the `apply_rotary_pos_emb` function!
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.zamba.modeling_zamba import ZambaAttention
# When following ZambaAttention dependencies, the function `apply_rotary_pos_emb` is not present
# by default as it is absent from the class definition (and the file altogether).
# Note that this syntax should be able to add both `apply_rotary_pos_emb` as imported directly, but
# `rotate_half` as well as a dependency from the imported function!!
class TestAttention(ZambaAttention):
def __init__(self):
pass
def forward(self):
_ = apply_rotary_pos_emb(1, 1, 1, 1)

View File

@ -0,0 +1,10 @@
# Note that llama and cohere have different definitions for rotate_half
from transformers.models.cohere.modeling_cohere import rotate_half # noqa
from transformers.models.llama.modeling_llama import LlamaAttention
# When following LlamaAttention dependencies, we will grab the function `rotate_half` defined
# in `modeling_llama.py`. But here we imported it explicitly from Cohere, so it should use Cohere's
# definition instead
class SwitchFunctionAttention(LlamaAttention):
pass

View File

@ -776,7 +776,7 @@ class ModelFileMapper(ModuleMapper):
else: else:
merged_dependencies.append(class_dep) merged_dependencies.append(class_dep)
# Sort both list according to the order in their respective file # Sort both list according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
# Add all original node first, then merged ones # Add all original node first, then merged ones
@ -801,7 +801,7 @@ class ModelFileMapper(ModuleMapper):
else: else:
original_dependencies.append(dep) original_dependencies.append(dep)
# Sort both list according to the order in their respective file # Sort both list according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
# Add all original node first, then merged ones # Add all original node first, then merged ones
@ -1321,6 +1321,20 @@ class ModularFileMapper(ModuleMapper):
self.added_objects_file_mapping[dep] = file self.added_objects_file_mapping[dep] = file
self.functions[dep] = visited_module.global_nodes[dep] self.functions[dep] = visited_module.global_nodes[dep]
# Add/overwrite the imported functions to other visited modules as well, in case it is absent/different
# in he modeling source file of the inherited class. See `examples/modular-tranformers/modular_switch_function.py`
# and `examples/modular-tranformers/modular_add_function.py` for examples
recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
node_recursive_dependencies_mapping = {
dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
}
for filename, module_mapper in self.visited_modules.items():
if filename != file:
module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
if len(recursive_dependencies) > 0:
module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
module_mapper.global_nodes.update(node_recursive_dependencies_mapping)
# Add assignments and their dependencies # Add assignments and their dependencies
elif object_name in visited_module.assignments and object_name not in self.assignments: elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name] self.assignments[object_name] = visited_module.assignments[object_name]