mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00
[CLeanup] Revert SDPA attention changes that got in the static kv cache PR (#29027)

* revert unrelated changes that got in
* style
This commit is contained in:
parent 7a0fccc6eb
commit 609a1767e8
@@ -659,34 +659,28 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
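For context on the restored cache path above: kv_seq_len is grown by however many tokens the cache already holds, and only the RoPE sin/cos are forwarded through cache_kwargs (no position_ids). Below is a minimal, self-contained sketch of that update pattern; ToyCache and the tensor shapes are illustrative stand-ins, not the transformers Cache API.

import torch

class ToyCache:
    """Illustrative stand-in for a KV cache; not the transformers Cache implementation."""

    def __init__(self):
        self.keys, self.values = [], []

    def get_usable_length(self, new_seq_len, layer_idx):
        # Number of tokens already cached for this layer (0 when the layer is empty).
        return self.keys[layer_idx].shape[-2] if layer_idx < len(self.keys) else 0

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # cache_kwargs carries RoPE sin/cos in the restored code path; a plain
        # dynamic cache can ignore them and just concatenate along the seq dim.
        if layer_idx < len(self.keys):
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], key_states], dim=-2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], value_states], dim=-2)
        else:
            self.keys.append(key_states)
            self.values.append(value_states)
        return self.keys[layer_idx], self.values[layer_idx]

# Usage with toy (bsz, num_heads, q_len, head_dim) tensors.
cache = ToyCache()
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)
kv_seq_len = k.shape[-2] + cache.get_usable_length(k.shape[-2], layer_idx=0)
k_all, v_all = cache.update(k, v, layer_idx=0, cache_kwargs={"sin": None, "cos": None})
print(kv_seq_len, k_all.shape)  # 4 torch.Size([1, 8, 4, 64])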
@@ -695,9 +689,10 @@ class MistralSdpaAttention(MistralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
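The restored call hands the (possibly None) 4D attention_mask straight to SDPA and only enables the fused causal path when no mask is supplied and there is more than one query token, since a single query token needs no causal masking. A minimal sketch of that call pattern; the shapes are assumptions chosen for illustration.

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 8, 4, 4, 64
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# A 4D additive mask of shape (bsz, 1, q_len, kv_seq_len); None means "no explicit mask".
attention_mask = None

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    # Only rely on the fused causal path when no explicit mask is supplied and
    # there is more than one query token; with q_len == 1 nothing needs masking.
    is_causal=attention_mask is None and q_len > 1,
)
print(attn_output.shape)  # torch.Size([1, 8, 4, 64])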
@@ -736,34 +736,28 @@ class MixtralSdpaAttention(MixtralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -772,9 +766,10 @@ class MixtralSdpaAttention(MixtralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
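The shape check restored for Mixtral (and identically for the other two models) expects an additive 4D mask that broadcasts over heads, of size (bsz, 1, q_len, kv_seq_len). A small sketch of building such a mask by hand and running the same check; the mask construction here is an illustrative assumption, not the library's AttentionMaskConverter.

import torch

bsz, q_len, kv_seq_len = 2, 3, 5
min_value = torch.finfo(torch.float32).min

# Additive mask: 0.0 where attention is allowed, a large negative value where it is not.
# diagonal offset aligns the last query position with the last key position.
causal = torch.triu(
    torch.full((q_len, kv_seq_len), min_value), diagonal=kv_seq_len - q_len + 1
)
attention_mask = causal[None, None, :, :].expand(bsz, 1, q_len, kv_seq_len)

if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
    raise ValueError(
        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    )
print(attention_mask.shape)  # torch.Size([2, 1, 3, 5])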
@@ -669,34 +669,28 @@ class Qwen2SdpaAttention(Qwen2Attention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -705,9 +699,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
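The retained contiguity workaround in all three models only triggers on CUDA with a user-supplied mask, matching the pytorch/pytorch#112577 report for torch==2.1.2. A rough sketch of that guard under assumed toy shapes; it also runs on CPU, where the guard is simply skipped.

import torch

# A transpose like the (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
# reshape in the attention code produces non-contiguous tensors.
query_states = torch.randn(1, 4, 8, 64).transpose(1, 2)
key_states = torch.randn(1, 4, 8, 64).transpose(1, 2)
value_states = torch.randn(1, 4, 8, 64).transpose(1, 2)
attention_mask = torch.zeros(1, 1, 4, 4)  # additive mask, all positions allowed

if torch.cuda.is_available():
    query_states, key_states, value_states, attention_mask = (
        t.to("cuda") for t in (query_states, key_states, value_states, attention_mask)
    )

# Workaround for https://github.com/pytorch/pytorch/issues/112577: with torch==2.1.2,
# the memory-efficient SDPA backend can misbehave on CUDA given non-contiguous q/k/v
# and a custom attn_mask, so force contiguous inputs in exactly that case.
if query_states.device.type == "cuda" and attention_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=attention_mask
)
print(attn_output.shape)  # torch.Size([1, 8, 4, 64])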