Add type hints for forward functions in Gemma2 (#35034)

* feat: add gemma2 type hints

* fix: mask is optional
Jacky Lee authored on 2024-12-02 06:03:36 -08:00; committed by GitHub
parent 7b5f76e32e
commit f41d5d8f74
2 changed files with 68 additions and 8 deletions
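
The annotated signatures below rely on Optional and Tuple from typing, on torch, and on Gemma2Config; the 8 deletions are the four old one-line signatures in each of the two files, and the 68 additions are their multi-line replacements, so the import blocks are untouched and those names are presumably already in scope. For reference, a minimal sketch of the imports the annotations assume (not part of this diff; the configuration import path is an assumption):

# Assumed to be in scope already in both Gemma2 files; not added by this commit.
from typing import Optional, Tuple

import torch

from .configuration_gemma2 import Gemma2Config  # import path assumed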

Changed file 1 of 2:

@@ -170,7 +170,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)
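
For orientation, a hedged sketch of what a function with this annotated signature computes, using only names visible in the surrounding context (repeat_kv, config.num_key_value_groups); the real eager_attention_forward additionally applies query scaling, logit soft-capping and attention dropout, which are omitted here, and the helper name is illustrative:

from typing import Optional, Tuple

import torch

def eager_attention_sketch(
    config: "Gemma2Config",       # config class from the surrounding module
    query: torch.Tensor,          # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,            # (batch, num_kv_heads, kv_len, head_dim)
    value: torch.Tensor,          # (batch, num_kv_heads, kv_len, head_dim)
    mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # repeat_kv comes from the surrounding module (see the context lines above).
    key_states = repeat_kv(key, config.num_key_value_groups)
    value_states = repeat_kv(value, config.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3))
    if mask is not None:
        attn_weights = attn_weights + mask[:, :, :, : key_states.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output, attn_weights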
@@ -192,7 +199,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights
-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
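
The new target_dtype: torch.dtype = torch.float16 parameter reflects the fact that flash-attention kernels only run in half precision, so float32 activations have to be cast down before the kernel call. A hedged sketch of that kind of guard (the helper name is hypothetical, and the exact cast handling lives in the function body, which this hunk does not show):

from typing import Tuple

import torch

def _maybe_downcast_qkv(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    target_dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Flash-attention kernels accept fp16/bf16 only; leave half-precision inputs alone.
    if query.dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)
    return query, key, value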
@@ -229,7 +244,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.
     return attn_output, None
-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
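
The tanh_softcap closure shown in the context lines is the flex-attention score modifier for Gemma2's logit soft-capping: scores are mapped through soft_cap * tanh(score / soft_cap), which is roughly the identity for small scores and saturates at ±soft_cap for large ones. A standalone, hedged illustration of that transform (the function name and the 50.0 default are illustrative):

import torch

def soft_cap_logits(score: torch.Tensor, soft_cap: float = 50.0) -> torch.Tensor:
    # Roughly linear for |score| << soft_cap, saturating at +/- soft_cap.
    return soft_cap * torch.tanh(score / soft_cap)

print(soft_cap_logits(torch.tensor([0.5, 10.0, 200.0])))
# ~ tensor([ 0.5000,  9.8688, 49.9665])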
@@ -255,7 +278,14 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
     return attn_output, attn_weights
-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)
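
With the annotations in place, all four callables share the shape (config, query, key, value, mask, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]], which is what lets the modeling code pick an implementation by name. A hedged sketch of such a dispatch (the table name, the keys, and the config._attn_implementation lookup are assumptions about the surrounding code, not part of this diff):

from typing import Callable, Dict

ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
    "flash_attention_2": flash_attention_forward,
    "flex_attention": flex_attention_forward,
}

def run_attention(config, query, key, value, mask=None, **kwargs):
    # Fall back to eager attention when the requested implementation is unknown.
    attention_fn = ATTENTION_FUNCTIONS.get(getattr(config, "_attn_implementation", "eager"), eager_attention_forward)
    return attention_fn(config, query, key, value, mask, **kwargs)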

Changed file 2 of 2:

@@ -213,7 +213,14 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass
-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)
@@ -235,7 +242,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights
-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
@@ -272,7 +287,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.
     return attn_output, None
-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
@@ -298,7 +321,14 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
     return attn_output, attn_weights
-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)