Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[OPT] Fix attention scaling (#38290)

* fix opt attention scaling
* add comment to why we do this
This commit is contained in:
parent a5a0c7b888
commit d03a3ca692
@@ -154,7 +154,11 @@ class OPTAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""

         bsz, tgt_len, _ = hidden_states.size()

-        # get query proj
+        # Scaling is susceptible to floating point arithmetics' inprecisions
+        # which can lead to different results (this is dependent from model
+        # to model, e.g. whisper is one such case). We therefore keep the
+        # original order of scaling to follow the original implementation
+        # and enforce no scaling (1.0) in the attention call below.
         query_states = self.q_proj(hidden_states) * self.scaling
         query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -187,7 +191,7 @@ class OPTAttention(nn.Module):
             value_states,
             attention_mask,
             dropout=0.0 if not self.training else self.dropout,
-            scaling=self.scaling,
+            scaling=1.0,
             **kwargs,
         )
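The comment added in the first hunk motivates the change: applying the scaling factor before the query/key matmul versus passing it to the attention interface (which scales the product) is mathematically equivalent but not bitwise identical under floating point arithmetic. A minimal standalone sketch of that effect, not taken from the repository (shapes, seed, and tensors are made up for illustration only):

# Illustration only: order of scaling vs. matmul under float32.
import torch

torch.manual_seed(0)
head_dim = 64
scaling = head_dim ** -0.5

q = torch.randn(1, 8, 16, head_dim)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, head_dim)

# Variant A: scale the query first, then take the dot product
# (the order kept by this fix, matching the original OPT implementation).
scores_a = (q * scaling) @ k.transpose(-1, -2)

# Variant B: take the dot product first, then scale
# (what a shared attention interface does when given scaling != 1.0).
scores_b = (q @ k.transpose(-1, -2)) * scaling

# Mathematically identical, but the intermediate rounding differs,
# so the two variants need not be bitwise equal.
print(torch.equal(scores_a, scores_b))        # may print False
print((scores_a - scores_b).abs().max())      # small, possibly non-zero

Passing scaling=1.0 to the attention call, as the second hunk does, keeps the scaling where the original implementation applied it while avoiding a second multiplication inside the attention interface.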