Author: Arthur
Date: 2025-06-30 12:26:51 +02:00
parent 96aabd77c7
commit 63df15bb24
2 changed files with 7 additions and 10 deletions


@@ -176,7 +176,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -227,7 +227,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -378,7 +378,7 @@ class LlamaModel(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -427,9 +427,6 @@
         )
-class KwargsForCausalLM(FlashAttentionKwargs, TransformersKwargs): ...
 @auto_docstring
 class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -476,7 +473,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -570,7 +567,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -657,7 +654,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
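
Every hunk above makes the same change: the bare `**kwargs` and `Unpack[FlashAttentionKwargs]` annotations are replaced by `Unpack[TransformersKwargs]`, so one typed keyword bundle can flow from the top-level `forward` methods down to `eager_attention_forward`. A minimal sketch of that pass-through pattern, using hypothetical class and key names rather than the actual Llama classes:

```python
from typing_extensions import TypedDict, Unpack


class ExtraKwargs(TypedDict, total=False):
    # Hypothetical stand-in for TransformersKwargs: every key is optional.
    output_attentions: bool


class ToyAttention:
    def forward(self, hidden_states, **kwargs: Unpack[ExtraKwargs]):
        # The innermost consumer reads only the keys it needs.
        if kwargs.get("output_attentions", False):
            return hidden_states, None
        return hidden_states


class ToyModel:
    def __init__(self) -> None:
        self.attn = ToyAttention()

    def forward(self, hidden_states, **kwargs: Unpack[ExtraKwargs]):
        # Outer layers annotate **kwargs with the same TypedDict and forward
        # it unchanged; a type checker validates keys at every call site.
        return self.attn.forward(hidden_states, **kwargs)


model = ToyModel()
model.forward([1.0, 2.0], output_attentions=False)
```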


@@ -854,7 +854,7 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     return decorator
-class TransformersKwars(FlashAttentionKwargs, total=False):
+class TransformersKwargs(FlashAttentionKwargs, TypedDict, total=False):
     """
     Keyword arguments to be passed to the loss function
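
The second file fixes the `TransformersKwars` typo and adds `TypedDict` as an explicit base. For reference, a sketch of the underlying pattern: subclassing a TypedDict inherits its keys, and `total=False` keeps every key optional, so `**kwargs: Unpack[...]` stays backwards compatible while becoming type-checkable. The names below, other than `TypedDict` and `Unpack`, are illustrative and not the library's actual keys:

```python
from typing_extensions import TypedDict, Unpack


class FlashKwargs(TypedDict, total=False):
    # Hypothetical stand-in for FlashAttentionKwargs.
    max_length_q: int
    max_length_k: int


class AllKwargs(FlashKwargs, total=False):
    # Hypothetical stand-in for TransformersKwargs: inherits the keys above
    # and adds a loss-related one; total=False keeps every key optional.
    num_items_in_batch: int


def loss_step(logits, labels, **kwargs: Unpack[AllKwargs]) -> None:
    # Inherited and newly declared keys are both accepted; a misspelled key
    # is flagged by the type checker at the call site, not at runtime.
    _ = kwargs.get("num_items_in_batch")


loss_step(None, None, max_length_q=128, num_items_in_batch=4)
```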