Mirror of https://github.com/huggingface/transformers.git
Represent query_length in a different way to solve jit issue (#25164)
Fix jit trace
commit d23d2c27c2
parent 2a78720104
@@ -154,9 +154,7 @@ class MptAttention(nn.Module):
 
         attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
 
-        query_length = seq_length
-        if past_key_value is not None:
-            query_length += past_key_value[0].shape[2]
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
 
         if position_bias is not None:
             if len(position_bias.shape) != 3:
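
For context, a minimal sketch of the two forms side by side (hypothetical standalone functions, not the actual MptAttention.forward code): the removed version assigns query_length and then mutates it under a branch, while the added version computes the same value in a single conditional expression, which, per the commit message, fixes the torch.jit trace issue.

import torch

# Hypothetical standalone sketch; the real change lives inside
# MptAttention.forward in modeling_mpt.py.

def query_length_before(seq_length, past_key_value):
    # Old form: assign, then conditionally mutate in place.
    query_length = seq_length
    if past_key_value is not None:
        query_length += past_key_value[0].shape[2]
    return query_length

def query_length_after(seq_length, past_key_value):
    # New form: one conditional expression; per the commit message,
    # this representation avoids the jit trace problem.
    return seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

# Both forms agree on the value; only the representation differs.
past = (torch.zeros(1, 4, 8, 16),)  # past_key_value[0]: (batch, heads, past_len, head_dim)
assert query_length_before(10, None) == query_length_after(10, None) == 10
assert query_length_before(10, past) == query_length_after(10, past) == 18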