mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
update

parent 96aabd77c7
commit 63df15bb24
@@ -176,7 +176,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -227,7 +227,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -378,7 +378,7 @@ class LlamaModel(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
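The three hunks above only touch the type annotation on **kwargs; at runtime the keyword arguments still flow through unchanged. As a minimal sketch of the pattern (the keys below are hypothetical stand-ins, not the library's real TransformersKwargs fields), annotating **kwargs with Unpack over a TypedDict lets a type checker follow the keyword arguments as they are forwarded from a model's forward down to the attention function:

# Minimal sketch; these key names are hypothetical stand-ins, not the library's.
from typing import Optional

from typing_extensions import TypedDict, Unpack


class TransformersKwargs(TypedDict, total=False):
    # total=False: every key is optional at the call site.
    output_attentions: Optional[bool]
    num_items_in_batch: Optional[int]


def eager_attention_forward(hidden_states: list, **kwargs: Unpack[TransformersKwargs]) -> list:
    # The checker knows exactly which keys may appear in kwargs.
    if kwargs.get("output_attentions"):
        print("caller asked for attention weights")
    return hidden_states


def model_forward(hidden_states: list, **kwargs: Unpack[TransformersKwargs]) -> list:
    # kwargs are forwarded verbatim; the shared annotation keeps the whole chain checkable.
    return eager_attention_forward(hidden_states, **kwargs)


model_forward([0.0, 1.0], output_attentions=True)   # accepted
# model_forward([0.0], unexpected_flag=True)        # rejected by a type checker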
@@ -427,9 +427,6 @@ class LlamaModel(LlamaPreTrainedModel):
         )
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, TransformersKwargs): ...
-
-
 @auto_docstring
 class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -476,7 +473,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -570,7 +567,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -657,7 +654,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
@@ -854,7 +854,7 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     return decorator
 
 
-class TransformersKwars(FlashAttentionKwargs, total=False):
+class TransformersKwargs(FlashAttentionKwargs, TypedDict, total=False):
     """
     Keyword arguments to be passed to the loss function
 
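The last hunk fixes the TransformersKwars typo and spells out the TypedDict base alongside FlashAttentionKwargs. A minimal sketch of what that buys, using hypothetical placeholder keys rather than the library's real ones: a total=False TypedDict can extend another TypedDict, so a single Unpack annotation covers both groups of optional keyword arguments:

# Sketch only; FlashAttentionKwargs' real fields differ, these are placeholders.
from typing import Optional

from typing_extensions import TypedDict, Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    # Hypothetical flash-attention options.
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class TransformersKwargs(FlashAttentionKwargs, TypedDict, total=False):
    # Inherits the keys above and adds its own; total=False keeps them all optional.
    num_items_in_batch: Optional[int]


def forward(hidden_states: list, **kwargs: Unpack[TransformersKwargs]) -> list:
    # Inherited and newly declared keys are both valid here.
    print(kwargs.get("max_length_q"), kwargs.get("num_items_in_batch"))
    return hidden_states


forward([0.0], max_length_q=128, num_items_in_batch=4)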