Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00

update models based on qwen2

commit e7705c981a (parent 113219becd)
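The change is the same in all four touched models (Dots1, Qwen2, Qwen3, SmolLM3): the `@can_return_tuple` decorator and the manual `output_attentions` / `output_hidden_states` bookkeeping inside `forward()` are dropped in favour of `@check_model_inputs` (imported from `...utils.generic`, or directly from `...utils` in the qwen2 file), and the decoder loop is reduced to threading `hidden_states` through the layers. Below is a minimal, self-contained sketch of that general pattern — resolving the output flag outside the model body and collecting per-layer outputs via forward hooks instead of inside the loop. It is an illustration only; `collect_layer_outputs` and `TinyDecoder` are hypothetical names, not the actual `check_model_inputs` implementation.

import torch
from torch import nn


def collect_layer_outputs(forward):
    # Hypothetical stand-in for transformers' check_model_inputs: handle the
    # output flag outside forward() and gather per-layer outputs with hooks,
    # so the model body stays a plain loop.
    def wrapper(self, x, output_hidden_states=False):
        collected, hooks = [], []
        if output_hidden_states:
            for layer in self.layers:
                hooks.append(layer.register_forward_hook(lambda mod, inp, out: collected.append(out)))
        result = forward(self, x)
        for handle in hooks:
            handle.remove()
        return (result, tuple(collected)) if output_hidden_states else result

    return wrapper


class TinyDecoder(nn.Module):
    def __init__(self, width=4, depth=3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))

    @collect_layer_outputs
    def forward(self, hidden_states):
        # Mirrors the simplified loop in this diff: no all_hidden_states /
        # all_self_attns accumulators, just thread the state through.
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


last, per_layer = TinyDecoder()(torch.randn(2, 4), output_hidden_states=True)
print(last.shape, len(per_layer))  # torch.Size([2, 4]) 3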
@@ -35,13 +35,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from .configuration_dots1 import Dots1Config
 
 
-logger = logging.get_logger(__name__)
-
-
 @use_kernel_forward_from_hub("RMSNorm")
 class Dots1RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -469,7 +467,7 @@ class Dots1Model(Dots1PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -479,30 +477,12 @@ class Dots1Model(Dots1PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -541,42 +521,22 @@ class Dots1Model(Dots1PreTrainedModel):
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **flash_attn_kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
 
@@ -26,7 +26,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, check_model_inputs, logging
 from .configuration_qwen2 import Qwen2Config
 
 
@@ -343,7 +343,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -353,30 +353,12 @@ class Qwen2Model(Qwen2PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -415,42 +397,22 @@ class Qwen2Model(Qwen2PreTrainedModel):
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **flash_attn_kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
 
@@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from .configuration_qwen3 import Qwen3Config
 
 
@@ -369,7 +370,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -379,30 +380,12 @@ class Qwen3Model(Qwen3PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -441,42 +424,22 @@ class Qwen3Model(Qwen3PreTrainedModel):
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **flash_attn_kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )
 
@@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from .configuration_smollm3 import SmolLM3Config
 
 
@@ -372,7 +373,7 @@ class SmolLM3Model(SmolLM3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
    @auto_docstring
     def forward(
         self,
@@ -382,30 +383,12 @@ class SmolLM3Model(SmolLM3PreTrainedModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
-        if not isinstance(past_key_values, (type(None), Cache)):
-            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -444,42 +427,22 @@ class SmolLM3Model(SmolLM3PreTrainedModel):
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **flash_attn_kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
        )
 
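From the caller's side the refactor is presumably behavior-preserving: `output_attentions` and `output_hidden_states` are no longer explicit `forward()` parameters, but they should still be accepted as keyword arguments and resolved by `@check_model_inputs`, with the returned `BaseModelOutputWithPast` populated as before. A usage sketch under that assumption (the checkpoint name is only a placeholder):

from transformers import AutoModel, AutoTokenizer

name = "Qwen/Qwen2-0.5B"  # placeholder qwen2-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

inputs = tokenizer("hello world", return_tensors="pt")
# Assumption: these flags are now consumed by the check_model_inputs decorator
# rather than by the forward() body, with the same observable result.
out = model(**inputs, output_attentions=True, output_hidden_states=True)
print(type(out).__name__, len(out.hidden_states), len(out.attentions))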