🚨🚨[core] Completely rewrite the masking logic for all attentions (#37866)

* start

* start having a clean 4d mask primitive

* Update mask_utils.py

* Update mask_utils.py

* switch name

* Update masking_utils.py

* add a new AttentionMask tensor class

* fix import

* nits

* fixes

* use full and quadrants

* general sdpa mask for all caches

* style

* start some tests

* tests with sliding, chunked

* add styling

* test hybrid

* Update masking_utils.py

* small temp fixes

* Update modeling_gemma2.py

* compile compatible

* Update masking_utils.py

* improve

* start making it more general

* Update masking_utils.py

* generate

* make it work with flex style primitives!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* improve

* Update cache_utils.py

* Update masking_utils.py

* simplify - starting to look good!

* Update masking_utils.py

* name

* Update masking_utils.py

* style

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* small fix for flex

* flex compile

* FA2

* Update masking_utils.py

* Escape for TGI/vLLM!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* General case without cache

* rename

* full test on llama4

* small fix for FA2 guard with chunk

* Update modeling_gemma2.py

* post rebase cleanup

* FA2 supports static cache!

* Update modeling_flash_attention_utils.py

* Update flex_attention.py

* Update masking_utils.py

* Update masking_utils.py

* Update utils.py

* override for export

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update masking_utils.py

* Update masking_utils.py

* output attentions

* style

* Update masking_utils.py

* Update executorch.py

* Add docstring

* Add license and put mask visualizer at the end

* Update test_modeling_common.py

* fix broken test

* Update test_modeling_gemma.py

* Update test_modeling_gemma2.py

* Use fullgraph=False with FA2

* Update utils.py

* change name

* Update masking_utils.py

* improve doc

* change name

* Update modeling_attn_mask_utils.py

* more explicit logic based on model's property

* pattern in config

* extend

* fixes

* make it better

* generalize to other test models

* fix

* Update masking_utils.py

* fix

* do not check mask equivalence if layer types are different

* executorch

* Update modeling_gemma2.py

* Update masking_utils.py

* use layer_idx instead

* adjust

* Update masking_utils.py

* test

* fix imports

* Update modeling_gemma2.py

* other test models

* Update modeling_llama4.py

* Update masking_utils.py

* improve

* simplify

* Update masking_utils.py

* typos

* typo

* fix

* Update masking_utils.py

* default DynamicCache

* remove default cache

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* export

* Update executorch.py

* Update executorch.py

* Update flex_attention.py

* Update executorch.py

* upstream to modular gemma 1 & 2

* Update modular_mistral.py

* switch names

* use dict

* put it in the Layer directly

* update copy model source for mask functions

* apply so many modular (hopefully 1 shot)

* use explicit dicts to make style happy

* protect import

* check docstring

* better default in hybrid caches

* qwens

* Update modular_qwen2.py

* simplify core logic!

* Update executorch.py

* qwen3 moe

* Update masking_utils.py

* Update masking_utils.py

* simplify a lot sdpa causal skip

* Update masking_utils.py

* post-rebase

* gemma3 finally

* style

* check it before

* gemma3

* More general with newer torch

* align gemma3

* Update utils.py

* Update utils.py

* Update masking_utils.py

* Update test_modeling_common.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* test

* executorch

* Update test_modeling_common.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update executorch.py

* Update test_modeling_common.py

* fix copies

* device

* sdpa can be used without mask -> pass the torchscript tests in this case

* Use enum for check

* revert enum and add check instead

* remove broken test

* cohere2

* some doc & reorganize the Interface

* Update tensor_parallel.py

* Update tensor_parallel.py

* doc and dummy

* Update test_modeling_paligemma2.py

* Update modeling_falcon_h1.py

* Update masking_utils.py

* executorch patch

* style

* CIs

* use register in executorch

* final comments!

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Cyril Vallez 2025-05-22 11:38:26 +02:00 committed by GitHub
parent f8630c778c
commit 163138a911
129 changed files with 2976 additions and 6800 deletions


@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
```
## Attention Mask Interface
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
the `AttentionInterface`:
```python
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
import torch
def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```
You need to register your mask function so that we can automatically correct its format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
and `attention_mask=None` will be passed along to the Attention layers.
The default signature of the attention mask functions is the following:
```python
def custom_attention_mask(
    batch_size: int,  # required arg
    cache_position: torch.Tensor,  # required arg
    kv_length: int,  # required arg
    kv_offset: int = 0,  # required arg
    mask_function: Callable = causal_mask_function,  # required arg
    attention_mask: Optional[torch.Tensor] = None,  # required arg
    **kwargs,  # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:
```
Most of the work happens in the `mask_function`, a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), which takes 4 indices as input and returns a boolean indicating whether this position should take part in the attention computation.
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
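As a quick illustration of the signature above, here is a sketch of a custom mask that restricts attention to a local band and delegates the tensor construction to the stock `sdpa_mask` (the `banded_*` names and the 128-token band are made up for this example; the registered key has to match the name of the attention implementation you pair it with):

```python
import torch
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

# mask_mod-style function: 4 indices in, boolean out (hypothetical 128-token band)
def banded_causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    return (kv_idx <= q_idx) & (q_idx - kv_idx < 128)

def banded_sdpa_mask(batch_size, cache_position, kv_length, kv_offset=0, mask_function=None, attention_mask=None, **kwargs):
    # Reuse the stock sdpa_mask for the heavy lifting, only swapping in the custom pattern
    return sdpa_mask(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=banded_causal_mask_function,
        attention_mask=attention_mask,
        **kwargs,
    )

AttentionMaskInterface.register("banded_sdpa_mask", banded_sdpa_mask)
```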


@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
[[autodoc]] AttentionInterface
- register
## Attention Mask Functions
[[autodoc]] AttentionMaskInterface
- register
## Rotary Position Embedding Functions
[[autodoc]] dynamic_rope_update


@ -445,6 +445,7 @@ else:
_import_structure["modeling_outputs"] = [] _import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
_import_structure["optimization"] = [ _import_structure["optimization"] = [
"Adafactor", "Adafactor",
"get_constant_schedule", "get_constant_schedule",
@ -914,6 +915,7 @@ if TYPE_CHECKING:
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)
from .masking_utils import AttentionMaskInterface
from .model_debugging_utils import (
model_addition_debugger_context,
)


@ -196,6 +196,18 @@ class Cache:
else:
return None
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length()
kv_length = query_length + past_seen_tokens
return kv_length, 0
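A small sketch of what this default returns, using a `DynamicCache` filled by hand (shapes and positions are arbitrary):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
# Pretend layer 0 has already cached 10 tokens: (batch=1, kv_heads=1, seq=10, head_dim=8)
cache.update(torch.zeros(1, 1, 10, 8), torch.zeros(1, 1, 10, 8), layer_idx=0)

# Decoding a single new token at absolute position 10
cache_position = torch.tensor([10])
kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
print(kv_length, kv_offset)  # 11 0 -> mask covers the 10 cached tokens plus the new one, no offset
```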
@dataclass
class CacheConfig:
@ -1084,8 +1096,6 @@ class SinkCache(Cache):
```
"""
is_sliding = True
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
@ -1390,6 +1400,16 @@ class StaticCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
kv_length = self.get_max_cache_shape()
return kv_length, 0
class SlidingWindowCache(StaticCache):
"""
@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
```
"""
is_sliding = True
is_compileable = True
def __init__(
@ -1465,6 +1484,7 @@ class SlidingWindowCache(StaticCache):
"config and it's not set to None." "config and it's not set to None."
) )
max_cache_len = min(config.sliding_window, max_cache_len) max_cache_len = min(config.sliding_window, max_cache_len)
self.sliding_window = config.sliding_window
super().__init__( super().__init__(
config=config, config=config,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
@ -1509,6 +1529,21 @@ class SlidingWindowCache(StaticCache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
# torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
kv_length = max(query_length, self.get_max_cache_shape())
return kv_length, kv_offset
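A tiny worked example of this shortcut (a sketch assuming `sliding_window = max_cache_len = 8`):

```python
# Single-token decode step with first_cache_position = 20, query_length = 1:
#   kv_offset = clamp(20 - 8 + 1, min=0) = 13
#   kv_length = max(1, 8) = 8  -> the mask always spans the full (static) sliding window
```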
class EncoderDecoderCache(Cache):
"""
@ -1761,12 +1796,17 @@ class HybridCache(Cache):
else config.num_key_value_heads
)
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
@ -1775,7 +1815,7 @@ class HybridCache(Cache):
layer_device = device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
@ -1796,7 +1836,7 @@ class HybridCache(Cache):
if cache_position is None:
raise ValueError("`cache_position` must be provided for HybridCache.")
is_sliding_layer = self.is_sliding_list[layer_idx]
is_sliding_layer = self.is_sliding[layer_idx]
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
@ -1843,6 +1883,26 @@ class HybridCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
local_mask_kv_length = max(query_length, self.sliding_window)
return local_mask_kv_length, local_mask_kv_offset
full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset
class HybridChunkedCache(Cache):
"""
@ -1912,11 +1972,11 @@ class HybridChunkedCache(Cache):
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype
if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@ -1999,11 +2059,7 @@ class HybridChunkedCache(Cache):
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update
update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
return update_fn(
cache_position,
layer_idx,
@ -2038,6 +2094,37 @@ class HybridChunkedCache(Cache):
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is the true general case for any Cache using local attention (sliding or chunked)
if first_cache_position >= self.sliding_window:
# Here the Cache is already full
local_mask_kv_length = self.sliding_window + query_length - 1
elif (
first_cache_position < self.sliding_window
and first_cache_position + query_length > self.sliding_window
):
# Here the Cache becomes full with the new input
local_mask_kv_length = first_cache_position + query_length
else:
# Here the Cache is still smaller than the local size, but we return the local size as it's static
local_mask_kv_length = self.sliding_window
return local_mask_kv_length, local_mask_kv_offset
full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset
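A worked example of the three branches above (a sketch; numbers are arbitrary and assume `sliding_window = 8`):

```python
# kv_offset = clamp(first_cache_position - 8 + 1, min=0)
#
# Cache already full:       first_cache_position=20, query_length=4
#   kv_offset = 13, kv_length = 8 + 4 - 1 = 11
# Cache becomes full now:   first_cache_position=5,  query_length=6
#   kv_offset = 0,  kv_length = 5 + 6     = 11
# Cache still smaller:      first_cache_position=2,  query_length=3
#   kv_offset = 0,  kv_length = 8 (the static local size is returned)
```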
class OffloadedHybridCache(HybridChunkedCache):
def __init__(


@ -1209,3 +1209,16 @@ if PretrainedConfig.push_to_hub.__doc__ is not None:
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
object="config", object_class="AutoConfig", object_files="configuration file"
)
ALLOWED_LAYER_TYPES = (
"full_attention",
"sliding_attention",
"chunked_attention",
)
def layer_type_validation(layer_types: list[str]):
"""Check that each entry in `layer_types` are allowed."""
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")


@ -46,6 +46,7 @@ from ..dynamic_module_utils import (
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
@ -74,6 +75,7 @@ from .candidate_generator import (
from .configuration_utils import (
NEED_SETUP_CACHE_CLASSES_MAPPING,
QUANT_BACKEND_CLASSES_MAPPING,
CompileConfig,
GenerationConfig,
GenerationMode,
)
@ -649,12 +651,22 @@ class GenerationMixin:
causal_mask_creation_function = getattr(
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
if causal_mask_creation_function is None:  # can't be found
logger.warning_once(
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
"writing code, see Llama for an example implementation. If you're a user, please report this "
"issue on GitHub."
)
# If it's not defined, it means the model uses the new general mask API
if causal_mask_creation_function is None:  # can't be found
output_attentions = kwargs.get("output_attentions", False)
token_type_ids = getattr(model_input, "token_type_ids", None)
# Some models may overwrite the general one
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
attention_mask = causal_mask_creation_function(
config=self.config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
token_type_ids=token_type_ids,
)
else:
attention_mask = causal_mask_creation_function(
@ -3533,6 +3545,19 @@ class GenerationMixin:
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2" and getattr(
model_kwargs.get("past_key_values"), "is_compileable", False
):
if generation_config.compile_config is None:
generation_config.compile_config = CompileConfig(fullgraph=False)
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
elif generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)
if generation_config.prefill_chunk_size is not None:
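From the user side, this roughly plays out as in the sketch below (the checkpoint name is only an example and flash-attn must be installed; any FA2-capable model works):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_2", torch_dtype="bfloat16", device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
inputs = tokenizer("The masking rewrite", return_tensors="pt").to(model.device)

# A static cache is compileable, so generate() compiles the forward pass, but only with fullgraph=False
generation_config = GenerationConfig(
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=False),
    max_new_tokens=20,
)
out = model.generate(**inputs, generation_config=generation_config)
```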


@ -11,18 +11,21 @@
# specific language governing permissions and limitations under the License.
import logging
from typing import Optional
from contextlib import contextmanager
from typing import Callable, Optional
import torch
from transformers.generation.configuration_utils import GenerationConfig
from ..utils.import_utils import is_torch_available
if is_torch_available():
from transformers import HybridCache, PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
from ..cache_utils import DynamicCache, HybridCache, StaticCache
from ..generation.configuration_utils import GenerationConfig
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_ignore_causal_mask_sdpa,
_is_torch_greater_or_equal_than_2_5,
prepare_padding_mask,
)
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
@ -54,19 +57,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
if not hasattr(model.config, "use_cache") or model.config.use_cache is False: if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.") raise ValueError("The model must have caching enabled to be performant.")
if not hasattr(model.config, "cache_implementation"): if not hasattr(model.config, "layer_types"):
# If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
# be used by default, so export will use `StaticCache` by default. # export will use `StaticCache` by default.
logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.") logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
self.model = TorchExportableModuleWithStaticCache(model) self.model = TorchExportableModuleWithStaticCache(model)
else: else:
if model.config.cache_implementation == "hybrid": self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
raise ValueError(
f"Unsupported cache implementation: {model.config.cache_implementation}. "
"Please use `hybrid` or `static`."
)
def forward( def forward(
self, self,
@ -105,16 +102,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
strict(`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
return torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
with patch_mask_interface():
exported_program = torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
return exported_program
@staticmethod
def generate(
@ -281,17 +285,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
)
self.register_buffer("mask", causal_mask, persistent=False)
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.
@ -314,13 +307,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache
outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
attention_mask=None,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
@ -445,18 +437,15 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
Returns:
torch.Tensor: Logits output from the model.
"""
batch_size, seq_len = input_ids.shape
batch_size = input_ids.shape[0]
# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
# Create attention mask (always ones for token-by-token generation)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
attention_mask=None,
position_ids=position_ids,
past_key_values=self.cache,
use_cache=True,
@ -467,6 +456,24 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
return outputs.logits
@contextmanager
def patch_mask_interface():
"""
Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail
with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles:
Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`.
Note that this seem to be an issue only for python<3.11.
"""
import transformers
original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
try:
yield
finally:
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original
def convert_and_export_with_cache(
model: PreTrainedModel,
example_input_ids: Optional[torch.Tensor] = None,
@ -493,6 +500,11 @@ def convert_and_export_with_cache(
import torch.export._trace
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
@ -503,13 +515,14 @@ def convert_and_export_with_cache(
)
if is_torch_greater_or_equal("2.6.0"):
exported_program = torch.export.export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
with patch_mask_interface():
exported_program = torch.export.export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
else:
if dynamic_shapes is not None:
logging.warning(
@ -521,13 +534,14 @@ def convert_and_export_with_cache(
#
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
with patch_mask_interface():
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
return exported_program
@ -620,9 +634,10 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
# Export the encoder
with torch.no_grad():
exported_encoder = torch.export.export(
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
)
with patch_mask_interface():
exported_encoder = torch.export.export(
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
)
return exported_encoder
@ -642,16 +657,17 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
# Export the decoder
with torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
},
strict=True,
)
with patch_mask_interface():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
},
strict=True,
)
return exported_decoder
@ -706,3 +722,131 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
break
return generated_ids
def export_with_dynamic_cache(
model: PreTrainedModel,
example_input_ids: Optional[torch.Tensor] = None,
example_attention_mask: Optional[torch.Tensor] = None,
):
"""
Export a model with DynamicCache using `torch.export`, ensuring the exported model is compatible with `ExecuTorch`.
Args:
model (`PreTrainedModel`): The pretrained model to be exported.
example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
example_attention_mask (`Optional[torch.Tensor]`): Example attention mask used by `torch.export`.
Returns:
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
"""
if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
with torch.no_grad():
exported_program = torch.export.export(
model,
(),
{
"input_ids": example_input_ids,
"attention_mask": example_attention_mask,
"past_key_values": DynamicCache(),
"use_cache": True,
},
strict=False,
)
return exported_program
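A usage sketch for this helper (the checkpoint name is an example, and torch >= 2.3 is assumed):

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import export_with_dynamic_cache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", use_cache=True)
model.eval()

input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
# Registers the "sdpa_without_vmap" mask/attention pair internally, then exports with a DynamicCache
exported = export_with_dynamic_cache(model, input_ids, attention_mask)
```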
def sdpa_mask_without_vmap(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Optional[Callable] = None,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
**kwargs,
) -> Optional[torch.Tensor]:
"""
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
return None
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
reshaped_cache_position = cache_position.view(-1, 1)
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
# the config through kwargs anyway, so it allows to rely on it
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
# but this is more efficient
sliding_window = getattr(kwargs["config"], "sliding_window", None)
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
if sliding_window is not None and chunk_size is not None:
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
# Simplest and most efficient way to obtain a causal mask
causal_mask = kv_arange <= reshaped_cache_position
# If using sliding window, add the sliding mask
if sliding_window is not None:
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
causal_mask *= sliding_mask_overlay
# If using chunk attention, add the chunked mask
elif chunk_size is not None:
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
causal_mask *= chunked_mask_overlay
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask
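A small sketch of calling this mask helper directly (a default `LlamaConfig` is used only because it sets neither `sliding_window` nor `attention_chunk_size`; `allow_is_causal_skip=False` forces the mask to be materialized):

```python
import torch
from transformers import LlamaConfig
from transformers.integrations.executorch import sdpa_mask_without_vmap

config = LlamaConfig()

# Prefill of 4 tokens, no cache yet, with one left-padded position in the batch
cache_position = torch.arange(4)
attention_mask = torch.tensor([[0, 1, 1, 1]])
mask = sdpa_mask_without_vmap(
    batch_size=1,
    cache_position=cache_position,
    kv_length=4,
    attention_mask=attention_mask,
    allow_is_causal_skip=False,
    config=config,
)
print(mask.shape)  # torch.Size([1, 1, 4, 4]) boolean mask, causal + padding
```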


@ -32,14 +32,12 @@ import torch
from packaging import version
from ..utils import is_torch_flex_attn_available
from ..utils.import_utils import _torch_version
from ..utils.import_utils import _torch_version, is_torchdynamo_compiling
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex
class WrappedFlexAttention:
@ -79,6 +77,24 @@ class WrappedFlexAttention:
return self._compiled_flex_attention
def compile_friendly_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
training=False,
**kwargs,
) -> torch.Tensor:
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
# Do not use compiled version if already compiling forward (it raises issues)
flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention
return flex_attention_compiled(
query,
key,
value,
**kwargs,
)
Offset = Union[torch.Tensor, int]
@ -90,6 +106,10 @@ def make_flex_block_causal_mask(
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
and will be removed in a future version without warnings. New code should not use it. It is only kept here
for BC for now, while models using it are being patched accordingly.
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
@ -133,7 +153,6 @@ def make_flex_block_causal_mask(
""" """
Defines the logic of a block causal mask by combining both a standard causal mask Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask. and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration. for an illustration.
""" """
@ -174,24 +193,6 @@ def make_flex_block_causal_mask(
)
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
training=False,
**kwargs,
) -> torch.Tensor:
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
flex_attention_compiled = WrappedFlexAttention(training)()
return flex_attention_compiled(
query,
key,
value,
**kwargs,
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,


@ -16,15 +16,15 @@ from __future__ import annotations
import operator
import os
import re
from collections.abc import MutableMapping
from functools import partial, reduce
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from ..utils import is_torch_greater_or_equal, logging
from ..utils.generic import GeneralInterface
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@ -720,20 +720,11 @@ class SequenceParallel(TensorParallelLayer):
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
class ParallelInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
def __init__(self):
self._local_mapping = {}
ParallelInterface._global_mapping = {
class ParallelInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given entry)
_global_mapping = (
{
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
@ -746,41 +737,12 @@ class ParallelInterface(MutableMapping):
"sequence_parallel": SequenceParallel(), "sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(), "replicate": ReplicateParallel(),
} }
if is_torch_greater_or_equal("2.5") and _torch_distributed_available
def __getitem__(self, key): else {}
# First check if instance has a local override )
if key in self._local_mapping:
return self._local_mapping[key]
return self._global_mapping[key]
def __setitem__(self, key, value):
# Allow local update of the default functions without impacting other instances
self._local_mapping.update({key: value})
def __delitem__(self, key):
del self._local_mapping[key]
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter({**self._global_mapping, **self._local_mapping})
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@classmethod
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
else:
ALL_PARALLEL_STYLES = None
def convert_local_tensor_to_dtensor( def convert_local_tensor_to_dtensor(

File diff suppressed because it is too large


@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
and will be removed in the future.
"""
from dataclasses import dataclass
from typing import Optional, Union


@ -148,6 +148,12 @@ def _upad_input(
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
# With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
# It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)


@ -27,7 +27,6 @@ import shutil
import tempfile
import warnings
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@ -124,6 +123,7 @@ from .utils import (
replace_return_docstrings,
strtobool,
)
from .utils.generic import GeneralInterface
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
@ -6076,7 +6076,7 @@ def get_disk_only_shard_files(device_map, weight_map):
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
class AttentionInterface(MutableMapping):
class AttentionInterface(GeneralInterface):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
@ -6091,36 +6091,6 @@ class AttentionInterface(MutableMapping):
"sdpa": sdpa_attention_forward, "sdpa": sdpa_attention_forward,
} }
def __init__(self):
self._local_mapping = {}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
return self._local_mapping[key]
return self._global_mapping[key]
def __setitem__(self, key, value):
# Allow local update of the default functions without impacting other instances
self._local_mapping.update({key: value})
def __delitem__(self, key):
del self._local_mapping[key]
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter({**self._global_mapping, **self._local_mapping})
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@classmethod
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
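The semantics that move into `GeneralInterface` stay the same: `register` updates the global mapping shared by every instance, while item assignment only overrides the entry locally. A brief sketch (`my_sdpa` is made up):

```python
from transformers import AttentionInterface
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def my_sdpa(*args, **kwargs):
    print("custom attention")
    return ALL_ATTENTION_FUNCTIONS["sdpa"](*args, **kwargs)

AttentionInterface.register("my_sdpa", my_sdpa)  # visible through every instance
local_interface = AttentionInterface()
local_interface["sdpa"] = my_sdpa                # only overrides "sdpa" for this instance
assert ALL_ATTENTION_FUNCTIONS["sdpa"] is not my_sdpa
```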


@ -25,20 +25,14 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
auto_docstring,
can_return_tuple,
is_torch_flex_attn_available,
logging,
)
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils.import_utils import is_torch_available
from ..auto import AutoModel
from .configuration_aria import AriaConfig, AriaTextConfig
@ -49,12 +43,6 @@ if is_torch_available():
from torch import nn
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -818,8 +806,13 @@ class AriaTextModel(AriaTextPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
@@ -865,129 +858,6 @@
            attentions=all_self_attns,
        )
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
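One subtlety in the deleted SDPA branch above: `AttentionMaskConverter._unmask_unattended` re-opens rows that are entirely masked (for example rows of pure left-padding), because PyTorch's memory-efficient SDPA path misbehaves when a query can attend to nothing (see pytorch/pytorch#110213). A hedged, standalone sketch of that fix-up on an additive float mask, not the library implementation:

```python
import torch

def unmask_fully_masked_rows(additive_mask: torch.Tensor, min_dtype: float) -> torch.Tensor:
    # additive_mask: (batch, 1, q_len, kv_len), 0 where attending, min_dtype where masked
    fully_masked = (additive_mask == min_dtype).all(dim=-1, keepdim=True)  # (batch, 1, q_len, 1)
    # Rows with no attendable position are reset to 0 (attend everywhere) so softmax stays finite
    return additive_mask.mul(~fully_masked)

mask = torch.full((1, 1, 2, 3), torch.finfo(torch.float32).min)
mask[0, 0, 1, 2] = 0.0  # second query row can attend to the last key
print(unmask_fully_masked_rows(mask, torch.finfo(torch.float32).min))
```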
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
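Both methods deleted above were copy-pasted into nearly every decoder-only model; `create_causal_mask` in `masking_utils` now builds the equivalent mask in one place. For intuition, a standalone sketch of the 4D additive mask the old helper produced (simplified: it skips the already-4D passthrough and device handling, and the names are just illustrative):

```python
import torch

def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position):
    # attention_mask: (batch, key_value_length) with 1 = keep, 0 = padding
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask.shape[0]
    causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    if sequence_length != 1:
        causal = torch.triu(causal, diagonal=1)  # block future positions
    # positions beyond the current cache slot are also masked
    causal *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal = causal[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # fold the 2D padding mask into the key/value dimension
    mask_len = attention_mask.shape[-1]
    padding = (causal[:, :, :, :mask_len] + attention_mask[:, None, None, :].to(dtype)) == 0
    causal[:, :, :, :mask_len] = causal[:, :, :, :mask_len].masked_fill(padding, min_dtype)
    return causal

# decoding one token at position 3 with a static cache of length 8 and one padded batch entry
mask_2d = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [0, 1, 1, 1, 0, 0, 0, 0]])
print(build_4d_causal_mask(mask_2d, 1, 8, torch.float32, torch.tensor([3])).shape)  # (2, 1, 1, 8)
```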
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@@ -1521,61 +1391,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
        return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
    "AriaForConditionalGeneration",


@@ -542,60 +542,5 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
        return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"] __all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]


@@ -757,7 +757,7 @@ class BartPreTrainedModel(PreTrainedModel):
        }
        return dummy_inputs

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -827,7 +827,7 @@ class BartPreTrainedModel(PreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -1602,7 +1602,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
        }
        return dummy_inputs

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1672,7 +1672,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -482,7 +482,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -552,7 +552,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -27,23 +27,17 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
 from .configuration_bitnet import BitNetConfig

-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask

 logger = logging.get_logger(__name__)
@@ -425,8 +419,13 @@ class BitNetModel(BitNetPreTrainedModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
        )

        hidden_states = inputs_embeds
@@ -472,129 +471,6 @@
            attentions=all_self_attns,
        )
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -493,7 +493,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
        }
        return dummy_inputs

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -563,7 +563,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -482,7 +482,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
        }
        return dummy_inputs

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -552,7 +552,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -658,7 +658,7 @@ class BloomModel(BloomPreTrainedModel):
            attentions=all_self_attentions,
        )

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -728,7 +728,7 @@ class BloomModel(BloomPreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -1057,7 +1057,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
            attentions=all_self_attns,
        )

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1127,7 +1127,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -491,7 +491,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
            attentions=all_self_attentions,
        )

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -561,7 +561,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
        return causal_mask

    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,


@@ -35,23 +35,17 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
 from .configuration_cohere import CohereConfig

-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask

 logger = logging.get_logger(__name__)
@@ -462,8 +456,13 @@ class CohereModel(CoherePreTrainedModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
        )

        hidden_states = inputs_embeds
@@ -509,129 +508,6 @@
            attentions=all_self_attns,
        )
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ...configuration_utils import PretrainedConfig
+from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...modeling_rope_utils import rope_config_validation
@@ -119,9 +119,8 @@ class Cohere2Config(PretrainedConfig):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window attention context.
-        sliding_window_pattern (`int`, *optional*, defaults to 4):
-            Pattern for the sliding window attention.
-        cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.

    ```python
    >>> from transformers import Cohere2Model, Cohere2Config
@@ -177,8 +176,7 @@ class Cohere2Config(PretrainedConfig):
        attention_bias=False,
        attention_dropout=0.0,
        sliding_window=4096,
-        sliding_window_pattern=4,
-        cache_implementation="hybrid",
+        layer_types=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -203,10 +201,9 @@ class Cohere2Config(PretrainedConfig):
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.sliding_window = sliding_window
-        self.sliding_window_pattern = sliding_window_pattern
+        self.layer_types = layer_types
        # Need to specify head_dim in the config so it can be used in the attention forward functions
        self.head_dim = hidden_size // num_attention_heads
-        self.cache_implementation = cache_implementation

        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self)
@@ -219,5 +216,14 @@ class Cohere2Config(PretrainedConfig):
            **kwargs,
        )

+        if self.layer_types is None:
+            # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+            sliding_window_pattern = getattr(self, "sliding_window_pattern", 4)
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types)

__all__ = ["Cohere2Config"]


@@ -25,25 +25,20 @@ import torch
 import torch.nn as nn

 from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache, StaticCache
+from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.deprecation import deprecate_kwarg
 from .configuration_cohere2 import Cohere2Config

-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask

 logger = logging.get_logger(__name__)
@@ -186,6 +181,7 @@ class Cohere2Attention(nn.Module):
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -199,9 +195,6 @@ class Cohere2Attention(nn.Module):
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
-        self.sliding_window = (
-            config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
-        )

    def forward(
        self,
@@ -224,19 +217,9 @@
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
-            cache_kwargs = {
-                "sin": sin,
-                "cos": cos,
-                "sliding_window": self.sliding_window,
-                "cache_position": cache_position,
-            }
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-            # Here we need to slice as we use a static cache by default, but FA2 does not support it
-            if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
-                seq_len = attention_mask.shape[-1]
-                key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
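The deleted FA2 guard existed because Cohere2 previously defaulted to a statically sized cache while flash-attention kernels only take an un-padded 2D mask, so keys/values had to be cut back to the number of processed tokens; with `DynamicCache` as the new default, the cache never carries trailing zeros. A toy illustration of what that slice did (all shapes here are invented for the example):

```python
import torch

# toy shapes: (batch, heads, max_cache_len, head_dim) K/V from a static cache,
# and an FA2-style 2D mask of shape (batch, processed_tokens)
key_states = torch.randn(1, 2, 16, 4)
value_states = torch.randn(1, 2, 16, 4)
attention_mask = torch.ones(1, 10)

seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
print(key_states.shape)  # torch.Size([1, 2, 10, 4])
```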
@@ -284,12 +267,10 @@
    def __init__(self, config: Cohere2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
-        self.self_attn = Cohere2Attention(config, layer_idx)
+        self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Cohere2MLP(config)
        self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-        self.config = config
-        self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
-        self.sliding_window = config.sliding_window
+        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(
@@ -322,34 +303,6 @@
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
        """
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
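The block removed above sliced the 4D mask down to an `effective_seq_len` window with `torch.arange(...) + offset` rather than a data-dependent `mask[..., offset : offset + effective_seq_len]`, keeping the graph torch.compile friendly; the new `create_sliding_window_causal_mask` primitive makes this per-layer surgery unnecessary. A small standalone sketch of that indexing trick in isolation (shapes are illustrative):

```python
import torch

def window_slice_compile_friendly(mask_4d: torch.Tensor, cache_position: torch.Tensor, window: int) -> torch.Tensor:
    # mask_4d: (batch, 1, q_len, max_cache_len); keep at most `window` key/value columns
    effective_seq_len = max(cache_position.shape[0], window)
    offset = torch.clamp(cache_position[-1] - effective_seq_len + 1, min=0)
    # gather the columns [offset, offset + effective_seq_len) without data-dependent slicing
    cols = torch.arange(min(effective_seq_len, mask_4d.shape[-1]), device=mask_4d.device) + offset
    return mask_4d[:, :, :, cols]

mask = torch.zeros(1, 1, 1, 16)
print(window_slice_compile_friendly(mask, cache_position=torch.tensor([9]), window=4).shape)  # (1, 1, 1, 4)
```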
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
@@ -440,7 +393,7 @@
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[HybridCache] = None,
+        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
@@ -467,15 +420,7 @@
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
-            batch_size, seq_len, _ = inputs_embeds.shape
-            # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
-            past_key_values = HybridCache(
-                self.config,
-                max_batch_size=batch_size,
-                max_cache_len=seq_len,
-                dtype=inputs_embeds.dtype,
-                device=self.device,
-            )
+            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -485,9 +430,22 @@
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        # It may already have been prepared by e.g. `generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "output_attentions": output_attentions,
+            }
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }

        hidden_states = inputs_embeds
@@ -505,7 +463,7 @@
            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
-                attention_mask=causal_mask,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
@@ -531,100 +489,6 @@
            attentions=all_self_attns,
        )
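The forward pass now builds one mask per attention flavour up front and lets each decoder layer pick the one matching its `attention_type`, instead of patching a single mask layer by layer. A standalone sketch of that dispatch (the `make_*` helpers below are stand-ins, not the transformers mask functions):

```python
import torch

def make_full_mask(q_len: int, kv_len: int) -> torch.Tensor:
    # standard causal mask: True = may attend
    return torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))

def make_sliding_mask(q_len: int, kv_len: int, window: int) -> torch.Tensor:
    full = make_full_mask(q_len, kv_len)
    distance = torch.arange(q_len)[:, None] - torch.arange(kv_len)[None, :]
    return full & (distance < window)  # additionally drop keys older than the window

layer_types = ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"]
mask_mapping = {
    "full_attention": make_full_mask(5, 5),
    "sliding_attention": make_sliding_mask(5, 5, window=2),
}
for layer_idx, layer_type in enumerate(layer_types):
    attention_mask = mask_mapping[layer_type]  # what each decoder layer receives
    print(layer_idx, layer_type, attention_mask.sum().item())
```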
@torch.no_grad()
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool = False,
):
# Flash Attention currently doesn't support static cache but Cohere2 work only with static cache.
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@@ -635,7 +499,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

-    def __init__(self, config: Cohere2Config):
+    def __init__(self, config):
        super().__init__(config)
        self.model = Cohere2Model(config)
        self.vocab_size = config.vocab_size
@@ -740,88 +604,5 @@
            attentions=outputs.attentions,
        )
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or cache_position[-1] >= input_ids.shape[1] # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
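The override removed above existed mainly to pre-expand the 2D mask to the `HybridCache` length; with the default cache path now sufficient, the generic parts (such as deriving `position_ids` from the padding mask) are left to the shared `prepare_inputs_for_generation`. For reference, a standalone sketch of that `cumsum` trick for left-padded batches:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])  # second row is left-padded
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # padding slots get a dummy position
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
```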
__all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"] __all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]


@@ -19,8 +19,9 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint

-from ...cache_utils import Cache, HybridCache
-from ...configuration_utils import PretrainedConfig
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import rope_config_validation
@ -140,9 +141,8 @@ class Cohere2Config(PretrainedConfig):
The dropout ratio for the attention probabilities. The dropout ratio for the attention probabilities.
sliding_window (`int`, *optional*, defaults to 4096): sliding_window (`int`, *optional*, defaults to 4096):
Size of the sliding window attention context. Size of the sliding window attention context.
sliding_window_pattern (`int`, *optional*, defaults to 4): layer_types (`list`, *optional*):
Pattern for the sliding window attention. Attention pattern for each layer.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
```python ```python
>>> from transformers import Cohere2Model, Cohere2Config >>> from transformers import Cohere2Model, Cohere2Config
@ -198,8 +198,7 @@ class Cohere2Config(PretrainedConfig):
attention_bias=False, attention_bias=False,
attention_dropout=0.0, attention_dropout=0.0,
sliding_window=4096, sliding_window=4096,
sliding_window_pattern=4, layer_types=None,
cache_implementation="hybrid",
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size
@@ -224,10 +223,9 @@ class Cohere2Config(PretrainedConfig):
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.sliding_window = sliding_window
-self.sliding_window_pattern = sliding_window_pattern
+self.layer_types = layer_types
# Need to specify head_dim in the config so it can be used in the attention forward functions
self.head_dim = hidden_size // num_attention_heads
-self.cache_implementation = cache_implementation
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
@@ -240,6 +238,15 @@ class Cohere2Config(PretrainedConfig):
**kwargs,
)
+if self.layer_types is None:
+    # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+    sliding_window_pattern = getattr(self, "sliding_window_pattern", 4)
+    self.layer_types = [
+        "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
+        for i in range(self.num_hidden_layers)
+    ]
+layer_type_validation(self.layer_types)
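The backward-compatibility branch above maps the legacy integer `sliding_window_pattern` onto the new per-layer `layer_types` list. A minimal standalone sketch of that derivation (the layer count and pattern value below are invented for illustration, not taken from any real checkpoint):

```python
# Illustrative only: derive per-layer attention types from the legacy integer pattern.
num_hidden_layers = 8        # assumed layer count for this example
sliding_window_pattern = 4   # legacy value: every 4th layer uses full attention

layer_types = [
    "sliding_attention" if (i + 1) % sliding_window_pattern else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
```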
class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
pass
@@ -261,6 +268,7 @@ class Cohere2Attention(CohereAttention, nn.Module):
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
+self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
@@ -274,9 +282,6 @@ class Cohere2Attention(CohereAttention, nn.Module):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
-self.sliding_window = (
-    config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
-)
def forward(
self,
@@ -299,19 +304,9 @@ class Cohere2Attention(CohereAttention, nn.Module):
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
-cache_kwargs = {
-    "sin": sin,
-    "cos": cos,
-    "sliding_window": self.sliding_window,
-    "cache_position": cache_position,
-}
+cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-# Here we need to slice as we use a static cache by default, but FA2 does not support it
-if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
-    seq_len = attention_mask.shape[-1]
-    key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -342,10 +337,7 @@ class Cohere2Attention(CohereAttention, nn.Module):
class Cohere2DecoderLayer(CohereDecoderLayer):
def __init__(self, config: Cohere2Config, layer_idx: int):
super().__init__(config, layer_idx)
-self.self_attn = Cohere2Attention(config, layer_idx)
-self.config = config
-self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
-self.sliding_window = config.sliding_window
+self.attention_type = config.layer_types[layer_idx]
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
@@ -378,34 +370,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
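The block removed above used to apply the sliding window inside the decoder layer itself; its `torch.arange`-plus-offset indexing is the usual trick for avoiding a data-dependent slice, and the same idea recurs in compile-friendly mask code. A small self-contained illustration with made-up shapes:

```python
import torch

# Toy 4D mask [bs, heads, q_len, kv_len]; keep a window of 4 key positions starting at `offset`.
mask = torch.randn(1, 1, 1, 10)
effective_seq_len = 4
offset = torch.tensor(3)  # at runtime this depends on cache_position, i.e. it is data-dependent

# Direct slice: fine eagerly, but the start index is a tensor value, which torch.compile dislikes.
sliced = mask[:, :, :, offset : offset + effective_seq_len]

# Equivalent gather with a fixed-size index tensor: no data-dependent slicing.
mask_indexes = torch.arange(effective_seq_len) + offset
gathered = mask[:, :, :, mask_indexes]

assert torch.equal(sliced, gathered)
```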
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -451,7 +415,7 @@ class Cohere2Model(Gemma2Model):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
-past_key_values: Optional[HybridCache] = None,
+past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -478,15 +442,7 @@ class Cohere2Model(Gemma2Model):
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
-batch_size, seq_len, _ = inputs_embeds.shape
-# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
-past_key_values = HybridCache(
-    self.config,
-    max_batch_size=batch_size,
-    max_cache_len=seq_len,
-    dtype=inputs_embeds.dtype,
-    device=self.device,
-)
+past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -496,9 +452,22 @@ class Cohere2Model(Gemma2Model):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-)
+# It may already have been prepared by e.g. `generate`
+if not isinstance(causal_mask_mapping := attention_mask, dict):
+    # Prepare mask arguments
+    mask_kwargs = {
+        "config": self.config,
+        "input_embeds": inputs_embeds,
+        "attention_mask": attention_mask,
+        "cache_position": cache_position,
+        "past_key_values": past_key_values,
+        "output_attentions": output_attentions,
+    }
+    # Create the masks
+    causal_mask_mapping = {
+        "full_attention": create_causal_mask(**mask_kwargs),
+        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+    }
hidden_states = inputs_embeds
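The `causal_mask_mapping` built above (skipped entirely when `generate` already passed a prepared dict, hence the walrus-style `isinstance` check) lets each decoder layer select its mask with one dictionary lookup, as the next hunk shows via `decoder_layer.attention_type`. A stripped-down sketch of the dispatch pattern, with plain stand-in functions instead of the real `create_causal_mask` / `create_sliding_window_causal_mask`:

```python
# Hypothetical stand-ins for the real mask builders, just to show the dispatch shape.
def build_full_mask(**mask_kwargs):
    return "full-attention mask"

def build_sliding_mask(**mask_kwargs):
    return "sliding-window mask"

layer_types = ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"]
mask_kwargs = {}  # in the real model: config, input embeds, 2D attention mask, cache position, cache

# Each mask variant is built once per forward pass, then reused by every layer of that type.
causal_mask_mapping = {
    "full_attention": build_full_mask(**mask_kwargs),
    "sliding_attention": build_sliding_mask(**mask_kwargs),
}

for layer_idx, attention_type in enumerate(layer_types):
    attention_mask = causal_mask_mapping[attention_type]
    print(layer_idx, attention_type, "->", attention_mask)
```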
@@ -516,7 +485,7 @@ class Cohere2Model(Gemma2Model):
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
-attention_mask=causal_mask,
+attention_mask=causal_mask_mapping[decoder_layer.attention_type],
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -544,91 +513,7 @@ class Cohere2Model(Gemma2Model):
class Cohere2ForCausalLM(CohereForCausalLM):
-def __init__(self, config: Cohere2Config):
+pass
super().__init__(config)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or cache_position[-1] >= input_ids.shape[1] # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"] __all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]


@@ -29,25 +29,19 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel
from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
from .generation_csm import CsmGenerationMixin
-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-    from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -516,8 +510,13 @@ class CsmDepthDecoderModel(CsmPreTrainedModel):
inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -564,129 +563,6 @@ class CsmDepthDecoderModel(CsmPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
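The `_prepare_4d_causal_attention_mask_with_cache_position` helpers being deleted here (and in the copies further down) all share the same construction; a tiny worked example may make the fill logic concrete. Sizes, cache position, and padding pattern are invented for illustration:

```python
import torch

# One query token sitting at cache position 2, a static cache of length 4,
# and a 2D padding mask that blocks the first key slot.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 1, 4, 1
cache_position = torch.tensor([2])
attention_mask_2d = torch.tensor([[0, 1, 1, 1]])  # 0 = padding

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
# Zero out (i.e. allow) cache slots up to the current position; keep min_dtype beyond it.
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the 2D padding mask into the 4D causal mask.
padding_mask = causal_mask[:, :, :, :target_length] + attention_mask_2d[:, None, None, :]
causal_mask[:, :, :, :target_length] = causal_mask[:, :, :, :target_length].masked_fill(
    padding_mask == 0, min_dtype
)

print(causal_mask)
# Single query row: [min, 0, 0, min] -> the padded slot 0 and the unfilled slot 3 are masked,
# key positions 1 and 2 remain attendable.
```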
class CsmCodebooksHead(nn.Module):
def __init__(self, hidden_size, num_codebooks, vocab_size):
@@ -946,8 +822,13 @@ class CsmBackboneModel(CsmPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -993,129 +874,6 @@ class CsmBackboneModel(CsmPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@auto_docstring(
custom_intro="""
@@ -1477,61 +1235,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
)
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"CsmPreTrainedModel",


@@ -21,6 +21,7 @@ import torch.nn as nn
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -240,8 +241,13 @@ class CsmDepthDecoderModel(LlamaModel):
inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -835,61 +841,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
)
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"CsmPreTrainedModel",


@@ -1001,7 +1001,7 @@ class DbrxModel(DbrxPreTrainedModel):
router_logits=all_router_logits,
)
-# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1071,7 +1071,7 @@ class DbrxModel(DbrxPreTrainedModel):
return causal_mask
@staticmethod
-# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -15,23 +15,17 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_deepseek_v3 import DeepseekV3Config
-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-    from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -608,8 +602,13 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -655,129 +654,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -31,7 +31,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import (
FlashAttentionKwargs,
_flash_attention_forward,
@@ -48,16 +48,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_diffllama import DiffLlamaConfig
-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-    from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -708,8 +702,13 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -755,129 +754,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -32,23 +32,17 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-    from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -1279,8 +1273,13 @@ class Emu3TextModel(Emu3PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
-causal_mask = self._update_causal_mask(
-    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+causal_mask = create_causal_mask(
+    config=self.config,
+    input_embeds=inputs_embeds,
+    attention_mask=attention_mask,
+    cache_position=cache_position,
+    past_key_values=past_key_values,
+    output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -1326,129 +1325,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@@ -1857,62 +1733,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Emu3ForConditionalGeneration",


@ -1212,62 +1212,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Emu3ForConditionalGeneration",


@ -976,7 +976,7 @@ class FalconModel(FalconPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@ -45,12 +45,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ...utils import auto_docstring, is_torchdynamo_compiling, logging, replace_return_docstrings
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_falcon_h1 import FalconH1Config


@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union
import torch
from torch import nn
@ -27,7 +27,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@ -39,16 +39,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_gemma import GemmaConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -384,7 +378,7 @@ class GemmaModel(GemmaPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -422,8 +416,13 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
# embed positions
@ -475,129 +474,6 @@ class GemmaModel(GemmaPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
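As an aside, a hypothetical standalone usage sketch of the new helper, assuming it accepts exactly the keyword arguments shown at the call sites in this diff and supports running without a cache or padding mask:

```python
import torch
from transformers import GemmaConfig
from transformers.masking_utils import create_causal_mask

# Tiny config only for illustration; the field values are arbitrary.
config = GemmaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=2, num_key_value_heads=1,
)
inputs_embeds = torch.zeros(1, 5, config.hidden_size)
cache_position = torch.arange(5)

mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=None,        # no padding in this toy example
    cache_position=cache_position,
    past_key_values=None,       # no cache yet
    output_attentions=False,
)
# Depending on the configured attention backend, `mask` may be a 4D float mask,
# a flex-attention BlockMask, or None when the backend can rely on `is_causal`.
```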


@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import sentencepiece as spm
import torch
@ -22,6 +22,7 @@ from torch import nn
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
@ -371,7 +372,7 @@ class GemmaModel(LlamaModel):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -409,8 +410,13 @@ class GemmaModel(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
# embed positions


@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
class Gemma2Config(PretrainedConfig):
@ -78,12 +78,16 @@ class Gemma2Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096):
in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
layer_types (`list`, *optional*):
Attention pattern for each layer.
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
scaling factor when applying tanh softcapping on the attention scores.
```python
>>> from transformers import Gemma2Model, Gemma2Config
@ -135,9 +139,9 @@ class Gemma2Config(PretrainedConfig):
attention_dropout=0.0,
query_pre_attn_scalar=256,
sliding_window=4096,
layer_types=None,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
cache_implementation="hybrid",
**kwargs,
):
super().__init__(
@ -166,7 +170,13 @@ class Gemma2Config(PretrainedConfig):
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
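For clarity, a quick sketch of the default computed above: even-indexed layers get sliding attention, which matches the previous `not bool(layer_idx % 2)` rule used elsewhere in this diff.

```python
num_hidden_layers = 6
layer_types = [
    "sliding_attention" if bool((i + 1) % 2) else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'sliding_attention', 'full_attention']
```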
__all__ = ["Gemma2Config"]


@ -26,8 +26,9 @@ import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@ -38,17 +39,11 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma2 import Gemma2Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -195,7 +190,7 @@ class Gemma2Attention(nn.Module):
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.attn_logit_softcapping = self.config.attn_logit_softcapping
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
@ -218,19 +213,9 @@ class Gemma2Attention(nn.Module):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@ -264,7 +249,7 @@ class Gemma2DecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -272,7 +257,6 @@ class Gemma2DecoderLayer(nn.Module):
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
@ -287,33 +271,6 @@ class Gemma2DecoderLayer(nn.Module):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
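The index-based selection in the removed block above is the compile-friendly equivalent of a data-dependent slice; a minimal, self-contained sketch of that trick (tensor shapes and values are illustrative only):

```python
import torch

# Selecting a window with `torch.arange(...) + offset` returns the same values as
# `mask[..., offset : offset + window]`, but keeps the graph traceable by torch.compile
# when `offset` is a runtime tensor rather than a Python int.
mask = torch.randn(1, 1, 1, 10)   # [bs, 1, query_len, max_cache_len]
window = 4
offset = torch.tensor(3)          # e.g. cache_position[-1] - window + 1, clamped to >= 0

mask_indexes = torch.arange(min(window, mask.shape[-1])) + offset
assert torch.equal(mask[:, :, :, mask_indexes], mask[:, :, :, 3:7])
```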
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -441,7 +398,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -468,15 +425,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -487,9 +436,22 @@ class Gemma2Model(Gemma2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
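To make the per-layer dispatch concrete, here is a toy, self-contained sketch (names and shapes are illustrative): each decoder layer simply looks up the pre-built mask matching its `attention_type` instead of re-slicing the mask inside the layer.

```python
import torch

# Stand-ins for the masks built by create_causal_mask / create_sliding_window_causal_mask.
causal_mask_mapping = {
    "full_attention": torch.zeros(1, 1, 4, 4),
    "sliding_attention": torch.zeros(1, 1, 4, 4),
}
layer_types = ["sliding_attention", "full_attention", "sliding_attention"]

for attention_type in layer_types:
    layer_mask = causal_mask_mapping[attention_type]
    # ... forwarded to the decoder layer as `attention_mask`
```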
# embed positions
hidden_states = inputs_embeds
@ -516,7 +478,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
@ -527,7 +489,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@ -553,100 +515,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
attentions=all_self_attns,
)
@torch.no_grad()
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool = False,
):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@auto_docstring
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
@ -688,7 +556,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -765,60 +633,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["attention_mask"] = attention_mask
return model_inputs
@auto_docstring(
custom_intro="""


@ -21,13 +21,14 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import is_torch_flex_attn_available, logging
from ...utils import logging
from ...utils.deprecation import deprecate_kwarg
from ..gemma.modeling_gemma import (
GemmaAttention,
@ -42,12 +43,6 @@ from ..gemma.modeling_gemma import (
)
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -107,12 +102,16 @@ class Gemma2Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096):
in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
layer_types (`list`, *optional*):
Attention pattern for each layer.
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
scaling factor when applying tanh softcapping on the attention scores.
```python
>>> from transformers import Gemma2Model, Gemma2Config
@ -164,9 +163,9 @@ class Gemma2Config(PretrainedConfig):
attention_dropout=0.0,
query_pre_attn_scalar=256,
sliding_window=4096,
layer_types=None,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
cache_implementation="hybrid",
**kwargs,
):
super().__init__(
@ -195,7 +194,13 @@ class Gemma2Config(PretrainedConfig):
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
class Gemma2RMSNorm(GemmaRMSNorm):
@ -250,7 +255,7 @@ class Gemma2Attention(GemmaAttention):
self.attention_dropout = self.config.attention_dropout
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
@ -273,19 +278,9 @@ class Gemma2Attention(GemmaAttention):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@ -319,7 +314,7 @@ class Gemma2DecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -327,7 +322,6 @@ class Gemma2DecoderLayer(nn.Module):
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
@ -342,33 +336,6 @@ class Gemma2DecoderLayer(nn.Module):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -414,7 +381,7 @@ class Gemma2Model(GemmaModel):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -441,15 +408,7 @@ class Gemma2Model(GemmaModel):
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -460,9 +419,22 @@ class Gemma2Model(GemmaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
# embed positions
hidden_states = inputs_embeds
@ -489,7 +461,7 @@ class Gemma2Model(GemmaModel):
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
position_embeddings,
causal_mask,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
@ -500,7 +472,7 @@ class Gemma2Model(GemmaModel):
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@ -526,45 +498,6 @@ class Gemma2Model(GemmaModel):
attentions=all_self_attns,
)
@torch.no_grad()
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool = False,
):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
class Gemma2ForCausalLM(GemmaForCausalLM):
def __init__(self, config):
@ -577,7 +510,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -649,60 +582,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["attention_mask"] = attention_mask
return model_inputs
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
def __init__(self, config):


@ -21,7 +21,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
from ..siglip import SiglipVisionConfig
@ -88,13 +88,14 @@ class Gemma3TextConfig(PretrainedConfig):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
Scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
size of the sliding window.
sliding_window (`int`, *optional*, defaults to 4096):
In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
layer_types (`list`, *optional*):
Attention pattern for each layer.
final_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@ -134,8 +135,6 @@ class Gemma3TextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
```python
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
@ -192,12 +191,11 @@ class Gemma3TextConfig(PretrainedConfig):
attention_dropout=0.0,
query_pre_attn_scalar=256,
sliding_window=4096,
layer_types=None,
final_logit_softcapping=None,
attn_logit_softcapping=None,
cache_implementation="hybrid",
rope_scaling=None,
rope_local_base_freq=10_000.0,
sliding_window_pattern=6,
**kwargs,
):
super().__init__(
@ -226,14 +224,21 @@ class Gemma3TextConfig(PretrainedConfig):
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
self.layer_types = layer_types
self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.rope_scaling = rope_scaling
rope_config_validation(self)
if self.layer_types is None:
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
sliding_window_pattern = getattr(self, "sliding_window_pattern", 6)
self.layer_types = [
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
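For readers following the backward-compatibility shim above, a minimal sketch of how the legacy integer `sliding_window_pattern` expands into `layer_types` (the layer count below is illustrative, not tied to any released checkpoint):

```python
# Re-running the BC expansion from the config above on a hypothetical 12-layer model.
sliding_window_pattern = 6
num_hidden_layers = 12
layer_types = [
    "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# 5 sliding layers followed by 1 full-attention layer, repeated:
# ['sliding_attention', ..., 'sliding_attention', 'full_attention', 'sliding_attention', ...]
```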
class Gemma3Config(PretrainedConfig): class Gemma3Config(PretrainedConfig):
r""" r"""


@ -29,32 +29,21 @@ import torch
import torch.nn as nn import torch.nn as nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    auto_docstring,
    can_return_tuple,
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ...utils.deprecation import deprecate_kwarg from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -300,7 +289,7 @@ class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3TextConfig, layer_idx: int): def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__() super().__init__()
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.config = config self.config = config
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@ -351,19 +340,9 @@ class Gemma3Attention(nn.Module):
if past_key_value is not None: if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache # sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
    "sin": sin,
    "cos": cos,
    "cache_position": cache_position,
    "sliding_window": self.sliding_window,
}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@ -374,9 +353,7 @@ class Gemma3Attention(nn.Module):
) )
else: else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if attention_mask is not None:
# backwards compatibility
attention_mask = attention_mask.to(query_states)
attn_output, attn_weights = attention_interface( attn_output, attn_weights = attention_interface(
self, self,
query_states, query_states,
@ -400,14 +377,13 @@ class Gemma3DecoderLayer(nn.Module):
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma3MLP(config) self.mlp = Gemma3MLP(config)
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0") @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
@ -423,33 +399,6 @@ class Gemma3DecoderLayer(nn.Module):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
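The removed branch above relied on a gather with a clamped offset to stay `torch.compile`-friendly. A standalone sketch of that trick, using a plain tensor with made-up sizes in place of the real 4D mask:

```python
import torch

# Instead of the data-dependent slice `mask[..., offset : offset + window]`,
# gather a fixed-size index range whose offset is clamped to zero while the
# cache is still shorter than the window.
mask = torch.arange(12).view(1, 1, 1, 12)   # stand-in for a [bs, 1, q_len, kv_len] mask
window = 5                                  # stand-in for `effective_seq_len`
cache_position = torch.tensor([9])          # decoding one token at absolute position 9

offset = torch.clamp(cache_position[-1] - window + 1, min=0)
mask_indexes = torch.arange(min(window, mask.shape[-1])) + offset
print(mask[:, :, :, mask_indexes])                        # tensor([[[[5, 6, 7, 8, 9]]]])
print(mask[:, :, :, int(offset) : int(offset) + window])  # same values, but data-dependent slicing
```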
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -568,7 +517,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
@ -595,13 +544,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training: if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
    self.config,
    max_batch_size=batch_size,
    max_cache_len=seq_len,
    dtype=inputs_embeds.dtype,
)
past_key_values = DynamicCache()
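With the rewrite, the text model defaults to a plain `DynamicCache`; the sliding/full distinction is handled by the per-layer masks rather than by a hybrid cache. A minimal sketch of the cache API used here (shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

# Illustrative only: a DynamicCache grows with the sequence and is shared by
# full-attention and sliding-attention layers alike; the per-layer behaviour
# comes from the mask each layer receives, not from the cache object.
cache = DynamicCache()
key = torch.randn(1, 4, 3, 8)    # (batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(1, 4, 3, 8)
key_all, value_all = cache.update(key, value, layer_idx=0)
print(cache.get_seq_length(), key_all.shape)  # 3 torch.Size([1, 4, 3, 8])
```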
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -614,13 +557,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
    attention_mask,
    inputs_embeds,
    cache_position,
    past_key_values,
    output_attentions,
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
    # Prepare mask arguments
    mask_kwargs = {
        "config": self.config,
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    # Create the masks
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
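For intuition, a torch-only illustration of the two mask flavors this mapping distinguishes. This shows the semantics only; the real `create_causal_mask`/`create_sliding_window_causal_mask` also handle padding, cache layout and backend-specific formats:

```python
import torch

# Boolean views for a 6-token prefill and a sliding window of 3 (sizes are illustrative).
seq_len, window = 6, 3
q = torch.arange(seq_len)[:, None]   # query positions
kv = torch.arange(seq_len)[None, :]  # key/value positions

full_attention = kv <= q                            # plain causal
sliding_attention = (kv <= q) & (q - kv < window)   # causal, restricted to the last `window` keys
print(full_attention.int())
print(sliding_attention.int())
```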
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -643,7 +595,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
hidden_states, hidden_states,
position_embeddings_global, position_embeddings_global,
position_embeddings_local, position_embeddings_local,
causal_mask, causal_mask_mapping[decoder_layer.attention_type],
position_ids, position_ids,
past_key_values, past_key_values,
output_attentions, output_attentions,
@ -655,7 +607,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
hidden_states, hidden_states,
position_embeddings_global=position_embeddings_global, position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local, position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask, attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
@ -681,100 +633,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
attentions=all_self_attns, attentions=all_self_attns,
) )
@torch.no_grad()
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool = False,
):
# Flash Attention currently doesn't support static cache but Gemma3Text works only with static cache.
# So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
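A concrete run of the conversion the removed helper performed, on a batch of one with a single padded position (numbers chosen for illustration):

```python
import torch

# 0 means "attend", torch.finfo(dtype).min means "masked". Four tokens, first one padded.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
attention_mask = torch.tensor([[0.0, 1.0, 1.0, 1.0]])  # 2D padding mask
sequence_length = target_length = 4
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)                       # future positions masked
causal_mask = causal_mask * (torch.arange(target_length) > cache_position.reshape(-1, 1))
causal_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1).clone()  # -> [bs, 1, q_len, kv_len]
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
print(causal_mask[0, 0])  # column 0 fully masked (padding), upper triangle masked (causality)
```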
@auto_docstring @auto_docstring
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
@ -818,7 +676,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
@ -895,60 +753,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["attention_mask"] = attention_mask
return model_inputs
class Gemma3MultiModalProjector(nn.Module): class Gemma3MultiModalProjector(nn.Module):
def __init__(self, config: Gemma3Config): def __init__(self, config: Gemma3Config):
@ -986,6 +790,22 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs) return projected_vision_outputs.type_as(vision_outputs)
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# If it's 1, we need to unmask it
return token_type_ids[batch_idx, kv_idx] == 1
return inner_mask
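A torch-only sketch of what the `or` combination means for this mask function, ignoring padding and sliding windows for brevity (the names below are illustrative, not library API):

```python
import torch

# token_type_ids is 1 for image tokens, 0 for text. The combined mask allows a
# (query, key) pair if EITHER the causal rule allows it OR the key is an image token.
token_type_ids = torch.tensor([[0, 1, 1, 0, 0]])   # text, img, img, text, text
seq_len = token_type_ids.shape[1]
q = torch.arange(seq_len)[:, None]
kv = torch.arange(seq_len)[None, :]

causal = kv <= q
image_kv = token_type_ids[0][None, :] == 1   # what `token_type_ids_mask_function` checks
combined = causal | image_kv                 # the "or" combination
print(combined.int())  # text tokens stay causal; image keys are visible to every query in the prefill
```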
@auto_docstring( @auto_docstring(
custom_intro=""" custom_intro="""
The Base Gemma3 model which consists of a vision backbone and a language model, without a language modeling head.
@ -1012,86 +832,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value) self.language_model.set_input_embeddings(value)
def _update_causal_mask(
self,
attention_mask,
token_type_ids,
past_key_values,
cache_position,
input_tensor,
is_training: bool = False,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
return attention_mask
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted
# form and requires no inversion or slicing.
return attention_mask
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
# Apply bidirectional mask on images if token type ids are provided
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
image_mask, 0.0
)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
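For reference, a worked example of the image-grouping part of the removed mask logic, showing how contiguous image blocks receive distinct group ids:

```python
import torch
import torch.nn as nn

# Each contiguous run of image tokens (token_type_id == 1) gets its own group id; text tokens get -1.
token_type_ids = torch.tensor([[0, 1, 1, 0, 1, 1, 1, 0]])
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
print(image_group_ids)  # tensor([[-1,  0,  0, -1,  1,  1,  1, -1]])
same_image = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
# `same_image` marks pairs inside the same image block, which the old code unmasked
# bidirectionally while keeping different image blocks causal with respect to each other.
```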
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
""" """
Projects the last hidden state from the vision model into language model space. Projects the last hidden state from the vision model into language model space.
@ -1161,8 +901,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size: if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id special_image_mask = input_ids == self.config.image_token_id
@ -1202,11 +940,31 @@ class Gemma3Model(Gemma3PreTrainedModel):
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
    attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
    # Prepare mask arguments
    mask_kwargs = {
        "config": self.config.get_text_config(),
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    if token_type_ids is not None and inputs_embeds.shape[1] != 1:
        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
            token_type_ids.to(cache_position.device)
        )
    # Create the masks
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
outputs = self.language_model( outputs = self.language_model(
attention_mask=causal_mask, attention_mask=causal_mask_mapping,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@ -1432,70 +1190,35 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0: if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs return model_inputs
@staticmethod @staticmethod
def create_masks_for_generate(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    output_attentions: bool = False,
    token_type_ids: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    # Prepare mask arguments
    mask_kwargs = {
        "config": config.get_text_config(),
        "input_embeds": input_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    # Add the token type ids mask for generate as well
    if token_type_ids is not None and input_embeds.shape[1] != 1:
        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))

    return create_masks_for_generate(**mask_kwargs)
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [ __all__ = [


@ -23,8 +23,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation from ...modeling_rope_utils import rope_config_validation
@ -56,7 +57,7 @@ from ..siglip import SiglipVisionConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Gemma3TextConfig(Gemma2Config): class Gemma3TextConfig(Gemma2Config, PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@ -114,13 +115,14 @@ class Gemma3TextConfig(Gemma2Config):
The dropout ratio for the attention probabilities. The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256): query_pre_attn_scalar (`float`, *optional*, defaults to 256):
Scaling factor used on the attention scores Scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
sliding_window (`int`, *optional*, defaults to 4096):
    In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
layer_types (`list`, *optional*):
Attention pattern for each layer.
final_logit_softcapping (`float`, *optional*): final_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the logits. Scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*): attn_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the attention scores. Scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@ -160,8 +162,6 @@ class Gemma3TextConfig(Gemma2Config):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0): rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention. The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
```python ```python
>>> from transformers import Gemma3TextModel, Gemma3TextConfig >>> from transformers import Gemma3TextModel, Gemma3TextConfig
@ -183,23 +183,74 @@ class Gemma3TextConfig(Gemma2Config):
def __init__( def __init__(
self, self,
vocab_size=262_208, vocab_size=262_208,
rope_theta=1_000_000.0,
rope_scaling=None,
rope_local_base_freq=10_000.0,
sliding_window_pattern=6,
max_position_embeddings=131_072,
final_logit_softcapping=None,
attn_logit_softcapping=None,
**super_kwargs,
):
    super().__init__(self, **super_kwargs)
hidden_size=2304,
intermediate_size=9216,
num_hidden_layers=26,
num_attention_heads=8,
num_key_value_heads=4,
head_dim=256,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=131_072,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=1_000_000.0,
attention_bias=False,
attention_dropout=0.0,
query_pre_attn_scalar=256,
sliding_window=4096,
layer_types=None,
final_logit_softcapping=None,
attn_logit_softcapping=None,
rope_scaling=None,
rope_local_base_freq=10_000.0,
**kwargs,
):
    PretrainedConfig.__init__(
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
    )
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.rope_local_base_freq = rope_local_base_freq self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
rope_config_validation(self) rope_config_validation(self)
if self.layer_types is None:
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
sliding_window_pattern = getattr(self, "sliding_window_pattern", 6)
self.layer_types = [
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
class Gemma3Config(PretrainedConfig): class Gemma3Config(PretrainedConfig):
r""" r"""
@ -336,7 +387,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding` # Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
class Gemma3Attention(Gemma2Attention): class Gemma3Attention(Gemma2Attention):
def __init__(self, config: Gemma3TextConfig, layer_idx: int): def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
super().__init__() super().__init__()
self.sliding_window = config.sliding_window if self.is_sliding else None self.sliding_window = config.sliding_window if self.is_sliding else None
@ -368,19 +419,9 @@ class Gemma3Attention(Gemma2Attention):
if past_key_value is not None: if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache # sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
    "sin": sin,
    "cos": cos,
    "cache_position": cache_position,
    "sliding_window": self.sliding_window,
}
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@ -391,9 +432,7 @@ class Gemma3Attention(Gemma2Attention):
) )
else: else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if attention_mask is not None:
# backwards compatibility
attention_mask = attention_mask.to(query_states)
attn_output, attn_weights = attention_interface( attn_output, attn_weights = attention_interface(
self, self,
query_states, query_states,
@ -417,14 +456,13 @@ class Gemma3DecoderLayer(nn.Module):
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma3MLP(config) self.mlp = Gemma3MLP(config)
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
@deprecate_kwarg("last_cache_position", version="4.53.0") @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward( def forward(
@ -440,33 +478,6 @@ class Gemma3DecoderLayer(nn.Module):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -557,7 +568,7 @@ class Gemma3TextModel(Gemma2Model):
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
@ -584,13 +595,7 @@ class Gemma3TextModel(Gemma2Model):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training: if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
    self.config,
    max_batch_size=batch_size,
    max_cache_len=seq_len,
    dtype=inputs_embeds.dtype,
)
past_key_values = DynamicCache()
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -603,13 +608,22 @@ class Gemma3TextModel(Gemma2Model):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
    attention_mask,
    inputs_embeds,
    cache_position,
    past_key_values,
    output_attentions,
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
    # Prepare mask arguments
    mask_kwargs = {
        "config": self.config,
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    # Create the masks
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -632,7 +646,7 @@ class Gemma3TextModel(Gemma2Model):
hidden_states, hidden_states,
position_embeddings_global, position_embeddings_global,
position_embeddings_local, position_embeddings_local,
causal_mask, causal_mask_mapping[decoder_layer.attention_type],
position_ids, position_ids,
past_key_values, past_key_values,
output_attentions, output_attentions,
@ -644,7 +658,7 @@ class Gemma3TextModel(Gemma2Model):
hidden_states, hidden_states,
position_embeddings_global=position_embeddings_global, position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local, position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask, attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
@ -716,6 +730,22 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs) return projected_vision_outputs.type_as(vision_outputs)
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# If it's 1, we need to unmask it
return token_type_ids[batch_idx, kv_idx] == 1
return inner_mask
class Gemma3Model(PaliGemmaModel): class Gemma3Model(PaliGemmaModel):
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
""" """
@ -731,85 +761,8 @@ class Gemma3Model(PaliGemmaModel):
image_features = self.multi_modal_projector(vision_outputs) image_features = self.multi_modal_projector(vision_outputs)
return image_features return image_features
def _update_causal_mask(self, **super_kwargs):
    raise AttributeError("We don't want to inherit it")
def _update_causal_mask(
    self,
attention_mask,
token_type_ids,
past_key_values,
cache_position,
input_tensor,
is_training: bool = False,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
return attention_mask
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted
# form and requires no inversion or slicing.
return attention_mask
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
# Apply bidirectional mask on images if token type ids are provided
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
image_mask, 0.0
)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@can_return_tuple @can_return_tuple
@auto_docstring @auto_docstring
@ -839,8 +792,6 @@ class Gemma3Model(PaliGemmaModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size: if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id special_image_mask = input_ids == self.config.image_token_id
@ -880,11 +831,31 @@ class Gemma3Model(PaliGemmaModel):
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
    attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
    # Prepare mask arguments
    mask_kwargs = {
        "config": self.config.get_text_config(),
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "output_attentions": output_attentions,
    }
    if token_type_ids is not None and inputs_embeds.shape[1] != 1:
        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
            token_type_ids.to(cache_position.device)
        )
    # Create the masks
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
outputs = self.language_model( outputs = self.language_model(
attention_mask=causal_mask, attention_mask=causal_mask_mapping,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@ -1066,16 +1037,39 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0: if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs return model_inputs
def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs):
raise AttributeError("We don't want to inherit it")
@staticmethod
def create_masks_for_generate(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
output_attentions: bool = False,
token_type_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# Prepare mask arguments
mask_kwargs = {
"config": config.get_text_config(),
"input_embeds": input_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Add the token type ids mask for generate as well
if token_type_ids is not None and input_embeds.shape[1] != 1:
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
return create_masks_for_generate(**mask_kwargs)
__all__ = [ __all__ = [
"Gemma3Config", "Gemma3Config",


@ -28,7 +28,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ( from ...modeling_outputs import (
@ -40,16 +40,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_glm import GlmConfig from .configuration_glm import GlmConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -443,8 +437,13 @@ class GlmModel(GlmPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
causal_mask = create_causal_mask(
    config=self.config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=past_key_values,
    output_attentions=output_attentions,
)
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -490,129 +489,6 @@ class GlmModel(GlmPreTrainedModel):
attentions=all_self_attns, attentions=all_self_attns,
) )
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
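The same `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair is deleted from every model below, since `create_causal_mask` now owns this logic. For reference, the arithmetic of the deleted helper can be reproduced standalone; this is a compressed sketch of the code above, not new behaviour:

```python
import torch

def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, cache_position):
    # Compressed restatement of the deleted helper above (2D padding mask -> inverted 4D mask).
    batch_size = attention_mask.shape[0]
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)  # block future positions
    # block cache slots beyond the positions currently being written
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # fold in padding: where both the causal mask and the 2D mask are 0, the slot is padding
    length = attention_mask.shape[-1]
    padding = mask[:, :, :, :length] + attention_mask[:, None, None, :].to(dtype)
    mask[:, :, :, :length] = mask[:, :, :, :length].masked_fill(padding == 0, min_dtype)
    return mask

# decoding one token at position 3, static cache of length 6, one left-pad token
mask = build_4d_causal_mask(
    attention_mask=torch.tensor([[0, 1, 1, 1]]),
    sequence_length=1,
    target_length=6,
    dtype=torch.float32,
    cache_position=torch.tensor([3]),
)
print(mask[0, 0, 0])  # [min, 0, 0, 0, min, min]: pad and unfilled cache slots stay masked
```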
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -28,7 +28,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -40,16 +40,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_glm4 import Glm4Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -451,8 +445,13 @@ class Glm4Model(Glm4PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -498,129 +497,6 @@ class Glm4Model(Glm4PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@auto_docstring
class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):


@@ -893,60 +893,5 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"] __all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]


@@ -682,7 +682,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
attentions=all_self_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -752,7 +752,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -12,7 +12,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -24,16 +24,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_gpt_neox import GPTNeoXConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -409,8 +403,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
# Prepare head mask if needed
@@ -479,129 +478,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
attentions=all_attentions,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
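One detail worth keeping in mind from the deleted code: the flash-attention-2 branch only forwards the 2D mask when it actually contains padding, and returns `None` otherwise so the kernel can use its pure-causal fast path. A tiny sketch of that check (the centralized equivalent now lives in `masking_utils`, possibly with different internals):

```python
from typing import Optional

import torch

def fa2_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the deleted flash_attention_2 branch above: keep the 2D mask only
    # when at least one position is padded, otherwise drop it entirely.
    if attention_mask is not None and (attention_mask == 0.0).any():
        return attention_mask
    return None

print(fa2_attention_mask(torch.ones(2, 5)))                 # None -> fully causal fast path
print(fa2_attention_mask(torch.tensor([[0, 1, 1, 1, 1]])))  # padded batch keeps its mask
```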


@@ -7,6 +7,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -349,8 +350,13 @@ class GPTNeoXModel(LlamaModel, nn.Module):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
# Prepare head mask if needed


@@ -540,7 +540,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
attentions=all_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -610,7 +610,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -793,7 +793,6 @@ class GPTJModel(GPTJPreTrainedModel):
attentions=all_self_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -863,7 +862,6 @@ class GPTJModel(GPTJPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -28,23 +28,17 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_granite import GraniteConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -446,8 +440,13 @@ class GraniteModel(GranitePreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -493,129 +492,6 @@ class GraniteModel(GranitePreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -20,6 +20,7 @@ import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...processing_utils import Unpack
@@ -174,8 +175,13 @@ class GraniteModel(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds


@@ -773,7 +773,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
router_logits=all_router_logits,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -843,7 +843,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -28,7 +28,7 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -40,16 +40,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_helium import HeliumConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -428,8 +422,13 @@ class HeliumModel(HeliumPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -475,129 +474,6 @@ class HeliumModel(HeliumPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -1316,7 +1316,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
image_hidden_states=image_hidden_states,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1386,7 +1386,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -1035,61 +1035,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"InternVLVisionPreTrainedModel",


@@ -1014,7 +1014,7 @@ class JetMoeModel(JetMoePreTrainedModel):
router_logits=all_router_logits,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1084,7 +1084,7 @@ class JetMoeModel(JetMoePreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -26,7 +26,8 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -40,18 +41,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_llama import LlamaConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
from ...integrations import use_kernel_forward_from_hub
logger = logging.get_logger(__name__)
@@ -433,8 +426,13 @@ class LlamaModel(LlamaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@@ -480,129 +478,6 @@ class LlamaModel(LlamaPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
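With the Llama reference implementation switched over, every model in this PR funnels mask creation through the same call. A hedged usage sketch; the public import path and the exact return type (a 4D tensor, a `BlockMask`, or `None`, depending on the configured attention backend) are assumptions based on the relative imports above:

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask  # assumed public path

config = LlamaConfig(
    hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=128, vocab_size=128,
)
batch_size, seq_len = 2, 6
inputs_embeds = torch.zeros(batch_size, seq_len, config.hidden_size)

causal_mask = create_causal_mask(
    config=config,
    input_embeds=inputs_embeds,
    attention_mask=torch.ones(batch_size, seq_len, dtype=torch.long),
    cache_position=torch.arange(seq_len),
    past_key_values=DynamicCache(),
    output_attentions=False,
)
# With SDPA and no padding this may legitimately be None, signalling that the
# backend should rely on its own is_causal fast path instead of an explicit mask.
print(type(causal_mask))
```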


@@ -15,7 +15,7 @@
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...utils import logging
@@ -233,12 +233,13 @@ class Llama4TextConfig(PretrainedConfig):
`no_rope_layer_interval` layers.
attention_chunk_size (`int`, *optional*, defaults to 8192):
<TODO>
layer_types (`list`, *optional*):
Attention pattern for each layer.
attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
Whether to dynamically scale the attention temperature for each query token based on sequence length.
Recommended for long sequences (e.g., >32k tokens) to maintain stable output results.
floor_scale (`int`, *optional*, defaults to 8192): TODO
attn_scale (`int`, *optional*, defaults to 0.1): TODO
cache_implementation (`<fill_type>`, *optional*, defaults to `"hybrid"`): <fill_docstring>
Example:
"""
@@ -298,10 +299,10 @@ class Llama4TextConfig(PretrainedConfig):
no_rope_layers=None,
no_rope_layer_interval=4,
attention_chunk_size=8192,
layer_types=None,
attn_temperature_tuning=True,
floor_scale=8192,
attn_scale=0.1,
cache_implementation="hybrid",
**kwargs,
):
super().__init__(
@@ -323,7 +324,6 @@ class Llama4TextConfig(PretrainedConfig):
self.num_attention_heads = num_attention_heads
self.rope_scaling = rope_scaling
self.attention_bias = False
self.cache_implementation = cache_implementation
# for backward compatibility # for backward compatibility
if num_key_value_heads is None: if num_key_value_heads is None:
num_key_value_heads = num_attention_heads num_key_value_heads = num_attention_heads
@ -363,6 +363,13 @@ class Llama4TextConfig(PretrainedConfig):
)
self.attention_chunk_size = attention_chunk_size
self.layer_types = layer_types
if layer_types is None:
self.layer_types = [
"chunked_attention" if no_rope else "full_attention" for no_rope in self.no_rope_layers
]
layer_type_validation(self.layer_types)
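As a quick illustration of this new default (hypothetical values; a truthy `no_rope_layers` entry simply selects the chunked pattern, and the config then validates the result with `layer_type_validation` as above):

no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # made-up 8-layer pattern with a NoPE layer every 4th layer
layer_types = ["chunked_attention" if no_rope else "full_attention" for no_rope in no_rope_layers]
print(layer_types)
# ['chunked_attention', 'chunked_attention', 'chunked_attention', 'full_attention',
#  'chunked_attention', 'chunked_attention', 'chunked_attention', 'full_attention']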
class Llama4Config(PretrainedConfig):
r"""

View File

@ -24,24 +24,19 @@ import torch.nn.functional as F
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, HybridChunkedCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations.hub_kernels import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask, create_chunked_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_llama4 import Llama4Config, Llama4TextConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -388,8 +383,9 @@ class Llama4TextDecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Llama4TextAttention(config, layer_idx)
self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx])
self.is_moe_layer = layer_idx in config.moe_layers
if self.is_moe_layer: # the 128E model interleaves dense / sparse
self.feed_forward = Llama4TextMoe(config)
@ -399,13 +395,10 @@ class Llama4TextDecoderLayer(nn.Module):
self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
chunk_causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
@ -419,10 +412,6 @@ class Llama4TextDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# use local attention mask for ROPE layers
if self.use_chunked_attention and chunk_causal_mask is not None:
attention_mask = chunk_causal_mask
# Self Attention
attention_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
@ -561,10 +550,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
if use_cache and past_key_values is None:
if self.config.get_text_config().attention_chunk_size is not None:
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
else:
past_key_values = DynamicCache()
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -575,9 +561,22 @@ class Llama4TextModel(Llama4PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"chunked_attention": create_chunked_causal_mask(**mask_kwargs),
}
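A rough sketch of how this mapping is then consumed: every decoder layer reads its own `attention_type` and picks the matching entry. The string values below are placeholders for whatever `create_causal_mask` / `create_chunked_causal_mask` return for the `mask_kwargs` above (a 4D tensor, a flex-attention BlockMask, or None, depending on the backend).

layer_types = ["chunked_attention", "chunked_attention", "full_attention"]  # illustrative pattern
causal_mask_mapping = {
    "full_attention": "<mask from create_causal_mask(**mask_kwargs)>",
    "chunked_attention": "<mask from create_chunked_causal_mask(**mask_kwargs)>",
}
for layer_idx, attention_type in enumerate(layer_types):
    attention_mask = causal_mask_mapping[attention_type]  # what the forward pass hands to each layer
    print(layer_idx, attention_type, attention_mask)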
hidden_states = inputs_embeds
@ -596,8 +595,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
chunk_causal_mask,
causal_mask_mapping[decoder_layer.attention_type],
position_ids,
past_key_values,
output_attentions,
@ -609,8 +607,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
chunk_causal_mask=chunk_causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@ -638,216 +635,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
attentions=all_self_attns,
)
@torch.compiler.disable(recursive=False) # the operations in this method are not compilable
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
chunked_attention_mask=None,
use_cache=True,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask, attention_mask # flash does not support chunked attn TODO support flash
return None, None
if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
return None, None
sequence_length = input_tensor.shape[1]
cache_position = cache_position.to(self.device)
attention_chunk_size = self.config.attention_chunk_size
using_chunked_attention = attention_chunk_size is not None
first_cache_position = cache_position[0]
if past_key_values is not None:
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
if using_chunked_attention:
cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
first_cache_position + sequence_length > attention_chunk_size
)
key_length = (
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
)
if use_cache
else full_cache_length
)
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
if using_chunked_attention:
offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
chunked_attention_mask = make_flex_block_causal_mask(
attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
)
attention_mask = make_flex_block_causal_mask(
attention_mask,
query_length=sequence_length,
key_length=full_cache_length,
offsets=(first_cache_position, 0),
)
return attention_mask, chunked_attention_mask
if isinstance(attention_mask, BlockMask):
return attention_mask, chunked_attention_mask
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
dtype, device = input_tensor.dtype, input_tensor.device
target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if using_chunked_attention and full_cache_length > attention_chunk_size:
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
end_idx = start_idx + key_length
chunked_attention_mask = self.create_chunked_attention_mask(
self.config.attention_chunk_size,
start=start_idx, # same offset as with flex
end=end_idx,
device=device,
)
local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well
# It may be smaller than attention_chunk_size -> pad it
requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
if requires_padding:
local_attention_mask = nn.functional.pad(
local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
)
# Depending on the padding, take the query tokens from the end or the cache_position
if not requires_padding:
chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
else:
chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
if self.config._attn_implementation == "eager":
min_dtype = torch.finfo(dtype).min
chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and attention_mask.ndim == 4
and not output_attentions # Only unmask for 4d masks
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
chunked_attention_mask = chunked_attention_mask.bool()
causal_mask = causal_mask != torch.finfo(dtype).min
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=first_cache_position,
is_training=self.training,
):
causal_mask = None
return causal_mask, chunked_attention_mask
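The `key_length` cascade near the top of this removed method is easier to follow with toy numbers (a chunk of 4 keys and 3 new query tokens; the helper below only restates the `torch.where` logic on plain ints):

def chunked_key_length(first_cache_position, sequence_length, attention_chunk_size):
    # cond1: decoding has already moved past the first chunk
    if first_cache_position >= attention_chunk_size:
        return attention_chunk_size + sequence_length - 1
    # cond2: the new tokens straddle the chunk boundary
    if first_cache_position + sequence_length > attention_chunk_size:
        return first_cache_position + sequence_length
    # otherwise everything still fits inside a single chunk
    return attention_chunk_size

for pos in (0, 2, 10):
    print(pos, chunked_key_length(pos, sequence_length=3, attention_chunk_size=4))
# 0 -> 4, 2 -> 5, 10 -> 6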
def create_chunked_attention_mask(
self, attention_chunk_size: int, start: int, end: int, device: torch.device
) -> torch.Tensor:
"""
Generate the following:
'What'      :  0 ■ ⬚ ⬚ ⬚ ⬚ ⬚  |
'▁is'       :  1 ■ ■ ⬚ ⬚ ⬚ ⬚  |
'▁ch'       :  2 ■ ■ ■ ⬚ ⬚ ⬚  |
'unked'     :  3 ⬚ ⬚ ⬚ ■ ⬚ ⬚  |
'▁attention':  4 ⬚ ⬚ ⬚ ■ ■ ⬚  |
'?'         :  5 ⬚ ⬚ ⬚ ■ ■ ■  |
If the chunk size is 3.
This can just be applied over the already created attention mask
"""
arange_vector = torch.arange(start, end, device=device)
block_pos = torch.abs(
arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
)
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
return mask.to(device)
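Running the same block/token arithmetic on the 6-token, chunk-size-3 example from the docstring reproduces the pattern (only `torch` is needed):

import torch

attention_chunk_size, start, end = 3, 0, 6
arange_vector = torch.arange(start, end)
block_pos = torch.abs(
    arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
)
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])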
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
cache_position.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@ -1726,61 +1513,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Llama4PreTrainedModel",

View File

@ -503,61 +503,5 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"] __all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]

View File

@ -719,7 +719,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -1589,7 +1589,7 @@ class LongT5Stack(LongT5PreTrainedModel):
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@ -1659,7 +1659,7 @@ class LongT5Stack(LongT5PreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -790,7 +790,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@ -860,7 +860,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -484,7 +484,7 @@ class MarianPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@ -554,7 +554,7 @@ class MarianPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -762,7 +762,7 @@ class MBartPreTrainedModel(PreTrainedModel):
}
return dummy_inputs
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@ -832,7 +832,7 @@ class MBartPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -1060,7 +1060,7 @@ class MimiTransformerModel(nn.Module):
attentions=all_self_attns,
)
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@ -1148,7 +1148,7 @@ class MimiTransformerModel(nn.Module):
return causal_mask
@staticmethod
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@ -10,10 +10,10 @@ import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@ -26,16 +26,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_mistral import MistralConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -403,8 +397,14 @@ class MistralModel(MistralPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
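The same two-line dispatch appears in every sliding-window model touched by this PR. To make the difference between the two primitives concrete, here is a toy boolean version for 6 tokens and an assumed window of 3 (True = attend); the real helpers take the config/cache arguments shown above and may return a 4D float mask, a flex-attention BlockMask, or None, depending on the attention backend.

import torch

seq_len, sliding_window = 6, 3
q = torch.arange(seq_len).unsqueeze(-1)  # query positions
k = torch.arange(seq_len).unsqueeze(0)   # key positions

full_causal = k <= q                                  # plain causal pattern (the create_causal_mask side of the dispatch)
sliding_causal = (k <= q) & (k > q - sliding_window)  # causal, but only the last `sliding_window` keys
print(sliding_causal.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])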
@ -450,161 +450,6 @@ class MistralModel(MistralPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: MistralConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`MistralConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

View File

@ -4,13 +4,13 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, QuestionAnsweringModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import is_torch_flex_attn_available, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
@ -27,12 +27,6 @@ from ..llama.modeling_llama import (
from .configuration_mistral import MistralConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
@ -118,166 +112,107 @@ class MistralPreTrainedModel(LlamaPreTrainedModel):
class MistralModel(LlamaModel):
def __init__(self, config: MistralConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: MistralConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`MistralConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class MistralModel(LlamaModel):
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralForCausalLM(LlamaForCausalLM):

View File

@ -32,12 +32,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
)
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
from ..auto import AutoModel
from .configuration_mistral3 import Mistral3Config
@ -538,60 +533,5 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"] __all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]

View File

@ -32,10 +32,10 @@ import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@ -48,16 +48,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_mixtral import MixtralConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -524,8 +518,14 @@ class MixtralModel(MixtralPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
@ -591,161 +591,6 @@ class MixtralModel(MixtralPreTrainedModel):
router_logits=all_router_logits,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: MixtralConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`MixtralConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
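The sliding-window branch deleted above is easier to reason about as a boolean rule: a query at absolute position q may attend to key k only if k <= q (causal) and k > q - sliding_window (inside the window); padding is masked on top of that. A minimal standalone sketch of that rule (illustrative names only, not part of the library):

import torch

def sliding_window_causal_bool_mask(cache_position, kv_length, sliding_window):
    # True means "may attend"; cache_position holds the absolute index of each query row.
    kv_idx = torch.arange(kv_length)
    q_idx = cache_position.reshape(-1, 1)
    causal = kv_idx <= q_idx                     # never attend to the future
    in_window = kv_idx > q_idx - sliding_window  # never attend further back than the window
    return causal & in_window

# 4 new tokens at absolute positions 5..8 against a 9-entry KV cache, window of 4:
print(sliding_window_causal_bool_mask(torch.arange(5, 9), kv_length=9, sliding_window=4).int())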

View File

@@ -29,6 +29,7 @@ from torch import nn
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
+ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...processing_utils import Unpack
@@ -315,12 +316,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel):
class MixtralModel(MistralModel):
- def __init__(self, config: MixtralConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -368,8 +363,14 @@ class MixtralModel(MistralModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ causal_mask = mask_function(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds
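In isolation, the replacement call pattern introduced by this hunk looks as follows. This is only a sketch mirroring the diff (the create_* keyword arguments are taken verbatim from it); models with a sliding_window in their config pick the sliding-window variant, everything else uses the plain causal mask:

from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

def build_decoder_mask(config, inputs_embeds, attention_mask, cache_position, past_key_values, output_attentions=False):
    # Pick the mask builder once per forward, based on the model's config.
    mask_function = (
        create_causal_mask
        if getattr(config, "sliding_window", None) is None
        else create_sliding_window_causal_mask
    )
    return mask_function(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        output_attentions=output_attentions,
    )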

View File

@@ -893,7 +893,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
if module.is_gated:
module.gate.data.zero_()
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -963,7 +963,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -27,11 +27,8 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import (
- AttentionMaskConverter,
- _prepare_4d_attention_mask,
- _prepare_4d_attention_mask_for_sdpa,
- )
+ from ...masking_utils import create_causal_mask
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -44,16 +41,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+ from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_moonshine import MoonshineConfig
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -752,8 +743,13 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds
@@ -826,129 +822,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
cross_attentions=all_cross_attentions,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def _compute_mask_indices(
shape: Tuple[int, int],

View File

@@ -21,6 +21,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
+ from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -748,8 +749,13 @@ class MoonshineDecoder(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds

View File

@@ -1084,7 +1084,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
attentions=all_self_attns,
)
- # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi
+ # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1172,7 +1172,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
return causal_mask
@staticmethod
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->MoshiDepth
+ # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->MoshiDepth
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -1391,7 +1391,7 @@ class MoshiModel(MoshiPreTrainedModel):
attentions=all_self_attns,
)
- # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi
+ # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1479,7 +1479,7 @@ class MoshiModel(MoshiPreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Moshi
+ # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Moshi
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -1180,7 +1180,7 @@ class MT5Stack(MT5PreTrainedModel):
cross_attentions=all_cross_attentions,
)
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1250,7 +1250,7 @@ class MT5Stack(MT5PreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -739,7 +739,7 @@ class NemotronModel(NemotronPreTrainedModel):
attentions=all_self_attns,
)
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -809,7 +809,7 @@ class NemotronModel(NemotronPreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -13,23 +13,17 @@ import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+ from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_olmo import OlmoConfig
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -406,8 +400,13 @@ class OlmoModel(OlmoPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds
@@ -453,129 +452,6 @@ class OlmoModel(OlmoPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
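A note on the AttentionMaskConverter._unmask_unattended call that disappears with this file: rows that may attend to nothing (for example pad-only rows under left padding) turn into NaNs once softmax divides zero by zero, which is what the linked PyTorch issue is about. A rough sketch of both the failure and the workaround the deleted code applied (simplified to a 2D mask; the real helper works on the 4D mask):

import torch

# Softmax over a fully masked row is 0/0 -> NaN:
print(torch.softmax(torch.full((1, 4), float("-inf")), dim=-1))  # tensor([[nan, nan, nan, nan]])

min_value = torch.finfo(torch.float32).min
mask = torch.full((2, 4), min_value)
mask[0, :2] = 0.0                                           # row 0 attends to the first two keys
fully_masked = (mask == min_value).all(dim=-1, keepdim=True)
safe_mask = mask.masked_fill(fully_masked, 0.0)             # row 1 now attends everywhere; its output is discarded anyway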

View File

@@ -13,23 +13,17 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+ from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_olmo2 import Olmo2Config
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -412,8 +406,13 @@ class Olmo2Model(Olmo2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds
@@ -459,129 +458,6 @@ class Olmo2Model(Olmo2PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
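The _ignore_causal_mask_sdpa early-return removed here exists so that, when there is no padding and no compilable cache, the model can hand SDPA no mask at all and let it use its fused causal path. The two calls below are numerically equivalent, but the second never materializes a (seq, seq) mask (a sketch of the rationale, not of the helper itself):

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Explicit additive causal bias vs. letting SDPA infer causality itself:
causal_bias = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bias)
out_with_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_with_mask, out_with_flag, atol=1e-5))  # True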

View File

@@ -384,7 +384,7 @@ class OPTDecoder(OPTPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -454,7 +454,7 @@ class OPTDecoder(OPTPreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -566,7 +566,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
return model_inputs
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -481,7 +481,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -551,7 +551,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -773,7 +773,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -843,7 +843,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -539,7 +539,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
attentions=all_self_attns,
)
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -609,7 +609,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
return causal_mask
@staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,

View File

@@ -13,7 +13,7 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -24,16 +24,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+ from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_phi import PhiConfig
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -400,8 +394,13 @@ class PhiModel(PhiPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
@@ -461,129 +460,6 @@ class PhiModel(PhiPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
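For the prefill case without a cache, the _prepare_4d_causal_attention_mask_with_cache_position helper deleted above boils down to the following (a simplified sketch: the real code also handles cache_position, a target_length larger than the query length, and masks that already arrive in 4D):

import torch

def to_4d_causal(attention_mask_2d, dtype=torch.float32):
    # (batch, seq_len) 0/1 padding mask -> (batch, 1, seq_len, seq_len) additive mask.
    bsz, seq_len = attention_mask_2d.shape
    min_dtype = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, -1, -1).clone()
    padding = attention_mask_2d[:, None, None, :] == 0  # True where the key is a pad token
    return causal.masked_fill(padding, min_dtype)

print(to_4d_causal(torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])).shape)  # torch.Size([2, 1, 4, 4])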

View File

@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from ...cache_utils import Cache, DynamicCache
+ from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -244,8 +245,13 @@ class PhiModel(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
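One branch worth spelling out from the _update_causal_mask bodies removed throughout this PR: for Flash Attention 2 the whole computation collapses to a passthrough, since the kernel handles causality itself and only needs the 2D mask when real padding is present (illustrative snippet):

import torch

attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
fa2_mask = attention_mask if (attention_mask == 0.0).any() else None  # keep the 2D mask only when padding exists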

View File

@@ -26,10 +26,10 @@ import torch
from torch import nn
from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -41,16 +41,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+ from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_phi3 import Phi3Config
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -458,8 +452,14 @@ class Phi3Model(Phi3PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ causal_mask = mask_function(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ )
hidden_states = inputs_embeds
@@ -505,161 +505,6 @@ class Phi3Model(Phi3PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: Phi3Config,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Phi3Config`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
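The right-padding ValueError removed from Phi3 (and Mixtral above) encodes a usage requirement rather than masking logic: batched generation with Flash Attention 2 wants left padding. A typical setup looks roughly like this (checkpoint name is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2", torch_dtype="auto", device_map="auto")

tokenizer.padding_side = "left"  # right padding is what the removed check rejected
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Hello", "A longer second prompt"], return_tensors="pt", padding=True).to(model.device)
generated = model.generate(**inputs, max_new_tokens=20)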

View File

@@ -28,13 +28,12 @@ import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+ from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -46,16 +45,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging, torch_int
+ from ...utils import auto_docstring, can_return_tuple, logging, torch_int
from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -1766,8 +1759,14 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds
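The call sites now delegate to the new mask primitives instead of a per-model `_update_causal_mask`. A hedged usage sketch of the pattern for a model whose layers all share one attention type — the wrapper name is invented, the keyword arguments are copied from the diff, and it assumes the new helpers are importable as `transformers.masking_utils`:

```python
# Sketch under assumptions: `build_decoder_mask` is a hypothetical wrapper, not part of the PR.
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask


def build_decoder_mask(config, inputs_embeds, attention_mask, cache_position, past_key_values, output_attentions=False):
    # Pick a single primitive based on whether the config defines a sliding window.
    mask_function = create_causal_mask if config.sliding_window is None else create_sliding_window_causal_mask
    return mask_function(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        output_attentions=output_attentions,
    )
```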
@@ -1813,161 +1812,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: Phi4MultimodalConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Phi4MultimodalConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify if the current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@auto_docstring
class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):


@@ -21,11 +21,11 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...configuration_utils import PretrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
@@ -1570,8 +1570,14 @@ class Phi4MultimodalModel(Phi3Model, nn.Module):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
)
hidden_states = inputs_embeds


@@ -1068,7 +1068,6 @@ class PhimoeModel(PhimoePreTrainedModel):
router_logits=all_router_logits,
)
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Phimoe
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1156,7 +1155,6 @@ class PhimoeModel(PhimoePreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -1345,7 +1345,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -1415,7 +1415,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -520,7 +520,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -590,7 +590,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -900,7 +900,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
@@ -970,7 +970,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,


@@ -14,7 +14,7 @@
# limitations under the License.
"""Qwen2 model configuration"""
from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
@@ -110,6 +110,8 @@ class Qwen2Config(PretrainedConfig):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
layer_types (`list`, *optional*):
Attention pattern for each layer.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
@@ -164,6 +166,7 @@ class Qwen2Config(PretrainedConfig):
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
layer_types=None,
attention_dropout=0.0,
**kwargs,
):
@@ -174,7 +177,7 @@ class Qwen2Config(PretrainedConfig):
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code
self.sliding_window = sliding_window if self.use_sliding_window else None
self.max_window_layers = max_window_layers
# for backward compatibility
@@ -195,6 +198,16 @@ class Qwen2Config(PretrainedConfig):
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention"
if self.sliding_window is not None and i >= self.max_window_layers
else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
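The `layer_types` default above reproduces the old `max_window_layers` behaviour: layers from `max_window_layers` onward get `"sliding_attention"` whenever a window is set, everything else stays `"full_attention"`. A minimal sketch of that derivation with invented toy values:

```python
# Minimal sketch of the `layer_types` default shown above (toy values, not from any checkpoint).
def default_layer_types(num_hidden_layers, max_window_layers, sliding_window):
    return [
        "sliding_attention" if sliding_window is not None and i >= max_window_layers else "full_attention"
        for i in range(num_hidden_layers)
    ]

print(default_layer_types(num_hidden_layers=6, max_window_layers=4, sliding_window=4096))
# ['full_attention', 'full_attention', 'full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']
```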


@@ -10,10 +10,10 @@ import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@@ -26,16 +26,10 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_qwen2 import Qwen2Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -143,6 +137,7 @@ class Qwen2Attention(nn.Module):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
@@ -168,14 +163,6 @@ class Qwen2Attention(nn.Module):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -194,7 +181,7 @@ class Qwen2Attention(nn.Module):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window, # main diff with Llama
sliding_window=self.sliding_window, # main diff with Llama
**kwargs,
)
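With `layer_types` in the config, the per-layer window is resolved once in `__init__` instead of being recomputed from `use_sliding_window`/`max_window_layers` on every forward. A tiny illustrative sketch (helper name and values are invented, not part of the PR):

```python
# Illustrative only: how a per-layer window falls out of `layer_types`.
def resolve_layer_window(sliding_window, layer_types, layer_idx):
    return sliding_window if layer_types[layer_idx] == "sliding_attention" else None

layer_types = ["full_attention", "full_attention", "sliding_attention"]
print(resolve_layer_window(4096, layer_types, 0))  # None -> this layer attends to the full context
print(resolve_layer_window(4096, layer_types, 2))  # 4096 -> this layer is windowed
```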
@@ -228,15 +215,13 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
self.attention_type = config.layer_types[layer_idx]
def forward(
self,
@@ -357,6 +342,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
# Initialize weights and apply final processing
self.post_init()
@@ -416,9 +402,24 @@ class Qwen2Model(Qwen2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
hidden_states = inputs_embeds
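The single `causal_mask` is replaced by a small dict keyed by attention type: each mask is built at most once via the new primitives, and every decoder layer looks up its own entry through `attention_type`. A toy sketch of just the dispatch pattern, with strings standing in for real masks and layers:

```python
# Toy sketch of the per-attention-type mask dispatch used above.
# Strings stand in for real masks and ToyLayer for real decoder layers; only the lookup pattern matters.
class ToyLayer:
    def __init__(self, attention_type):
        self.attention_type = attention_type

    def __call__(self, hidden_states, attention_mask):
        return hidden_states + [f"{self.attention_type} used {attention_mask}"]


layer_types = ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]
layers = [ToyLayer(t) for t in layer_types]

# Build each mask at most once, mirroring `has_sliding_layers` in the diff.
causal_mask_mapping = {"full_attention": "full mask"}
if "sliding_attention" in layer_types:
    causal_mask_mapping["sliding_attention"] = "sliding mask"

hidden_states = []
for layer in layers:
    hidden_states = layer(hidden_states, attention_mask=causal_mask_mapping[layer.attention_type])
print(hidden_states)
```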
@@ -435,7 +436,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -463,161 +464,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2Config,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
config (`Qwen2Config`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
text_config = config.get_text_config()
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify if the current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - text_config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@@ -4,11 +4,15 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ...utils import auto_docstring, can_return_tuple, logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
@@ -43,6 +47,7 @@ class Qwen2Attention(LlamaAttention):
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
@@ -68,14 +73,6 @@ class Qwen2Attention(LlamaAttention):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -94,7 +91,7 @@ class Qwen2Attention(LlamaAttention):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window, # main diff with Llama
sliding_window=self.sliding_window, # main diff with Llama
**kwargs,
)
@@ -106,13 +103,7 @@
class Qwen2DecoderLayer(LlamaDecoderLayer):
def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen2MLP(config)
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
self.attention_type = config.layer_types[layer_idx]
class Qwen2PreTrainedModel(LlamaPreTrainedModel):
@@ -120,7 +111,120 @@ class Qwen2PreTrainedModel(LlamaPreTrainedModel):
class Qwen2Model(MistralModel):
pass
def __init__(self, config: Qwen2Config):
super().__init__(config)
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"output_attentions": output_attentions,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class Qwen2ForCausalLM(LlamaForCausalLM):

Some files were not shown because too many files have changed in this diff.