Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
update other models as well just making fix-copies
commit 37b4ef022e (parent 7433c44376)
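The diff below propagates one refactor across the copied model files (via `make fix-copies`): attention and decoder-layer modules gain a class-level `return_hooks` attribute naming which element of their return value is worth recording, decoder layers return plain `hidden_states` instead of `(hidden_states, self_attn_weights)` tuples, and the models swap `@can_return_tuple` plus hand-rolled `output_attentions`/`output_hidden_states` bookkeeping for the `@check_model_inputs` decorator. Below is a minimal sketch of that hook-recording idea, assuming `return_hooks` maps an output name to an index into the module's return tuple (the commit itself spells it `{"attentions", 1}`, a set literal; a dict is assumed here). `record_outputs`, `ToyAttention`, and `ToyModel` are invented for illustration and are not the transformers API:

```python
from collections import defaultdict
from functools import wraps

import torch
from torch import nn


def record_outputs(forward):
    """Toy stand-in for `check_model_inputs`: collect declared outputs via forward hooks."""

    @wraps(forward)
    def wrapper(self, *args, output_attentions=False, **kwargs):
        if not output_attentions:
            return forward(self, *args, **kwargs), None
        collected = defaultdict(list)
        handles = []
        for module in self.modules():
            hooks = getattr(module, "return_hooks", None)  # e.g. {"attentions": 1}
            if not hooks:
                continue

            def hook(mod, inputs, output, _hooks=hooks):
                # Record the declared element of this submodule's return tuple.
                for name, index in _hooks.items():
                    collected[name].append(output[index])

            handles.append(module.register_forward_hook(hook))
        try:
            last_hidden_state = forward(self, *args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()
        return last_hidden_state, tuple(collected["attentions"])

    return wrapper


class ToyAttention(nn.Module):
    # Assumed semantics: output name -> index into this module's return tuple.
    return_hooks = {"attentions": 1}

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        weights = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        return self.proj(weights @ hidden_states), weights


class ToyModel(nn.Module):
    def __init__(self, dim=8, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(ToyAttention(dim) for _ in range(num_layers))

    @record_outputs
    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)[0]  # layers hand back only the tensor callers need
        return hidden_states


hidden, attentions = ToyModel()(torch.randn(1, 4, 8), output_attentions=True)
assert len(attentions) == 2  # one recorded attention map per layer
```

This is why the diffs below can delete every `all_hidden_states`/`all_self_attns` tuple and every `output_attentions=` plumbing line: the recording concern moves out of the forward passes entirely.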
@@ -43,7 +43,8 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, can_return_tuple
+from ...utils import can_return_tuple
+from ...utils.generic import check_model_inputs
 from .configuration_arcee import ArceeConfig
 
 
@@ -194,6 +195,8 @@ def eager_attention_forward(
 class ArceeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: ArceeConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -225,7 +228,7 @@ class ArceeAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -262,6 +265,8 @@ class ArceeAttention(nn.Module):
 
 
 class ArceeDecoderLayer(GradientCheckpointingLayer):
+    return_hooks = {"hidden_states", 0}
+
     def __init__(self, config: ArceeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -278,27 +283,25 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )
+        )[0]
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -306,12 +309,7 @@ class ArceeDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 @auto_docstring
@@ -321,7 +319,6 @@ class ArceePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["ArceeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -368,7 +365,7 @@ class ArceeModel(ArceePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -377,31 +374,13 @@ class ArceeModel(ArceePreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -426,52 +405,28 @@ class ArceeModel(ArceePreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
 class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -516,11 +471,9 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -544,12 +497,6 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -557,8 +504,6 @@ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
@@ -609,8 +554,7 @@ class ArceeForSequenceClassification(ArceePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -626,8 +570,7 @@ class ArceeForSequenceClassification(ArceePreTrainedModel):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         hidden_states = transformer_outputs.last_hidden_state
         logits = self.score(hidden_states)
@@ -697,8 +640,6 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
@@ -707,8 +648,6 @@ class ArceeForQuestionAnswering(ArceePreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
         )
 
         sequence_output = outputs.last_hidden_state
@@ -766,8 +705,7 @@ class ArceeForTokenClassification(ArceePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -783,8 +721,7 @@ class ArceeForTokenClassification(ArceePreTrainedModel):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         sequence_output = outputs.last_hidden_state
         sequence_output = self.dropout(sequence_output)
 
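A recurring edit above narrows the `**kwargs` annotation from `Unpack[KwargsForCausalLM]` to `Unpack[FlashAttentionKwargs]`. The `Unpack[...]` idiom gives static types to the extra keyword arguments that a forward pass forwards verbatim to the attention backend. A self-contained sketch of the idiom, with a trimmed-down `FlashAttentionKwargs` (the real TypedDict lives in `transformers.modeling_flash_attention_utils` and may carry different fields):

```python
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    # Trimmed-down assumption of the real TypedDict's fields.
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


def forward(input_ids: torch.LongTensor, **kwargs: Unpack[FlashAttentionKwargs]) -> None:
    # A type checker now flags misspelled or mistyped extras at the call site,
    # while the runtime behavior of **kwargs is unchanged.
    attention_kwargs = dict(kwargs)  # forwarded untouched to the attention call
    print(sorted(attention_kwargs))


forward(torch.tensor([[1, 2]]), cu_seq_lens_q=torch.tensor([0, 2]))
```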
@@ -32,7 +32,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from ...utils.import_utils import is_torch_available
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig
@@ -43,9 +44,6 @@ if is_torch_available():
     from torch import nn
 
 
-logger = logging.get_logger(__name__)
-
-
 @use_kernel_forward_from_hub("RMSNorm")
 class AriaTextRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -501,6 +499,8 @@ def eager_attention_forward(
 class AriaTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: AriaTextConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -532,7 +532,7 @@ class AriaTextAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -581,6 +581,8 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer):
             Index of the layer.
     """
 
+    return_hooks = {"hidden_states", 0}
+
     def __init__(self, config: AriaTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -596,27 +598,25 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )
+        )[0]
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -624,12 +624,7 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 @auto_docstring
@@ -667,7 +662,6 @@ class AriaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["AriaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -751,7 +745,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -760,31 +754,13 @@ class AriaTextModel(AriaTextPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -809,46 +785,25 @@ class AriaTextModel(AriaTextPreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
@@ -898,11 +853,9 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -926,12 +879,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -939,8 +886,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
 
@@ -265,6 +265,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class BambaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: BambaConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -296,7 +298,7 @@ class BambaAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -34,13 +34,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from .configuration_bitnet import BitNetConfig
 
 
-logger = logging.get_logger(__name__)
-
-
 @use_kernel_forward_from_hub("RMSNorm")
 class BitNetRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -154,6 +152,8 @@ def eager_attention_forward(
 class BitNetAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: BitNetConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -225,6 +225,8 @@ class BitNetAttention(nn.Module):
 
 
 class BitNetDecoderLayer(GradientCheckpointingLayer):
+    return_hooks = {"hidden_states", 0}
+
     def __init__(self, config: BitNetConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -241,27 +243,25 @@ class BitNetDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )
+        )[0]
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -269,12 +269,7 @@ class BitNetDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 class BitNetRotaryEmbedding(nn.Module):
@@ -318,7 +313,6 @@ class BitNetPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["BitNetDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -365,7 +359,7 @@ class BitNetModel(BitNetPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -374,31 +368,13 @@ class BitNetModel(BitNetPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -423,52 +399,28 @@ class BitNetModel(BitNetPreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 @auto_docstring
 class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -513,11 +465,9 @@ class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -541,12 +491,6 @@ class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "User: Hey, are you conscious? Can you talk to me?Assistant: No, I'm not conscious. I'm an artificial intelligence designed to assist with information and tasks. How can I help you today?"
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -554,8 +498,6 @@ class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
 
@@ -42,13 +42,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from .configuration_cohere import CohereConfig
 
 
-logger = logging.get_logger(__name__)
-
-
 class CohereLayerNorm(nn.Module):
     def __init__(self, hidden_size=None, eps=1e-5, bias=False):
         """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
@@ -195,6 +193,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
@@ -355,7 +355,6 @@ class CoherePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["CohereDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -402,7 +401,7 @@ class CohereModel(CoherePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -411,31 +410,13 @@ class CohereModel(CoherePreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -460,46 +441,25 @@ class CohereModel(CoherePreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
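Every model touched above also drops the same preamble (`output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions`, and friends). That boilerplate can live once in a decorator that fills unset flags from the config before forward runs. The sketch below is a guess at roughly what `@check_model_inputs` handles on the input side, not its actual implementation; `resolve_from_config` is an invented name:

```python
import functools


def resolve_from_config(*flag_names):
    """Fill keyword flags that are None (or absent) from `self.config` before calling forward."""

    def decorator(forward):
        @functools.wraps(forward)
        def wrapper(self, *args, **kwargs):
            for name in flag_names:
                if kwargs.get(name) is None:
                    kwargs[name] = getattr(self.config, name)
            return forward(self, *args, **kwargs)

        return wrapper

    return decorator


class Config:
    output_attentions = False
    output_hidden_states = False
    use_cache = True


class Model:
    config = Config()

    @resolve_from_config("output_attentions", "output_hidden_states", "use_cache")
    def forward(self, output_attentions=None, output_hidden_states=None, use_cache=None):
        return output_attentions, output_hidden_states, use_cache


# Unset flags fall back to the config; explicit arguments win.
assert Model().forward(output_attentions=True) == (True, False, True)
```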
@@ -36,6 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
 from .configuration_cohere2 import Cohere2Config
 
 
@@ -172,6 +173,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class Cohere2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
@@ -334,7 +337,6 @@ class Cohere2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Cohere2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -381,7 +383,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
 
@@ -280,6 +280,8 @@ def eager_attention_forward(
 class CsmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: CsmConfig, layer_idx: int):
         super().__init__()
         self.config = config
@@ -311,7 +313,7 @@ class CsmAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -348,6 +350,8 @@ class CsmAttention(nn.Module):
 
 
 class CsmDecoderLayer(GradientCheckpointingLayer):
+    return_hooks = {"hidden_states", 0}
+
     def __init__(self, config: CsmConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -364,27 +368,25 @@ class CsmDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )
+        )[0]
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -392,12 +394,7 @@ class CsmDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 @auto_docstring
@@ -753,11 +750,9 @@ class CsmBackboneModel(CsmPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
@@ -771,25 +766,9 @@ class CsmBackboneModel(CsmPreTrainedModel):
 
         [What are input IDs?](../glossary#input-ids)
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -814,46 +793,25 @@ class CsmBackboneModel(CsmPreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
@@ -22,13 +22,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from .configuration_deepseek_v3 import DeepseekV3Config
 
 
-logger = logging.get_logger(__name__)
-
-
 @use_kernel_forward_from_hub("RMSNorm")
 class DeepseekV3RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -441,6 +439,8 @@ class DeepseekV3Attention(nn.Module):
 
 
 class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
+    return_hooks = {"hidden_states", 0}
+
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -461,27 +461,25 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )
+        )[0]
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -489,12 +487,7 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 @auto_docstring
@@ -504,7 +497,6 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DeepseekV3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -555,7 +547,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -564,31 +556,13 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -613,52 +587,28 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+            past_key_values=past_key_values,
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 @auto_docstring
 class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -703,11 +653,9 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -731,12 +679,6 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs: BaseModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -744,8 +686,6 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
 
@@ -255,6 +255,8 @@ def eager_attention_forward(
 class DiaSelfAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    return_hooks = {"attentions", 1}
+
     def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
         super().__init__()
         self.config = config
@@ -281,7 +283,7 @@ class DiaSelfAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@ -48,7 +48,8 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_diffllama import DiffLlamaConfig


@ -497,6 +498,8 @@ DIFFLLAMA_ATTENTION_CLASSES = {


class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states", 0}

def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -513,27 +516,25 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -541,12 +542,7 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
@ -556,7 +552,6 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DiffLlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = False
@ -642,7 +637,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -651,31 +646,13 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@ -700,52 +677,26 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
past_key_values=past_key_values,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -790,11 +741,9 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -818,12 +767,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -831,8 +774,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -896,8 +837,7 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -913,8 +853,7 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -984,8 +923,6 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
@ -994,8 +931,6 @@ class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

sequence_output = outputs.last_hidden_state
@ -1053,8 +988,7 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1070,8 +1004,7 @@ class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -35,7 +35,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_dots1 import Dots1Config


@ -172,6 +172,8 @@ def eager_attention_forward(
class Dots1Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Dots1Config, layer_idx: int):
super().__init__()
self.config = config
@ -360,6 +362,8 @@ class Dots1TopkRouter(nn.Module):


class Dots1DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states", 0}

def __init__(self, config: Dots1Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -381,27 +385,25 @@ class Dots1DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -409,12 +411,7 @@ class Dots1DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
@ -424,7 +421,6 @@ class Dots1PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Dots1DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -585,9 +581,6 @@ class Dots1Model(Dots1PreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -632,11 +625,9 @@ class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -660,12 +651,6 @@ class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -673,8 +658,6 @@ class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -39,13 +39,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig


logger = logging.get_logger(__name__)


@use_kernel_forward_from_hub("RMSNorm")
class Emu3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@ -158,6 +156,8 @@ def eager_attention_forward(
class Emu3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Emu3Config, layer_idx: int):
super().__init__()
self.config = config
@ -189,7 +189,7 @@ class Emu3Attention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@ -226,6 +226,8 @@ class Emu3Attention(nn.Module):


class Emu3DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states", 0}

def __init__(self, config: Emu3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1209,7 +1211,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -1218,31 +1220,13 @@ class Emu3TextModel(Emu3PreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@ -1267,52 +1251,28 @@ class Emu3TextModel(Emu3PreTrainedModel):
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
past_key_values=past_key_values,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -1358,11 +1318,9 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1386,12 +1344,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
>>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
>>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -1399,8 +1351,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -330,6 +330,8 @@ def eager_attention_forward(
class FalconH1Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: FalconH1Config, layer_idx: int):
super().__init__()
self.config = config
@ -39,7 +39,8 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_gemma import GemmaConfig


@ -191,6 +192,8 @@ def eager_attention_forward(
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.config = config
@ -222,7 +225,7 @@ class GemmaAttention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@ -259,6 +262,8 @@ class GemmaAttention(nn.Module):


class GemmaDecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states", 0}

def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -275,27 +280,25 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -303,12 +306,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
@ -318,7 +316,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -365,7 +362,7 @@ class GemmaModel(GemmaPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -469,9 +466,6 @@ class GemmaModel(GemmaPreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -516,11 +510,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -544,12 +536,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -557,8 +543,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -622,8 +606,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -639,8 +622,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -716,8 +698,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -733,8 +714,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_gemma2 import Gemma2Config


@ -167,6 +168,8 @@ def eager_attention_forward(
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.config = config
@ -339,7 +342,6 @@ class Gemma2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -386,7 +388,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -656,8 +658,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -673,8 +674,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -750,8 +750,7 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -767,8 +766,7 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -40,6 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig

@ -267,6 +268,8 @@ def eager_attention_forward(
class Gemma3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
@ -422,7 +425,6 @@ class Gemma3PreTrainedModel(PreTrainedModel):
"SiglipMultiheadAttentionPoolingHead",
]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -484,7 +486,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -1261,6 +1261,8 @@ def apply_rotary_pos_emb(
class Gemma3nTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
super().__init__()
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
@ -1477,7 +1479,6 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma3nDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -40,7 +40,8 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_glm import GlmConfig


@ -155,6 +156,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class GlmAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
@ -184,7 +187,7 @@ class GlmAttention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@ -276,6 +279,8 @@ class GlmRotaryEmbedding(nn.Module):


class GlmDecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states", 0}

def __init__(self, config: GlmConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -292,27 +297,25 @@ class GlmDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -320,12 +323,7 @@ class GlmDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
@ -335,7 +333,6 @@ class GlmPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GlmDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -382,7 +379,7 @@ class GlmModel(GlmPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -391,33 +388,15 @@ class GlmModel(GlmPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()
@ -440,52 +419,26 @@ class GlmModel(GlmPreTrainedModel):
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
past_key_values=past_key_values,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -530,11 +483,9 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -558,12 +509,6 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -571,8 +516,6 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -636,8 +579,7 @@ class GlmForSequenceClassification(GlmPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -653,8 +595,7 @@ class GlmForSequenceClassification(GlmPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -730,8 +671,7 @@ class GlmForTokenClassification(GlmPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -747,8 +687,7 @@ class GlmForTokenClassification(GlmPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_glm4 import Glm4Config


@ -213,6 +214,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class Glm4Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions", 1}

def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
@ -242,7 +245,7 @@ class Glm4Attention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@ -343,7 +346,6 @@ class Glm4PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Glm4DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -390,7 +392,7 @@ class Glm4Model(Glm4PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
@ -399,33 +401,15 @@ class Glm4Model(Glm4PreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()
@ -448,46 +432,23 @@ class Glm4Model(Glm4PreTrainedModel):
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
past_key_values=past_key_values,
)


@ -535,11 +496,9 @@ class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -563,12 +522,6 @@ class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -576,8 +529,6 @@ class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -641,8 +592,7 @@ class Glm4ForSequenceClassification(Glm4PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -658,8 +608,7 @@ class Glm4ForSequenceClassification(Glm4PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -735,8 +684,7 @@ class Glm4ForTokenClassification(Glm4PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -752,8 +700,7 @@ class Glm4ForTokenClassification(Glm4PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -26,6 +26,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_gpt_neox import GPTNeoXConfig


@ -292,7 +293,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTNeoXLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -339,7 +339,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_in = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,

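The recurring edit in these hunks is that the `output_attentions`/`output_hidden_states` defaulting boilerplate leaves each `forward` and the new `@check_model_inputs` decorator takes over. A minimal sketch of the idea, assuming a config-backed fallback (the real decorator lives in `transformers/utils/generic.py` and also records per-layer tensors through hooks):

```python
import functools

def check_model_inputs(forward):
    # Hypothetical sketch, not the transformers implementation.
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        # Resolve the flags from the config, then pop them so the wrapped
        # forward can drop them from its signature, as the hunks above do.
        flags = {}
        for name in ("output_attentions", "output_hidden_states"):
            value = kwargs.pop(name, None)
            flags[name] = value if value is not None else getattr(self.config, name, False)
        self._requested_outputs = flags  # hypothetical attribute read by hook setup
        return forward(self, *args, **kwargs)

    return wrapper
```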
@ -36,6 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_granite import GraniteConfig


@ -117,6 +118,8 @@ def eager_attention_forward(
class GraniteAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
@ -148,7 +151,7 @@ class GraniteAttention(nn.Module):
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

@ -222,6 +225,8 @@ class GraniteMLP(nn.Module):


class GraniteDecoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: GraniteConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -305,7 +310,6 @@ class GranitePreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["GraniteDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -387,7 +391,7 @@ class GraniteModel(GranitePreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,

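`return_hooks` announces which element of a module's raw output should be recorded under which field name ("attentions" from element 1 of the attention module, "hidden_states" from element 0 of the decoder layer). Note that `{"attentions", 1}` as written is a Python set literal; if a name-to-index mapping is intended, `{"attentions": 1}` is presumably the right spelling. A hedged sketch of how such a table could drive hook-based collection, assuming modules that return tuples:

```python
from torch import nn

# Illustrative only: a dict (not a set) mapping an output name to the
# index of the tuple element a module's forward returns.
RETURN_HOOKS = {"attentions": 1, "hidden_states": 0}

def record_output(module: nn.Module, name: str, store: dict):
    index = RETURN_HOOKS[name]

    def hook(_module, _inputs, outputs):
        # `outputs` is assumed to be the tuple the module's forward returned.
        store.setdefault(name, []).append(outputs[index])

    return module.register_forward_hook(hook)
```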
@ -40,7 +40,8 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_helium import HeliumConfig


@ -195,6 +196,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class HeliumAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
@ -224,7 +227,7 @@ class HeliumAttention(nn.Module):
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

@ -261,6 +264,8 @@ class HeliumAttention(nn.Module):


class HeliumDecoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -277,27 +282,25 @@ class HeliumDecoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        )[0]
        hidden_states = residual + hidden_states

        # Fully Connected
@ -305,12 +308,7 @@ class HeliumDecoderLayer(GradientCheckpointingLayer):
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
        return hidden_states


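After the rewrite, a decoder layer consumes and returns plain hidden states: attention weights are no longer threaded through the layer, and the call site keeps only element `[0]` of the attention module's output. A self-contained sketch of the resulting shape (illustrative module, not the Helium code):

```python
import torch
from torch import nn

class TinyDecoderLayer(nn.Module):
    # Minimal stand-in mirroring the rewritten pattern: pre-norm attention
    # and MLP blocks with residual connections, returning a plain tensor.
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # nn.MultiheadAttention returns (output, weights); keep element [0],
        # exactly like the `self.self_attn(...)[0]` call above.
        hidden_states = self.self_attn(hidden_states, hidden_states, hidden_states)[0]
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states
```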
@auto_docstring
@ -320,7 +318,6 @@ class HeliumPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["HeliumDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -367,7 +364,7 @@ class HeliumModel(HeliumPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,
@ -376,31 +373,13 @@ class HeliumModel(HeliumPreTrainedModel):
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@ -425,52 +404,28 @@ class HeliumModel(HeliumPreTrainedModel):
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            past_key_values=past_key_values,
        )


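With the accumulation code gone from the loop, per-layer tensors have to be captured from outside the model body; forward hooks are one way to rebuild the `hidden_states`/`attentions` tuples without threading flags through every call. A sketch under that assumption (illustrative, not the transformers wiring):

```python
import torch
from torch import nn

def run_with_hidden_state_collection(layers: nn.ModuleList, hidden_states: torch.Tensor):
    # Hypothetical illustration of hook-based collection; assumes each
    # layer takes and returns a plain tensor, as in the rewritten layers.
    collected = []
    handles = [
        layer.register_forward_hook(lambda _m, _i, out: collected.append(out))
        for layer in layers
    ]
    try:
        for layer in layers:
            hidden_states = layer(hidden_states)
    finally:
        for handle in handles:
            handle.remove()
    return hidden_states, tuple(collected)
```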
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@ -515,11 +470,9 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -543,12 +496,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -556,8 +503,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
@ -621,8 +566,7 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -638,8 +582,7 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
@ -715,8 +658,7 @@ class HeliumForTokenClassification(HeliumPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -732,8 +674,7 @@ class HeliumForTokenClassification(HeliumPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

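`KwargsForCausalLM` is a `TypedDict` union, `FlashAttentionKwargs` plus `LossKwargs`, used purely to type `**kwargs`; with `LossKwargs` dropped from `...utils`, the causal-LM signatures above fall back to `Unpack[FlashAttentionKwargs]`. A sketch of the pattern with illustrative field names (the real TypedDicts live in transformers):

```python
from typing import TypedDict

from typing_extensions import Unpack  # `typing.Unpack` on Python 3.12+

class FlashAttentionKwargs(TypedDict, total=False):
    # Illustrative subset; the real TypedDict carries varlen attention metadata.
    max_length_q: int
    max_length_k: int

class LossKwargs(TypedDict, total=False):
    num_items_in_batch: int  # illustrative

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

def forward_sketch(**kwargs: Unpack[KwargsForCausalLM]) -> None:
    # At runtime this is a plain dict; only type checkers see the union.
    print(sorted(kwargs))

forward_sketch(max_length_q=128, num_items_in_batch=4)
```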
@ -1510,8 +1510,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@ -1302,8 +1302,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1319,8 +1318,6 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

@ -40,7 +40,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ..auto import AutoModel
from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig

@ -1095,9 +1095,6 @@ class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel):
        return causal_mask


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@ -1149,11 +1146,9 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1187,12 +1182,6 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
        >>> output_tokens = model.generate(**inputs)
        >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True))
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -1200,8 +1189,6 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

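`logits_to_keep` in these signatures follows the usual transformers convention: `0` keeps logits for every position, while an integer `N` keeps only the last `N` positions before the LM head runs, which saves memory during generation. A small sketch of that slicing:

```python
import torch

def slice_for_lm_head(hidden_states: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
    # (batch, seq_len, hidden) -> keep all positions for 0, else the last N.
    if logits_to_keep == 0:
        return hidden_states
    return hidden_states[:, -logits_to_keep:, :]

x = torch.randn(2, 10, 16)
print(slice_for_lm_head(x, 1).shape)  # torch.Size([2, 1, 16])
```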
@ -176,6 +176,8 @@ def eager_attention_forward(
class LightGlueAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()
        self.config = config

@ -285,7 +285,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
@ -300,7 +299,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
@ -386,14 +384,14 @@ class LlamaModel(LlamaPreTrainedModel):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

@ -412,7 +410,7 @@ class LlamaModel(LlamaPreTrainedModel):
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_outputs = decoder_layer(
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
@ -422,8 +420,6 @@ class LlamaModel(LlamaPreTrainedModel):
                **kwargs,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
@ -574,7 +570,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -590,8 +586,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
@ -670,8 +665,6 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state
@ -729,7 +722,7 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -745,8 +738,7 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

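The Llama hunks also make the lazy cache explicit: when `use_cache` is on and nothing was passed, a fresh `DynamicCache` is created, and `cache_position` is derived from how many tokens the cache has already seen. The arithmetic, isolated:

```python
import torch

def make_cache_position(past_seen_tokens: int, num_new_tokens: int) -> torch.Tensor:
    # Positions of the incoming tokens continue right after the cached ones,
    # mirroring the `torch.arange` call in the hunk above.
    return torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)

print(make_cache_position(6, 2))  # tensor([6, 7])
```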
@ -342,6 +342,8 @@ def eager_attention_forward(
class MiniMaxAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        super().__init__()
        self.config = config
@ -590,7 +592,6 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["MiniMaxDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -1047,8 +1048,7 @@ class MiniMaxForSequenceClassification(MiniMaxPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1064,8 +1064,7 @@ class MiniMaxForSequenceClassification(MiniMaxPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
@ -1141,8 +1140,7 @@ class MiniMaxForTokenClassification(MiniMaxPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1158,8 +1156,7 @@ class MiniMaxForTokenClassification(MiniMaxPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

@ -26,7 +26,7 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_mistral import MistralConfig


@ -124,6 +124,8 @@ def eager_attention_forward(
class MistralAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__()
        self.config = config
@ -205,6 +207,8 @@ class MistralRMSNorm(nn.Module):


class MistralDecoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -219,27 +223,25 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        )[0]
        hidden_states = residual + hidden_states

        # Fully Connected
@ -247,12 +249,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
        return hidden_states


@auto_docstring
@ -262,7 +259,6 @@ class MistralPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["MistralDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -445,9 +441,6 @@ class MistralModel(MistralPreTrainedModel):
        )


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@ -492,11 +485,9 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -520,12 +511,6 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -533,8 +518,6 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
@ -592,8 +575,7 @@ class MistralForTokenClassification(MistralPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -609,8 +591,7 @@ class MistralForTokenClassification(MistralPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
@ -669,8 +650,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -686,8 +666,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

@ -236,6 +236,8 @@ def eager_attention_forward(
class MixtralAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.config = config
@ -417,7 +419,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["MixtralDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -830,8 +831,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -847,8 +847,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
@ -924,8 +923,7 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -941,8 +939,7 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

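Every `*ForSequenceClassification`/`*ForTokenClassification` wrapper touched here has the same shape: run the backbone, take `last_hidden_state`, project with a small head. A generic sketch of the pattern (names are illustrative, not any one model's code):

```python
import torch
from torch import nn

class ClassifierHeadSketch(nn.Module):
    # Generic stand-in for the classification wrappers above.
    def __init__(self, backbone: nn.Module, hidden_size: int, num_labels: int):
        super().__init__()
        self.model = backbone
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, **kwargs) -> torch.Tensor:
        outputs = self.model(**kwargs)
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)
        return self.score(hidden_states)  # (batch, seq_len, num_labels)
```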
@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_moonshine import MoonshineConfig


@ -169,6 +170,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class MoonshineAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(
        self,
        config: MoonshineConfig,
@ -329,6 +332,8 @@ class MoonshineRotaryEmbedding(nn.Module):


class MoonshineEncoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: MoonshineConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -351,27 +356,25 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        )[0]
        hidden_states = residual + hidden_states

        # Fully Connected
@ -379,12 +382,7 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer):
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
        return hidden_states


class MoonshineDecoderLayer(GradientCheckpointingLayer):
@ -678,7 +676,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,

@ -1012,8 +1012,7 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1029,8 +1028,6 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
@ -1102,8 +1099,6 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPast = self.transformer(
@ -1112,8 +1107,6 @@ class NemotronForQuestionAnswering(NemotronPreTrainedModel):
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state
@ -1172,8 +1165,7 @@ class NemotronForTokenClassification(NemotronPreTrainedModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1189,8 +1181,7 @@ class NemotronForTokenClassification(NemotronPreTrainedModel):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)

@ -20,13 +20,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_olmo import OlmoConfig


logger = logging.get_logger(__name__)


class OlmoLayerNorm(nn.Module):
    """LayerNorm but with no learnable weight or bias."""

@ -133,6 +131,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class OlmoAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: OlmoConfig, layer_idx: int):
        super().__init__()
        self.config = config
@ -210,6 +210,8 @@ class OlmoAttention(nn.Module):


class OlmoDecoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: OlmoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -225,27 +227,25 @@ class OlmoDecoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        )[0]
        hidden_states = residual + hidden_states

        # Fully Connected
@ -253,12 +253,7 @@ class OlmoDecoderLayer(GradientCheckpointingLayer):
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
        return hidden_states


class OlmoRotaryEmbedding(nn.Module):
@ -301,7 +296,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["OlmoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -346,7 +340,7 @@ class OlmoModel(OlmoPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,
@ -355,31 +349,13 @@ class OlmoModel(OlmoPreTrainedModel):
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@ -404,52 +380,28 @@ class OlmoModel(OlmoPreTrainedModel):
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            past_key_values=past_key_values,
        )


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@ -494,11 +446,9 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -522,12 +472,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -535,8 +479,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

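One of the blocks deleted from these model forwards is the guard that disables caching under gradient checkpointing: checkpointing re-runs layer forwards during backprop, so a KV cache filled on the first pass would be corrupted by the recomputation. Restated on its own:

```python
def resolve_use_cache(use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    # Mirrors the deleted `logger.warning_once(...)` branch above.
    if gradient_checkpointing and training and use_cache:
        return False  # `use_cache=True` is incompatible with gradient checkpointing
    return use_cache
```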
@ -20,13 +20,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_olmo2 import Olmo2Config


logger = logging.get_logger(__name__)


@use_kernel_forward_from_hub("RMSNorm")
class Olmo2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
@ -124,6 +122,8 @@ def rotate_half(x):
class Olmo2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    return_hooks = {"attentions", 1}

    def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
@ -214,6 +214,8 @@ class Olmo2MLP(nn.Module):


class Olmo2DecoderLayer(GradientCheckpointingLayer):
    return_hooks = {"hidden_states", 0}

    def __init__(self, config: Olmo2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
@ -305,7 +307,6 @@ class Olmo2PreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["Olmo2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
@ -352,7 +353,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @check_model_inputs
    @auto_docstring
    def forward(
        self,
@ -361,31 +362,13 @@ class Olmo2Model(Olmo2PreTrainedModel):
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@ -410,52 +393,28 @@ class Olmo2Model(Olmo2PreTrainedModel):
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            past_key_values=past_key_values,
        )


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@ -500,11 +459,9 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -528,12 +485,6 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@ -541,8 +492,6 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

@ -842,8 +842,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -859,8 +858,6 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -937,8 +934,7 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -954,8 +950,7 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -24,7 +24,8 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_phi import PhiConfig


@ -106,6 +107,8 @@ def eager_attention_forward(
class PhiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
self.config = config
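`return_hooks` appears to map an output name to that value's position in the module's raw return tuple, letting a wrapper record intermediate tensors without threading `output_attentions` through every call. A minimal sketch of how such a mapping could be consumed (hypothetical helper, not the actual transformers implementation):

from torch import nn

def record_outputs(module: nn.Module, store: dict):
    # Hypothetical helper: capture the tuple slots named by `return_hooks`
    # ({output_name: tuple_index}) via a forward hook.
    mapping = getattr(module, "return_hooks", {})

    def hook(mod, args, output):
        out = output if isinstance(output, tuple) else (output,)
        for name, index in mapping.items():
            if index < len(out):
                store.setdefault(name, []).append(out[index])

    return module.register_forward_hook(hook)

With `{"attentions": 1}` the hook captures the second element of an attention module's output; for a decoder layer that returns a bare tensor, `{"hidden_states": 0}` points at the tensor itself.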
@ -295,7 +298,6 @@ class PhiPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -344,7 +346,7 @@ class PhiModel(PhiPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
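`@check_model_inputs` replaces `@can_return_tuple` here, and the boilerplate removed from the `forward` bodies below suggests it now resolves the `output_*` flags and records intermediate outputs centrally. A rough sketch of the assumed behavior (the real decorator lives in `...utils.generic` and is not shown in this diff):

import functools

def check_model_inputs(forward):
    # Assumed behavior, inferred from this diff: register recording hooks on
    # every submodule that declares `return_hooks`, run the wrapped forward,
    # then attach whatever was captured to the returned ModelOutput.
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        captured, handles = {}, []
        for module in self.modules():
            mapping = getattr(module, "return_hooks", None)
            if not mapping:
                continue

            def hook(mod, inp, out, mapping=mapping):
                seq = out if isinstance(out, tuple) else (out,)
                for name, index in mapping.items():
                    if index < len(seq):
                        captured.setdefault(name, []).append(seq[index])

            handles.append(module.register_forward_hook(hook))
        try:
            output = forward(self, *args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()
        for name, values in captured.items():
            setattr(output, name, tuple(values))
        return output

    return wrapper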
@ -442,9 +444,6 @@ class PhiModel(PhiPreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -489,11 +488,9 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -517,12 +514,6 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -530,8 +521,6 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -595,8 +584,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -612,8 +600,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -689,8 +676,7 @@ class PhiForTokenClassification(PhiPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -706,8 +692,7 @@ class PhiForTokenClassification(PhiPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -41,7 +41,7 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_phi3 import Phi3Config


@ -234,6 +234,8 @@ class Phi3RMSNorm(nn.Module):


class Phi3DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: Phi3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -316,7 +318,6 @@ class Phi3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -500,9 +501,6 @@ class Phi3Model(Phi3PreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -547,11 +545,9 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -575,12 +571,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -588,8 +578,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -692,8 +680,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -709,8 +696,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -786,8 +772,7 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -803,8 +788,7 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -1448,6 +1448,8 @@ class Phi4MultimodalAttention(nn.Module):


class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1622,7 +1624,6 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -1433,8 +1433,7 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1450,8 +1449,6 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -26,7 +26,7 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_qwen2 import Qwen2Config


@ -124,6 +124,8 @@ def eager_attention_forward(
class Qwen2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.config = config
@ -206,6 +208,8 @@ class Qwen2RMSNorm(nn.Module):


class Qwen2DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -223,27 +227,25 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -251,12 +253,7 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
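With this rewrite the decoder layer is called for its hidden states only; attention maps are no longer part of its return value. A small usage sketch against the post-commit contract (config values and shapes are arbitrary, chosen only for illustration):

import torch
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2RotaryEmbedding

# Tiny illustrative config.
config = Qwen2Config(hidden_size=64, num_attention_heads=4, num_key_value_heads=2,
                     intermediate_size=128, num_hidden_layers=2)
layer = Qwen2DecoderLayer(config, layer_idx=0)

hidden = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)
position_embeddings = Qwen2RotaryEmbedding(config=config)(hidden, position_ids)

# The layer now returns the hidden-states tensor directly; attention weights
# are observed through `return_hooks`, not the return value.
hidden = layer(hidden, position_ids=position_ids, position_embeddings=position_embeddings)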
@ -266,7 +263,6 @@ class Qwen2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -459,9 +455,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -506,11 +499,9 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -534,12 +525,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -547,8 +532,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
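Note the typing change running through these signatures: with the `KwargsForCausalLM` alias gone (it mixed `LossKwargs` into `FlashAttentionKwargs`), `**kwargs` is annotated with `Unpack[FlashAttentionKwargs]` alone. At runtime `**kwargs` accepts the same dict either way; only static checking of the accepted keywords narrows. A minimal sketch of the pattern itself (illustrative field names, not the library's definitions):

import torch
from typing import TypedDict
from typing_extensions import Unpack

class FlashAttentionKwargs(TypedDict, total=False):
    # Illustrative stand-in for the real class in transformers; these field
    # names are assumptions made for the sketch.
    cu_seq_lens_q: torch.LongTensor
    max_length_q: int

def forward(**kwargs: Unpack[FlashAttentionKwargs]) -> None:
    # Runtime behavior is unchanged; a type checker now flags loss-related
    # keywords that the old KwargsForCausalLM alias used to allow.
    ...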
@ -612,8 +595,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -629,8 +611,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -706,8 +687,7 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -723,8 +703,7 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -771,8 +750,6 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
@ -781,8 +758,6 @@ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

sequence_output = outputs.last_hidden_state
@ -1236,8 +1236,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1253,8 +1252,6 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -1331,8 +1328,7 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1348,8 +1344,7 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -160,6 +160,8 @@ def eager_attention_forward(
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__()
self.config = config
@ -232,6 +234,8 @@ class Qwen3Attention(nn.Module):


class Qwen3DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -249,27 +253,25 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -277,12 +279,7 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


@auto_docstring
@ -292,7 +289,6 @@ class Qwen3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -532,11 +528,9 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -560,12 +554,6 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -573,8 +561,6 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -638,8 +624,7 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -655,8 +640,7 @@ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -732,8 +716,7 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -749,8 +732,7 @@ class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -797,8 +779,6 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
@ -807,8 +787,6 @@ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

sequence_output = outputs.last_hidden_state
@ -125,6 +125,8 @@ def eager_attention_forward(
class Qwen3MoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
super().__init__()
self.config = config
@ -424,7 +426,6 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -837,8 +838,7 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -854,8 +854,7 @@ class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -931,8 +930,7 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -948,8 +946,7 @@ class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -996,8 +993,6 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
@ -1006,8 +1001,6 @@ class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

sequence_output = outputs.last_hidden_state
@ -41,7 +41,7 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_smollm3 import SmolLM3Config


@ -123,6 +123,8 @@ def eager_attention_forward(
class SmolLM3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: SmolLM3Config, layer_idx: int):
super().__init__()
self.config = config
@ -226,7 +228,6 @@ class SmolLM3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["SmolLM3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -266,6 +267,8 @@ class SmolLM3MLP(nn.Module):


class SmolLM3DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: SmolLM3Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -283,27 +286,25 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -311,12 +312,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


class SmolLM3RotaryEmbedding(nn.Module):
@ -489,9 +485,6 @@ class SmolLM3Model(SmolLM3PreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -536,11 +529,9 @@ class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -564,12 +555,6 @@ class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -577,8 +562,6 @@ class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -642,8 +625,7 @@ class SmolLM3ForSequenceClassification(SmolLM3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -659,8 +641,7 @@ class SmolLM3ForSequenceClassification(SmolLM3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -736,8 +717,7 @@ class SmolLM3ForTokenClassification(SmolLM3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -753,8 +733,7 @@ class SmolLM3ForTokenClassification(SmolLM3PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -801,8 +780,6 @@ class SmolLM3ForQuestionAnswering(SmolLM3PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
outputs: BaseModelOutputWithPast = self.transformer(
@ -811,8 +788,6 @@ class SmolLM3ForQuestionAnswering(SmolLM3PreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

sequence_output = outputs.last_hidden_state
@ -1069,8 +1069,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1086,8 +1085,6 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -1164,8 +1161,7 @@ class StableLmForTokenClassification(StableLmPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1181,8 +1177,7 @@ class StableLmForTokenClassification(StableLmPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -44,7 +44,7 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_starcoder2 import Starcoder2Config


@ -143,6 +143,8 @@ def eager_attention_forward(
class Starcoder2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
@ -208,6 +210,8 @@ class Starcoder2Attention(nn.Module):


class Starcoder2DecoderLayer(GradientCheckpointingLayer):
return_hooks = {"hidden_states": 0}

def __init__(self, config: Starcoder2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -222,27 +226,25 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
)[0]
hidden_states = residual + hidden_states

# Fully Connected
@ -250,12 +252,7 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)

return outputs
return hidden_states


class Starcoder2RotaryEmbedding(nn.Module):
@ -299,7 +296,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Starcoder2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
@ -449,9 +445,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -496,11 +489,9 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
**kwargs: Unpack[FlashAttentionKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -524,12 +515,6 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
@ -537,8 +522,6 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
@ -602,8 +585,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -619,8 +601,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
@ -696,8 +677,7 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -713,8 +693,7 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
@ -205,6 +205,8 @@ def eager_attention_forward(
class T5GemmaSelfAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
super().__init__()
self.config = config
@ -280,6 +282,8 @@ class T5GemmaSelfAttention(nn.Module):
class T5GemmaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

return_hooks = {"attentions": 1}

def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
super().__init__()
self.config = config
@ -561,7 +565,6 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["T5GemmaBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True