diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 480687ae7f0..eef54b02ec0 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -154,7 +154,11 @@ class OPTAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""
         bsz, tgt_len, _ = hidden_states.size()
 
-        # get query proj
+        # Scaling is susceptible to floating point arithmetic's imprecisions,
+        # which can lead to different results (this varies from model to
+        # model, e.g. whisper is one such case). We therefore keep the
+        # original order of scaling to follow the original implementation
+        # and enforce no scaling (1.0) in the attention call below.
         query_states = self.q_proj(hidden_states) * self.scaling
         query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
 
@@ -187,7 +191,7 @@ class OPTAttention(nn.Module):
             value_states,
             attention_mask,
             dropout=0.0 if not self.training else self.dropout,
-            scaling=self.scaling,
+            scaling=1.0,
             **kwargs,
         )
 
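For context on the new comment: the patch keeps the query pre-scaled by `self.scaling` right after `q_proj`, and passes `scaling=1.0` to the attention interface so the scale is applied exactly once and in the original order. The sketch below is not part of the patch; it is a minimal illustration, assuming plain PyTorch and made-up tensor shapes, of why the order matters: scaling the query before `Q @ K^T` and scaling the scores afterwards are mathematically equivalent but can differ in the low-order floating point bits, which is enough to break strict equivalence checks.

```python
# Minimal sketch (not from the patch): shows that pre-scaling the query and
# post-scaling the attention scores are not necessarily bitwise identical.
import torch

torch.manual_seed(0)
head_dim = 64
scaling = head_dim**-0.5

# Toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, head_dim)
k = torch.randn(1, 8, 16, head_dim)

# Original OPT order: scale the query first, then take the dot product.
scores_prescaled = (q * scaling) @ k.transpose(-1, -2)

# Alternative order: take the dot product first, then scale the scores
# (what an attention backend does when it is given the scaling factor).
scores_postscaled = (q @ k.transpose(-1, -2)) * scaling

# Same math, but the rounding happens at different points.
print(torch.equal(scores_prescaled, scores_postscaled))    # may be False
print((scores_prescaled - scores_postscaled).abs().max())  # tiny, possibly non-zero drift
```

Because the attention call is told `scaling=1.0`, the backend does not rescale the already-scaled queries, and the computation follows the original OPT order of operations rather than leaving the scaling step to the kernel.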