mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix use cache (#3)
This commit is contained in:
parent
7800457d1a
commit
689f599132
@ -317,7 +317,7 @@ class OptFlashAttention2(OPTAttention):
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
bsz, _, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
@ -351,13 +351,15 @@ class OptFlashAttention2(OPTAttention):
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_length = query_states.shape[1]
|
||||
tgt_len = key_states.shape[-2]
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
|
||||
_, query_length, _, _ = query_states.shape
|
||||
|
||||
attn_dropout = self.dropout if self.training else 0.0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user