Represent query_length in a different way to solve JIT issue (#25164)

Fix JIT trace
This commit is contained in:
jiqing-feng 2023-07-28 20:19:10 +08:00 committed by GitHub
parent 2a78720104
commit d23d2c27c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -154,9 +154,7 @@ class MptAttention(nn.Module):
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
query_length = seq_length
if past_key_value is not None:
query_length += past_key_value[0].shape[2]
query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
if position_bias is not None:
if len(position_bias.shape) != 3: