[fix] sliding window attention mask (#38045)

* fix sliding attn

* make style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* on a second thought, should default to `True` for BC

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Authored by Raushan Turganbay on 2025-05-20 11:32:19 +02:00, committed by GitHub
parent 555715f418
commit 0a52bd2403
17 changed files with 93 additions and 36 deletions


@@ -1192,12 +1192,13 @@ class MimiTransformerModel(nn.Module):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
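The same two-line change repeats in every text backbone below: the sliding-window branch now consults `use_sliding_window` (read with `getattr(..., True)`, so configs that never defined the flag keep the old behavior) before tightening the causal mask. A standalone sketch of the mask arithmetic this branch manipulates, with toy sizes and plain PyTorch rather than code taken from the commit:

```python
import torch

# Toy sizes so the effect of the window is visible by eye.
target_length, sliding_window = 8, 3
cache_position = torch.arange(target_length)  # prefill: one cache slot per query token

# True where the key position lies in the future of the query -> masked (causal part).
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)

# True where the key position is `sliding_window` or more steps behind the query -> also masked.
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)

diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
allowed = ~diagonal_attend_mask  # each query may attend to itself and the 2 previous tokens
print(allowed.int())
```

With `use_sliding_window=False` (or `sliding_window` unset) only the causal part remains, which is exactly the branch the new check skips.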


@@ -580,12 +580,13 @@ class MistralModel(MistralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -254,12 +254,13 @@ class MistralModel(LlamaModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -721,12 +721,13 @@ class MixtralModel(MixtralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1216,12 +1216,13 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1522,12 +1523,13 @@ class MoshiModel(MoshiPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -635,12 +635,13 @@ class Phi3Model(Phi3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1943,12 +1943,13 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1200,12 +1200,13 @@ class PhimoeModel(PhimoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -593,12 +593,13 @@ class Qwen2Model(Qwen2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -2128,12 +2128,13 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -2833,12 +2834,13 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1354,12 +1354,13 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1047,12 +1047,13 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -1314,12 +1314,13 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -620,12 +620,13 @@ class Qwen3Model(Qwen3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -732,12 +732,13 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask


@@ -584,12 +584,13 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
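Why `getattr(text_config, "use_sliding_window", True)` rather than a plain attribute access: some of the configs above (Mistral-style) only define `sliding_window`, while others (Qwen2-style) also expose a `use_sliding_window` switch. A minimal sketch of the resulting rule, using a hypothetical stand-in config class:

```python
class DummyTextConfig:
    """Hypothetical stand-in; real text configs may or may not define `use_sliding_window`."""

    sliding_window = 4096


def sliding_mask_applies(text_config) -> bool:
    # Configs that predate the flag fall back to the old behavior (apply the sliding mask
    # whenever `sliding_window` is set); configs that have the flag can opt out explicitly.
    return getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None


legacy = DummyTextConfig()
print(sliding_mask_applies(legacy))  # True -> backward-compatible default

opted_out = DummyTextConfig()
opted_out.use_sliding_window = False
print(sliding_mask_applies(opted_out))  # False -> sliding mask skipped, plain causal mask remains
```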


@@ -4323,6 +4323,45 @@ class ModelTesterMixin:
         return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
 
+    def test_sliding_window_mask(self):
+        """Tests that we can control the sliding window attention behavior of a model."""
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.has_attentions:
+            self.skipTest(reason="Model does not support output_attentions")
+
+        if not (hasattr(config, "sliding_window") and hasattr(config, "use_sliding_window")):
+            self.skipTest(reason="Model does not support sliding window mask")
+
+        seq_len = self.model_tester.seq_length
+        batch_size = self.model_tester.batch_size
+        sliding_window = 3  # set to arbitrary small number
+        sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
+        for i in range(seq_len):
+            start = max(0, i - sliding_window + 1)
+            sliding_mask[i, start : i + 1] = True
+        sliding_mask = sliding_mask.to(torch_device)
+
+        config.sliding_window = sliding_window
+        inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            # Set sliding window to `True` and check that all tokens beyond window size are masked
+            model.config.use_sliding_window = True
+            attentions = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions:
+                self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
+            # Set sliding window to `False` while keeping `sliding_window=3`
+            # Check that all tokens beyond window size are not masked
+            model.config.use_sliding_window = False
+            attentions_not_sliding = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions_not_sliding:
+                self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
     def test_custom_4d_attention_mask(self):
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
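For a checkpoint whose config defines both fields (Qwen2-style), the behavior this test exercises can be toggled from user code roughly as follows; the checkpoint name and window length are illustrative, not taken from the commit:

```python
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint = "Qwen/Qwen2-0.5B"  # illustrative; any config exposing `use_sliding_window` works
config = AutoConfig.from_pretrained(checkpoint)
config.use_sliding_window = True  # apply the sliding-window mask
config.sliding_window = 1024  # window length in tokens

model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)

# Setting `use_sliding_window = False` keeps `sliding_window` in the config but skips the
# extra masking, mirroring the second half of `test_sliding_window_mask` above.
```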