Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add type hints for forward functions in Gemma2 (#35034)
* feat: add gemma2 type hints
* fix: mask is optional
parent 7b5f76e32e
commit f41d5d8f74
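The new signatures rely on Optional and Tuple from typing plus the Gemma2Config type; note that "mask is optional" means the parameter accepts None, not that it gained a default, so callers still pass it explicitly. A minimal sketch of the annotated eager signature as it reads after the change (the top-level transformers import is for illustration only; inside the modeling file Gemma2Config resolves via a relative import):

from typing import Optional, Tuple

import torch

from transformers import Gemma2Config  # illustrative; the modeling file uses a relative import


def eager_attention_forward(
    config: Gemma2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],  # may be None, but has no default value,
    **_kwargs,                     # so callers must still pass it explicitly
) -> Tuple[torch.Tensor, torch.Tensor]:  # (attn_output, attn_weights)
    ...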
@@ -170,7 +170,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)
 
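For context on what these forwards consume: repeat_kv expands grouped key/value heads so they line up with the query head count. A minimal sketch consistent with the reshape shown in the hunk context above (the expand step is assumed from the standard transformers implementation, not quoted from this diff):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq, dim) to (batch, num_kv_heads * n_rep, seq, dim)
    # so grouped-query KV heads match the number of query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 4, 10, 64)  # 4 KV heads
print(repeat_kv(kv, 2).shape)   # torch.Size([1, 8, 10, 64]) -> matches 8 query heads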
@@ -192,7 +199,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights
 
 
-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
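The flash path trims its inputs to the mask's sequence length before calling the kernel. A small shape sketch of the trim visible in this hunk, with illustrative tensor sizes (the shapes here are assumptions for the demo, not from the diff):

import torch

query = torch.randn(2, 8, 16, 64)  # (batch, heads, q_len, head_dim), assumed shapes
mask = torch.ones(2, 12)           # (batch, seq_len) -> effective seq_len = 12

if mask is not None:
    seq_len = mask.shape[1]
    query = query[:, :, :seq_len]  # drop positions beyond the mask

print(query.shape)  # torch.Size([2, 8, 12, 64])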
@@ -229,7 +244,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
     return attn_output, None
 
 
-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
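tanh_softcap is the score-mod function for flex attention: it rescales each attention score so it saturates at plus or minus attn_logit_softcapping while staying nearly linear for small scores. A quick numeric check, assuming the 50.0 cap used in released Gemma2 configs:

import torch

soft_cap = 50.0  # assumed value of config.attn_logit_softcapping for illustration
scores = torch.tensor([-500.0, -5.0, 0.0, 5.0, 500.0])
capped = soft_cap * torch.tanh(scores / soft_cap)  # the formula from the hunk above
print(capped)  # approx. tensor([-50.0000, -4.9834, 0.0000, 4.9834, 50.0000])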
@@ -255,7 +278,14 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
     return attn_output, attn_weights
 
 
-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)
 
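Once key and value are repeated to the full head count, the SDPA path can hand off to PyTorch's stock kernel; the Tuple[torch.Tensor, None] return type records that no attention weights are produced. A sketch with illustrative shapes (F.scaled_dot_product_attention is the standard PyTorch API, not quoted from this diff):

import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 10, 64)  # assumed shapes for the demo
key = torch.randn(1, 8, 10, 64)    # already repeated to 8 heads via repeat_kv
value = torch.randn(1, 8, 10, 64)
mask = None                        # Optional[torch.Tensor]: None is a valid value

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
print(attn_output.shape)  # torch.Size([1, 8, 10, 64]); no weights are returned

The hunks below repeat the same signature changes in the diff's second file; its class Gemma2RotaryEmbedding(GemmaRotaryEmbedding) context suggests the modular Gemma2 source from which the modeling file is generated.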
@@ -213,7 +213,14 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass
 
 
-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)
 
@@ -235,7 +242,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights
 
 
-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
@@ -272,7 +287,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
     return attn_output, None
 
 
-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
@@ -298,7 +321,14 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
     return attn_output, attn_weights
 
 
-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)
 
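One payoff of uniform annotations across all four backends is that they now share a common callable shape, so a dispatch table over them can itself be typed, and a checker will force callers to handle the optional weights slot. A hypothetical, self-contained sketch (the table and the fake_sdpa stand-in are illustrative, not from this commit):

from typing import Callable, Optional, Tuple

import torch

# Common shape of all four forwards after this change.
AttentionForward = Callable[..., Tuple[torch.Tensor, Optional[torch.Tensor]]]

def fake_sdpa(query: torch.Tensor, **_kwargs) -> Tuple[torch.Tensor, None]:
    # Stand-in with the same return type as sdpa_attention_forward:
    # SDPA never materializes attention weights, hence the None slot.
    return query, None

FORWARDS: dict[str, AttentionForward] = {"sdpa": fake_sdpa}

out, weights = FORWARDS["sdpa"](torch.randn(1, 8, 10, 64))
if weights is not None:  # type checkers require this narrowing before use
    print(weights.shape)
print(out.shape)  # torch.Size([1, 8, 10, 64])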