diff --git a/docs/source/en/attention_interface.md b/docs/source/en/attention_interface.md
index 2afb9222875..76f264d83ee 100644
--- a/docs/source/en/attention_interface.md
+++ b/docs/source/en/attention_interface.md
@@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
 
 # You can also globally `register` a new function directly on it
 >>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
-```
\ No newline at end of file
+```
+
+## Attention Mask Interface
+
+Adding a new attention function may mean that you also need a new attention mask format to decide which key and value tokens
+the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
+the `AttentionInterface`:
+
+```python
+from transformers import AttentionMaskInterface
+from transformers.masking_utils import sdpa_mask
+import torch
+
+def my_new_sdpa_mask(*args, **kwargs):
+    print("I just entered the attention mask computation")
+    return sdpa_mask(*args, **kwargs)
+
+AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
+```
+
+You have to register it because we need to automatically adapt your mask format to the attention implementation in use (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
+By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
+and `attention_mask=None` will be passed along to the Attention layers.
+
+The default signature of the attention mask functions is the following:
+
+```python
+def custom_attention_mask(
+    batch_size: int,  # required arg
+    cache_position: torch.Tensor,  # required arg
+    kv_length: int,  # required arg
+    kv_offset: int = 0,  # required arg
+    mask_function: Callable = causal_mask_function,  # required arg
+    attention_mask: Optional[torch.Tensor] = None,  # required arg
+    **kwargs,  # a few additional args may be passed as kwargs, especially the model's config is always passed
+) -> Optional[torch.Tensor]:
+```
+
+Mask creation mostly works through the `mask_function`, a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), which takes 4 indices as input and returns a boolean indicating whether this position should take part in the attention computation.
+
+If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
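+For instance, here is a minimal sketch of a sliding-window-style mask built only from the primitives above and those in `masking_utils`
+(`and_masks`, `causal_mask_function` and `sdpa_mask`); the overlay, the registered name and the window size of 4 are purely illustrative:
+
+```python
+import torch
+
+from transformers import AttentionMaskInterface
+from transformers.masking_utils import and_masks, causal_mask_function, sdpa_mask
+
+def recent_tokens_overlay(window: int):
+    # mask_mod-style callable: takes the 4 indices and returns a boolean
+    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
+        return kv_idx > q_idx - window
+    return inner_mask
+
+def windowed_sdpa_mask(*args, mask_function=causal_mask_function, **kwargs):
+    # Combine the causal pattern with the custom overlay, then delegate to the stock sdpa mask creation
+    windowed = and_masks(recent_tokens_overlay(4), mask_function)
+    return sdpa_mask(*args, mask_function=windowed, **kwargs)
+
+AttentionMaskInterface.register("windowed_sdpa_mask", windowed_sdpa_mask)
+
+# Inspect the boolean mask for a 5-token prompt (disable the `is_causal` fast path so the mask is materialized)
+mask = windowed_sdpa_mask(
+    batch_size=1,
+    cache_position=torch.arange(5),
+    kv_length=5,
+    allow_is_causal_skip=False,
+)
+print(mask.shape)  # torch.Size([1, 1, 5, 5])
+```
+
+To actually use such a function in a model, you would typically register it under the same name as a matching attention function
+(as done for `sdpa_without_vmap` in the export code) and select that name as the model's attention implementation.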
\ No newline at end of file diff --git a/docs/source/en/internal/modeling_utils.md b/docs/source/en/internal/modeling_utils.md index 1c7d16ad061..11f13de081b 100644 --- a/docs/source/en/internal/modeling_utils.md +++ b/docs/source/en/internal/modeling_utils.md @@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the [[autodoc]] AttentionInterface - register +## Attention Mask Functions + +[[autodoc]] AttentionMaskInterface + - register + ## Rotary Position Embedding Functions [[autodoc]] dynamic_rope_update diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2c7db1eec71..155d0fd6d39 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -445,6 +445,7 @@ else: _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"] _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] + _import_structure["masking_utils"] = ["AttentionMaskInterface"] _import_structure["optimization"] = [ "Adafactor", "get_constant_schedule", @@ -914,6 +915,7 @@ if TYPE_CHECKING: TorchExportableModuleWithStaticCache, convert_and_export_with_cache, ) + from .masking_utils import AttentionMaskInterface from .model_debugging_utils import ( model_addition_debugger_context, ) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d24edd390c2..c0bd42f2e39 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -196,6 +196,18 @@ class Cache: else: return None + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, 0 + @dataclass class CacheConfig: @@ -1084,8 +1096,6 @@ class SinkCache(Cache): ``` """ - is_sliding = True - def __init__(self, window_length: int, num_sink_tokens: int) -> None: super().__init__() self.key_cache: List[torch.Tensor] = [] @@ -1390,6 +1400,16 @@ class StaticCache(Cache): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length = self.get_max_cache_shape() + return kv_length, 0 + class SlidingWindowCache(StaticCache): """ @@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache): ``` """ - is_sliding = True is_compileable = True def __init__( @@ -1465,6 +1484,7 @@ class SlidingWindowCache(StaticCache): "config and it's not set to None." 
) max_cache_len = min(config.sliding_window, max_cache_len) + self.sliding_window = config.sliding_window super().__init__( config=config, max_batch_size=max_batch_size, @@ -1509,6 +1529,21 @@ class SlidingWindowCache(StaticCache): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor + kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + kv_length = max(query_length, self.get_max_cache_shape()) + return kv_length, kv_offset + class EncoderDecoderCache(Cache): """ @@ -1761,12 +1796,17 @@ class HybridCache(Cache): else config.num_key_value_heads ) - layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC - self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)] + # If the attribute does not exist in the config, fallback to a simple StaticCache + if hasattr(config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] + else: + self.is_sliding = [False] * config.num_hidden_layers + self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) + self.sliding_window = min(config.sliding_window, max_cache_len) device = torch.device(device) if device is not None else None for i in range(config.num_hidden_layers): if layer_device_map is not None: @@ -1775,7 +1815,7 @@ class HybridCache(Cache): layer_device = device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape + cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1796,7 +1836,7 @@ class HybridCache(Cache): if cache_position is None: raise ValueError("`cache_position` must be provided for HybridCache.") - is_sliding_layer = self.is_sliding_list[layer_idx] + is_sliding_layer = self.is_sliding[layer_idx] # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. 
Gemma2) @@ -1843,6 +1883,26 @@ class HybridCache(Cache): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + if self.is_sliding[layer_idx]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + local_mask_kv_length = max(query_length, self.sliding_window) + return local_mask_kv_length, local_mask_kv_offset + + full_mask_kv_offset = 0 + full_mask_kv_length = self.get_max_cache_shape() + return full_mask_kv_length, full_mask_kv_offset + class HybridChunkedCache(Cache): """ @@ -1912,11 +1972,11 @@ class HybridChunkedCache(Cache): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self._dtype = dtype - if hasattr(config.get_text_config(), "no_rope_layers"): - self.is_sliding = config.no_rope_layers + # If the attribute does not exist in the config, fallback to a simple StaticCache + if hasattr(config, "layer_types"): + self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] else: - layer_switch = getattr(config, "sliding_window_pattern", 2) - self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)] + self.is_sliding = [False] * config.num_hidden_layers self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] @@ -1999,11 +2059,7 @@ class HybridChunkedCache(Cache): key_states = key_states.to(k_out.dtype) value_states = value_states.to(v_out.dtype) - if self.is_sliding[layer_idx]: - update_fn = self._sliding_update - else: - update_fn = self._static_update - + update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update return update_fn( cache_position, layer_idx, @@ -2038,6 +2094,37 @@ class HybridChunkedCache(Cache): self.value_cache[layer_idx].zero_() self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + if self.is_sliding[layer_idx]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) + # This is the true general case for any Cache using local attention (sliding or chunked) + if first_cache_position >= self.sliding_window: + # Here the Cache is already full + local_mask_kv_length = self.sliding_window + query_length - 1 + elif ( + first_cache_position < self.sliding_window + and first_cache_position + query_length > self.sliding_window + ): + # Here the Cache becomes full with the new input + local_mask_kv_length = first_cache_position + query_length + else: + # Here the Cache is still smaller than the local size, but we return the local size as it's static + local_mask_kv_length = self.sliding_window + return local_mask_kv_length, local_mask_kv_offset + + full_mask_kv_offset = 0 + full_mask_kv_length = self.get_max_cache_shape() + return full_mask_kv_length, full_mask_kv_offset + class OffloadedHybridCache(HybridChunkedCache): def __init__( diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 18cc0ff3a5d..6e75fbfb54a 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -1209,3 +1209,16 @@ if PretrainedConfig.push_to_hub.__doc__ is not None: PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( object="config", object_class="AutoConfig", object_files="configuration file" ) + + +ALLOWED_LAYER_TYPES = ( + "full_attention", + "sliding_attention", + "chunked_attention", +) + + +def layer_type_validation(layer_types: list[str]): + """Check that each entry in `layer_types` are allowed.""" + if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types): + raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}") diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 60859662f61..49dc4b8df72 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -46,6 +46,7 @@ from ..dynamic_module_utils import ( ) from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..integrations.fsdp import is_fsdp_managed_module +from ..masking_utils import create_masks_for_generate from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie @@ -74,6 +75,7 @@ from .candidate_generator import ( from .configuration_utils import ( NEED_SETUP_CACHE_CLASSES_MAPPING, QUANT_BACKEND_CLASSES_MAPPING, + CompileConfig, GenerationConfig, GenerationMode, ) @@ -649,12 +651,22 @@ class GenerationMixin: causal_mask_creation_function = getattr( decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None ) + + # If it's not defined, it means the model uses the new general mask API if causal_mask_creation_function is None: # can't be found - logger.warning_once( - f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method " - "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're " - "writing code, see Llama for an example implementation. If you're a user, please report this " - "issue on GitHub." 
+ output_attentions = kwargs.get("output_attentions", False) + token_type_ids = getattr(model_input, "token_type_ids", None) + # Some models may overwrite the general one + causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) + attention_mask = causal_mask_creation_function( + config=self.config, + # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings + input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype), + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, + token_type_ids=token_type_ids, ) else: attention_mask = causal_mask_creation_function( @@ -3533,6 +3545,19 @@ class GenerationMixin: compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) if compile_forward: os.environ["TOKENIZERS_PARALLELISM"] = "0" + # If we use FA2 and a static cache, we cannot compile with fullgraph + if self.config._attn_implementation == "flash_attention_2" and getattr( + model_kwargs.get("past_key_values"), "is_compileable", False + ): + if generation_config.compile_config is None: + generation_config.compile_config = CompileConfig(fullgraph=False) + # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user) + elif generation_config.compile_config.fullgraph: + logger.warning_once( + "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as " + "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`." + ) + generation_config.compile_config.fullgraph = False model_forward = self.get_compiled_call(generation_config.compile_config) if generation_config.prefill_chunk_size is not None: diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index c0a3839afbb..eb17dab55af 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -11,18 +11,21 @@ # specific language governing permissions and limitations under the License. 
import logging -from typing import Optional +from contextlib import contextmanager +from typing import Callable, Optional import torch -from transformers.generation.configuration_utils import GenerationConfig - -from ..utils.import_utils import is_torch_available - - -if is_torch_available(): - from transformers import HybridCache, PreTrainedModel, StaticCache - from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3 +from ..cache_utils import DynamicCache, HybridCache, StaticCache +from ..generation.configuration_utils import GenerationConfig +from ..masking_utils import ( + ALL_MASK_ATTENTION_FUNCTIONS, + _ignore_causal_mask_sdpa, + _is_torch_greater_or_equal_than_2_5, + prepare_padding_mask, +) +from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3 class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): @@ -54,19 +57,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): if not hasattr(model.config, "use_cache") or model.config.use_cache is False: raise ValueError("The model must have caching enabled to be performant.") - if not hasattr(model.config, "cache_implementation"): - # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will - # be used by default, so export will use `StaticCache` by default. - logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.") + if not hasattr(model.config, "layer_types"): + # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so + # export will use `StaticCache` by default. + logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.") self.model = TorchExportableModuleWithStaticCache(model) else: - if model.config.cache_implementation == "hybrid": - self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) - else: - raise ValueError( - f"Unsupported cache implementation: {model.config.cache_implementation}. " - "Please use `hybrid` or `static`." - ) + self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) def forward( self, @@ -105,16 +102,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`. 
""" + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + self.model.model.config._attn_implementation = "sdpa_without_vmap" + example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long) example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) - return torch.export.export( - self.model, - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + with patch_mask_interface(): + exported_program = torch.export.export( + self.model, + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + return exported_program @staticmethod def generate( @@ -281,17 +285,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) - self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) - if self.is_causal: - causal_mask = torch.tril( - torch.ones( - self.static_cache.max_cache_len, - self.static_cache.max_cache_len, - dtype=torch.bool, - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ Forward pass of the module, which is compatible with the ExecuTorch runtime. @@ -314,13 +307,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. """ _, seqlen = input_ids.shape - attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None position_ids = cache_position.unsqueeze(0) past_key_values = self.static_cache outs = self.model( input_ids=input_ids, - attention_mask=attn_mask, + attention_mask=None, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, @@ -445,18 +437,15 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): Returns: torch.Tensor: Logits output from the model. """ - batch_size, seq_len = input_ids.shape + batch_size = input_ids.shape[0] # Generate position_ids from cache_position position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) - # Create attention mask (always ones for token-by-token generation) - attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device) - # Forward pass with the model outputs = self.model( input_ids=input_ids, - attention_mask=attention_mask, + attention_mask=None, position_ids=position_ids, past_key_values=self.cache, use_cache=True, @@ -467,6 +456,24 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): return outputs.logits +@contextmanager +def patch_mask_interface(): + """ + Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail + with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles: + Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`. + Note that this seem to be an issue only for python<3.11. 
+ """ + import transformers + + original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping + try: + yield + finally: + transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original + + def convert_and_export_with_cache( model: PreTrainedModel, example_input_ids: Optional[torch.Tensor] = None, @@ -493,6 +500,11 @@ def convert_and_export_with_cache( import torch.export._trace + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + model.config._attn_implementation = "sdpa_without_vmap" + with torch.no_grad(): # TODO: The default inputs only work for text models. We need to add support for vision/audio models. example_input_ids = ( @@ -503,13 +515,14 @@ def convert_and_export_with_cache( ) if is_torch_greater_or_equal("2.6.0"): - exported_program = torch.export.export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + with patch_mask_interface(): + exported_program = torch.export.export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) else: if dynamic_shapes is not None: logging.warning( @@ -521,13 +534,14 @@ def convert_and_export_with_cache( # # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. 
- exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids,), - kwargs={"cache_position": example_cache_position}, - pre_dispatch=False, - strict=True, - ) + with patch_mask_interface(): + exported_program = torch.export._trace._export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) return exported_program @@ -620,9 +634,10 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the encoder with torch.no_grad(): - exported_encoder = torch.export.export( - wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True - ) + with patch_mask_interface(): + exported_encoder = torch.export.export( + wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + ) return exported_encoder @@ -642,16 +657,17 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the decoder with torch.no_grad(): - exported_decoder = torch.export.export( - wrapped_decoder, - (decoder_input_ids, encoder_hidden_states, cache_position), - dynamic_shapes={ - "decoder_input_ids": None, - "encoder_hidden_states": {1: encoder_seq_len_dim}, - "cache_position": None, - }, - strict=True, - ) + with patch_mask_interface(): + exported_decoder = torch.export.export( + wrapped_decoder, + (decoder_input_ids, encoder_hidden_states, cache_position), + dynamic_shapes={ + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_seq_len_dim}, + "cache_position": None, + }, + strict=True, + ) return exported_decoder @@ -706,3 +722,131 @@ class Seq2SeqLMExportableModule(torch.nn.Module): break return generated_ids + + +def export_with_dynamic_cache( + model: PreTrainedModel, + example_input_ids: Optional[torch.Tensor] = None, + example_attention_mask: Optional[torch.Tensor] = None, +): + """ + Export a model with DynamicCache using `torch.export`, ensuring the exported model is compatible with `ExecuTorch`. + + Args: + model (`PreTrainedModel`): The pretrained model to be exported. + example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`. + example_attention_mask (`Optional[torch.Tensor]`): Example attention mask used by `torch.export`. + + Returns: + Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`. 
+ """ + if not is_torch_greater_or_equal_than_2_3: + raise ImportError("torch >= 2.3 is required.") + + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + model.config._attn_implementation = "sdpa_without_vmap" + + with torch.no_grad(): + exported_program = torch.export.export( + model, + (), + { + "input_ids": example_input_ids, + "attention_mask": example_attention_mask, + "past_key_values": DynamicCache(), + "use_cache": True, + }, + strict=False, + ) + return exported_program + + +def sdpa_mask_without_vmap( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Optional[Callable] = None, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + allow_torch_fix: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """ + Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that + the element should take part in the attention computation, and False that it should not. + + This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + local_size (`int`, optional): + The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` + to try to skip mask creation if possible. + allow_is_causal_skip (`bool`, optional): + Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in + `torch.sdpa` instead. Default to `True`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. + + """ + + q_length = cache_position.shape[0] + # Potentially pad the 2D mask, and slice it correctly + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) + + # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument + if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size): + return None + + # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. 
torch.compile friendly) + kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_arange += kv_offset + reshaped_cache_position = cache_position.view(-1, 1) + + # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward + # the config through kwargs anyway, so it allows to rely on it + # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it, + # but this is more efficient + sliding_window = getattr(kwargs["config"], "sliding_window", None) + chunk_size = getattr(kwargs["config"], "attention_chunk_size", None) + + if sliding_window is not None and chunk_size is not None: + raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`") + + # Simplest and most efficient way to obtain a causal mask + causal_mask = kv_arange <= reshaped_cache_position + # If using sliding window, add the sliding mask + if sliding_window is not None: + sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window + causal_mask *= sliding_mask_overlay + # If using chunk attention, add the chunked mask + elif chunk_size is not None: + chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size + causal_mask *= chunked_mask_overlay + + causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1) + if padding_mask is not None: + causal_mask = causal_mask * padding_mask[:, None, None, :] + + # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any + # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 + if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: + causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) + return causal_mask diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 56c35e8d195..afdaba5199d 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -32,14 +32,12 @@ import torch from packaging import version from ..utils import is_torch_flex_attn_available -from ..utils.import_utils import _torch_version +from ..utils.import_utils import _torch_version, is_torchdynamo_compiling if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask, flex_attention - from torch.nn.attention.flex_attention import ( - create_block_mask as create_block_causal_mask_flex, - ) + from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex class WrappedFlexAttention: @@ -79,6 +77,24 @@ class WrappedFlexAttention: return self._compiled_flex_attention +def compile_friendly_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + training=False, + **kwargs, +) -> torch.Tensor: + # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention + # Do not use compiled version if already compiling forward (it raises issues) + flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention + return flex_attention_compiled( + query, + key, + value, + **kwargs, + ) + + Offset = Union[torch.Tensor, int] @@ -90,6 +106,10 @@ def make_flex_block_causal_mask( offsets: Optional[Tuple[Offset, Offset]] = None, ) -> "BlockMask": """ + IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in 
`masking_utils.py`, + and will be removed in a future version without warnings. New code should not use it. It is only kept here + for BC for now, while models using it are being patched accordingly. + Create a block causal document mask for a batch of sequences, both packed and unpacked. Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. The resultant BlockMask is a compressed representation of the full block causal @@ -133,7 +153,6 @@ def make_flex_block_causal_mask( """ Defines the logic of a block causal mask by combining both a standard causal mask and a block diagonal document mask. - See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` for an illustration. """ @@ -174,24 +193,6 @@ def make_flex_block_causal_mask( ) -@torch.compiler.disable(recursive=False) -def compile_friendly_flex_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - training=False, - **kwargs, -) -> torch.Tensor: - # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention - flex_attention_compiled = WrappedFlexAttention(training)() - return flex_attention_compiled( - query, - key, - value, - **kwargs, - ) - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d07a768b01a..8312891941b 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -16,15 +16,15 @@ from __future__ import annotations import operator import os import re -from collections.abc import MutableMapping from functools import partial, reduce -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist from torch import nn from ..utils import is_torch_greater_or_equal, logging +from ..utils.generic import GeneralInterface ALL_LAYERNORM_LAYERS = [nn.LayerNorm] @@ -720,20 +720,11 @@ class SequenceParallel(TensorParallelLayer): return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) -class ParallelInterface(MutableMapping): - """ - Dict-like object keeping track of allowed attention functions. You can easily add a new attention function - with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, - it needs to declare a new instance of this class inside the `modeling_.py`, and declare it on that instance. 
- """ - +class ParallelInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if - # a new instance is created (in order to locally override a given function) - - def __init__(self): - self._local_mapping = {} - - ParallelInterface._global_mapping = { + # a new instance is created (in order to locally override a given entry) + _global_mapping = ( + { "colwise": ColwiseParallel(), "rowwise": RowwiseParallel(), "colwise_rep": ColwiseParallel(output_layouts=Replicate()), @@ -746,41 +737,12 @@ class ParallelInterface(MutableMapping): "sequence_parallel": SequenceParallel(), "replicate": ReplicateParallel(), } - - def __getitem__(self, key): - # First check if instance has a local override - if key in self._local_mapping: - return self._local_mapping[key] - return self._global_mapping[key] - - def __setitem__(self, key, value): - # Allow local update of the default functions without impacting other instances - self._local_mapping.update({key: value}) - - def __delitem__(self, key): - del self._local_mapping[key] - - def __iter__(self): - # Ensure we use all keys, with the overwritten ones on top - return iter({**self._global_mapping, **self._local_mapping}) - - def __len__(self): - return len(self._global_mapping.keys() | self._local_mapping.keys()) - - @classmethod - def register(cls, key: str, value: Callable): - cls._global_mapping.update({key: value}) - - def valid_keys(self) -> List[str]: - return list(self.keys()) + if is_torch_greater_or_equal("2.5") and _torch_distributed_available + else {} + ) -# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones - -if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() -else: - ALL_PARALLEL_STYLES = None +ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() def convert_local_tensor_to_dtensor( diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py new file mode 100644 index 00000000000..36538882af5 --- /dev/null +++ b/src/transformers/masking_utils.py @@ -0,0 +1,1129 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import itertools +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F + +from .cache_utils import Cache +from .configuration_utils import PretrainedConfig +from .utils.generic import GeneralInterface +from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_torchdynamo_compiling + + +if is_torch_flex_attn_available(): + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + +_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True) + + +def and_masks(*mask_functions: list[Callable]) -> Callable: + """Returns a mask function that is the intersection of provided mask functions""" + if not all(callable(arg) for arg in mask_functions): + raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}") + + def and_mask(batch_idx, head_idx, q_idx, kv_idx): + result = q_idx.new_ones((), dtype=torch.bool) + for mask in mask_functions: + result = result & mask(batch_idx, head_idx, q_idx, kv_idx) + return result + + return and_mask + + +def or_masks(*mask_functions: list[Callable]) -> Callable: + """Returns a mask function that is the union of provided mask functions""" + if not all(callable(arg) for arg in mask_functions): + raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}") + + def or_mask(batch_idx, head_idx, q_idx, kv_idx): + result = q_idx.new_zeros((), dtype=torch.bool) + for mask in mask_functions: + result = result | mask(batch_idx, head_idx, q_idx, kv_idx) + return result + + return or_mask + + +def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + """ + This creates a basic lower-diagonal causal mask. + """ + return kv_idx <= q_idx + + +def sliding_window_overlay(sliding_window: int) -> Callable: + """ + This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding + window mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return kv_idx > q_idx - sliding_window + + return inner_mask + + +def chunked_overlay(chunk_size: int) -> Callable: + """ + This is an overlay depicting a chuned attention pattern. Add it on top of a causal mask for a proper chunked + attention mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return kv_idx // chunk_size == q_idx // chunk_size + + return inner_mask + + +def sliding_window_causal_mask_function(sliding_window: int) -> Callable: + """ + This return the mask_function function to create a sliding window mask. + """ + return and_masks(sliding_window_overlay(sliding_window), causal_mask_function) + + +def chunked_causal_mask_function(chunk_size: int) -> Callable: + """ + This return the mask_function function to create a chunked attention mask. + """ + return and_masks(chunked_overlay(chunk_size), causal_mask_function) + + +def padding_mask_function(padding_mask: torch.Tensor) -> Callable: + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. 
This is because + # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not + # vectorizable on accelerator devices + return padding_mask[batch_idx, kv_idx] + + return inner_mask + + +def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset) + + return inner_mask + + +def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: + """ + Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over + the batch and head indices as well if `bh_indices=True`. + Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive + functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different). + + Args: + mask_function (`Callable`): + The mask_function to vmap. + bh_indices (`bool`, optional): + Whether to vmap over the batch and head indices as well, or only q and kv indices. + + Returns: + Callable: The vmapped function. + """ + # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions + dimensions = [(None, None, None, 0), (None, None, 0, None)] + if bh_indices: + # We extend broadcasting over the [batch_idx, head_idx] dimensions + dimensions.extend([(None, 0, None, None), (0, None, None, None)]) + + for dims in dimensions: + mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0) + return mask_function + + +def prepare_padding_mask( + attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True +) -> Optional[torch.Tensor]: + """ + From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing + according to the `kv_offset` if `_slice` is `True`. + """ + local_padding_mask = attention_mask + if attention_mask is not None: + # Pad it if necesary + if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0: + local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length)) + # For flex, we should not slice them, only use an offset + if _slice: + # Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`, + # but without data-dependent slicing (i.e. torch.compile friendly) + mask_indices = torch.arange(kv_length, device=local_padding_mask.device) + mask_indices += kv_offset + local_padding_mask = local_padding_mask[:, mask_indices] + return local_padding_mask + + +def _ignore_causal_mask_sdpa( + padding_mask: Optional[torch.Tensor], + query_length: int, + kv_length: int, + kv_offset: int, + local_attention_size: Optional[int] = None, +) -> bool: + """ + Detects whether the causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the 2D `padding_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is + passed). 
+ """ + is_tracing = torch.jit.is_tracing() or isinstance(padding_mask, torch.fx.Proxy) or is_torchdynamo_compiling() + if padding_mask is not None and padding_mask.shape[-1] > kv_length: + mask_indices = torch.arange(kv_length, device=padding_mask.device) + mask_indices += kv_offset + padding_mask = padding_mask[:, mask_indices] + + # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is + # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True` + # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set + # `ignore_causal_mask = True` if we are not tracing + if ( + not is_tracing + # only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108 + and (query_length == 1 or kv_length == query_length) + # in this case we need to add special patterns to the mask so cannot be skipped otherwise + and (local_attention_size is None or kv_length < local_attention_size) + # In this case, we need to add padding to the mask, so cannot be skipped otherwise + and (padding_mask is None or padding_mask.all()) + ): + return True + + return False + + +def sdpa_mask_recent_torch( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """ + Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that + the element should take part in the attention computation, and False that it should not. + This function can only be used with torch>=2.5, as the context manager is otherwise not available. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + local_size (`int`, optional): + The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` + to try to skip mask creation if possible. + allow_is_causal_skip (`bool`, optional): + Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in + `torch.sdpa` instead. Default to `True`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. 
+ + + ## Creating a simple causal mask: + + To create the following causal mask: + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ■ ■ ■ ■ ⬚ + 4 ■ ■ ■ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [ True, True, True, True, False], + [ True, True, True, True, True]]]]) + ``` + + ## Creating a sliding window mask: + + To create the following sliding window mask (`sliding_window=3`): + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ⬚ ■ ■ ■ ⬚ + 4 ⬚ ⬚ ■ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3)) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True]]]]) + ``` + + ## Creating a chunked attention mask + + To create the following chunked attention mask (`chunk_size=3`): + + 0 ■ ⬚ ⬚ ⬚ ⬚ + 1 ■ ■ ⬚ ⬚ ⬚ + 2 ■ ■ ■ ⬚ ⬚ + 3 ⬚ ⬚ ⬚ ■ ⬚ + 4 ⬚ ⬚ ⬚ ■ ■ + + You can do + + ```python + >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3)) + >>> tensor([[[[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, False, False, True, False], + [False, False, False, True, True]]]]) + ``` + + """ + q_length = cache_position.shape[0] + # Potentially pad the 2D mask, and slice it correctly + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False) + + # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument + if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): + return None + + # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. torch.compile friendly) + kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_arange += kv_offset + + # Potentially add the padding 2D mask + if padding_mask is not None: + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + + batch_arange = torch.arange(batch_size, device=cache_position.device) + head_arange = torch.arange(1, device=cache_position.device) + # This creates the 4D mask easily. 
Note that we need this context manager as vmap cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it + # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices + with TransformGetItemToIndex(): + causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + + return causal_mask + + +def sdpa_mask_older_torch( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + allow_torch_fix: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """ + NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise. + + Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that + the element should take part in the attention computation, and False that it should not. + If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend + to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does + not change the final result). + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + local_size (`int`, optional): + The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` + to try to skip mask creation if possible. + allow_is_causal_skip (`bool`, optional): + Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in + `torch.sdpa` instead. Default to `True`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. + """ + q_length = cache_position.shape[0] + # Potentially pad the 2D mask, and slice it correctly + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) + + # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument + if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): + return None + + # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. torch.compile friendly) + kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_arange += kv_offset + + # This creates the 4D mask easily. 
Note that we do not include vmap over the batch_idx dimension as well, + # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow + # However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have + # `sdpa_mask_recent_torch`, as it allows more general `mask_function` + causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1) + if padding_mask is not None: + causal_mask = causal_mask * padding_mask[:, None, None, :] + + # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any + # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 + if allow_torch_fix: + causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) + return causal_mask + + +# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions +# (especially mask_function indexing a tensor, such as the padding mask function) +sdpa_mask = sdpa_mask_recent_torch if is_torch_flex_attn_available() else sdpa_mask_older_torch + + +def eager_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + **kwargs, +) -> torch.Tensor: + """ + Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that + the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that + it should not. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + dtype (`torch.dtype`, optional): + The dtype to use for the mask. By default, `torch.float32`. + """ + # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf + _ = kwargs.pop("allow_is_causal_skip", None) + mask = sdpa_mask( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_torch_fix=False, + **kwargs, + ) + min_dtype = torch.finfo(dtype).min + # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type) + mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) + return mask + + +def flash_attention_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, +): + """ + Create the attention mask necesary to use FA2. 
Since FA2 is un-padded by definition, here we simply return + `None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens. + We just slice it in case of sliding window. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + """ + if attention_mask is not None: + # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing) + attention_mask = attention_mask[:, -kv_length:] + # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2 + # (note that the attention_mask is a boolean dtype here) + if attention_mask.all(): + attention_mask = None + + return attention_mask + + +def flex_attention_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> "BlockMask": + """ + Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential + for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. 
+ attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + """ + q_length, q_offset = cache_position.shape[0], cache_position[0] + + # Potentially add the padding 2D mask + if attention_mask is not None: + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False) + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + + # Add the offsets on top (because flex interface only allows length, not start and end indices) + mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset) + + # Finally create the block mask + block_mask = create_block_mask( + mask_mod=mask_function, + B=batch_size, + H=None, + Q_LEN=q_length, + KV_LEN=kv_length, + device=cache_position.device, + _compile=True, + ) + return block_mask + + +class AttentionMaskInterface(GeneralInterface): + # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if + # a new instance is created (in order to locally override a given function) + _global_mapping = { + "sdpa": sdpa_mask, + "eager": eager_mask, + "flash_attention_2": flash_attention_mask, + "flex_attention": flex_attention_mask, + } + + +# Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones +ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() + + +def _preprocess_mask_arguments( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, BlockMask]], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + layer_idx: Optional[int], +) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]: + """ + Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the + key-value length and offsets, and if we should early exit or not. + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + layer_idx (`int`, optional): + If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value + length and offset. Indeed, for hybrid caches, different layers may return different lengths. + + Returns: + early_exit (`bool`): + Whether we should early exit mask creation, and return the mask as-is. + attention_mask (`torch.Tensor` or `BlockMask` or `None`): + The attention mask to either return immediately, or to use in downstream mask creation. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`): + An offset to indicate at which first position the key and values states will refer to. 
+ """ + # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom) + if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4: + return True, attention_mask, None, None + + # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask! + if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS: + return True, None, None, None + + # Move the mask to correct device, and potentially switch dtype for efficiency + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool) + + # If using a cache, it can give all informations about mask sizes based on seen tokens + if past_key_values is not None: + kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx) + # Otherwise, the sizes are simply the input sizes + else: + kv_length, kv_offset = input_embeds.shape[1], 0 + + return False, attention_mask, kv_length, kv_offset + + +def _get_mask_interface(config: PretrainedConfig, output_attentions: bool = False) -> Callable: + """ + Return the mask interface (a function) to be used, based on the type of attention found in the config. + + Args: + config (`PretrainedConfig`): + The model config. + output_attentions (`bool`, optional): + Whether we return the attention scores or not. By default `False`. + """ + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Sdpa fallbacks to eager in the Attention modules if `output_attentions=True` + if config._attn_implementation == "sdpa" and output_attentions: + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS["eager"] + return mask_interface + + +def create_causal_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + output_attentions: bool = False, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, +) -> Optional[Union[torch.Tensor, "BlockMask"]]: + """ + Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values` + has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align + to what is needed in the `modeling_xxx.py` files). + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + output_attentions (`bool`, optional): + Whether we return the attention scores or not. By default `False`. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the causal mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. 
+ and_mask_function (`Callable`, optional): + An optional mask function to combine with the causal mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + """ + # If we have an HybridCache structure, here we want to create the mask for the full layers + try: + layer_idx = past_key_values.is_sliding.index(False) + except (ValueError, AttributeError): + layer_idx = 0 + + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + ) + if early_exit: + return attention_mask + + batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + mask_factory_function = causal_mask_function + mask_interface = _get_mask_interface(config, output_attentions) + + # Do not allow skip if we are compiling (this is to match BC) + # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it + allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + + # Allow slight deviations from causal mask + if or_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + + # We now create the mask + causal_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_factory_function, + attention_mask=attention_mask, + allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa + dtype=dtype, # Additional kwarg for eager + config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + ) + return causal_mask + + +def create_sliding_window_causal_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + output_attentions: bool = False, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, +) -> Optional[Union[torch.Tensor, "BlockMask"]]: + """ + Create a sliding window causal mask based on the attention implementation used (stored in the config). This type + of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this + function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the + `modeling_xxx.py` files). + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. 
+ cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + output_attentions (`bool`, optional): + Whether we return the attention scores or not. By default `False`. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. + and_mask_function (`Callable`, optional): + An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. + """ + # If we have an HybridCache structure, here we want to create the mask for the sliding layers + try: + layer_idx = past_key_values.is_sliding.index(True) + except (ValueError, AttributeError): + layer_idx = 0 + + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + ) + if early_exit: + return attention_mask + + sliding_window = getattr(config, "sliding_window", None) + if sliding_window is None: + raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set") + + batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + mask_factory_function = sliding_window_causal_mask_function(sliding_window) + mask_interface = _get_mask_interface(config, output_attentions) + + # Do not allow skip if we are compiling (this is to match BC) + # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it + allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + + # Allow slight deviations from sliding causal mask + if or_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + + # We now create the mask + causal_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_factory_function, + attention_mask=attention_mask, + allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa + local_size=sliding_window, # Additional kwarg for sdpa + dtype=dtype, # Additional kwarg for eager + config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + ) + return causal_mask + + +def create_chunked_causal_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + output_attentions: bool = False, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, +) -> Optional[Union[torch.Tensor, "BlockMask"]]: + """ + Create a chunked attention causal mask 
based on the attention implementation used (stored in the config). This type + of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this + function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the + `modeling_xxx.py` files). + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + output_attentions (`bool`, optional): + Whether we return the attention scores or not. By default `False`. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. + and_mask_function (`Callable`, optional): + An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. + """ + # If we have an HybridCache structure, here we want to create the mask for the sliding layers + try: + layer_idx = past_key_values.is_sliding.index(True) + except (ValueError, AttributeError): + layer_idx = 0 + + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + ) + if early_exit: + return attention_mask + + chunk_size = getattr(config, "attention_chunk_size", None) + if chunk_size is None: + raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set") + + # Raise if using chunked attention on context too large with FA2 + if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size: + raise ValueError( + "Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the " + "chunked pattern cannot be respected. 
You should use another `attn_implementation` when instantiating the model" + ) + + batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + mask_factory_function = chunked_causal_mask_function(chunk_size) + mask_interface = _get_mask_interface(config, output_attentions) + + # Do not allow skip if we are compiling (this is to match BC) + # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it + allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + + # Allow slight deviations from chunked causal mask + if or_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_5: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + + # We now create the mask + causal_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_factory_function, + attention_mask=attention_mask, + allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa + local_size=chunk_size, # Additional kwarg for sdpa + dtype=dtype, # Additional kwarg for eager + config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + ) + return causal_mask + + +LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = { + "full_attention": create_causal_mask, + "sliding_attention": create_sliding_window_causal_mask, + "chunked_attention": create_chunked_causal_mask, +} + + +def create_masks_for_generate( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + output_attentions: bool = False, + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, + **kwargs, +): + """ + This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in `generate` in order + to easily create the masks in advance, when we compile the forwards with Static caches. + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + output_attentions (`bool`, optional): + Whether we return the attention scores or not. By default `False`. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the other mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. 
+ and_mask_function (`Callable`, optional): + An optional mask function to combine with the other mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + """ + # The attribute reside in the text config for composite models + effective_config = config.get_text_config() + # Prepare the mask args + mask_kwargs = { + "config": effective_config, + "input_embeds": input_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + "or_mask_function": or_mask_function, + "and_mask_function": and_mask_function, + } + + # If the attribute exist, we need several masks + if hasattr(effective_config, "layer_types"): + causal_masks = {} + for layer_pattern in set(effective_config.layer_types): + causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs) + return causal_masks + # In this case, all layers are sliding + elif getattr(effective_config, "sliding_window", None) is not None: + return create_sliding_window_causal_mask(**mask_kwargs) + # In this case, all layers are chunked + elif getattr(effective_config, "attention_chunk_size", None) is not None: + return create_chunked_causal_mask(**mask_kwargs) + # All layers use standard causal attention + return create_causal_mask(**mask_kwargs) + + +# Below are utilities to pretty-print the different masks +# Print the matrix with words as row labels +GREEN = "\033[92m" +YELLOW = "\033[93m" +RESET = "\033[0m" +BLACK_SQUARE = "■" +WHITE_SQUARE = "⬚" +GREY_SQUARE = "∙" +LOW_TRIANGLE = "⬕" +UPPER_TRIANGLE = "⬔" + + +def get_style(style): + if style == "majong": + BLACK_SQUARE = "🀞" # Full block (represents "on" or active) + BLACK_SQUARE = "🀙" # Full block (represents "on" or active) + WHITE_SQUARE = "🀆" # "▒" # Light shade (represents "off" or inactive) + LOW_TRIANGLE = "🀛" # Lower left triangle (stylized indication) + UPPER_TRIANGLE = "🀛" # Upper left triangle (stylized indication) + else: + BLACK_SQUARE = "█" # Full block (represents "on" or active) + WHITE_SQUARE = "░" # "▒" # Light shade (represents "off" or inactive) + LOW_TRIANGLE = "▙" # Lower left triangle (stylized indication)) + UPPER_TRIANGLE = "▜" # Upper left triangle (stylized indication) + + return BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE + + +# LOW_TRIANGLE = UPPER_TRIANGLE = "⟍" # Upper right triangle (stylized indication) + +YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}" +GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}" + + +def tensor_to_mask_visual(original_tensor: torch.Tensor, grid_size=(20, 40), style="majong") -> str: + BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style) + h, w = original_tensor.shape + max_h, max_w = grid_size + if not (h < max_h and w < max_w): + # Preserve aspect ratio within max grid size + aspect_ratio = 2 * w / h + if aspect_ratio > 1: + w = max_w + h = min(max_h, max(1, round(max_w / aspect_ratio))) + else: + h = max_h + w = max(1, round(max_h * aspect_ratio)) + + # Step 1: Rescale tensor by average pooling + tensor = original_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + tensor = F.adaptive_avg_pool2d(tensor, output_size=(h, w))[0, 0] # Remove extra dims + else: + tensor = original_tensor + + # Step 3: Build the string representation + result = [] + for i in range(h): + row = "" + for j in range(w): + if tensor[i, j] == 1: + row += BLACK_SQUARE + elif 
tensor[i, j] == 0: + row += WHITE_SQUARE + else: + if j > 0: + if tensor[i, j - 1] == 1: + row += LOW_TRIANGLE + elif tensor[i, j - 1] == 0: + row += UPPER_TRIANGLE + else: + row += BLACK_SQUARE if tensor[i, j] == 1 else WHITE_SQUARE + else: + row += ( + BLACK_SQUARE + if tensor[i, j] == 1 + else ( + WHITE_SQUARE + if tensor[i, j] == 0 + else (UPPER_TRIANGLE if tensor[i, j + 1] == 1 else LOW_TRIANGLE) + ) + ) + result.append(row) + + return "\n".join(result) + + +class AttentionMask(torch.Tensor): + def __new__(cls, data, style=None): + # Create a new instance of AttentionMask as a Tensor + cls.style = style + return torch.Tensor._make_subclass(cls, data, require_grad=False) + + def __init__(self, data): + # You can initialize any additional metadata here if needed + pass + + def to_string(self, grid_size=(20, 40), limit=4): + """Returns a string representation of the block mask.""" + dense_mask = self + *batch_dims, num_rows, num_cols = dense_mask.shape + total_vis = [] + + for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])): + if idx == limit: + total_vis.append("...") + total_vis.append("To print out more, set AttentionMask.to_string(limit=N)") + total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head") + break + block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style) + total_vis.append(block_vis) + + total_vis.append(f"torch.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})") + return "\n".join(total_vis) + + def __repr__(self): + return self.to_string() + + def __str__(self): + return self.to_string() + + @classmethod + def from_tensor(cls, tensor: torch.Tensor, style: Optional[str] = None) -> "AttentionMask": + res = cls(tensor) + res.style = style + return res diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index dfdd976f015..df85a307aae 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -11,6 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general +`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now, +and will be removed in the future. +""" + from dataclasses import dataclass from typing import Optional, Union diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index c7d54dc4151..678ee983da5 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -148,6 +148,12 @@ def _upad_input( Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
""" indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c52de1fcb68..97e95b4161b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -27,7 +27,6 @@ import shutil import tempfile import warnings from collections import defaultdict -from collections.abc import MutableMapping from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -124,6 +123,7 @@ from .utils import ( replace_return_docstrings, strtobool, ) +from .utils.generic import GeneralInterface from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files from .utils.import_utils import ( ENV_VARS_TRUE_VALUES, @@ -6076,7 +6076,7 @@ def get_disk_only_shard_files(device_map, weight_map): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] -class AttentionInterface(MutableMapping): +class AttentionInterface(GeneralInterface): """ Dict-like object keeping track of allowed attention functions. You can easily add a new attention function with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, @@ -6091,36 +6091,6 @@ class AttentionInterface(MutableMapping): "sdpa": sdpa_attention_forward, } - def __init__(self): - self._local_mapping = {} - - def __getitem__(self, key): - # First check if instance has a local override - if key in self._local_mapping: - return self._local_mapping[key] - return self._global_mapping[key] - - def __setitem__(self, key, value): - # Allow local update of the default functions without impacting other instances - self._local_mapping.update({key: value}) - - def __delitem__(self, key): - del self._local_mapping[key] - - def __iter__(self): - # Ensure we use all keys, with the overwritten ones on top - return iter({**self._global_mapping, **self._local_mapping}) - - def __len__(self): - return len(self._global_mapping.keys() | self._local_mapping.keys()) - - @classmethod - def register(cls, key: str, value: Callable): - cls._global_mapping.update({key: value}) - - def valid_keys(self) -> List[str]: - return list(self.keys()) - # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0d516c5f1c0..cd794846275 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -25,20 +25,14 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from 
...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - LossKwargs, - auto_docstring, - can_return_tuple, - is_torch_flex_attn_available, - logging, -) +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from ...utils.import_utils import is_torch_available from ..auto import AutoModel from .configuration_aria import AriaConfig, AriaTextConfig @@ -49,12 +43,6 @@ if is_torch_available(): from torch import nn -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -818,8 +806,13 @@ class AriaTextModel(AriaTextPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -865,129 +858,6 @@ class AriaTextModel(AriaTextPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
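The Aria hunks above show the migration pattern applied throughout this PR: the per-model `_update_causal_mask` and `_prepare_4d_causal_attention_mask_with_cache_position` helpers are removed, and the forward pass instead calls `create_causal_mask` from `masking_utils`. The following is a minimal, editorial sketch (not part of the diff) of calling `create_causal_mask` directly, the way the refactored forwards do; the toy `LlamaConfig` values and the manual `_attn_implementation` assignment are illustrative assumptions rather than requirements of the API.

```python
import torch
from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask

# Toy config purely for illustration; any decoder config with `_attn_implementation` set works the same way.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2)
config._attn_implementation = "sdpa"  # selects the sdpa mask format (4D boolean tensor, or None when skippable)

batch_size, seq_len, hidden_dim = 2, 5, 64
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_dim)
cache_position = torch.arange(seq_len)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[0, :2] = 0  # left-pad the first sequence so the mask cannot be skipped

mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,  # without a cache, kv_length falls back to the input length
)
# With padding present, sdpa cannot rely on `is_causal`, so a (batch_size, 1, q_len, kv_len)
# boolean mask is materialized; with no padding and a plain causal pattern it may be `None`.
print(None if mask is None else mask.shape)
```

Custom mask formats registered through `AttentionMaskInterface` (see the `ALL_MASK_ATTENTION_FUNCTIONS` mapping earlier in this file) are picked up by the same call, since `_get_mask_interface` dispatches on `config._attn_implementation`.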
@@ -1521,61 +1391,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "AriaForConditionalGeneration", diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index b7d7521ef69..a851d4d0a0f 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -542,60 +542,5 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"] diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 3fb6c8b1c4b..01f7f19a79e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -757,7 +757,7 @@ class BartPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -827,7 +827,7 @@ class BartPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index a1d972ec1ac..4ff34b9ef25 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1602,7 +1602,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1672,7 +1672,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from 
transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 900a8fe1785..d93b6f6ae2d 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -482,7 +482,7 @@ class BioGptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -552,7 +552,7 @@ class BioGptPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 48076bfb784..e98f9ed1162 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -27,23 +27,17 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_bitnet import BitNetConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -425,8 +419,13 @@ class BitNetModel(BitNetPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -472,129 +471,6 @@ class BitNetModel(BitNetPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - 
input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. 
- dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index ede764a513c..8eb282ac6fa 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -493,7 +493,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -563,7 +563,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index a6fa6cda00e..2f778d72939 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -482,7 +482,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -552,7 +552,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position 
def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 43f9ee6c20d..bdba37a73b0 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -658,7 +658,7 @@ class BloomModel(BloomPreTrainedModel): attentions=all_self_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -728,7 +728,7 @@ class BloomModel(BloomPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 8b5d2a46067..e7a46b43cf3 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1057,7 +1057,7 @@ class ChameleonModel(ChameleonPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1127,7 +1127,7 @@ class ChameleonModel(ChameleonPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index d421432d77f..a1bef381ce7 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -491,7 +491,7 @@ class CodeGenModel(CodeGenPreTrainedModel): attentions=all_self_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -561,7 +561,7 @@ class CodeGenModel(CodeGenPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9611d92ca6d..37f698a86ec 100644 --- 
a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -35,23 +35,17 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_cohere import CohereConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -462,8 +456,13 @@ class CohereModel(CoherePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -509,129 +508,6 @@ class CohereModel(CoherePreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py index c792ab3f827..e407fb83dfd 100644 --- a/src/transformers/models/cohere2/configuration_cohere2.py +++ b/src/transformers/models/cohere2/configuration_cohere2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation @@ -119,9 +119,8 @@ class Cohere2Config(PretrainedConfig): The dropout ratio for the attention probabilities. sliding_window (`int`, *optional*, defaults to 4096): Size of the sliding window attention context. - sliding_window_pattern (`int`, *optional*, defaults to 4): - Pattern for the sliding window attention. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + layer_types (`list`, *optional*): + Attention pattern for each layer. 
```python >>> from transformers import Cohere2Model, Cohere2Config @@ -177,8 +176,7 @@ class Cohere2Config(PretrainedConfig): attention_bias=False, attention_dropout=0.0, sliding_window=4096, - sliding_window_pattern=4, - cache_implementation="hybrid", + layer_types=None, **kwargs, ): self.vocab_size = vocab_size @@ -203,10 +201,9 @@ class Cohere2Config(PretrainedConfig): self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.sliding_window = sliding_window - self.sliding_window_pattern = sliding_window_pattern + self.layer_types = layer_types # Need to specify head_dim in the config so it can be used in the attention forward functions self.head_dim = hidden_size // num_attention_heads - self.cache_implementation = cache_implementation # Validate the correctness of rotary position embeddings parameters rope_config_validation(self) @@ -219,5 +216,14 @@ class Cohere2Config(PretrainedConfig): **kwargs, ) + if self.layer_types is None: + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + sliding_window_pattern = getattr(self, "sliding_window_pattern", 4) + self.layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + __all__ = ["Cohere2Config"] diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 689b535f579..144667f1e3d 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -25,25 +25,20 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from ...utils.deprecation import deprecate_kwarg from .configuration_cohere2 import Cohere2Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -186,6 +181,7 @@ class Cohere2Attention(nn.Module): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -199,9 +195,6 @@ class Cohere2Attention(nn.Module): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.sliding_window = ( - config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None - ) def forward( self, 
@@ -224,19 +217,9 @@ class Cohere2Attention(nn.Module): query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -284,12 +267,10 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Cohere2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Cohere2Attention(config, layer_idx) + self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx) self.mlp = Cohere2MLP(config) self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) - self.config = config - self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 - self.sliding_window = config.sliding_window + self.attention_type = config.layer_types[layer_idx] @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -322,34 +303,6 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence """ - - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. 
torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -440,7 +393,7 @@ class Cohere2Model(Cohere2PreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -467,15 +420,7 @@ class Cohere2Model(Cohere2PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - device=self.device, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -485,9 +430,22 @@ class Cohere2Model(Cohere2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } hidden_states = inputs_embeds @@ -505,7 +463,7 @@ class Cohere2Model(Cohere2PreTrainedModel): layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, @@ -531,100 +489,6 @@ class Cohere2Model(Cohere2PreTrainedModel): attentions=all_self_attns, ) - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool = False, - ): - # Flash Attention currently doesn't support static cache but Cohere2 work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": - return attention_mask - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, (HybridCache, StaticCache)): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
@@ -635,7 +499,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config: Cohere2Config): + def __init__(self, config): super().__init__(config) self.model = Cohere2Model(config) self.vocab_size = config.vocab_size @@ -740,88 +604,5 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or cache_position[-1] >= input_ids.shape[1] # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs - __all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"] diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 256a597a415..792d278cc0a 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -19,8 +19,9 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from ...cache_utils import Cache, HybridCache -from ...configuration_utils import PretrainedConfig +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation @@ -140,9 +141,8 @@ class Cohere2Config(PretrainedConfig): The dropout ratio for the attention probabilities. sliding_window (`int`, *optional*, defaults to 4096): Size of the sliding window attention context. - sliding_window_pattern (`int`, *optional*, defaults to 4): - Pattern for the sliding window attention. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + layer_types (`list`, *optional*): + Attention pattern for each layer. 
```python >>> from transformers import Cohere2Model, Cohere2Config @@ -198,8 +198,7 @@ class Cohere2Config(PretrainedConfig): attention_bias=False, attention_dropout=0.0, sliding_window=4096, - sliding_window_pattern=4, - cache_implementation="hybrid", + layer_types=None, **kwargs, ): self.vocab_size = vocab_size @@ -224,10 +223,9 @@ class Cohere2Config(PretrainedConfig): self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.sliding_window = sliding_window - self.sliding_window_pattern = sliding_window_pattern + self.layer_types = layer_types # Need to specify head_dim in the config so it can be used in the attention forward functions self.head_dim = hidden_size // num_attention_heads - self.cache_implementation = cache_implementation # Validate the correctness of rotary position embeddings parameters rope_config_validation(self) @@ -240,6 +238,15 @@ class Cohere2Config(PretrainedConfig): **kwargs, ) + if self.layer_types is None: + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + sliding_window_pattern = getattr(self, "sliding_window_pattern", 4) + self.layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + class Cohere2RotaryEmbedding(CohereRotaryEmbedding): pass @@ -261,6 +268,7 @@ class Cohere2Attention(CohereAttention, nn.Module): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -274,9 +282,6 @@ class Cohere2Attention(CohereAttention, nn.Module): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.sliding_window = ( - config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None - ) def forward( self, @@ -299,19 +304,9 @@ class Cohere2Attention(CohereAttention, nn.Module): query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -342,10 +337,7 @@ class Cohere2Attention(CohereAttention, nn.Module): class Cohere2DecoderLayer(CohereDecoderLayer): def __init__(self, config: Cohere2Config, layer_idx: int): super().__init__(config, layer_idx) - self.self_attn = Cohere2Attention(config, layer_idx) - self.config = config - self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 - self.sliding_window = 
config.sliding_window + self.attention_type = config.layer_types[layer_idx] @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -378,34 +370,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence """ - - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -451,7 +415,7 @@ class Cohere2Model(Gemma2Model): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -478,15 +442,7 @@ class Cohere2Model(Gemma2Model): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - device=self.device, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -496,9 +452,22 @@ class Cohere2Model(Gemma2Model): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } hidden_states = inputs_embeds @@ -516,7 +485,7 @@ class Cohere2Model(Gemma2Model): layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, @@ -544,91 +513,7 @@ class Cohere2Model(Gemma2Model): class Cohere2ForCausalLM(CohereForCausalLM): - def __init__(self, config: Cohere2Config): - super().__init__(config) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or cache_position[-1] >= input_ids.shape[1] # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs + pass __all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"] diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 58042c64abb..6f8fd7a487f 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -29,25 +29,19 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_csm import CsmConfig, CsmDepthDecoderConfig from .generation_csm import CsmGenerationMixin -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -516,8 +510,13 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): inputs_embeds = self.inputs_embeds_projector(inputs_embeds) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -564,129 +563,6 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: 
Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. 
- cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class CsmCodebooksHead(nn.Module): def __init__(self, hidden_size, num_codebooks, vocab_size): @@ -946,8 +822,13 @@ class CsmBackboneModel(CsmPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -993,129 +874,6 @@ class CsmBackboneModel(CsmPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring( custom_intro=""" @@ -1477,61 +1235,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None, ) - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "CsmPreTrainedModel", diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 8b7a54d2765..35fdf127fcd 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -21,6 +21,7 @@ import torch.nn as nn from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -240,8 +241,13 @@ class CsmDepthDecoderModel(LlamaModel): inputs_embeds = self.inputs_embeds_projector(inputs_embeds) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -835,61 +841,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None, ) - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "CsmPreTrainedModel", diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 050476355c9..0a530e87ae1 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1001,7 +1001,7 @@ class DbrxModel(DbrxPreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1071,7 +1071,7 @@ class DbrxModel(DbrxPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 8388f743f2a..b15301e2884 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -15,23 +15,17 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_deepseek_v3 import DeepseekV3Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from 
...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -608,8 +602,13 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -655,129 +654,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
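Illustrative note, not part of the patch: every decoder model touched below repeats the same pattern as the deepseek_v3 hunk above, deleting the per-model `_update_causal_mask` and `_prepare_4d_causal_attention_mask_with_cache_position` helpers and calling the shared `create_causal_mask` from `masking_utils` in the forward pass instead. The sketch below calls that helper standalone so the expected inputs are easier to see; the `LlamaConfig` instance, the dummy tensor shapes, and the standalone usage are assumptions made purely for demonstration, while the keyword arguments mirror the calls added in these hunks.

```python
# Minimal sketch (assumptions noted above), showing the shared helper that replaces
# the removed per-model mask code in this patch.
import torch

from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask

# Toy config and dummy inputs, chosen only for illustration
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2)
batch_size, seq_len = 2, 5
inputs_embeds = torch.zeros(batch_size, seq_len, config.hidden_size)
cache_position = torch.arange(seq_len)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# The keyword arguments mirror the calls added in the model hunks of this diff
causal_mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,  # the models pass their Cache object; None works for a prefill-only sketch
    output_attentions=False,
)
# With the default eager implementation this yields a 4D float mask of shape
# (batch_size, 1, seq_len, kv_length); other implementations may instead return
# None (e.g. flash attention without padding) or a BlockMask (flex attention),
# matching the behavior of the removed per-model helpers.
```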
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index bc896482d3a..84df7b4d41f 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, _flash_attention_forward, @@ -48,16 +48,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_diffllama import DiffLlamaConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -708,8 +702,13 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -755,129 +754,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 61ccecd6501..3b570fd1f26 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -32,23 +32,17 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -1279,8 +1273,13 @@ class Emu3TextModel(Emu3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -1326,129 +1325,6 @@ class Emu3TextModel(Emu3PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = 
make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -1857,62 +1733,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "Emu3ForConditionalGeneration", diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index ade55f93a16..bf2e6a5efa7 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -1212,62 +1212,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "Emu3ForConditionalGeneration", diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d2d80af9fe9..df87d36242e 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -976,7 +976,7 @@ class FalconModel(FalconPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 5913a7f80bc..3a2e20e7cc1 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -45,12 +45,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - auto_docstring, - is_torchdynamo_compiling, - logging, - replace_return_docstrings, -) +from ...utils import auto_docstring, is_torchdynamo_compiling, logging, replace_return_docstrings from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 7344bd0da0e..897f329e56c 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -27,7 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -39,16 +39,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_gemma import GemmaConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -384,7 +378,7 @@ class GemmaModel(GemmaPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -422,8 +416,13 @@ class GemmaModel(GemmaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) # embed positions @@ -475,129 +474,6 @@ class GemmaModel(GemmaPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index fd7bae38554..1a1e8cc1c63 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional import sentencepiece as spm import torch @@ -22,6 +22,7 @@ from torch import nn from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig +from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging @@ -371,7 +372,7 @@ class GemmaModel(LlamaModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -409,8 +410,13 @@ class GemmaModel(LlamaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) # embed positions diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index c9e66f8beac..810c10cc928 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation class Gemma2Config(PretrainedConfig): @@ -78,12 +78,16 @@ class Gemma2Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the - size of the sliding window. - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): + scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + scaling factor when applying tanh softcapping on the attention scores. ```python >>> from transformers import Gemma2Model, Gemma2Config @@ -135,9 +139,9 @@ class Gemma2Config(PretrainedConfig): attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, + layer_types=None, final_logit_softcapping=30.0, attn_logit_softcapping=50.0, - cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -166,7 +170,13 @@ class Gemma2Config(PretrainedConfig): self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping - self.cache_implementation = cache_implementation + self.layer_types = layer_types + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) __all__ = ["Gemma2Config"] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 04f3edbf6ec..fe5576ae1c8 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -26,8 +26,9 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -38,17 +39,11 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, 
is_torch_flex_attn_available, logging +from ...utils import auto_docstring, can_return_tuple, logging from ...utils.deprecation import deprecate_kwarg from .configuration_gemma2 import Gemma2Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -195,7 +190,7 @@ class Gemma2Attention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.attn_logit_softcapping = self.config.attn_logit_softcapping - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -218,19 +213,9 @@ class Gemma2Attention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -264,7 +249,7 @@ class Gemma2DecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size self.config = config - self.is_sliding = not bool(layer_idx % 2) + self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -272,7 +257,6 @@ class Gemma2DecoderLayer(nn.Module): self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.sliding_window = config.sliding_window @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -287,33 +271,6 @@ class Gemma2DecoderLayer(nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = 
torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -441,7 +398,7 @@ class Gemma2Model(Gemma2PreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -468,15 +425,7 @@ class Gemma2Model(Gemma2PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - device=self.device, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -487,9 +436,22 @@ class Gemma2Model(Gemma2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } # embed positions hidden_states = inputs_embeds @@ -516,7 +478,7 @@ class Gemma2Model(Gemma2PreTrainedModel): partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -527,7 +489,7 @@ class Gemma2Model(Gemma2PreTrainedModel): layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -553,100 +515,6 @@ class Gemma2Model(Gemma2PreTrainedModel): attentions=all_self_attns, ) - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool = False, - ): - # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. - if self.config._attn_implementation == "flash_attention_2": - return attention_mask - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, (HybridCache, StaticCache)): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. 
- sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): @@ -688,7 +556,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -765,60 +633,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - if logits_to_keep is None: - _ = model_inputs.pop("logits_to_keep", None) - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - 
dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - model_inputs["attention_mask"] = attention_mask - - return model_inputs - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9a448370dd1..7d0b721d809 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -21,13 +21,14 @@ import torch.nn as nn import torch.utils.checkpoint from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache -from ...configuration_utils import PretrainedConfig +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import is_torch_flex_attn_available, logging +from ...utils import logging from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import ( GemmaAttention, @@ -42,12 +43,6 @@ from ..gemma.modeling_gemma import ( ) -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -107,12 +102,16 @@ class Gemma2Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the - size of the sliding window. - final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. - attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): + scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + scaling factor when applying tanh softcapping on the attention scores. 
```python >>> from transformers import Gemma2Model, Gemma2Config @@ -164,9 +163,9 @@ class Gemma2Config(PretrainedConfig): attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, + layer_types=None, final_logit_softcapping=30.0, attn_logit_softcapping=50.0, - cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -195,7 +194,13 @@ class Gemma2Config(PretrainedConfig): self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping - self.cache_implementation = cache_implementation + self.layer_types = layer_types + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) class Gemma2RMSNorm(GemmaRMSNorm): @@ -250,7 +255,7 @@ class Gemma2Attention(GemmaAttention): self.attention_dropout = self.config.attention_dropout self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -273,19 +278,9 @@ class Gemma2Attention(GemmaAttention): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -319,7 +314,7 @@ class Gemma2DecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size self.config = config - self.is_sliding = not bool(layer_idx % 2) + self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -327,7 +322,6 @@ class Gemma2DecoderLayer(nn.Module): self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.sliding_window = config.sliding_window @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -342,33 +336,6 @@ class Gemma2DecoderLayer(nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we 
must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -414,7 +381,7 @@ class Gemma2Model(GemmaModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -441,15 +408,7 @@ class Gemma2Model(GemmaModel): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - device=self.device, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -460,9 +419,22 @@ class Gemma2Model(GemmaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } # embed positions hidden_states = inputs_embeds @@ -489,7 +461,7 @@ class Gemma2Model(GemmaModel): partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, position_embeddings, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -500,7 +472,7 @@ class Gemma2Model(GemmaModel): layer_outputs = decoder_layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -526,45 +498,6 @@ class Gemma2Model(GemmaModel): attentions=all_self_attns, ) - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool = False, - ): - # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. - if self.config._attn_implementation == "flash_attention_2": - return attention_mask - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, (HybridCache, StaticCache)): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
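# --- illustrative sketch, not part of the patch -------------------------------------
# For reference, a condensed standalone version of what the removed per-model helper
# did: expand a 2D padding mask into the additive 4D (batch, 1, query_len, kv_len)
# causal mask consumed by sdpa/eager attention. All shapes/values are made up.
import torch

batch_size, sequence_length, target_length = 2, 4, 6
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)          # positions of the query tokens
attention_mask = torch.ones(batch_size, target_length)  # 2D padding mask (1 = keep)
attention_mask[1, :2] = 0                                # pretend sample 1 is left-padded

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)                       # block strictly-future positions
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0    # allowed but padded
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
print(causal_mask.shape)  # torch.Size([2, 1, 4, 6])
# -------------------------------------------------------------------------------------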
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - class Gemma2ForCausalLM(GemmaForCausalLM): def __init__(self, config): @@ -577,7 +510,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -649,60 +582,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM): attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - if logits_to_keep is None: - _ = model_inputs.pop("logits_to_keep", None) - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - model_inputs["attention_mask"] = attention_mask - - return model_inputs - class Gemma2ForSequenceClassification(GemmaForSequenceClassification): def __init__(self, config): diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index a1680b7f5aa..db2749644cc 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -21,7 +21,7 @@ # limitations under the License. from typing import Any, Dict, Optional, Union -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging from ..siglip import SiglipVisionConfig @@ -88,13 +88,14 @@ class Gemma3TextConfig(PretrainedConfig): The dropout ratio for the attention probabilities. query_pre_attn_scalar (`float`, *optional*, defaults to 256): Scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the - size of the sliding window. 
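# --- illustrative sketch, not part of the patch -------------------------------------
# Gemma3 keeps backward compatibility with the old integer `sliding_window_pattern`
# (five sliding layers followed by one full-attention layer when the pattern is 6).
# A standalone example of the default built in `Gemma3TextConfig.__init__`, using a
# made-up layer count:
num_hidden_layers = 12       # hypothetical value for illustration
sliding_window_pattern = 6   # legacy value still present in configs on the Hub
layer_types = [
    "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types.count("full_attention"))  # 2 -> every 6th layer uses full attention
# -------------------------------------------------------------------------------------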
+ sliding_window (`int`, *optional*, defaults to 4096): + In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. final_logit_softcapping (`float`, *optional*): Scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*): Scaling factor when applying tanh softcapping on the attention scores. - cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -134,8 +135,6 @@ class Gemma3TextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. ```python >>> from transformers import Gemma3TextModel, Gemma3TextConfig @@ -192,12 +191,11 @@ class Gemma3TextConfig(PretrainedConfig): attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, + layer_types=None, final_logit_softcapping=None, attn_logit_softcapping=None, - cache_implementation="hybrid", rope_scaling=None, rope_local_base_freq=10_000.0, - sliding_window_pattern=6, **kwargs, ): super().__init__( @@ -226,14 +224,21 @@ class Gemma3TextConfig(PretrainedConfig): self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping - self.cache_implementation = cache_implementation + self.layer_types = layer_types self.rope_local_base_freq = rope_local_base_freq - # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = sliding_window_pattern self.rope_scaling = rope_scaling rope_config_validation(self) + if self.layer_types is None: + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + sliding_window_pattern = getattr(self, "sliding_window_pattern", 6) + self.layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + class Gemma3Config(PretrainedConfig): r""" diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 9aeba5f18aa..122d16aafce 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -29,32 +29,21 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache, StaticCache +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - auto_docstring, - can_return_tuple, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -300,7 +289,7 @@ class Gemma3Attention(nn.Module): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -351,19 +340,9 @@ class Gemma3Attention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -374,9 +353,7 @@ class Gemma3Attention(nn.Module): ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: - # backwards compatibility - attention_mask = attention_mask.to(query_states) + attn_output, attn_weights = attention_interface( self, query_states, @@ -400,14 +377,13 @@ class Gemma3DecoderLayer(nn.Module): self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma3MLP(config) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_sliding = self.self_attn.is_sliding - self.sliding_window = config.sliding_window @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -423,33 +399,6 @@ class Gemma3DecoderLayer(nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be 
larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -568,7 +517,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -595,13 +544,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -614,13 +557,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } # embed positions hidden_states = inputs_embeds @@ -643,7 +595,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): hidden_states, position_embeddings_global, position_embeddings_local, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -655,7 +607,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): hidden_states, position_embeddings_global=position_embeddings_global, position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -681,100 +633,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): attentions=all_self_attns, ) - @torch.no_grad() - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: HybridCache, - output_attentions: bool = False, - ): - # Flash Attention currently doesn't support static cache but Gemma3Text work only with static cache. - # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape - # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible - # as it doesn't cause dynamic control issues. - if self.config._attn_implementation == "flash_attention_2": - return attention_mask - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, (HybridCache, StaticCache)): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. 
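# --- illustrative sketch, not part of the patch -------------------------------------
# The new masking utilities are driven by mask functions with the same signature as
# the `inner_mask` added by this patch: (batch_idx, head_idx, q_idx, kv_idx) -> bool.
# A minimal standalone sketch of the two patterns used by Gemma-style models; the
# function names and the window size are assumptions for illustration only.
SLIDING_WINDOW = 4  # hypothetical window size

def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    # A query token may attend to itself and to any earlier key/value token.
    return kv_idx <= q_idx

def sliding_window_causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    # Same as above, but restricted to the last SLIDING_WINDOW positions.
    return (kv_idx <= q_idx) and (q_idx - kv_idx < SLIDING_WINDOW)

print(sliding_window_causal_mask_function(0, 0, q_idx=5, kv_idx=1))  # False: outside the window
print(sliding_window_causal_mask_function(0, 0, q_idx=5, kv_idx=3))  # True: inside the window
# -------------------------------------------------------------------------------------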
- sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): @@ -818,7 +676,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -895,60 +753,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten: has a special cache type, `HybridCache` - - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - if logits_to_keep is None: - _ = model_inputs.pop("logits_to_keep", None) - - if ( - isinstance(past_key_values, HybridCache) - and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" - ): - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - 
dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - ) - model_inputs["attention_mask"] = attention_mask - - return model_inputs - class Gemma3MultiModalProjector(nn.Module): def __init__(self, config: Gemma3Config): @@ -986,6 +790,22 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) +def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # If it's 1, we need to unmask it + return token_type_ids[batch_idx, kv_idx] == 1 + + return inner_mask + + @auto_docstring( custom_intro=""" The Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head., @@ -1012,86 +832,6 @@ class Gemma3Model(Gemma3PreTrainedModel): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - return attention_mask - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted - # form and requires no inversion or slicing. - return attention_mask - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - - # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - # Apply bidirectional mask on images if token type ids are provided - if token_type_ids is not None and sequence_length != 1: - token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) - token_type_mask[token_type_ids == 0] = False # if text token do not change anything - - # Find where a new image block starts: 1 if image and previous not image - # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally - is_image = token_type_ids == 1 - new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] - image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 - image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) - - same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2) - same_image_mask[image_group_ids == -1] = False # remove non-image - image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool) - - causal_mask = causal_mask.clone() - causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( - image_mask, 0.0 - ) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Projects the last hidden state from the vision model into language model space. @@ -1161,8 +901,6 @@ class Gemma3Model(Gemma3PreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - is_training = token_type_ids is not None and labels is not None - # Replace image id woth PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id @@ -1202,11 +940,31 @@ class Gemma3Model(Gemma3PreTrainedModel): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config.get_text_config(), + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + if token_type_ids is not None and inputs_embeds.shape[1] != 1: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device) + ) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + outputs = self.language_model( - attention_mask=causal_mask, + attention_mask=causal_mask_mapping, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1432,70 +1190,35 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self.model._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask return model_inputs @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, + def create_masks_for_generate( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, - batch_size: int, + past_key_values: Optional[Cache], + output_attentions: bool = False, + token_type_ids: Optional[torch.Tensor] = None, **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + ) -> dict: + # Prepare mask arguments + mask_kwargs = { + "config": config.get_text_config(), + "input_embeds": input_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Add the token type ids mask for generate as well + if token_type_ids is not None and input_embeds.shape[1] != 1: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device)) - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. 
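# --- illustrative sketch, not part of the patch -------------------------------------
# Why the token-type mask is passed as an `or_mask_function`: key/value positions that
# belong to image tokens (token_type_ids == 1) should be attendable bidirectionally,
# so the extra mask is OR-ed with the causal one. Standalone sketch with a made-up
# sequence of 2 text tokens, 3 image tokens, then 1 text token:
import torch

token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0]])  # hypothetical, batch of 1

def causal(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx

def image_mask(b, h, q_idx, kv_idx):
    # Unmask every key/value position that belongs to an image token
    return bool(token_type_ids[b, kv_idx] == 1)

q_idx, kv_idx = 2, 4  # an image token looking at a *later* image token
print(causal(0, 0, q_idx, kv_idx))                                      # False (blocked causally)
print(causal(0, 0, q_idx, kv_idx) or image_mask(0, 0, q_idx, kv_idx))   # True (unmasked by the OR)
# -------------------------------------------------------------------------------------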
- dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask + return create_masks_for_generate(**mask_kwargs) __all__ = [ diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 495fe167d79..f0761d863d1 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -23,8 +23,9 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from ...cache_utils import Cache, HybridCache, StaticCache -from ...configuration_utils import PretrainedConfig +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation @@ -56,7 +57,7 @@ from ..siglip import SiglipVisionConfig logger = logging.get_logger(__name__) -class Gemma3TextConfig(Gemma2Config): +class Gemma3TextConfig(Gemma2Config, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text model according to the specified arguments, defining the model architecture. Instantiating a configuration with the @@ -114,13 +115,14 @@ class Gemma3TextConfig(Gemma2Config): The dropout ratio for the attention probabilities. query_pre_attn_scalar (`float`, *optional*, defaults to 256): Scaling factor used on the attention scores - sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the - size of the sliding window. + sliding_window (`int`, *optional*, defaults to 4096): + In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. final_logit_softcapping (`float`, *optional*): Scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*): Scaling factor when applying tanh softcapping on the attention scores. 
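# --- illustrative sketch, not part of the patch -------------------------------------
# The pattern the patch converges on: collect the shared arguments once, then build one
# mask per attention pattern and let each layer pick the entry matching its
# `attention_type`. Shown as pseudo-usage only; the helper name is made up and the
# arguments are assumed to exist as in the model's forward pass.
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

def build_mask_mapping(config, inputs_embeds, attention_mask, cache_position, past_key_values, output_attentions=False):
    mask_kwargs = {
        "config": config,
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    return {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
# -------------------------------------------------------------------------------------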
- cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -160,8 +162,6 @@ class Gemma3TextConfig(Gemma2Config): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. - sliding_window_pattern (`int`, *optional*, defaults to 6): - Pattern for the sliding window attention. ```python >>> from transformers import Gemma3TextModel, Gemma3TextConfig @@ -183,23 +183,74 @@ class Gemma3TextConfig(Gemma2Config): def __init__( self, vocab_size=262_208, - rope_theta=1_000_000.0, - rope_scaling=None, - rope_local_base_freq=10_000.0, - sliding_window_pattern=6, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", max_position_embeddings=131_072, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=1_000_000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, + layer_types=None, final_logit_softcapping=None, attn_logit_softcapping=None, - **super_kwargs, + rope_scaling=None, + rope_local_base_freq=10_000.0, + **kwargs, ): - super().__init__(self, **super_kwargs) + PretrainedConfig.__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types self.rope_local_base_freq = rope_local_base_freq - # For configuring HybridCache to work with 5:1 attention pattern - self.sliding_window_pattern = sliding_window_pattern self.rope_scaling = rope_scaling rope_config_validation(self) + if self.layer_types is None: + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + sliding_window_pattern = getattr(self, "sliding_window_pattern", 6) + self.layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + class Gemma3Config(PretrainedConfig): r""" @@ -336,7 +387,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): # Weird way to inherit but otherwise the sliding window gets 
defined first and can't access `is_sliding` class Gemma3Attention(Gemma2Attention): def __init__(self, config: Gemma3TextConfig, layer_idx: int): - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" super().__init__() self.sliding_window = config.sliding_window if self.is_sliding else None @@ -368,19 +419,9 @@ class Gemma3Attention(Gemma2Attention): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -391,9 +432,7 @@ class Gemma3Attention(Gemma2Attention): ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: - # backwards compatibility - attention_mask = attention_mask.to(query_states) + attn_output, attn_weights = attention_interface( self, query_states, @@ -417,14 +456,13 @@ class Gemma3DecoderLayer(nn.Module): self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma3MLP(config) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_sliding = self.self_attn.is_sliding - self.sliding_window = config.sliding_window @deprecate_kwarg("last_cache_position", version="4.53.0") def forward( @@ -440,33 +478,6 @@ class Gemma3DecoderLayer(nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # In prefill, we may be larger than sliding window - effective_seq_len = max(cache_position.shape[0], self.sliding_window) - # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), - # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask[:, -effective_seq_len:] - # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice - # from the left, with an offset if we are beyond the sliding window - else: - min_dtype = torch.finfo(attention_mask.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), 
diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - # In case we are beyond the sliding window, we need to correctly offset the mask slicing - offset = cache_position[-1] - effective_seq_len + 1 - # Should only be used when beyond the sliding window (i.e. offset > 0) - offset = torch.clamp(offset, min=0) - # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, - # but without data-dependent slicing (i.e. torch.compile friendly) - mask_indexes = torch.arange( - min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device - ) - mask_indexes += offset - attention_mask = attention_mask[:, :, :, mask_indexes] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -557,7 +568,7 @@ class Gemma3TextModel(Gemma2Model): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -584,13 +595,7 @@ class Gemma3TextModel(Gemma2Model): inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None and not self.training: - batch_size, seq_len, _ = inputs_embeds.shape - past_key_values = HybridCache( - self.config, - max_batch_size=batch_size, - max_cache_len=seq_len, - dtype=inputs_embeds.dtype, - ) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -603,13 +608,22 @@ class Gemma3TextModel(Gemma2Model): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } # embed positions hidden_states = inputs_embeds @@ -632,7 +646,7 @@ class Gemma3TextModel(Gemma2Model): hidden_states, position_embeddings_global, position_embeddings_local, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -644,7 +658,7 @@ class Gemma3TextModel(Gemma2Model): hidden_states, position_embeddings_global=position_embeddings_global, position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -716,6 +730,22 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) +def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # If it's 1, we need to unmask it + return token_type_ids[batch_idx, kv_idx] == 1 + + return inner_mask + + class Gemma3Model(PaliGemmaModel): def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: """ @@ -731,85 +761,8 @@ class Gemma3Model(PaliGemmaModel): image_features = self.multi_modal_projector(vision_outputs) return image_features - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - return attention_mask - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted - # form and requires no inversion or slicing. - return attention_mask - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device - ) - - # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - # Apply bidirectional mask on images if token type ids are provided - if token_type_ids is not None and sequence_length != 1: - token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) - token_type_mask[token_type_ids == 0] = False # if text token do not change anything - - # Find where a new image block starts: 1 if image and previous not image - # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally - is_image = token_type_ids == 1 - new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] - image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 - image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) - - same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2) - same_image_mask[image_group_ids == -1] = False # remove non-image - image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool) - - causal_mask = causal_mask.clone() - causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( - image_mask, 0.0 - ) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask + def _update_causal_mask(self, **super_kwargs): + raise AttributeError("We don't want to inherit it") @can_return_tuple @auto_docstring @@ -839,8 +792,6 @@ class Gemma3Model(PaliGemmaModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - is_training = token_type_ids is not None and labels is not None - # Replace image id woth PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id @@ -880,11 +831,31 @@ class Gemma3Model(PaliGemmaModel): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config.get_text_config(), + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + if token_type_ids is not None and inputs_embeds.shape[1] != 1: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device) + ) + + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + outputs = self.language_model( - attention_mask=causal_mask, + attention_mask=causal_mask_mapping, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1066,16 +1037,39 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self.model._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask return model_inputs + def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs): + raise AttributeError("We don't want to inherit it") + + @staticmethod + def create_masks_for_generate( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + output_attentions: bool = False, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # Prepare mask arguments + mask_kwargs = { + "config": config.get_text_config(), + "input_embeds": input_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Add the token type ids mask for generate as well + if token_type_ids is not None and input_embeds.shape[1] != 1: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device)) + + return create_masks_for_generate(**mask_kwargs) + __all__ = [ "Gemma3Config", diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 607a3b7f693..f3ac600e22b 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_outputs import ( @@ -40,16 +40,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_glm import GlmConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -443,8 +437,13 @@ class GlmModel(GlmPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -490,129 +489,6 @@ class GlmModel(GlmPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
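The GLM hunk above is representative of the change applied to most text-only decoders in this patch: the copied `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair is deleted and the forward pass asks `masking_utils.create_causal_mask` for the mask instead. A minimal sketch of that call pattern outside any model class follows; the tiny `GlmConfig` values and the forced `eager` implementation are illustrative assumptions, while the keyword names mirror the hunk above:

```python
import torch

from transformers import GlmConfig
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask

# Tiny illustrative config; "eager" forces a dense 4D float mask so it can be inspected.
config = GlmConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attn_implementation="eager",
)

batch_size, seq_len = 2, 5
inputs_embeds = torch.zeros(batch_size, seq_len, config.hidden_size)
cache_position = torch.arange(seq_len)
# Left-pad the first sequence by two tokens so the padding is visible in the result.
attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

causal_mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=DynamicCache(),
)
print(type(causal_mask), getattr(causal_mask, "shape", None))  # e.g. a (2, 1, 5, 5) tensor for "eager"
```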
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 98ee428e2c4..4525ba15018 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -40,16 +40,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_glm4 import Glm4Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -451,8 +445,13 @@ class Glm4Model(Glm4PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -498,129 +497,6 @@ class Glm4Model(Glm4PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin): diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index f99598c4dcc..6da4405fad5 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -893,60 +893,5 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"] diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index b67bd37c6b2..95de9e82d5e 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -682,7 +682,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): attentions=all_self_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -752,7 +752,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 00eaa2b325f..9c32acdb06a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -12,7 +12,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -24,16 +24,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_gpt_neox import GPTNeoXConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -409,8 +403,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): if position_ids is None: position_ids = 
cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) # Prepare head mask if needed @@ -479,129 +478,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): attentions=all_attentions, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
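Hybrid-attention models take the same idea one step further: the Gemma3 hunks earlier in this patch (and the Llama4 ones further down) build a small dict of masks, one entry per attention pattern, and each decoder layer indexes it with its `attention_type`. Below is a rough sketch of that dispatch; `Gemma3TextConfig` is used only as a stand-in for a config that defines `sliding_window`, and the per-layer type list is made up for illustration:

```python
import torch

from transformers import Gemma3TextConfig
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

config = Gemma3TextConfig(
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=1,
    sliding_window=3,
    attn_implementation="eager",
)

batch_size, seq_len = 1, 6
mask_kwargs = {
    "config": config,
    "input_embeds": torch.zeros(batch_size, seq_len, config.hidden_size),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
    "cache_position": torch.arange(seq_len),
    "past_key_values": DynamicCache(),
}

# One mask per attention pattern, keyed the same way the Gemma3/Llama4 hunks key them.
causal_mask_mapping = {
    "full_attention": create_causal_mask(**mask_kwargs),
    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}

# Each layer then picks the entry matching its `attention_type`
# (an invented alternation here; the real models read it from the config / layer).
layer_types = ["sliding_attention", "sliding_attention", "full_attention", "sliding_attention"]
for layer_idx, layer_type in enumerate(layer_types):
    mask = causal_mask_mapping[layer_type]
    print(layer_idx, layer_type, getattr(mask, "shape", None))
```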
diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 8e01a82b03f..70bee31b280 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -7,6 +7,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -349,8 +350,13 @@ class GPTNeoXModel(LlamaModel, nn.Module): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) # Prepare head mask if needed diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b5c11163a61..f4a073cc4a7 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -540,7 +540,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): attentions=all_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -610,7 +610,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 9e6210604e1..093daaef193 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -793,7 +793,6 @@ class GPTJModel(GPTJPreTrainedModel): attentions=all_self_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -863,7 +862,6 @@ class GPTJModel(GPTJPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 52eea021799..fdba3f4c0eb 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -28,23 +28,17 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import 
GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_granite import GraniteConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -446,8 +440,13 @@ class GraniteModel(GranitePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -493,129 +492,6 @@ class GraniteModel(GranitePreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
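For the multimodal Gemma3 path, the hand-written bidirectional image block is replaced by passing `token_type_ids_mask_function(token_type_ids)` as `or_mask_function`: a `mask_mod`-style predicate `(batch_idx, head_idx, q_idx, kv_idx) -> bool` that is OR-ed with the causal rule. The pure-PyTorch sketch below just evaluates that combination on a toy sequence to show which positions it opens up; every name in it is local to the example:

```python
import torch

# Toy sequence: 0 = text token, 1 = image token (two image tokens in the middle).
token_type_ids = torch.tensor([[0, 1, 1, 0, 0]])
batch_idx, seq_len = 0, token_type_ids.shape[1]

def causal_rule(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def token_type_ids_rule(batch_idx, head_idx, q_idx, kv_idx):
    # Mirrors `inner_mask` from the Gemma3 hunk: unmask every key position that is an image token.
    return bool(token_type_ids[batch_idx, kv_idx] == 1)

# `or_mask_function` semantics: a position is attendable if either rule allows it.
allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for q_idx in range(seq_len):
    for kv_idx in range(seq_len):
        allowed[q_idx, kv_idx] = causal_rule(batch_idx, 0, q_idx, kv_idx) or token_type_ids_rule(
            batch_idx, 0, q_idx, kv_idx
        )

print(allowed.int())
# Columns 1 and 2 (the image keys) are now visible to every query, which is how bidirectional
# attention over the image block is recovered on top of the causal mask.
```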
diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 1d0ff532d8a..424a0cc3fa2 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from torch import nn from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...processing_utils import Unpack @@ -174,8 +175,13 @@ class GraniteModel(LlamaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index df6fd0f9aae..fdd7addc450 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -773,7 +773,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -843,7 +843,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 5c83db0ab72..5d58ca59458 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -28,7 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -40,16 +40,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_helium import HeliumConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import 
make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -428,8 +422,13 @@ class HeliumModel(HeliumPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -475,129 +474,6 @@ class HeliumModel(HeliumPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
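Further down, the Llama4 hunks drop the `HybridChunkedCache` / `chunk_causal_mask` plumbing and route chunked layers through the same dict dispatch, with `create_chunked_causal_mask` filling the `"chunked_attention"` entry. The snippet below is only a reading of what that pattern means, assuming the usual definition of chunked causal attention (causal, but restricted to keys inside the query's own chunk of `attention_chunk_size` positions); it is not code from this patch:

```python
import torch

attention_chunk_size = 4
seq_len = 10

q_idx = torch.arange(seq_len).unsqueeze(1)   # (seq_len, 1)
kv_idx = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)

# Causal within the same chunk: the key must not be in the future and must share the query's chunk.
chunked_causal = (kv_idx <= q_idx) & (kv_idx // attention_chunk_size == q_idx // attention_chunk_size)

print(chunked_causal.int())
# Lower-triangular blocks of size `attention_chunk_size` on the diagonal: position 4, for example,
# no longer sees positions 0-3 because they belong to the previous chunk.
```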
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 80c9af0f91e..9aabc686795 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1316,7 +1316,7 @@ class IdeficsModel(IdeficsPreTrainedModel): image_hidden_states=image_hidden_states, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1386,7 +1386,7 @@ class IdeficsModel(IdeficsPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 106aaaa3c36..26463c20091 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -1035,61 +1035,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "InternVLVisionPreTrainedModel", diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index e241898bd33..788b2066b5d 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1014,7 +1014,7 @@ class JetMoeModel(JetMoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1084,7 +1084,7 @@ class JetMoeModel(JetMoePreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a25f05ea7c3..1718c587d94 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -26,7 +26,8 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -40,18 +41,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_llama import LlamaConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - -from ...integrations import use_kernel_forward_from_hub - - logger = logging.get_logger(__name__) @@ -433,8 +426,13 @@ 
class LlamaModel(LlamaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -480,129 +478,6 @@ class LlamaModel(LlamaPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py index 675bf6a5ef8..42f9442d678 100644 --- a/src/transformers/models/llama4/configuration_llama4.py +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -15,7 +15,7 @@ # limitations under the License. -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...utils import logging @@ -233,12 +233,13 @@ class Llama4TextConfig(PretrainedConfig): `no_rope_layer_interval` layers. attention_chunk_size (`int`, *optional*, defaults to 8192): + layer_types (`list`, *optional*): + Attention pattern for each layer. attn_temperature_tuning (`bool`, *optional*, defaults to `True`): Whether to dynamically scale the attention temperature for each query token based on sequence length. Recommended for long sequences (e.g., >32k tokens) to maintain stable output results. 
floor_scale (`int`, *optional*, defaults to 8192): TODO attn_scale (`int`, *optional*, defaults to 0.1): TODO - cache_implementation (``, *optional*, defaults to `"hybrid"`): Example: """ @@ -298,10 +299,10 @@ class Llama4TextConfig(PretrainedConfig): no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, + layer_types=None, attn_temperature_tuning=True, floor_scale=8192, attn_scale=0.1, - cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -323,7 +324,6 @@ class Llama4TextConfig(PretrainedConfig): self.num_attention_heads = num_attention_heads self.rope_scaling = rope_scaling self.attention_bias = False - self.cache_implementation = cache_implementation # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -363,6 +363,13 @@ class Llama4TextConfig(PretrainedConfig): ) self.attention_chunk_size = attention_chunk_size + self.layer_types = layer_types + if layer_types is None: + self.layer_types = [ + "chunked_attention" if no_rope else "full_attention" for no_rope in self.no_rope_layers + ] + layer_type_validation(self.layer_types) + class Llama4Config(PretrainedConfig): r""" diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index f8274bf1822..82caab17dba 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -24,24 +24,19 @@ import torch.nn.functional as F from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, HybridChunkedCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations.hub_kernels import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_chunked_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_llama4 import Llama4Config, Llama4TextConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - logger = logging.get_logger(__name__) @@ -388,8 +383,9 @@ class Llama4TextDecoderLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] self.self_attn = Llama4TextAttention(config, layer_idx) - self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx]) self.is_moe_layer = layer_idx in config.moe_layers if self.is_moe_layer: # the 128E model interleaves dense / sparse self.feed_forward = Llama4TextMoe(config) @@ -399,13 +395,10 @@ class Llama4TextDecoderLayer(nn.Module): self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = 
Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layer_idx = layer_idx - def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - chunk_causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -419,10 +412,6 @@ class Llama4TextDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) - # use local attention mask for ROPE layers - if self.use_chunked_attention and chunk_causal_mask is not None: - attention_mask = chunk_causal_mask - # Self Attention attention_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, @@ -561,10 +550,7 @@ class Llama4TextModel(Llama4PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) if use_cache and past_key_values is None: - if self.config.get_text_config().attention_chunk_size is not None: - past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1]) - else: - past_key_values = DynamicCache() + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -575,9 +561,22 @@ class Llama4TextModel(Llama4PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask, chunk_causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "chunked_attention": create_chunked_causal_mask(**mask_kwargs), + } hidden_states = inputs_embeds @@ -596,8 +595,7 @@ class Llama4TextModel(Llama4PreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, - chunk_causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -609,8 +607,7 @@ class Llama4TextModel(Llama4PreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - chunk_causal_mask=chunk_causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -638,216 +635,6 @@ class Llama4TextModel(Llama4PreTrainedModel): attentions=all_self_attns, ) - @torch.compiler.disable(recursive=False) # the operations in this method are not compilable - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - chunked_attention_mask=None, - use_cache=True, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask, attention_mask # flash does not support chunked attn TODO support flash - return None, None - - if 
self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: - return None, None - - sequence_length = input_tensor.shape[1] - cache_position = cache_position.to(self.device) - attention_chunk_size = self.config.attention_chunk_size - using_chunked_attention = attention_chunk_size is not None - - first_cache_position = cache_position[0] - - if past_key_values is not None: - full_cache_length = past_key_values.get_max_cache_shape() or sequence_length - else: - full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length - - if using_chunked_attention: - cond1 = first_cache_position >= attention_chunk_size - cond2 = (first_cache_position < attention_chunk_size) & ( - first_cache_position + sequence_length > attention_chunk_size - ) - key_length = ( - torch.where( - cond1, - attention_chunk_size + sequence_length - 1, - torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), - ) - if use_cache - else full_cache_length - ) - - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - if using_chunked_attention: - offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0)) - chunked_attention_mask = make_flex_block_causal_mask( - attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets - ) - attention_mask = make_flex_block_causal_mask( - attention_mask, - query_length=sequence_length, - key_length=full_cache_length, - offsets=(first_cache_position, 0), - ) - return attention_mask, chunked_attention_mask - if isinstance(attention_mask, BlockMask): - return attention_mask, chunked_attention_mask - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- dtype, device = input_tensor.dtype, input_tensor.device - target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - if using_chunked_attention and full_cache_length > attention_chunk_size: - start_idx = max(first_cache_position - attention_chunk_size + 1, 0) - end_idx = start_idx + key_length - chunked_attention_mask = self.create_chunked_attention_mask( - self.config.attention_chunk_size, - start=start_idx, # same offset as with flex - end=end_idx, - device=device, - ) - - local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well - # It may be smaller than attention_chunk_size -> pad it - requires_padding = local_attention_mask.shape[-1] < attention_chunk_size - if requires_padding: - local_attention_mask = nn.functional.pad( - local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) - ) - # Depending on the padding, take the query tokens from the end or the cache_position - if not requires_padding: - chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] - else: - chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] - - chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1) - chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] - if self.config._attn_implementation == "eager": - min_dtype = torch.finfo(dtype).min - chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and attention_mask.ndim == 4 - and not output_attentions # Only unmask for 4d masks - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None: - chunked_attention_mask = chunked_attention_mask.bool() - causal_mask = causal_mask != torch.finfo(dtype).min - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=first_cache_position, - is_training=self.training, - ): - causal_mask = None - return causal_mask, chunked_attention_mask - - def create_chunked_attention_mask( - self, attention_chunk_size: int, start: int, end: int, device: torch.device - ) -> torch.Tensor: - """ - Generate the following: - - 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | - '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | - '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | - 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | - '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | - '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | - - If the chunk size is 3. 
- This can just be applied over the already created attention mask - """ - arange_vector = torch.arange(start, end, device=device) - block_pos = torch.abs( - arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size - ) - token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) - mask = (block_pos == 0) & (token_pos <= 0) - return mask.to(device) - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - cache_position.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -1726,61 +1513,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
- - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = [ "Llama4PreTrainedModel", diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index b27624241ac..448879ec06f 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -503,61 +503,5 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return model_inputs - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"] diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index fa92bc62236..496049e3123 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -719,7 +719,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi return model_inputs @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 7cb9aa8e948..62e696e8111 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1589,7 +1589,7 @@ class LongT5Stack(LongT5PreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1659,7 +1659,7 @@ class LongT5Stack(LongT5PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1da696ae038..757a393a0bc 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -790,7 +790,7 @@ class M2M100PreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from 
transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -860,7 +860,7 @@ class M2M100PreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index ee52dd1be06..a9a3fd353ec 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -484,7 +484,7 @@ class MarianPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -554,7 +554,7 @@ class MarianPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 1cc334db7e0..3fbb3e8b5be 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -762,7 +762,7 @@ class MBartPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -832,7 +832,7 @@ class MBartPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 515fdc4f045..4d7b92979ab 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1060,7 +1060,7 @@ class MimiTransformerModel(nn.Module): attentions=all_self_attns, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1148,7 +1148,7 @@ class MimiTransformerModel(nn.Module): return causal_mask @staticmethod - # Copied from 
transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2748a71f527..92ef09fc739 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -10,10 +10,10 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -26,16 +26,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_mistral import MistralConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -403,8 +397,14 @@ class MistralModel(MistralPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -450,161 +450,6 @@ class MistralModel(MistralPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: MistralConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`MistralConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
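The Mistral hunks above collapse the hand-rolled `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair into a single call to `create_causal_mask` or `create_sliding_window_causal_mask`, selected on `config.sliding_window`. For reference, the boolean rule that the deleted sliding-window code encoded (and that the new helper is expected to reproduce) can be sketched as follows; the function name and example values are illustrative only, not library code.

```python
import torch

def sliding_window_causal_bool_mask(
    cache_position: torch.Tensor,  # (query_length,) absolute positions of the query tokens
    kv_length: int,                # number of key/value positions covered by the mask
    sliding_window: int,
) -> torch.Tensor:
    """Return a (query_length, kv_length) boolean mask, True where attention is allowed."""
    kv_arange = torch.arange(kv_length, device=cache_position.device)
    # causal: a query at position q may only look at keys at positions <= q
    causal = kv_arange[None, :] <= cache_position[:, None]
    # sliding window: keys at positions <= q - sliding_window are masked out as well
    in_window = kv_arange[None, :] > cache_position[:, None] - sliding_window
    return causal & in_window

# Example: 4 query tokens at absolute positions 6..9, window of 4 keys
mask = sliding_window_causal_bool_mask(torch.arange(6, 10), kv_length=10, sliding_window=4)
```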
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index d56f2d2f370..9dd2e051b56 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -4,13 +4,13 @@ import torch import torch.utils.checkpoint from torch import nn -from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import is_torch_flex_attn_available, logging +from ...utils import auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -27,12 +27,6 @@ from ..llama.modeling_llama import ( from .configuration_mistral import MistralConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" @@ -118,166 +112,107 @@ class MistralPreTrainedModel(LlamaPreTrainedModel): class MistralModel(LlamaModel): - def __init__(self, config: MistralConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - - def _update_causal_mask( + @can_return_tuple + @auto_docstring + def forward( self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, past_key_values=past_key_values, + output_attentions=output_attentions, ) - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + hidden_states = inputs_embeds - return causal_mask + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: MistralConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`MistralConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) class MistralForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index a74a4663fec..625e1c3185e 100644 --- 
a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -32,12 +32,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - LossKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, -) +from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling from ..auto import AutoModel from .configuration_mistral3 import Mistral3Config @@ -538,60 +533,5 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index cfb5b1be8b6..9f176a35d89 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -32,10 +32,10 @@ import torch.nn.functional as F from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -48,16 +48,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_mixtral import MixtralConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -524,8 +518,14 @@ class MixtralModel(MixtralPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -591,161 +591,6 @@ class MixtralModel(MixtralPreTrainedModel): router_logits=all_router_logits, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation 
== "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: MixtralConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`MixtralConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
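Mixtral's copy of `_prepare_4d_causal_attention_mask_with_cache_position` is deleted in the hunk above, as it is for most models touched by this diff. Its core job was to expand a 2D padding mask of shape `(batch_size, kv_length)` into the 4D additive mask `(batch_size, 1, query_length, kv_length)` consumed by the eager and SDPA paths, with forbidden positions set to the dtype minimum. The simplified sketch below illustrates that idea only; it deliberately omits the static-cache, sliding-window, and `_unmask_unattended` branches, and the helper name is hypothetical.

```python
import torch

def additive_4d_causal_mask(
    padding_mask_2d: torch.Tensor,  # (batch_size, kv_length); 1 = real token, 0 = padding
    cache_position: torch.Tensor,   # (query_length,) absolute positions of the query tokens
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return a (batch_size, 1, query_length, kv_length) additive mask:
    0.0 where attention is allowed, dtype-min where it is forbidden."""
    batch_size, kv_length = padding_mask_2d.shape
    min_dtype = torch.finfo(dtype).min

    kv_arange = torch.arange(kv_length, device=cache_position.device)
    causal_ok = kv_arange[None, :] <= cache_position[:, None]      # (query_length, kv_length)
    not_padding = padding_mask_2d.bool()[:, None, :]               # (batch_size, 1, kv_length)
    allowed = causal_ok[None, :, :] & not_padding                  # (batch_size, query_length, kv_length)

    additive = torch.full(allowed.shape, min_dtype, dtype=dtype, device=allowed.device)
    additive = additive.masked_fill(allowed, 0.0)
    return additive[:, None, :, :]                                 # insert the singleton head axis

# Example: batch of 2, the first sequence is left-padded, 3 query tokens at positions 0..2
pad = torch.tensor([[0, 1, 1], [1, 1, 1]])
mask = additive_4d_causal_mask(pad, cache_position=torch.arange(3))
```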
diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 726d2e99588..95d8defddde 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -29,6 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import DynamicCache +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack @@ -315,12 +316,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel): class MixtralModel(MistralModel): - def __init__(self, config: MixtralConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -368,8 +363,14 @@ class MixtralModel(MistralModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index f1e634d9115..fcccd2b9ea6 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -893,7 +893,7 @@ class MllamaPreTrainedModel(PreTrainedModel): if module.is_gated: module.gate.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -963,7 +963,7 @@ class MllamaPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index d849d1945ba..edc70eafa08 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -27,11 +27,8 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_attention_mask, - _prepare_4d_attention_mask_for_sdpa, -) +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import 
FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -44,16 +41,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, can_return_tuple, logging from .configuration_moonshine import MoonshineConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -752,8 +743,13 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -826,129 +822,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): cross_attentions=all_cross_attentions, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - def _compute_mask_indices( shape: Tuple[int, int], diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index edb1a70279e..c22198843c4 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -21,6 +21,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -748,8 +749,13 @@ class MoonshineDecoder(LlamaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index fe0c1024016..7e71eb2ce20 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1084,7 +1084,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): attentions=all_self_attns, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1172,7 +1172,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->MoshiDepth + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->MoshiDepth def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1391,7 +1391,7 @@ class MoshiModel(MoshiPreTrainedModel): attentions=all_self_attns, ) - # 
Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1479,7 +1479,7 @@ class MoshiModel(MoshiPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Moshi + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Moshi def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 500b2ca4312..6b488b66d22 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1180,7 +1180,7 @@ class MT5Stack(MT5PreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1250,7 +1250,7 @@ class MT5Stack(MT5PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 3e1ce01041b..ea0f0de0456 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -739,7 +739,7 @@ class NemotronModel(NemotronPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -809,7 +809,7 @@ class NemotronModel(NemotronPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 72abff85dd5..fe4e081a3e4 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -13,23 +13,17 @@ import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import 
FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_olmo import OlmoConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -406,8 +400,13 @@ class OlmoModel(OlmoPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -453,129 +452,6 @@ class OlmoModel(OlmoPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
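The OLMo hunk removes the same machinery in its non-sliding form. Since the whole point of these 4D float masks is to be handed to an attention kernel, a tiny self-contained demo of that hand-off may help when reviewing; everything here (shapes, values) is made up for illustration and does not call any transformers code.

```python
import torch
import torch.nn.functional as F

# Hand an additive float mask (0.0 = attend, dtype-min = blocked) to scaled_dot_product_attention,
# the consumer these per-model helpers were building masks for.
torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

min_dtype = torch.finfo(q.dtype).min
causal = torch.triu(torch.full((q_len, kv_len), min_dtype), diagonal=1)[None, None]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```

Queries whose rows end up fully masked, for example the leading rows under left padding, are why the removed code passed the mask through `AttentionMaskConverter._unmask_unattended` before reaching SDPA's memory-efficient path.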
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 58cad693741..dd2fcdb17f5 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -13,23 +13,17 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_olmo2 import Olmo2Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -412,8 +406,13 @@ class Olmo2Model(Olmo2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -459,129 +458,6 @@ class Olmo2Model(Olmo2PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 1f115cb8a7b..480687ae7f0 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -384,7 +384,7 @@ class OPTDecoder(OPTPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -454,7 +454,7 @@ class OPTDecoder(OPTPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index b584a3b1318..15180b91b96 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -566,7 +566,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return model_inputs @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 9f57f79df4e..3f59a8c9186 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -481,7 +481,7 @@ class PegasusPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -551,7 +551,7 @@ class PegasusPreTrainedModel(PreTrainedModel): return causal_mask 
@staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index c9554255e47..04cf37a7622 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -773,7 +773,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -843,7 +843,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 2e53ce9bb4c..3f8983083e2 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -539,7 +539,7 @@ class PersimmonModel(PersimmonPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -609,7 +609,7 @@ class PersimmonModel(PersimmonPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 35df0f78cf2..4dade5e49e2 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -13,7 +13,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -24,16 +24,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, 
is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_phi import PhiConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -400,8 +394,13 @@ class PhiModel(PhiPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama @@ -461,129 +460,6 @@ class PhiModel(PhiPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
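The Phi hunk deletes the same `_update_causal_mask`, including its `flash_attention_2` early exit. That branch is short enough to restate as a runnable sketch, which may help when checking that the new shared path preserves the behaviour; the helper name is made up.

```python
from typing import Optional

import torch

def flash_attention_2_padding_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the removed branch: keep the 2D padding mask only if something is actually padded,
    # otherwise return None so the kernel can run without any mask at all.
    if attention_mask is not None and (attention_mask == 0.0).any():
        return attention_mask
    return None

print(flash_attention_2_padding_mask(torch.ones(1, 4, dtype=torch.long)))  # None
print(flash_attention_2_padding_mask(torch.tensor([[0, 1, 1, 1]])))        # the mask, unchanged
```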
diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 5faee931e0a..6f9edaba941 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -244,8 +245,13 @@ class PhiModel(LlamaModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index a914461ad73..52cf0ef96d3 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -26,10 +26,10 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -41,16 +41,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_phi3 import Phi3Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -458,8 +452,14 @@ class Phi3Model(Phi3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -505,161 +505,6 @@ class Phi3Model(Phi3PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if 
self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Phi3Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Phi3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
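Phi3 keeps sliding-window support, so its forward now chooses between `create_causal_mask` and `create_sliding_window_causal_mask` based on whether `config.sliding_window` is set, as the `+` lines in the phi3 hunk show. The boolean pattern both the old and the new code have to encode is small enough to eyeball; the sketch below prints it for a toy sequence (the helper name is illustrative).

```python
import torch

def sliding_window_causal_pattern(seq_len: int, sliding_window: int) -> torch.Tensor:
    """Boolean (query, key) grid: True where attention is allowed."""
    q = torch.arange(seq_len).reshape(-1, 1)
    kv = torch.arange(seq_len)
    causal = kv <= q                        # never attend to the future
    in_window = kv > (q - sliding_window)   # never attend further back than the window
    return causal & in_window

print(sliding_window_causal_pattern(seq_len=6, sliding_window=3).int())
# Each row (query) attends to itself and at most the two previous positions.
```

This is the same constraint the deleted code expressed by or-ing `sliding_attend_mask` into `diagonal_attend_mask` before multiplying it into the float mask.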
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 2d6a2a7f155..6b6cef7df3e 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -28,13 +28,12 @@ import torch.nn.functional as F from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask - from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -46,16 +45,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging, torch_int +from ...utils import auto_docstring, can_return_tuple, logging, torch_int from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -1766,8 +1759,14 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -1813,161 +1812,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Phi4MultimodalConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. 
- target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Phi4MultimodalConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - @auto_docstring class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 5b3a585bc1d..344e6c1776e 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -21,11 +21,11 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask - from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -1570,8 +1570,14 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = 
self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 5002014006b..e81d38e2d88 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1068,7 +1068,6 @@ class PhimoeModel(PhimoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Phimoe def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1156,7 +1155,6 @@ class PhimoeModel(PhimoePreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 243c81a0c85..f9a5b00218d 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1345,7 +1345,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1415,7 +1415,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index c82cfb90697..614baee8bf3 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -520,7 +520,7 @@ class PLBartPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -590,7 +590,7 @@ class PLBartPreTrainedModel(PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def 
_prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5191a6a4901..c63b4df774b 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -900,7 +900,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -970,7 +970,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/qwen2/configuration_qwen2.py b/src/transformers/models/qwen2/configuration_qwen2.py index fde5bd502b5..f89be07bf13 100644 --- a/src/transformers/models/qwen2/configuration_qwen2.py +++ b/src/transformers/models/qwen2/configuration_qwen2.py @@ -14,7 +14,7 @@ # limitations under the License. """Qwen2 model configuration""" -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging @@ -110,6 +110,8 @@ class Qwen2Config(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
@@ -164,6 +166,7 @@ class Qwen2Config(PretrainedConfig): use_sliding_window=False, sliding_window=4096, max_window_layers=28, + layer_types=None, attention_dropout=0.0, **kwargs, ): @@ -174,7 +177,7 @@ class Qwen2Config(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.sliding_window = sliding_window if self.use_sliding_window else None self.max_window_layers = max_window_layers # for backward compatibility @@ -195,6 +198,16 @@ class Qwen2Config(PretrainedConfig): self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 46fc9b720c8..9e9b0641f0d 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -10,10 +10,10 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -26,16 +26,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_qwen2 import Qwen2Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -143,6 +137,7 @@ class Qwen2Attention(nn.Module): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -168,14 +163,6 @@ class Qwen2Attention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - sliding_window = None - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= 
self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -194,7 +181,7 @@ class Qwen2Attention(nn.Module): attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=sliding_window, # main diff with Llama + sliding_window=self.sliding_window, # main diff with Llama **kwargs, ) @@ -228,15 +215,13 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -357,6 +342,7 @@ class Qwen2Model(Qwen2PreTrainedModel): self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types # Initialize weights and apply final processing self.post_init() @@ -416,9 +402,24 @@ class Qwen2Model(Qwen2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -435,7 +436,7 @@ class Qwen2Model(Qwen2PreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -463,161 +464,6 @@ class Qwen2Model(Qwen2PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
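The same mask-mapping pattern recurs in the Qwen2, Qwen3 and Phi4Multimodal hunks: build one mask per attention pattern with the new `masking_utils` helpers, then let each decoder layer index the mapping with its `attention_type`. The following is only a rough, self-contained sketch of that pattern, not code from the patch itself: the keyword names (including `output_attentions`) are copied from the hunks above and may differ in other library versions, and the tiny `Qwen2Config` values are made up purely for illustration.

```python
import torch

from transformers import DynamicCache, Qwen2Config
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

# 4 layers: with `max_window_layers=2`, layers 2 and 3 get "sliding_attention"
config = Qwen2Config(num_hidden_layers=4, use_sliding_window=True, sliding_window=4096, max_window_layers=2)

batch_size, seq_len = 2, 8
inputs_embeds = torch.zeros(batch_size, seq_len, config.hidden_size)
cache_position = torch.arange(seq_len)

mask_kwargs = {
    "config": config,
    "input_embeds": inputs_embeds,
    "attention_mask": None,             # no padding in this toy prefill
    "cache_position": cache_position,
    "past_key_values": DynamicCache(),  # empty cache, as in the model's forward
    "output_attentions": False,         # mirrors the hunks above; may be dropped in later versions
}

# One mask per attention pattern actually present in `config.layer_types`
causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
if "sliding_attention" in config.layer_types:
    causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

# Each decoder layer then picks its mask by its configured attention type
for layer_idx, layer_type in enumerate(config.layer_types):
    layer_mask = causal_mask_mapping[layer_type]
    print(layer_idx, layer_type, None if layer_mask is None else tuple(layer_mask.shape))
```

Building the mapping once per forward pass means the sliding-window mask is computed at most once no matter how many layers use it, and the per-backend format conversion (4D tensor for sdpa/eager, `BlockMask` for flex attention, `None` for flash attention) stays inside `masking_utils` instead of being repeated in every model's `_update_causal_mask`.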
diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index c7ac5a6a312..5a24b425f0c 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -4,11 +4,15 @@ import torch import torch.utils.checkpoint from torch import nn -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import logging +from ...utils import auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -43,6 +47,7 @@ class Qwen2Attention(LlamaAttention): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -68,14 +73,6 @@ class Qwen2Attention(LlamaAttention): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - sliding_window = None - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -94,7 +91,7 @@ class Qwen2Attention(LlamaAttention): attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=sliding_window, # main diff with Llama + sliding_window=self.sliding_window, # main diff with Llama **kwargs, ) @@ -106,13 +103,7 @@ class Qwen2Attention(LlamaAttention): class Qwen2DecoderLayer(LlamaDecoderLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() - self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) - self.mlp = Qwen2MLP(config) - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) + self.attention_type = config.layer_types[layer_idx] class Qwen2PreTrainedModel(LlamaPreTrainedModel): @@ -120,7 +111,120 @@ class Qwen2PreTrainedModel(LlamaPreTrainedModel): class Qwen2Model(MistralModel): - pass + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) class Qwen2ForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py index 132b562f9ad..06db2e83f1f 100644 --- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -372,7 +372,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.sliding_window = sliding_window if self.use_sliding_window else None self.max_window_layers = max_window_layers # for backward compatibility @@ -392,6 +392,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) + if self.rope_scaling is None: self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index e973c5c64d9..84e55346533 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -68,16 +68,15 @@ else: apply_rotary_emb = None +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + if 
is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index ce228ef73a1..a75c0fafa40 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -27,7 +27,6 @@ from torch import nn from torch.nn import Parameter from transformers.models.llama.modeling_llama import rotate_half -from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionTransformerPretrainedModel, @@ -45,6 +44,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbeddin from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( auto_docstring, @@ -251,7 +251,7 @@ class Qwen2_5OmniAudioEncoderConfig(Qwen2AudioEncoderConfig): del self.encoder_layerdrop -class Qwen2_5OmniTextConfig(Qwen2Config): +class Qwen2_5OmniTextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. 
Instantiating a configuration @@ -362,6 +362,23 @@ class Qwen2_5OmniTextConfig(Qwen2Config): ```""" model_type = "qwen2_5_omni_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen25OmniText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, @@ -371,11 +388,52 @@ class Qwen2_5OmniTextConfig(Qwen2Config): num_hidden_layers=28, num_attention_heads=28, num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, rope_theta=1000000.0, + rope_scaling=None, + use_sliding_window=False, sliding_window=32768, - **super_kwargs, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, ): - super().__init__(**super_kwargs) + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + if self.rope_scaling is None: self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index ccd4a6b4191..a529dcdb559 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -915,7 +915,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2Moe + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Qwen2Moe def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1003,7 +1003,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2Moe + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Qwen2Moe def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index f5e5a08cdd4..faa0f7aabb7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1182,7 +1182,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2VL + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Qwen2VL def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1270,7 +1270,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL + # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Qwen2VL def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1675,7 +1675,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): return output if return_dict else output.to_tuple() @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/qwen3/configuration_qwen3.py b/src/transformers/models/qwen3/configuration_qwen3.py index 06e527ce53f..8335c798d16 100644 --- a/src/transformers/models/qwen3/configuration_qwen3.py +++ b/src/transformers/models/qwen3/configuration_qwen3.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""Qwen3 model configuration""" -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging @@ -114,6 +114,8 @@ class Qwen3Config(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. @@ -170,6 +172,7 @@ class Qwen3Config(PretrainedConfig): use_sliding_window=False, sliding_window=4096, max_window_layers=28, + layer_types=None, attention_dropout=0.0, **kwargs, ): @@ -180,7 +183,7 @@ class Qwen3Config(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.sliding_window = sliding_window if self.use_sliding_window else None self.max_window_layers = max_window_layers # for backward compatibility @@ -203,6 +206,16 @@ class Qwen3Config(PretrainedConfig): self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 082de2f193b..48eb9489be2 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -25,10 +25,10 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -41,16 +41,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_qwen3 import Qwen3Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -190,13 +184,7 @@ class Qwen3Attention(nn.Module): ) self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # 
unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window - if not ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - self.sliding_window = None + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -253,17 +241,13 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if ( - config.sliding_window and config._attn_implementation != "flash_attention_2" - ): # diff with Llama is this warning - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -384,6 +368,7 @@ class Qwen3Model(Qwen3PreTrainedModel): self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3RotaryEmbedding(config=config) self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types # Initialize weights and apply final processing self.post_init() @@ -443,9 +428,24 @@ class Qwen3Model(Qwen3PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -462,7 +462,7 @@ class Qwen3Model(Qwen3PreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -490,161 +490,6 @@ class Qwen3Model(Qwen3PreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen3Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index f914faa57fe..096a0e5b9c6 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -28,16 +28,18 @@ from ...utils import LossKwargs, logging from ..gemma.modeling_gemma import GemmaMLP from ..llama.modeling_llama import ( LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, - LlamaRMSNorm, +) +from ..qwen2.modeling_qwen2 import ( + Qwen2DecoderLayer, + Qwen2ForCausalLM, + Qwen2ForQuestionAnswering, + Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, + Qwen2Model, + Qwen2RMSNorm, apply_rotary_pos_emb, eager_attention_forward, ) -from ..mistral.modeling_mistral import MistralModel from .configuration_qwen3 import Qwen3Config @@ -46,7 +48,7 @@ logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" -class Qwen3RMSNorm(LlamaRMSNorm): +class Qwen3RMSNorm(Qwen2RMSNorm): pass @@ -59,13 +61,7 @@ class Qwen3Attention(LlamaAttention): super().__init__(config, layer_idx) self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window - if not ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - self.sliding_window = None + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -118,28 +114,18 @@ class Qwen3Attention(LlamaAttention): return attn_output, attn_weights -class Qwen3DecoderLayer(LlamaDecoderLayer): - def __init__(self, config: Qwen3Config, layer_idx: int): - super().__init__() - self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) - self.mlp = Qwen3MLP(config) - if ( - config.sliding_window and config._attn_implementation != "flash_attention_2" - ): # diff with Llama is this warning - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) +class Qwen3DecoderLayer(Qwen2DecoderLayer): + pass -class Qwen3Model(MistralModel): # mistral model creates sliding window +class Qwen3Model(Qwen2Model): pass class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -class Qwen3ForCausalLM(LlamaForCausalLM): +class Qwen3ForCausalLM(Qwen2ForCausalLM): def forward( self, **super_kwargs: Unpack[KwargsForCausalLM], @@ -169,15 +155,15 @@ class Qwen3ForCausalLM(LlamaForCausalLM): return super().forward(**super_kwargs) -class Qwen3ForSequenceClassification(LlamaForSequenceClassification): +class Qwen3ForSequenceClassification(Qwen2ForSequenceClassification): pass -class Qwen3ForTokenClassification(LlamaForTokenClassification): +class Qwen3ForTokenClassification(Qwen2ForTokenClassification): pass -class Qwen3ForQuestionAnswering(LlamaForQuestionAnswering): +class Qwen3ForQuestionAnswering(Qwen2ForQuestionAnswering): pass diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 72591887a41..f349f2f3d6d 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -27,10 +27,10 @@ import torch.nn.functional as F from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -43,16 +43,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_qwen3_moe import Qwen3MoeConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = 
logging.get_logger(__name__) @@ -155,13 +149,7 @@ class Qwen3MoeAttention(nn.Module): ) self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window - if not ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - self.sliding_window = None + self.sliding_window = getattr(config, "sliding_window", None) def forward( self, @@ -535,8 +523,14 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -602,161 +596,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): router_logits=all_router_logits, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen3MoeConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen3MoeConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index 3c673b6079b..7170164f515 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -42,7 +42,9 @@ logger = logging.get_logger(__name__) class Qwen3MoeAttention(Qwen3Attention): # This is the main diff with qwen2Moe! 
- pass + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.sliding_window = getattr(config, "sliding_window", None) class Qwen3MoeMLP(nn.Module): diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 001e0b8deab..be9af304d51 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -793,7 +793,7 @@ class StableLmModel(StableLmPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -863,7 +863,7 @@ class StableLmModel(StableLmPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 1e1cfdde775..6d79e0f0f74 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -30,9 +30,9 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -44,16 +44,10 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_starcoder2 import Starcoder2Config -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -404,8 +398,14 @@ class Starcoder2Model(Starcoder2PreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -454,161 +454,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel): 
attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Starcoder2Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Starcoder2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
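
The hunks above repeat the same refactor across Qwen3, Qwen3MoE, and Starcoder2: the per-model `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair is deleted, and the decoder `forward` instead calls the shared helpers from `masking_utils`, which take over the per-implementation dispatch (flash_attention_2, flex_attention, sdpa) that the removed methods used to do by hand. A minimal sketch of the new call site, mirroring the keyword arguments used in the hunks above (illustrative only, not an excerpt of any single model):

```python
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask


def build_decoder_causal_mask(config, inputs_embeds, attention_mask, cache_position, past_key_values, output_attentions=False):
    """Sketch of the mask construction now shared by these decoder-only models."""
    # Models whose config defines `sliding_window` route through the sliding-window helper;
    # everything else uses the plain causal helper.
    mask_function = create_causal_mask if config.sliding_window is None else create_sliding_window_causal_mask
    # Keyword names match the call sites added for Qwen3MoeModel and Starcoder2Model above.
    return mask_function(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        output_attentions=output_attentions,
    )
```
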
diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index ad1ff328755..fd5840e40a7 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -212,8 +213,14 @@ class Starcoder2Model(MistralModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a2253026ea7..d2db7781d2f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1123,7 +1123,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): router_probs=all_router_probs, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1193,7 +1193,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index da8fe7db171..466b725bce2 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1195,7 +1195,7 @@ class T5Stack(T5PreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1265,7 +1265,7 @@ class T5Stack(T5PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: 
int, diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index b1049851fbf..478baddc505 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1360,7 +1360,7 @@ class UdopStack(UdopPreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1430,7 +1430,7 @@ class UdopStack(UdopPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 512cd159f9a..e45d63aba7d 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -836,7 +836,7 @@ class UMT5Stack(UMT5PreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -906,7 +906,7 @@ class UMT5Stack(UMT5PreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 535c796e8ab..937d44a7817 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -648,7 +648,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return model_inputs @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 51c4f032f4a..375278b5d85 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -460,60 +460,5 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return model_inputs - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, 
- ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - __all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"] diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2bf387de82a..e42c7ce3060 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1230,7 +1230,7 @@ class WhisperDecoder(WhisperPreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1300,7 +1300,7 @@ class WhisperDecoder(WhisperPreTrainedModel): return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b7291969ebb..e865237485a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -538,6 +538,13 @@ def convert_and_export_with_cache(*args, **kwargs): requires_backends(convert_and_export_with_cache, 
["torch"]) +class AttentionMaskInterface(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + def model_addition_debugger_context(*args, **kwargs): requires_backends(model_addition_debugger_context, ["torch"]) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index e1530da235a..8fc2a857dd0 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -26,7 +26,7 @@ from contextlib import ExitStack, contextmanager from dataclasses import fields, is_dataclass from enum import Enum from functools import partial, wraps -from typing import Any, ContextManager, Optional, TypedDict +from typing import Any, Callable, ContextManager, List, Optional, TypedDict import numpy as np from packaging import version @@ -977,3 +977,44 @@ def can_return_tuple(func): return output return wrapper + + +class GeneralInterface(MutableMapping): + """ + Dict-like object keeping track of a class-wide mapping, as well as a local one. Allows to have library-wide + modifications though the class mapping, as well as local modifications in a single file with the local mapping. + """ + + # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if + # a new instance is created (in order to locally override a given function) + _global_mapping = {} + + def __init__(self): + self._local_mapping = {} + + def __getitem__(self, key): + # First check if instance has a local override + if key in self._local_mapping: + return self._local_mapping[key] + return self._global_mapping[key] + + def __setitem__(self, key, value): + # Allow local update of the default functions without impacting other instances + self._local_mapping.update({key: value}) + + def __delitem__(self, key): + del self._local_mapping[key] + + def __iter__(self): + # Ensure we use all keys, with the overwritten ones on top + return iter({**self._global_mapping, **self._local_mapping}) + + def __len__(self): + return len(self._global_mapping.keys() | self._local_mapping.keys()) + + @classmethod + def register(cls, key: str, value: Callable): + cls._global_mapping.update({key: value}) + + def valid_keys(self) -> List[str]: + return list(self.keys()) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 0a4cff3c7ff..373b3ffbe22 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -13,7 +13,6 @@ # limitations under the License. 
"""Testing suite for the PyTorch Gemma model.""" -import tempfile import unittest import pytest @@ -23,7 +22,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_to from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( cleanup, - is_flaky, require_bitsandbytes, require_flash_attn, require_read_token, @@ -303,39 +301,45 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Gemma flash attention does not support right padding") + @require_torch_sdpa + @require_torch_accelerator + def test_sdpa_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(reason="Model does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) + + model.config._attn_implementation = "sdpa" + states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[-1] + + model.config._attn_implementation = "eager" + states_eager = model(dummy_input, output_hidden_states=True).hidden_states[-1] + + torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test - @is_flaky() - @slow def test_flash_attn_2_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(reason="Model does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) + model = model_class(config).to(device=torch_device, dtype=torch.float16) + dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" - ) - model_fa.to(torch_device) + model.config._attn_implementation = "flash_attention_2" + states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[1] - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") - model.to(torch_device) + model.config._attn_implementation = "eager" + states_eager = model(dummy_input, output_hidden_states=True).hidden_states[1] - dummy_input = inputs_dict[model_class.main_input_name] - dummy_input = dummy_input.to(torch_device) - outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) - - logits = outputs.hidden_states[-1] - logits_fa = outputs_fa.hidden_states[-1] - - # gemma flash attention 2 needs a high tolerance - assert torch.allclose(logits_fa, logits, atol=3e-3) + # Here we use higher tolerance and the output of the 2nd layer because otherwise small diffs add-up + torch.testing.assert_close(states_sdpa, states_eager, atol=1e-3, rtol=1e-3) @slow diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 6ee4e8f2327..2561875f387 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_eager_matches_fa2_generate(self): pass + 
@unittest.skip("Gemma2 eager/FA2 attention outputs are expected to be different") + def test_flash_attn_2_equivalence(self): + pass + @slow @require_torch_accelerator diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index efd6f4095cd..87350c04895 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -25,7 +25,6 @@ from transformers import ( AutoTokenizer, Gemma3Config, Gemma3TextConfig, - GenerationConfig, is_torch_available, ) from transformers.testing_utils import ( @@ -635,46 +634,6 @@ class Gemma3IntegrationTest(unittest.TestCase): EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS) - def test_generation_beyond_sliding_window_with_generation_config(self): - """ - Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 - -- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. - """ - model_id = "google/gemma-3-1b-it" - attn_implementation = "sdpa" - - input_text = [ - "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens - "A list of colors: red, blue", # This will almost all be padding tokens - ] - tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") - inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) - - model = AutoModelForCausalLM.from_pretrained( - model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16 - ).to(torch_device) - - # Make sure prefill is larger than sliding window - input_size = inputs.input_ids.shape[-1] - self.assertGreater(input_size, model.config.sliding_window) - - generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5) - out = model.generate(**inputs, generation_config=generation_config) - - out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:] - output_text = tokenizer.batch_decode(out) - EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip - self.assertEqual(output_text, EXPECTED_COMPLETIONS) - - # Generation works beyond sliding window - self.assertGreater(out.shape[1], model.config.sliding_window) - self.assertEqual(out.shape[1], input_size + 5) - - # Note: Auto-inheritance only works for models saved starting from 4.50.0 - model.generation_config.transformers_version = "4.49.0" - with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache - out = model.generate(**inputs, generation_config=generation_config) - def test_export_text_only_with_hybrid_cache(self): if not is_torch_greater_or_equal("2.6.0"): self.skipTest(reason="This test requires torch >= 2.6 to run.") diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index bc62c527e29..95cb5d2785b 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -26,6 +26,7 @@ from transformers import ( is_vision_available, ) from transformers.testing_utils import ( + is_flaky, require_torch, torch_device, ) @@ -381,3 +382,8 @@ class 
PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache") def test_generate_with_static_cache(self): pass + + @pytest.mark.generate + @is_flaky + def test_generate_compile_model_forward(self): + super().test_generate_compile_model_forward() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 92f3daf8b73..2f16a86d80b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1172,25 +1172,10 @@ class ModelTesterMixin: traced_model = torch.jit.trace(model, example_inputs, check_trace=False) else: main_input = inputs[main_input_name] - - if model.config._attn_implementation == "sdpa": - trace_input = {main_input_name: main_input} - - if "attention_mask" in inputs: - trace_input["attention_mask"] = inputs["attention_mask"] - else: - self.skipTest(reason="testing SDPA without attention_mask is not supported") - - outputs = model(main_input, attention_mask=inputs["attention_mask"]) - if any(isinstance(x, Cache) for x in outputs): - continue - # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. - traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) - else: - outputs = model(main_input) - if any(isinstance(x, Cache) for x in outputs): - continue - traced_model = torch.jit.trace(model, (main_input,)) + outputs = model(main_input) + if any(isinstance(x, Cache) for x in outputs): + continue + traced_model = torch.jit.trace(model, (main_input,)) except RuntimeError: self.fail("Couldn't trace module.") @@ -3907,6 +3892,11 @@ class ModelTesterMixin: self.skipTest( "DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile." ) + if getattr(config, "cache_implementation", None) == "hybrid": + self.skipTest( + "Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` " + "is a forbidden call." 
+ ) model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -4346,18 +4336,31 @@ class ModelTesterMixin: config.sliding_window = sliding_window inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device) for model_class in self.all_model_classes: - model = model_class(config).to(torch_device) - model.eval() - # Set sliding window to `True` and check that all tokens beyond window size are masked - model.config.use_sliding_window = True + config.use_sliding_window = True + config_dict = config.to_diff_dict() + if hasattr(config, "layer_types"): + del config_dict["layer_types"] + new_config = config.__class__(**config_dict) + model = model_class(new_config).to(torch_device) + model.eval() + layer_types = getattr(model.config, "layer_types", ["sliding_attention"] * config.num_hidden_layers) attentions = model(**inputs, output_attentions=True).attentions - for layer_attention in attentions: - self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item()) + for layer_attention, layer_type in zip(attentions, layer_types): + if layer_type == "sliding_attention": + self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item()) + else: + self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item()) # Set sliding window to `False` while keeping `sliding_window=3` # Check that all tokens beyond window size are not masked - model.config.use_sliding_window = False + config.use_sliding_window = False + config_dict = config.to_diff_dict() + if hasattr(config, "layer_types"): + del config_dict["layer_types"] + new_config = config.__class__(**config_dict) + model = model_class(new_config).to(torch_device) + model.eval() attentions_not_sliding = model(**inputs, output_attentions=True).attentions for layer_attention in attentions_not_sliding: self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item()) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index ea56763a65f..3d1fa7a4474 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -55,6 +55,7 @@ if is_torch_available(): convert_and_export_with_cache, pipeline, ) + from transformers.integrations.executorch import export_with_dynamic_cache TEST_CACHE_IMPLEMENTATIONS = [ @@ -593,22 +594,11 @@ class CacheExportIntegrationTest(unittest.TestCase): attention_mask = inputs.attention_mask input_ids = inputs.input_ids - past_key_values = DynamicCache() - ep = torch.export.export( - model, - (), - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - }, - strict=False, - ) + ep = export_with_dynamic_cache(model, input_ids, attention_mask) res = ep.module()( input_ids=input_ids, attention_mask=attention_mask, - past_key_values=past_key_values, + past_key_values=DynamicCache(), use_cache=True, ) self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index cea9ed693d5..238da721bf6 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -44,9 +44,11 @@ SPECIAL_CASES_TO_ALLOW = { "expert_layer_offset", "expert_layer_period", ], - "Qwen2Config": ["use_sliding_window"], + "Qwen2Config": ["use_sliding_window", "max_window_layers"], "Qwen2MoeConfig": ["use_sliding_window"], "Qwen2VLConfig": ["use_sliding_window"], + "Qwen3Config": ["max_window_layers", "use_sliding_window"], # now use `layer_types` 
instead + "Qwen3MoeConfig": ["max_window_layers", "use_sliding_window"], # `cache_implementation` should be in the default generation config, but we don't yet support per-model # generation configs (TODO joao) "Gemma2Config": ["tie_word_embeddings", "cache_implementation"], @@ -263,6 +265,7 @@ SPECIAL_CASES_TO_ALLOW = { "router_aux_loss_coef", "router_jitter_noise", "cache_implementation", + "attention_chunk_size", ], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], }
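
One closing note on the `check_config_attributes.py` allow-list above: `use_sliding_window` and `max_window_layers` stay in the Qwen configs but are no longer read by the modeling code; the per-layer behaviour now comes from `config.layer_types`, as the Qwen3 modular hunk and the updated sliding-window test earlier in this patch show. A minimal sketch of that per-layer selection, assuming a config exposing `layer_types`, `sliding_window`, and `num_hidden_layers`:

```python
from typing import Optional


def sliding_window_for_layer(config, layer_idx: int) -> Optional[int]:
    """Sketch: pick the sliding window for one layer from `config.layer_types`."""
    # Mirrors the Qwen3 attention change: only layers tagged "sliding_attention" keep a window.
    # The getattr fallback mirrors the updated common test, which defaults to sliding attention
    # on every layer when a config has no `layer_types` attribute.
    layer_types = getattr(config, "layer_types", ["sliding_attention"] * config.num_hidden_layers)
    return config.sliding_window if layer_types[layer_idx] == "sliding_attention" else None
```
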