Arthur 2025-06-30 12:26:51 +02:00
parent 96aabd77c7
commit 63df15bb24
2 changed files with 7 additions and 10 deletions

View File

@@ -176,7 +176,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -227,7 +227,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -378,7 +378,7 @@ class LlamaModel(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -427,9 +427,6 @@
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, TransformersKwargs): ...
-
-
 @auto_docstring
 class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -476,7 +473,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -570,7 +567,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -657,7 +654,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
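
The Llama changes above replace the per-signature `Unpack[FlashAttentionKwargs]` (or bare `**kwargs`) annotations with a single `Unpack[TransformersKwargs]` contract and drop the now-redundant `KwargsForCausalLM` alias. A minimal sketch of the underlying pattern, with a hypothetical `ExampleKwargs` TypedDict and toy functions standing in for `TransformersKwargs` and the real forward methods (not the actual Transformers code):

# Sketch only: one TypedDict + Unpack annotation types the pass-through
# **kwargs at every level of a nested call stack.
from typing_extensions import TypedDict, Unpack


class ExampleKwargs(TypedDict, total=False):
    # Illustrative keys; the real TransformersKwargs carries flash-attention
    # and loss-related options.
    output_attentions: bool
    num_items_in_batch: int


def attention_forward(scaling: float, dropout: float = 0.0, **kwargs: Unpack[ExampleKwargs]) -> bool:
    # A static type checker now knows exactly which extra keys may appear here.
    return kwargs.get("output_attentions", False)


def model_forward(**kwargs: Unpack[ExampleKwargs]) -> bool:
    # The same annotation is reused unchanged at every call site, so a single
    # TypedDict describes the kwargs for the whole stack.
    return attention_forward(scaling=1.0, **kwargs)


print(model_forward(output_attentions=True))  # True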

View File

@@ -854,7 +854,7 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     return decorator
 
 
-class TransformersKwars(FlashAttentionKwargs, total=False):
+class TransformersKwargs(FlashAttentionKwargs, TypedDict, total=False):
     """
     Keyword arguments to be passed to the loss function
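
For reference, a hedged sketch of the composition used in the `TransformersKwargs` definition above: a TypedDict subclass inherits the flash-attention keys and adds its own, and `total=False` keeps every key optional. Class and field names below are illustrative stand-ins, not the real ones.

from typing_extensions import TypedDict


class FlashKwargsSketch(TypedDict, total=False):
    # Stand-in for FlashAttentionKwargs; real field names differ.
    cu_seq_lens_q: object
    max_length_q: int


class CombinedKwargsSketch(FlashKwargsSketch, total=False):
    # Stand-in for TransformersKwargs: inherits the flash-attention keys
    # and adds loss-related ones; total=False keeps every key optional.
    num_items_in_batch: int


# At runtime a TypedDict instance is just a dict; the class only guides type checkers.
ok: CombinedKwargsSketch = {"max_length_q": 128, "num_items_in_batch": 4}
print(ok)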