diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 96f879d31ff..515fdc4f045 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -1192,12 +1192,13 @@ class MimiTransformerModel(nn.Module):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index ac7ed954b96..2748a71f527 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -580,12 +580,13 @@ class MistralModel(MistralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py
index 0d525407152..d56f2d2f370 100644
--- a/src/transformers/models/mistral/modular_mistral.py
+++ b/src/transformers/models/mistral/modular_mistral.py
@@ -254,12 +254,13 @@ class MistralModel(LlamaModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 23a7c663ed2..cfb5b1be8b6 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -721,12 +721,13 @@ class MixtralModel(MixtralPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index ef705af4169..fe0c1024016 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1216,12 +1216,13 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -1522,12 +1523,13 @@ class MoshiModel(MoshiPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index d0728df4638..a914461ad73 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -635,12 +635,13 @@ class Phi3Model(Phi3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
index a373e723676..2d6a2a7f155 100644
--- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
@@ -1943,12 +1943,13 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 35c24718fb0..5002014006b 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -1200,12 +1200,13 @@ class PhimoeModel(PhimoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 8ad0facfc84..46fc9b720c8 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -593,12 +593,13 @@ class Qwen2Model(Qwen2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 92be4356e81..e973c5c64d9 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2128,12 +2128,13 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
@@ -2833,12 +2834,13 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 12d98835338..18da004501b 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1354,12 +1354,13 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 7bc3606c5b2..ccd4a6b4191 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1047,12 +1047,13 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 47e54e76ae8..17cd7d5dcac 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1314,12 +1314,13 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index bc5a770a8c3..082de2f193b 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -620,12 +620,13 @@ class Qwen3Model(Qwen3PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 6c2ac9e4611..72591887a41 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -732,12 +732,13 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 50317b78f8b..1e1cfdde775 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -584,12 +584,13 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                 -1, 1
             )
-            if config.get_text_config().sliding_window is not None:
+            text_config = config.get_text_config()
+            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                     sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                        cache_position.reshape(-1, 1) - text_config.sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
             causal_mask *= diagonal_attend_mask
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 67da1ad857e..92f3daf8b73 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4323,6 +4323,45 @@ class ModelTesterMixin:
 
         return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
 
+    def test_sliding_window_mask(self):
+        """Tests that we can control the sliding window attention behavior of a model."""
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model does not support output_attentions")
+
+        if not (hasattr(config, "sliding_window") and hasattr(config, "use_sliding_window")):
+            self.skipTest(reason="Model does not support sliding window mask")
+
+        seq_len = self.model_tester.seq_length
+        batch_size = self.model_tester.batch_size
+        sliding_window = 3  # set to arbitrary small number
+
+        sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
+        for i in range(seq_len):
+            start = max(0, i - sliding_window + 1)
+            sliding_mask[i, start : i + 1] = True
+        sliding_mask = sliding_mask.to(torch_device)
+
+        config.sliding_window = sliding_window
+        inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            # Set sliding window to `True` and check that all tokens beyond window size are masked
+            model.config.use_sliding_window = True
+            attentions = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions:
+                self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
+            # Set sliding window to `False` while keeping `sliding_window=3`
+            # Check that all tokens beyond window size are not masked
+            model.config.use_sliding_window = False
+            attentions_not_sliding = model(**inputs, output_attentions=True).attentions
+            for layer_attention in attentions_not_sliding:
+                self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
+
     def test_custom_4d_attention_mask(self):
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
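
Note (reviewer sketch, not part of the patch): the gating added above can be reproduced in isolation. The helper name `build_causal_mask` below is made up for illustration; only the interplay between `use_sliding_window` and `sliding_window` mirrors the changed branch.

# Standalone sketch of the gating this patch introduces. Assumes only PyTorch; it is not
# the library implementation, just the same mask arithmetic in a self-contained form.
import torch


def build_causal_mask(target_length, cache_position, sliding_window=None, use_sliding_window=True):
    # Boolean mask of positions to hide: a query at cache position p must not see keys at j > p.
    diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
    if use_sliding_window and sliding_window is not None:
        # Additionally hide keys that are `sliding_window` or more positions behind the query.
        sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
        diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
    # Return a keep-mask of shape (len(cache_position), target_length): True where attention is allowed.
    return ~diagonal_attend_mask


cache_position = torch.arange(6)
print(build_causal_mask(6, cache_position, sliding_window=3, use_sliding_window=False).int())  # full causal mask
print(build_causal_mask(6, cache_position, sliding_window=3, use_sliding_window=True).int())   # only the last 3 positions kept

With `use_sliding_window=False` the sliding term is skipped even though `sliding_window` is set, which is the case the new `test_sliding_window_mask` test asserts on; configs without a `use_sliding_window` attribute fall back to the previous behavior via the `getattr(..., True)` default.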