fix Starcoder FA2 implementation (#28891)

commit d9deddb4c1 (parent 64d1518cbf)
Author: Sourab Mangrulkar
Date:   2024-02-07 14:10:10 +05:30, committed via GitHub


@@ -363,13 +363,6 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         attn_dropout = self.attn_pdrop if self.training else 0.0
 
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
-        upcast = query.dtype != softmax_dtype
-        softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
-        softmax_scale = softmax_scale**-1
-        if self.scale_attn_weights:
-            softmax_scale /= self.head_dim**0.5
-
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
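
Why these lines go away: in the eager GPTBigCode attention path, the (layer_idx + 1)**-1 factor computed here is compensated later, when the upcast softmax multiplies the logits by layer_idx + 1 again, so the net scale stays 1/sqrt(head_dim). Flash attention has no such second step, so forwarding this softmax_scale shrank the attention logits. A minimal sketch of the mismatch, assuming layer_idx and head_dim carry the same meaning as the module attributes above (illustrative plain Python, not part of the commit):

    def removed_fa2_scale(layer_idx: int, head_dim: int, upcast: bool = True) -> float:
        # What the deleted hunk computed and forwarded to flash attention.
        softmax_scale = layer_idx + 1 if upcast else 1
        softmax_scale = softmax_scale**-1
        softmax_scale /= head_dim**0.5
        return softmax_scale

    def default_fa2_scale(head_dim: int) -> float:
        # flash-attn's fallback when it receives no scale: 1/sqrt(head_dim).
        return head_dim**-0.5

    # The eager path re-applies the (layer_idx + 1) factor inside its upcast
    # softmax, so the two cancel there; flash attention never does, so the
    # removed scale understated the logits by a factor of layer_idx + 1.
    print(removed_fa2_scale(layer_idx=7, head_dim=64))  # 0.015625 (wrong under FA2)
    print(default_fa2_scale(head_dim=64))               # 0.125 == 64 ** -0.5
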
@@ -393,7 +386,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
             value = value.to(target_dtype)
 
         attn_output = self._flash_attention_forward(
-            query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale
+            query, key, value, attention_mask, query_length, dropout=attn_dropout
         )
 
         attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
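
Why dropping the keyword is sufficient: _flash_attention_forward defaults softmax_scale to None, and flash-attn substitutes 1/sqrt(head_dim) when it receives None, which is the scale the eager path effectively applies. A hedged sketch of that fall-through (flash_attn_scale_or_default is a hypothetical stand-in, not the real kernel API):

    from typing import Optional

    def flash_attn_scale_or_default(head_dim: int, softmax_scale: Optional[float] = None) -> float:
        # Stand-in for flash-attn's default behavior: None -> 1/sqrt(head_dim).
        return softmax_scale if softmax_scale is not None else head_dim**-0.5

    assert flash_attn_scale_or_default(64, softmax_scale=0.015625) == 0.015625  # old call
    assert flash_attn_scale_or_default(64) == 0.125                             # fixed call
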