fix use cache (#3)

Younes Belkada 2023-09-26 13:20:15 +02:00 committed by GitHub
parent 7800457d1a
commit 689f599132


@@ -317,7 +317,7 @@ class OptFlashAttention2(OPTAttention):
         # for the decoder
         is_cross_attention = key_value_states is not None

-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, _, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states)
@@ -351,13 +351,15 @@ class OptFlashAttention2(OPTAttention):
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)

+        query_length = query_states.shape[1]
+        tgt_len = key_states.shape[-2]
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
         key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
         value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

-        _, query_length, _, _ = query_states.shape
-
         attn_dropout = self.dropout if self.training else 0.0
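
Why the fix matters, in brief: with use_cache enabled only the newly generated token(s) pass through the layer as hidden_states, while key_states/value_states already contain the cached positions, so the two sequence lengths diverge and the reshape lengths must be read from the projected tensors themselves rather than from hidden_states. A minimal standalone sketch of that shape reasoning (tensor sizes and setup are illustrative, not taken from the diff):

import torch

bsz, num_heads, head_dim = 1, 12, 64
past_len, new_tokens = 10, 1

# Incremental decoding: hidden_states (and hence the query projection) only
# covers the new token(s)...
query_states = torch.randn(bsz, new_tokens, num_heads * head_dim)
# ...but the keys/values are the cached positions plus the new ones.
key_states = torch.randn(bsz, num_heads, past_len + new_tokens, head_dim)

query_length = query_states.shape[1]  # 1  -> length used to reshape the queries
tgt_len = key_states.shape[-2]        # 11 -> length used to reshape the keys/values

# Pre-fix, both reshapes used the length taken from hidden_states, which no
# longer matches the key/value length once the cache is in play.
assert query_length != tgt_len
print(query_length, tgt_len)  # 1 11

The first hunk follows from the same observation: only the batch size is still read from hidden_states, since both sequence lengths are now derived from the projected query and key tensors.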