fix moonshine

Arthur 2025-07-03 10:48:50 +02:00
parent 4fc83fa3a2
commit 499ae87ef7
3 changed files with 34 additions and 205 deletions

View File

@@ -41,14 +41,11 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_moonshine import MoonshineConfig
logger = logging.get_logger(__name__)
class MoonshineEncoderMLP(nn.Module):
def __init__(self, config, hidden_act):
super().__init__()
@@ -419,14 +416,12 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -439,12 +434,10 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states, _ = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
@@ -454,18 +447,11 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
return hidden_states
@auto_docstring
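The layer's forward now returns the hidden-states tensor directly; the tuple packing and the output_attentions branch are gone. Each sub-block in the hunk follows the same pre-norm residual pattern (normalize, transform, add the input back). A minimal self-contained sketch of that pattern follows, with illustrative sizes and activation rather than Moonshine's actual configuration:

import torch
import torch.nn as nn

class PreNormResidualBlock(nn.Module):
    """Illustrative only: the normalize -> transform -> residual-add pattern
    used by the self-attention, cross-attention, and MLP sub-blocks above."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states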
@@ -479,6 +465,14 @@ class MoonshinePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
_can_record_outputs = {
"attentions": (MoonshineAttention, 1),
"cross_attentions": (
MoonshineAttention,
1,
), # The issue is that we are attaching hooks to all MoonshineAttention instances
"hidden_states": (MoonshineDecoderLayer, 0),
}
def _init_weights(self, module):
std = self.config.initializer_range
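The new _can_record_outputs mapping is what lets the layers drop their tuple returns: the check_model_inputs decorator reads it and attaches forward hooks that collect, for each (module class, tuple index) entry, the indexed element of every matching module's output. The sketch below is an illustrative recorder under that assumption, not the transformers implementation; it also makes the comment above concrete, since self- and cross-attention share the MoonshineAttention class and cannot be told apart by class-based hooks alone.

import torch.nn as nn

def attach_recorders(model: nn.Module, can_record_outputs: dict):
    # Illustrative helper: hook every submodule of the declared class and
    # collect the element at the declared index of its forward output.
    collected = {key: [] for key in can_record_outputs}
    handles = []
    for key, (module_cls, index) in can_record_outputs.items():
        def make_hook(key=key, index=index):
            def hook(module, args, output):
                value = output[index] if isinstance(output, tuple) else output
                collected[key].append(value)
            return hook
        for submodule in model.modules():
            if isinstance(submodule, module_cls):
                # This matches every instance of module_cls, which is why
                # attentions and cross_attentions currently collide.
                handles.append(submodule.register_forward_hook(make_hook()))
    return collected, handles

After a forward pass, each list holds one tensor per hooked module, and the hooks can be detached with handle.remove().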
@@ -545,10 +539,8 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
@can_return_tuple
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
input_values: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
r"""
@@ -564,23 +556,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_values is None:
raise ValueError("You must specify input_values.")
# conv downsampling
input_values = input_values.unsqueeze(1)
hidden_states = nn.functional.tanh(self.conv1(input_values))
@@ -610,38 +586,19 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# encoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = encoder_layer(
hidden_states = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last encoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
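With the decorator-based flow, callers still pass output_attentions and output_hidden_states, but as keyword arguments consumed by check_model_inputs rather than explicit parameters threaded through every layer. A hypothetical call, with the encoder instance and input shape assumed:

import torch

input_values = torch.randn(1, 16000)  # (batch, num_samples); shape assumed
encoder_outputs = encoder(
    input_values,
    output_attentions=True,       # recorded via hooks, not returned by layers
    output_hidden_states=True,
)
print(encoder_outputs.last_hidden_state.shape)
print(len(encoder_outputs.attentions), len(encoder_outputs.hidden_states))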
@@ -681,8 +638,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -698,21 +653,9 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -743,11 +686,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
@@ -757,7 +695,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
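The sdpa branch no longer special-cases output_attentions, since attention weights now come from hooks instead of an eager fallback. For context, _prepare_4d_attention_mask_for_sdpa expands a 2-D padding mask of shape (bsz, src_len) into the 4-D additive form (bsz, 1, tgt_len, src_len) that scaled_dot_product_attention expects; a rough functional equivalent, with semantics assumed rather than taken from the library:

import torch

def expand_mask_4d(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # 1 marks attendable positions. Broadcast to 4-D, then convert to additive
    # form: 0.0 where attention is allowed, a large negative value where masked.
    bsz, src_len = mask_2d.shape
    mask = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min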
@@ -769,43 +707,24 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
causal_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
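End to end, the decoder's output object is now populated entirely by the recorded hooks; nothing is accumulated inside the layer loop. A hypothetical call, with the decoder instance and input tensors assumed:

decoder_outputs = decoder(
    input_ids=input_ids,
    encoder_hidden_states=encoder_outputs.last_hidden_state,
    encoder_attention_mask=attention_mask,
    output_attentions=True,
)
self_attn_maps = decoder_outputs.attentions         # filled by hooks
cross_attn_maps = decoder_outputs.cross_attentions  # likewise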

View File

@@ -450,14 +450,12 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -470,12 +468,10 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states, _ = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
@@ -485,18 +481,11 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer):
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
return hidden_states
@auto_docstring
@@ -510,6 +499,14 @@ class MoonshinePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
_can_record_outputs = {
"attentions": (MoonshineAttention, 1),
"cross_attentions": (
MoonshineAttention,
1,
), # The issue is that we are attaching hooks to all MoonshineAttention instances
"hidden_states": (MoonshineDecoderLayer, 0),
}
def _init_weights(self, module):
std = self.config.initializer_range
@@ -576,10 +573,8 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
@can_return_tuple
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
input_values: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
r"""
@@ -595,23 +590,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_values is None:
raise ValueError("You must specify input_values.")
# conv downsampling
input_values = input_values.unsqueeze(1)
hidden_states = nn.functional.tanh(self.conv1(input_values))
@@ -641,38 +620,19 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# encoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = encoder_layer(
hidden_states = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last encoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@@ -694,8 +654,6 @@ class MoonshineDecoder(LlamaModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -711,21 +669,9 @@ class MoonshineDecoder(LlamaModel):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -756,11 +702,6 @@ class MoonshineDecoder(LlamaModel):
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
@@ -770,7 +711,7 @@ class MoonshineDecoder(LlamaModel):
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
@@ -782,43 +723,24 @@ class MoonshineDecoder(LlamaModel):
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
causal_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)

View File

@@ -252,8 +252,6 @@ class SamHQMaskDecoder(nn.Module):
Whether to use only the high-quality token output or combine with SAM output.
intermediate_embeddings (`torch.Tensor`):
Intermediate embeddings from the vision encoder for feature fusion.
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
attention_similarity (`torch.Tensor`, *optional*):
Optional tensor for attention similarity computation.
target_embedding (`torch.Tensor`, *optional*):
@@ -426,20 +424,10 @@ class SamHQModel(SamModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
vision_output = self.vision_encoder(
pixel_values=pixel_values,
)
vision_output = self.vision_encoder(pixel_values=pixel_values)
image_embeddings = vision_output[0]
intermediate_embeddings = vision_output[1]
return image_embeddings, intermediate_embeddings
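The call into the vision encoder collapses to one line and the stale docstring arguments are dropped; the helper still returns the (image_embeddings, intermediate_embeddings) pair. Hypothetical usage follows, noting that the enclosing method name is not shown in the hunk, so get_image_embeddings is an assumption:

# get_image_embeddings is assumed; the hunk does not show the def line.
image_embeddings, intermediate_embeddings = model.get_image_embeddings(pixel_values)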
@can_return_tuple