Mirror of https://github.com/huggingface/transformers.git
[fix] sliding window attention mask (#38045)
* fix sliding attn

* make style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* on a second thought, should default to `True` for BC

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
commit 0a52bd2403 (parent 555715f418)
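For readers skimming the diff: every modeling hunk below applies the same change. The sliding-window branch of the causal-mask helper is now gated on the text config's `use_sliding_window` flag, read with `getattr(..., True)` so checkpoints whose configs never defined the attribute keep the previous behavior, and the repeated `config.get_text_config()` call is hoisted into a local `text_config`. A minimal self-contained sketch of the masking logic (plain tensors and toy sizes, not the library helper itself):

import torch

def sketch_mask(target_length, cache_position, sliding_window=None, use_sliding_window=True):
    # True = the query may NOT attend to this position (strictly in the future)
    blocked = torch.arange(target_length) > cache_position.reshape(-1, 1)
    if use_sliding_window and sliding_window is not None:
        # additionally block tokens that have fallen out of the sliding window
        blocked |= torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
    # the real helper multiplies a min-dtype tensor by this boolean; here we simply
    # return 1.0 where attention is allowed and 0.0 where it is blocked
    return (~blocked).float()

cache_position = torch.arange(6)
print(sketch_mask(6, cache_position, sliding_window=3))                             # banded (windowed) causal mask
print(sketch_mask(6, cache_position, sliding_window=3, use_sliding_window=False))   # plain lower-triangular mask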
@@ -1192,12 +1192,13 @@ class MimiTransformerModel(nn.Module):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -580,12 +580,13 @@ class MistralModel(MistralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -254,12 +254,13 @@ class MistralModel(LlamaModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -721,12 +721,13 @@ class MixtralModel(MixtralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1216,12 +1216,13 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1522,12 +1523,13 @@ class MoshiModel(MoshiPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -635,12 +635,13 @@ class Phi3Model(Phi3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1943,12 +1943,13 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1200,12 +1200,13 @@ class PhimoeModel(PhimoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -593,12 +593,13 @@ class Qwen2Model(Qwen2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -2128,12 +2128,13 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -2833,12 +2834,13 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1354,12 +1354,13 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1047,12 +1047,13 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1314,12 +1314,13 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -620,12 +620,13 @@ class Qwen3Model(Qwen3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -732,12 +732,13 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -584,12 +584,13 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
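The new common test below builds its reference window mask with a plain loop before comparing it against the attentions a model returns. For a concrete picture of what that loop produces, the same construction with toy sizes (seq_len=5 and window 3, chosen here purely for illustration) yields a band of width 3 along the diagonal:

import torch

seq_len, sliding_window = 5, 3  # toy sizes for illustration only
sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
for i in range(seq_len):
    start = max(0, i - sliding_window + 1)
    sliding_mask[i, start : i + 1] = True  # row i may attend to the last `sliding_window` positions up to i

print(sliding_mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])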
@@ -4323,6 +4323,45 @@ class ModelTesterMixin:
         return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
 
+    def test_sliding_window_mask(self):
+        """Tests that we can control the sliding window attention behavior of a model."""
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model does not support output_attentions")
+
+        if not (hasattr(config, "sliding_window") and hasattr(config, "use_sliding_window")):
+            self.skipTest(reason="Model does not support sliding window mask")
+
+        seq_len = self.model_tester.seq_length
+        batch_size = self.model_tester.batch_size
+        sliding_window = 3  # set to arbitrary small number
+
+        sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
+        for i in range(seq_len):
+            start = max(0, i - sliding_window + 1)
+            sliding_mask[i, start : i + 1] = True
+        sliding_mask = sliding_mask.to(torch_device)
+
+        config.sliding_window = sliding_window
+        inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            # Set sliding window to `True` and check that all tokens beyond window size are masked
+            model.config.use_sliding_window = True
+            attentions = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions:
+                self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
+            # Set sliding window to `False` while keeping `sliding_window=3`
+            # Check that all tokens beyond window size are not masked
+            model.config.use_sliding_window = False
+            attentions_not_sliding = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions_not_sliding:
+                self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
     def test_custom_4d_attention_mask(self):
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
 
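As a usage note (not part of the commit itself): the flag exercised by this test is the one already exposed on configs such as `Qwen2Config`. Toggling it leaves `sliding_window` untouched but switches the mask-building code between the windowed and the plain causal path, and a config that never defined the attribute falls back to `True` through the `getattr` default, which is the backward-compatibility point from the commit message. A small illustration:

from transformers import Qwen2Config

config = Qwen2Config(sliding_window=8, use_sliding_window=True)
print(getattr(config, "use_sliding_window", True))   # True  -> windowed causal mask

config.use_sliding_window = False
print(getattr(config, "use_sliding_window", True))   # False -> plain causal mask, window size ignored


class LegacyConfig:  # stand-in for an older config class that never defined the flag
    sliding_window = 8


print(getattr(LegacyConfig(), "use_sliding_window", True))  # True -> previous behavior preserved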