Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
Add GraniteMoeHybrid support for 4.0 (#37658)
* initial config and MLA layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* first pass at decoder Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* completion of layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* modeling class Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* adding hybrid class to imports Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix imports granitemoehybrid Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix granitehybrid imports Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix granitehybrid import Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix generated modeling file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* add some comments Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* minor fixes in layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* add sharedMLP layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* correct layer names Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fixes in mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* change name of MLP layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix seq mizer layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* correct mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fixes in param names Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* enable hybrid model Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* update config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix config granite hybrid Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix attention layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* cleanup to re-use mamba code Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* keep layer types Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* attention bias cleanup Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* update mamba layer name Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* first pass at tests Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* first pass at tests Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* use granite attention Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix: self attn weights Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* pass at making pos_emb optional Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* initialize self_attn only as needed Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* overwrite forward to create HybridMambaCache Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* Log invalid layer types
* Add attention outputs test
* Only emit attentions/logits if not None
* Fix config test hidden size divisibility
* mark granitmoehybrid as stateful
* Initialize mamba convolutional layers
* Formatting fixes
* config docstring, removed some unused attrs
* Fix missing arg in models test
* Fix create and check decoder model test
* support logits to keep in granitemoe
* regen to pass logits_to_keep
* Allow None or rope
* Fix gradient checkpointing
* Add granitemoehybrid as special cache for generate check
* Remove unused MLA refs
* Fix mamba layer mask
* Remove logits to keep from config
* Minor docstring nits
* Update licenses
* Enable cache by default
* map layer types to layer block type
* First pass at granite moe hybrid docs
* Ignore granite moe hybrid in valid checkpoint check
* Align attention interfaces
* regenerate modular granitemoeshared attention interface
* Align granite moe hybrid attn interface
* run formatting
* Handle mamba initialization
* avoid conditional attr defs
* Move hybrid layer validation to config
* Add placeholder integration tests
* Docs nits / Update model names
* Clean up forward conditions
* Use gradient checkpointing layer
* Remove some copied bamba tests + inherit align test init
  delete more tests
  Use common layer init with bamba tests
  finish test consolidation
* avoid redundant intermediate std var
* use @can_return_tuple
* Remove unused moe state
* make skipped test names consistent
* Fix docstring order
* Add missing toc
* Always create the shared mlp
* Fix name in docstring
* link preview model in docs

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Co-authored-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent fe29b8c487
commit 471958b620
@@ -495,6 +495,8 @@
      title: Granite
    - local: model_doc/granitemoe
      title: GraniteMoe
+   - local: model_doc/granitemoehybrid
+     title: GraniteMoeHybrid
    - local: model_doc/granitemoeshared
      title: GraniteMoeShared
    - local: model_doc/helium
64 docs/source/en/model_doc/granitemoehybrid.md Normal file
@@ -0,0 +1,64 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GraniteMoeHybrid

## Overview

The `GraniteMoeHybrid` model builds on top of `GraniteMoeSharedModel` and `Bamba`. Its decoder layers consist of either state space (Mamba) layers or MoE attention layers with shared experts. By default, the attention layers do not use positional encoding.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()

# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."

# tokenize the text
input_tokens = tokenizer(prompt, return_tensors="pt")
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print, in this example the batch size is 1
for i in output:
    print(i)
```

This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
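Since the checkpoint mixes layer types, it can be handy to see which decoder layers are state space ("mamba") layers and which are attention layers. A minimal sketch, assuming the preview checkpoint above, using the config's `layers_block_type` property:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ibm-granite/granite-4.0-tiny-preview")

# "mamba" for state space layers, "attention" for MoE attention layers
print(config.layers_block_type)
# None by default, i.e. the attention layers use no rotary position embedding
print(config.position_embedding_type)
```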
## GraniteMoeHybridConfig

[[autodoc]] GraniteMoeHybridConfig

## GraniteMoeHybridModel

[[autodoc]] GraniteMoeHybridModel
    - forward

## GraniteMoeHybridForCausalLM

[[autodoc]] GraniteMoeHybridForCausalLM
    - forward
|
@ -129,6 +129,7 @@ if TYPE_CHECKING:
|
||||
from .granite import *
|
||||
from .granite_speech import *
|
||||
from .granitemoe import *
|
||||
from .granitemoehybrid import *
|
||||
from .granitemoeshared import *
|
||||
from .grounding_dino import *
|
||||
from .groupvit import *
|
||||
|
@ -146,6 +146,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("granite", "GraniteConfig"),
|
||||
("granite_speech", "GraniteSpeechConfig"),
|
||||
("granitemoe", "GraniteMoeConfig"),
|
||||
("granitemoehybrid", "GraniteMoeHybridConfig"),
|
||||
("granitemoeshared", "GraniteMoeSharedConfig"),
|
||||
("granitevision", "LlavaNextConfig"),
|
||||
("graphormer", "GraphormerConfig"),
|
||||
@ -509,6 +510,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("granite", "Granite"),
|
||||
("granite_speech", "GraniteSpeech"),
|
||||
("granitemoe", "GraniteMoeMoe"),
|
||||
("granitemoehybrid", "GraniteMoeHybrid"),
|
||||
("granitemoeshared", "GraniteMoeSharedMoe"),
|
||||
("granitevision", "LLaVA-NeXT"),
|
||||
("graphormer", "Graphormer"),
|
||||
|
@ -138,6 +138,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
|
||||
("granite", "GraniteModel"),
|
||||
("granitemoe", "GraniteMoeModel"),
|
||||
("granitemoehybrid", "GraniteMoeHybridModel"),
|
||||
("granitemoeshared", "GraniteMoeSharedModel"),
|
||||
("graphormer", "GraphormerModel"),
|
||||
("grounding-dino", "GroundingDinoModel"),
|
||||
@ -558,6 +559,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("gptj", "GPTJForCausalLM"),
|
||||
("granite", "GraniteForCausalLM"),
|
||||
("granitemoe", "GraniteMoeForCausalLM"),
|
||||
("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
|
||||
("granitemoeshared", "GraniteMoeSharedForCausalLM"),
|
||||
("helium", "HeliumForCausalLM"),
|
||||
("jamba", "JambaForCausalLM"),
|
||||
|
@ -854,6 +854,7 @@ class BambaMixer(nn.Module):
|
||||
# Init cache
|
||||
if ssm_state is not None and cache_params is not None:
|
||||
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
||||
cache_params.has_previous_state = True
|
||||
|
||||
scan_output = self.norm(y, gate)
|
||||
|
||||
|
@ -651,6 +651,7 @@ class BambaMixer(nn.Module):
|
||||
# Init cache
|
||||
if ssm_state is not None and cache_params is not None:
|
||||
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
||||
cache_params.has_previous_state = True
|
||||
|
||||
scan_output = self.norm(y, gate)
|
||||
|
||||
|
@ -166,6 +166,8 @@ class GraniteMoeConfig(PretrainedConfig):
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# this model has rope embedding type, hardcoded for BC
|
||||
self.position_embedding_type = "rope"
|
||||
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
@ -13,25 +13,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -439,10 +438,9 @@ class GraniteMoeAttention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
@ -455,260 +453,75 @@ class GraniteMoeAttention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
cos, sin = position_embeddings if position_embeddings is not None else (None, None)
|
||||
if position_embeddings is not None:
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
|
||||
# TODO cyril: modular
|
||||
class GraniteMoeFlashAttention2(GraniteMoeAttention):
|
||||
"""
|
||||
GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (GraniteMoeRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
|
||||
# TODO cyril: modular
|
||||
class GraniteMoeSdpaAttention(GraniteMoeAttention):
|
||||
"""
|
||||
GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`GraniteMoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
# Adapted from GraniteMoeAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GraniteMoeModel is using GraniteMoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
GRANITEMOE_ATTENTION_CLASSES = {
|
||||
"eager": GraniteMoeAttention,
|
||||
"flash_attention_2": GraniteMoeFlashAttention2,
|
||||
"sdpa": GraniteMoeSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class GraniteMoeDecoderLayer(nn.Module):
|
||||
class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GraniteMoeConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GRANITEMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
|
||||
self.block_sparse_moe = GraniteMoeMoE(config)
|
||||
self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -827,13 +640,12 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GraniteMoeRMSNorm):
|
||||
@ -947,8 +759,8 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
# rope
|
||||
self.rotary_emb = GraniteMoeRotaryEmbedding(config)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.rotary_emb = GraniteMoeRotaryEmbedding(config) if self.position_embedding_type == "rope" else None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -1019,8 +831,10 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
position_embeddings = None
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
if self.rotary_emb is not None:
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -1032,31 +846,17 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
output_router_logits,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1265,6 +1065,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
||||
r"""
|
||||
@ -1273,6 +1074,13 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
@ -1315,8 +1123,10 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
# Only compute necessary logits
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
logits = logits / self.config.logits_scaling
|
||||
|
||||
loss = None
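To make the `logits_to_keep` slicing above concrete, here is a tiny standalone sketch with a toy tensor (illustrative only, not part of the model code):

```python
import torch

hidden_states = torch.randn(1, 10, 16)  # (batch, seq_len, hidden_size)
logits_to_keep = 1                       # generation only needs the last position

# Same slicing as in the forward pass above; an int of 0 keeps every position.
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 1, 16])
```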
|
||||
|
29 src/transformers/models/granitemoehybrid/__init__.py Normal file
@ -0,0 +1,29 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_granitemoehybrid import *
|
||||
from .modeling_granitemoehybrid import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
@ -0,0 +1,256 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""GraniteMoeHybrid model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteMoeHybridConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GraniteMoeHybridConfig`]. It is used to
|
||||
instantiate an GraniteMoeHybrid model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the GraniteMoeHybrid model. Defines the number of different tokens that
|
||||
can be represented by the `input_ids` passed when calling [`GraniteMoeHybridModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
Only relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier.
|
||||
logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits.
|
||||
residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier.
|
||||
attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier.
|
||||
num_local_experts (`int`, *optional*, defaults to 8): total number of experts.
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 2): number of experts per token.
|
||||
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the router logits should be returned by the model. Enabling this will also
|
||||
allow the model to output the auxiliary loss.
|
||||
router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
|
||||
shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
|
||||
position_embedding_type (`str`, *optional*): Positional embedding
|
||||
type to be used; defaults to None. Allowed options: `[None, "rope"]`
|
||||
layer_types (`List`, *optional*): list of strings to be used as layer types.
|
||||
Allowed choices: "mamba", "attention".
|
||||
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||
The number of mamba heads used.
|
||||
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||
The number of the mamba groups used.
|
||||
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||
The dimension of the mamba latent state space.
|
||||
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||
Head embedding dimension size.
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel.
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
The chunks in which to break the sequence when doing prefill/training.
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"])
|
||||
of the mamba mixer block.
|
||||
```python
>>> from transformers import GraniteMoeHybridModel, GraniteMoeHybridConfig

>>> # Initializing a GraniteMoeHybrid config
>>> configuration = GraniteMoeHybridConfig()

>>> # Initializing a model from the configuration
>>> model = GraniteMoeHybridModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""
|
||||
|
||||
model_type = "granitemoehybrid"
|
||||
attribute_map = {
|
||||
"layers_block_type": "layer_types",
|
||||
}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
embedding_multiplier=1.0,
|
||||
logits_scaling=1.0,
|
||||
residual_multiplier=1.0,
|
||||
attention_multiplier=1.0,
|
||||
num_local_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
shared_intermediate_size=1024,
|
||||
position_embedding_type=None,
|
||||
layer_types=None,
|
||||
mamba_n_heads=128,
|
||||
mamba_n_groups=1,
|
||||
mamba_d_state=256,
|
||||
mamba_d_head="auto",
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=256,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
self.logits_scaling = logits_scaling
|
||||
self.residual_multiplier = residual_multiplier
|
||||
self.attention_multiplier = attention_multiplier
|
||||
self.attention_dropout = attention_dropout
|
||||
self.num_local_experts = num_local_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.shared_intermediate_size = shared_intermediate_size
|
||||
self.position_embedding_type = position_embedding_type
|
||||
|
||||
mamba_intermediate = mamba_expand * hidden_size
|
||||
|
||||
if layer_types is not None and any(layer_type not in ["mamba", "attention"] for layer_type in layer_types):
|
||||
raise ValueError("layer_types must be a list strings in [`mamba` `attention`]")
|
||||
|
||||
if mamba_intermediate % mamba_n_heads != 0:
|
||||
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||
|
||||
# for the mamba_v2, must satisfy the following
|
||||
if mamba_d_head == "auto":
|
||||
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||
|
||||
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
|
||||
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_head = mamba_d_head
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.mamba_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
self.mamba_expand = mamba_expand
|
||||
self.layer_types = layer_types
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.position_embedding_type == "rope":
|
||||
rope_config_validation(self)
|
||||
|
||||
# expose the layer types under the attribute name used by `HybridMambaAttentionDynamicCache`
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return self.layer_types if self.layer_types else ["mamba"] * self.num_hidden_layers
|
||||
|
||||
|
||||
__all__ = ["GraniteMoeHybridConfig"]
|
File diff suppressed because it is too large
@ -0,0 +1,510 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ..bamba.configuration_bamba import BambaConfig
|
||||
from ..bamba.modeling_bamba import (
|
||||
BambaMixer,
|
||||
BambaRMSNormGated,
|
||||
HybridMambaAttentionDynamicCache,
|
||||
)
|
||||
from ..granitemoeshared.modeling_granitemoeshared import (
|
||||
GraniteMoeSharedAttention,
|
||||
GraniteMoeSharedDecoderLayer,
|
||||
GraniteMoeSharedForCausalLM,
|
||||
GraniteMoeSharedMLP,
|
||||
GraniteMoeSharedModel,
|
||||
GraniteMoeSharedPreTrainedModel,
|
||||
)
|
||||
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
|
||||
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
|
||||
class GraniteMoeHybridMambaLayer(BambaMixer):
|
||||
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
|
||||
super().__init__(BambaConfig(config), layer_idx)
|
||||
|
||||
|
||||
class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps)
|
||||
|
||||
|
||||
class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
|
||||
def __init__(self, config: GraniteMoeHybridConfig):
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
||||
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.shared_mlp = GraniteMoeHybridMLP(config)
|
||||
# Either attention or mamba will be initialized, depending on the layer type.
|
||||
self.self_attn = None
|
||||
self.mamba = None
|
||||
|
||||
if config.layers_block_type[layer_idx] == "mamba":
|
||||
self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx)
|
||||
else:
|
||||
self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
|
||||
self.layer_type = config.layers_block_type[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
if self.mamba is not None:
|
||||
hidden_states = self.mamba(
|
||||
hidden_states=hidden_states,
|
||||
cache_position=cache_position,
|
||||
cache_params=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
# No attention weights for state space layers
|
||||
self_attn_weights = None
|
||||
else:
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states * self.residual_multiplier
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
|
||||
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states * self.residual_multiplier
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (past_key_value,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
GRANITEMOEHYBRID_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`GraniteMoeHybridConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare GraniteMoeHybrid Model outputting raw hidden-states without any specific head on top.",
|
||||
GRANITEMOEHYBRID_START_DOCSTRING,
|
||||
)
|
||||
class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel):
|
||||
config_class = GraniteMoeHybridConfig
|
||||
_no_split_modules = ["GraniteMoeHybridDecoderLayer"]
|
||||
_is_stateful = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
# Initialize Mamba modules
|
||||
if isinstance(module, (nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, GraniteMoeHybridMambaLayer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
elif isinstance(module, GraniteMoeHybridRMSNormGated):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
GRANITEMOEHYBRID_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    "The bare GraniteMoeHybrid Model outputting raw hidden-states without any specific head on top.",
    GRANITEMOEHYBRID_START_DOCSTRING,
)
class GraniteMoeHybridModel(GraniteMoeSharedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`GraniteMoeHybridDecoderLayer`]

    Args:
        config: GraniteMoeHybridConfig
    """

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
    @can_return_tuple
    @add_start_docstrings_to_model_forward(GRANITEMOEHYBRID_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = inputs_embeds * self.embedding_multiplier

        ## overwritten because `HybridMambaAttentionDynamicCache` is needed
        if use_cache and past_key_values is None:
            logger.warning_once(
                "GraniteMoeHybrid requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. "
                "Because one was not provided, no cache will be returned."
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        # embed positions
        hidden_states = inputs_embeds

        position_embeddings = None
        # create position embeddings to be shared across the decoder layers
        if self.rotary_emb is not None:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                output_router_logits=output_router_logits,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                if layer_outputs[1] is not None:
                    # append attentions only of attention layers. Mamba layers return `None` as the attention weights
                    all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                if layer_outputs[-1] is not None:
                    # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
                    all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask
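# Illustrative sketch (not part of the diff): calling the hybrid model directly with `use_cache=True` requires an
# already-initialized `HybridMambaAttentionDynamicCache`, mirroring the constructor call used in
# `prepare_inputs_for_generation` below; `model` and `inputs` are placeholder names, not library API.
#
#     batch_size = inputs["input_ids"].shape[0]
#     cache = HybridMambaAttentionDynamicCache(model.config, batch_size, model.dtype, device=model.device)
#     outputs = model(**inputs, past_key_values=cache, use_cache=True)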
class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.model = GraniteMoeHybridModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

        empty_past_kv = past_key_values is None

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    def _supports_default_dynamic_cache(self) -> bool:
        """
        Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache`
        and does not need to initialize the cache in advance in order to save memory
        (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
        for `HybridMambaAttentionDynamicCache`).
        """
        return False
__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
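A usage sketch (illustration only, not part of the diff): `prepare_inputs_for_generation` above builds the
`HybridMambaAttentionDynamicCache` on the first generation step, so the standard `generate` API works without any
manual cache handling. The checkpoint name is taken from the integration tests further below and may not be
publicly released yet.

    from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-tiny")
    model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")

    inputs = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)
    generated = model.generate(**inputs, max_new_tokens=16, do_sample=False)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))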
@@ -169,6 +169,8 @@ class GraniteMoeSharedConfig(PretrainedConfig):
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
+        # this model has rope embedding type, hardcoded for BC
+        self.position_embedding_type = "rope"

        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -29,10 +29,10 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
@@ -300,6 +300,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoeShared
# no longer copied after attention refactors
class GraniteMoeSharedAttention(nn.Module):
@@ -343,10 +370,9 @@ class GraniteMoeSharedAttention(nn.Module):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
@ -359,262 +385,48 @@ class GraniteMoeSharedAttention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
cos, sin = position_embeddings if position_embeddings is not None else (None, None)
|
||||
if position_embeddings is not None:
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoeShared
|
||||
# TODO cyril: modular
|
||||
class GraniteMoeSharedFlashAttention2(GraniteMoeSharedAttention):
|
||||
"""
|
||||
GraniteMoeShared flash attention module. This module inherits from `GraniteMoeSharedAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (GraniteMoeSharedRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoeShared
|
||||
# TODO cyril: modular
|
||||
class GraniteMoeSharedSdpaAttention(GraniteMoeSharedAttention):
|
||||
"""
|
||||
GraniteMoeShared attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`GraniteMoeSharedAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from GraniteMoeSharedAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GraniteMoeSharedModel is using GraniteMoeSharedSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
GRANITEMOESHARED_ATTENTION_CLASSES = {
|
||||
"eager": GraniteMoeSharedAttention,
|
||||
"flash_attention_2": GraniteMoeSharedFlashAttention2,
|
||||
"sdpa": GraniteMoeSharedSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class GraniteMoeSharedDecoderLayer(nn.Module):
|
||||
class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GRANITEMOESHARED_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config=config, layer_idx=layer_idx
|
||||
)
|
||||
|
||||
self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
|
||||
self.block_sparse_moe = GraniteMoeSharedMoE(config)
|
||||
self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -632,7 +444,7 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -739,13 +551,12 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GraniteMoeSharedRMSNorm):
|
||||
@ -893,8 +704,8 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
|
||||
# rope
|
||||
self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config) if self.position_embedding_type == "rope" else None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -965,8 +776,10 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
position_embeddings = None
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
if self.rotary_emb is not None:
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -978,31 +791,17 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
output_router_logits,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@@ -1291,6 +1090,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
@@ -1299,6 +1099,13 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for
+            that token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
        Returns:

        Example:
@@ -1341,8 +1148,10 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
            cache_position=cache_position,
        )

+        # Only compute necessary logits
        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits / self.config.logits_scaling

        loss = None
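A brief illustration of the `logits_to_keep` semantics documented above (a sketch only, assuming an already loaded
GraniteMoeShared/GraniteMoeHybrid causal-LM `model` and a batch of `input_ids`; not part of the diff):

    # int: keep only the logits of the last `logits_to_keep` tokens -> shape (batch, 1, vocab_size) for 1
    out_last = model(input_ids, logits_to_keep=1)

    # 0 (the default): keep logits for every position -> shape (batch, seq_len, vocab_size)
    out_all = model(input_ids, logits_to_keep=0)

    # 1D tensor of indices: keep logits only at those sequence positions
    keep = torch.tensor([0, 5, 9], device=input_ids.device)
    out_some = model(input_ids, logits_to_keep=keep)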
@@ -16,7 +16,6 @@
from typing import Optional, Tuple

import torch
-import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
@@ -78,7 +77,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        output_router_logits: Optional[bool] = False,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
@@ -2491,6 +2491,7 @@ class GenerationTesterMixin:
            "bamba",
            "ctrl",
            "fsmt",
+            "granitemoehybrid",
            "gptbigcode",
            "mega",
            "reformer",
@@ -47,6 +47,11 @@ if is_torch_available():


class BambaModelTester:
+    config_class = BambaConfig
+    if is_torch_available():
+        model_class = BambaModel
+        for_causal_lm_class = BambaForCausalLM
+
    def __init__(
        self,
        parent,
@@ -118,6 +123,7 @@ class BambaModelTester:
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

+        self._update_layer_configs()
        config = self.get_config()

        return config, input_ids, input_mask, token_labels
@@ -133,10 +139,12 @@ class BambaModelTester:
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

-    def get_config(self):
+    def _update_layer_configs(self):
+        """Configures hidden layers and attn layer indices if they are not set."""
        # Fix for SDPA tests, force at least 4 layers
        if self.num_hidden_layers < 4:
            self.num_hidden_layers = 4

        if self.attn_layer_indices is None:
            d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
            if len(d) == 0:
@@ -144,7 +152,8 @@ class BambaModelTester:
            d = d[-1]  # get the largest divisor
            self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]

-        return BambaConfig(
+    def get_config(self, **kwargs):
+        return self.config_class(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
@@ -164,6 +173,7 @@ class BambaModelTester:
            mamba_d_conv=self.mamba_d_conv,
            mamba_expand=self.mamba_expand,
            mamba_chunk_size=self.mamba_chunk_size,
+            **kwargs,
        )

    def create_and_check_model(
@@ -173,7 +183,7 @@ class BambaModelTester:
        input_mask,
        token_labels,
    ):
-        model = BambaModel(config=config)
+        model = self.model_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
@@ -187,7 +197,7 @@ class BambaModelTester:
        input_mask,
        token_labels,
    ):
-        model = BambaForCausalLM(config=config)
+        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -205,7 +215,7 @@ class BambaModelTester:
    ):
        # config.is_decoder = True
        # config.add_cross_attention = True
-        model = BambaForCausalLM(config=config)
+        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()

@@ -258,6 +268,7 @@ class BambaModelTester:

@require_torch
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    model_tester_class = BambaModelTester
    all_model_classes = (BambaModel, BambaForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
@@ -276,8 +287,8 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
    model_split_percents = [0.5, 0.7, 0.8]

    def setUp(self):
-        self.model_tester = BambaModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
+        self.model_tester = self.model_tester_class(self)
+        self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64)

    def test_config(self):
        self.config_tester.run_common_tests()
tests/models/granitemoehybrid/__init__.py (new file, 0 lines)
tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py (new file, 164 lines)
@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch GraniteMoeHybrid model."""

import unittest

import pytest

from transformers import (
    AutoTokenizer,
    GraniteMoeHybridConfig,
    is_torch_available,
)
from transformers.testing_utils import (
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...models.bamba.test_modeling_bamba import BambaModelTest, BambaModelTester


if is_torch_available():
    import torch

    from transformers import (
        GraniteMoeHybridForCausalLM,
        GraniteMoeHybridModel,
    )


class GraniteMoeHybridModelTester(BambaModelTester):
    config_class = GraniteMoeHybridConfig
    if is_torch_available():
        model_class = GraniteMoeHybridModel
        for_causal_lm_class = GraniteMoeHybridForCausalLM

    def __init__(
        self,
        parent,
        use_cache=False,
        shared_intermediate_size=174,
        layer_types=None,
    ):
        super().__init__(parent)
        self.shared_intermediate_size = shared_intermediate_size
        self.layer_types = layer_types
        self.use_cache = use_cache

    def _update_layer_configs(self):
        super()._update_layer_configs()
        # GraniteMoeHybrid uses layer_types instead of attn_layer_indices
        self.layer_types = ["mamba"] * self.num_hidden_layers
        for idx in self.attn_layer_indices:
            self.layer_types[idx] = "attention"

    def get_config(self):
        return super().get_config(
            shared_intermediate_size=self.shared_intermediate_size,
            layer_types=self.layer_types,
        )
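# Illustrative note (not part of the test file): `_update_layer_configs` above turns the attention layer indices
# computed by `BambaModelTester` into a `layer_types` list. Assuming, say, num_hidden_layers == 4 and
# attn_layer_indices == [1, 3] (the exact defaults live in the Bamba tester), the mapping behaves like:
#
#     layer_types = ["mamba"] * 4
#     for idx in [1, 3]:
#         layer_types[idx] = "attention"
#     assert layer_types == ["mamba", "attention", "mamba", "attention"]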
@require_torch
class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase):
    model_tester_class = GraniteMoeHybridModelTester
    all_model_classes = (
        (
            GraniteMoeHybridModel,
            GraniteMoeHybridForCausalLM,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": GraniteMoeHybridModel,
            "text-generation": GraniteMoeHybridForCausalLM,
        }
        if is_torch_available()
        else {}
    )

    def test_config_requires_mamba_or_attention_layers(self):
        """Ensure we can't create a config with disallowed layers."""
        with pytest.raises(ValueError):
            GraniteMoeHybridConfig(layer_types=["not allowed!"])


# TODO (@alex-jw-brooks) - update this once the model(s) are out
@unittest.skip(reason="GraniteMoeHybrid models are not yet released")
@require_torch_gpu
class GraniteMoeHybridIntegrationTest(unittest.TestCase):
    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    @slow
    def test_model_logits(self):
        input_ids = [31390, 631, 4162, 30, 322, 25342, 432, 1875, 43826, 10066, 688, 225]

        model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")

        with torch.no_grad():
            out = model(torch.tensor([input_ids]).to(torch_device))

        # fmt: off
        # Expected mean on dim = -1
        EXPECTED_MEAN = torch.tensor([
            [-2.9711, -2.2554, -1.0814, -1.6123, -0.8780, -1.0685, -0.6368, -1.9732, -3.3548, -2.6895, -2.3062, -2.6338]
        ])

        torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)

        # slicing logits[0, 0, 0:15]
        EXPECTED_SLICE = torch.tensor([
            [4.0662, 5.9547, 3.5803, 3.1306, 4.3211, 3.8902, 4.6438, 8.5434, 7.5865, 5.1623, 5.2240, 9.2982, 5.9094, 6.8834, 5.7551],
        ])
        # fmt: on

        self.assertTrue(
            torch.allclose(
                EXPECTED_SLICE.to(torch_device),
                out.logits[0, 0, :15].float(),
                atol=1e-3,
                rtol=1e-3,
            )
        )

    @slow
    def test_model_generation(self):
        EXPECTED_TEXT_COMPLETION = (
            "Simply put, the theory of relativity states that 1) time is relative, and 2) space is relative. The first"
        )
        prompt = "Simply put, the theory of relativity states that "
        tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-tiny")
        model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")
        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # greedy generation outputs
        generated_ids = model.generate(**model_inputs, max_new_tokens=16, do_sample=False)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

@@ -47,6 +47,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "LlamaConfig",
    "GraniteConfig",
    "GraniteMoeConfig",
+    "GraniteMoeHybridConfig",
    "Qwen3MoeConfig",
    "GraniteSpeechConfig",
}