Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 18:51:14 +06:00

Commit 63df15bb24 ("update"), parent 96aabd77c7.
@@ -176,7 +176,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
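For context on what the new annotation buys: typing `**kwargs` with `Unpack[...]` over a `TypedDict` tells a static checker exactly which keyword names (and value types) the function may receive, whereas a bare `**kwargs` accepts anything. A minimal sketch of the pattern, using made-up names (`AttnKwargs`, `eager_forward_stub`) rather than the real transformers definitions:

# Minimal sketch of typing **kwargs with Unpack over a TypedDict.
# AttnKwargs and eager_forward_stub are hypothetical stand-ins, not the
# actual transformers definitions touched by this diff.
from typing_extensions import TypedDict, Unpack


class AttnKwargs(TypedDict, total=False):
    output_attentions: bool
    sliding_window: int


def eager_forward_stub(scaling: float, dropout: float = 0.0, **kwargs: Unpack[AttnKwargs]) -> None:
    # At runtime kwargs is still a plain dict; the annotation only tells a
    # type checker which keys (and value types) it may contain.
    if kwargs.get("output_attentions", False):
        print("caller asked for attention weights")


# Accepted by a checker: both keyword names are declared in AttnKwargs.
eager_forward_stub(scaling=0.125, output_attentions=True)
# A misspelled key such as output_attention=True would be flagged by
# mypy/pyright, whereas a bare **kwargs would accept it silently.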
@@ -227,7 +227,7 @@ class LlamaAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -378,7 +378,7 @@ class LlamaModel(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -427,9 +427,6 @@ class LlamaModel(LlamaPreTrainedModel):
         )


-class KwargsForCausalLM(FlashAttentionKwargs, TransformersKwargs): ...
-
-
 @auto_docstring
 class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
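One way to read the removal of `KwargsForCausalLM` in this hunk: once `TransformersKwargs` itself inherits from `FlashAttentionKwargs`, a TypedDict that merely combines the two declares no additional keys. A hedged sketch of that inheritance behaviour, with made-up field names rather than the library's own:

# Sketch of TypedDict inheritance: the subclass exposes the union of its own
# keys and the parent's, so a separate "combined" TypedDict adds nothing.
# Field names here are illustrative, not the real transformers definitions.
from typing_extensions import TypedDict, Unpack


class FlashKwargsSketch(TypedDict, total=False):
    cu_seq_lens_q: list        # stand-in for a flash-attention key
    max_length_q: int


class TransformersKwargsSketch(FlashKwargsSketch, total=False):
    num_items_in_batch: int    # stand-in for a loss-related key


def forward_stub(**kwargs: Unpack[TransformersKwargsSketch]) -> None:
    print(sorted(kwargs))


# Parent and child keys both flow through the single subclass type, which is
# why a combining class in the style of KwargsForCausalLM becomes redundant.
forward_stub(max_length_q=128, num_items_in_batch=4)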
@@ -476,7 +473,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -570,7 +567,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -657,7 +654,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,
         end_positions: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> QuestionAnsweringModelOutput:
         outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
@@ -854,7 +854,7 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     return decorator


-class TransformersKwars(FlashAttentionKwargs, total=False):
+class TransformersKwargs(FlashAttentionKwargs, TypedDict, total=False):
     """
     Keyword arguments to be passed to the loss function

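A usage note on the `total=False` part of the new definition: it makes every declared key optional, so callers may pass any subset of the keyword arguments or none at all. A minimal sketch under that assumption, with invented field names rather than the actual `TransformersKwargs` contents:

# Sketch of a total=False TypedDict used as a **kwargs contract: every key is
# optional, so callers may pass any subset. LossKwargsSketch and
# compute_loss_stub are illustrative names, not the real TransformersKwargs.
from typing_extensions import TypedDict, Unpack


class LossKwargsSketch(TypedDict, total=False):
    num_items_in_batch: int
    ignore_index: int


def compute_loss_stub(logits: float, labels: float, **kwargs: Unpack[LossKwargsSketch]) -> float:
    # Because no key is required, each one is looked up with a default.
    ignore_index = kwargs.get("ignore_index", -100)
    print(f"ignore_index={ignore_index}, provided keys={sorted(kwargs)}")
    return 0.0


# Any subset of the declared keys is a valid call.
compute_loss_stub(1.0, 0.0)
compute_loss_stub(1.0, 0.0, num_items_in_batch=8)
compute_loss_stub(1.0, 0.0, num_items_in_batch=8, ignore_index=-1)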