From d23d2c27c2cd5beeb79f1d54c1a2b7be646f5a47 Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Fri, 28 Jul 2023 20:19:10 +0800
Subject: [PATCH] Represent query_length in a different way to solve jit issue
 (#25164)

Fix jit trace
---
 src/transformers/models/mpt/modeling_mpt.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index c7157155862..e1e176568bf 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -154,9 +154,7 @@ class MptAttention(nn.Module):
 
         attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
 
-        query_length = seq_length
-        if past_key_value is not None:
-            query_length += past_key_value[0].shape[2]
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
 
         if position_bias is not None:
             if len(position_bias.shape) != 3:
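
The rewrite computes query_length as a single conditional expression instead of initializing it and then mutating it inside a Python branch, which the subject line says is what tripped up torch.jit.trace. Below is a minimal sketch of the tracing workflow this patch is meant to unblock; the snippet is not part of the patch, and the checkpoint name is illustrative (any transformers-native MPT checkpoint should behave the same):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# torchscript=True makes the model return plain tuples, which jit.trace requires.
model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", torchscript=True)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")

inputs = tokenizer("Hello, world", return_tensors="pt")
# Trace with input_ids only: input_ids is MPT's first positional argument,
# and the remaining forward arguments keep their defaults during tracing.
traced = torch.jit.trace(model, (inputs["input_ids"],))
logits = traced(inputs["input_ids"])[0]

Since tracing records one concrete execution path, the single-assignment form leaves query_length with the same, well-defined origin on both the no-cache path captured above and the cached-generation path, rather than depending on an in-place update behind a branch.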