Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
fix Starcoder FA2 implementation (#28891)
commit d9deddb4c1
parent 64d1518cbf
@@ -363,13 +363,6 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
 
         attn_dropout = self.attn_pdrop if self.training else 0.0
 
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
-        upcast = query.dtype != softmax_dtype
-        softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
-        softmax_scale = softmax_scale**-1
-        if self.scale_attn_weights:
-            softmax_scale /= self.head_dim**0.5
-
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
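For reference, below is a minimal standalone sketch of what the deleted block evaluated to. The flag values, `layer_idx`, and `head_dim` are hypothetical and the code is re-typed outside the class purely for illustration. When the fp32-upcast branch was active, the scale handed to flash attention picked up an extra 1 / (layer_idx + 1) factor instead of the plain 1 / sqrt(head_dim) that scaled dot-product attention expects; presumably that factor is only meaningful on the eager path, which compensates for it around its upcasted softmax, while the flash-attention kernel cannot.

import torch

# Hypothetical stand-ins for the module attributes used by the deleted lines.
layer_idx, head_dim = 3, 64
attention_softmax_in_fp32 = True
scale_attention_softmax_in_fp32 = True
scale_attn_weights = True
query_dtype = torch.float16

# Re-evaluation of the removed block, for illustration only.
softmax_dtype = torch.float32 if attention_softmax_in_fp32 else query_dtype
upcast = query_dtype != softmax_dtype
softmax_scale = layer_idx + 1 if scale_attention_softmax_in_fp32 and upcast else 1
softmax_scale = softmax_scale**-1
if scale_attn_weights:
    softmax_scale /= head_dim**0.5

# With the upcast branch taken, the scale is 1 / ((layer_idx + 1) * sqrt(head_dim)),
# not the standard 1 / sqrt(head_dim).
print(softmax_scale, head_dim**-0.5)  # 0.03125 vs. 0.125 for these values

With the flags disabled, the removed code reduced to exactly `head_dim**-0.5`, which is what flash attention applies by default anyway, so dropping the explicit argument (second hunk below) falls back to the standard scaling in either case.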
@@ -393,7 +386,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
             value = value.to(target_dtype)
 
         attn_output = self._flash_attention_forward(
-            query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale
+            query, key, value, attention_mask, query_length, dropout=attn_dropout
         )
 
         attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
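The fixed call leaves `softmax_scale` unset. A hedged sketch of the fallback this relies on, assuming flash-attn's documented behavior that `softmax_scale=None` means the kernel applies 1 / sqrt(head_dim) itself (the helper name below is made up for illustration, not library code):

from typing import Optional

def effective_softmax_scale(head_dim: int, softmax_scale: Optional[float] = None) -> float:
    """Mirror of the fallback: an explicit scale wins, otherwise use 1/sqrt(head_dim)."""
    return softmax_scale if softmax_scale is not None else head_dim**-0.5

# With the argument removed from the call above, the FA2 path uses the standard scaling again.
assert effective_softmax_scale(64) == 64**-0.5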