[OPT] Fix attention scaling (#38290)

* fix opt attention scaling

* add a comment explaining why we do this
Author: Anton Vlasjuk, 2025-05-26 11:02:16 +02:00 (committed by GitHub)
parent a5a0c7b888
commit d03a3ca692

@@ -154,7 +154,11 @@ class OPTAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""
         bsz, tgt_len, _ = hidden_states.size()
-        # get query proj
+        # Scaling is susceptible to floating point arithmetic's imprecision,
+        # which can lead to different results (this varies from model to
+        # model, e.g. Whisper is one such case). We therefore keep the
+        # original order of scaling to follow the original implementation
+        # and enforce no scaling (1.0) in the attention call below.
         query_states = self.q_proj(hidden_states) * self.scaling
         query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -187,7 +191,7 @@ class OPTAttention(nn.Module):
             value_states,
             attention_mask,
             dropout=0.0 if not self.training else self.dropout,
-            scaling=self.scaling,
+            scaling=1.0,
             **kwargs,
         )
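
For illustration, here is a minimal sketch (not part of the commit) of the numerical point the new comment makes: scaling the query before the matmul and scaling the attention scores after it are mathematically equivalent, but floating point rounding differs between the two orders, so the outputs are generally not bit-identical. Only torch is assumed; the tensor shapes, head_dim, and variable names below are made up for the example.

import torch

torch.manual_seed(0)

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 80
scaling = head_dim**-0.5  # not an exact power of two, so rounding differs between orders

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)

# Order restored by this commit: scale the query projection first, then compute the scores.
scores_scale_first = (q * scaling) @ k.transpose(-1, -2)

# Alternative order: compute the scores, then scale inside the attention call.
scores_scale_last = (q @ k.transpose(-1, -2)) * scaling

print(torch.equal(scores_scale_first, scores_scale_last))           # usually False
print((scores_scale_first - scores_scale_last).abs().max().item())  # tiny but typically non-zero
print(torch.allclose(scores_scale_first, scores_scale_last))        # True: the difference is on the order of float32 eps

The differences are harmless on their own, but they make the refactored model diverge from the reference implementation, which is why the commit keeps the original scaling order and passes scaling=1.0 to the attention call.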