mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
🚨🚨[core] Completely rewrite the masking logic for all attentions (#37866)
* start * start having a clean 4d mask primitive * Update mask_utils.py * Update mask_utils.py * switch name * Update masking_utils.py * add a new AttentionMask tensor class * fix import * nits * fixes * use full and quandrants * general sdpa mask for all caches * style * start some tests * tests with sliding, chunked * add styling * test hybrid * Update masking_utils.py * small temp fixes * Update modeling_gemma2.py * compile compatible * Update masking_utils.py * improve * start making it more general * Update masking_utils.py * generate * make it work with flex style primitives! * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * improve * Update cache_utils.py * Update masking_utils.py * simplify - starting to look good! * Update masking_utils.py * name * Update masking_utils.py * style * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * small fix for flex * flex compile * FA2 * Update masking_utils.py * Escape for TGI/vLLM! * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * General case without cache * rename * full test on llama4 * small fix for FA2 guard with chunk * Update modeling_gemma2.py * post rebase cleanup * FA2 supports static cache! * Update modeling_flash_attention_utils.py * Update flex_attention.py * Update masking_utils.py * Update masking_utils.py * Update utils.py * override for export * Update executorch.py * Update executorch.py * Update executorch.py * Update executorch.py * Update masking_utils.py * Update masking_utils.py * output attentions * style * Update masking_utils.py * Update executorch.py * Add doicstring * Add license and put mask visualizer at the end * Update test_modeling_common.py * fix broken test * Update test_modeling_gemma.py * Update test_modeling_gemma2.py * Use fullgraph=False with FA2 * Update utils.py * change name * Update masking_utils.py * improve doc * change name * Update modeling_attn_mask_utils.py * more explicit logic based on model's property * pattern in config * extend * fixes * make it better * generalize to other test models * fix * Update masking_utils.py * fix * do not check mask equivalence if layer types are different * executorch * Update modeling_gemma2.py * Update masking_utils.py * use layer_idx instead * adjust * Update masking_utils.py * test * fix imports * Update modeling_gemma2.py * other test models * Update modeling_llama4.py * Update masking_utils.py * improve * simplify * Update masking_utils.py * typos * typo * fix * Update masking_utils.py * default DynamicCache * remove default cache * simplify * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * simplify * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * export * Update executorch.py * Update executorch.py * Update flex_attention.py * Update executorch.py * upstream to modular gemma 1 & 2 * Update modular_mistral.py * switch names * use dict * put it in the Layer directly * update copy model source for mask functions * apply so many modular (hopefully 1 shot) * use explicite dicts for make style happy * protect import * check docstring * better default in hybrid caches * qwens * Update modular_qwen2.py * simplify core logic! * Update executorch.py * qwen3 moe * Update masking_utils.py * Update masking_utils.py * simplify a lot sdpa causal skip * Update masking_utils.py * post-rebase * gemma3 finally * style * check it before * gemma3 * More general with newer torch * align gemma3 * Update utils.py * Update utils.py * Update masking_utils.py * Update test_modeling_common.py * Update flex_attention.py * Update flex_attention.py * Update flex_attention.py * test * executorch * Update test_modeling_common.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update executorch.py * Update test_modeling_common.py * fix copies * device * sdpa can be used without mask -> pass the torchscript tests in this case * Use enum for check * revert enum and add check instead * remove broken test * cohere2 * some doc & reorganize the Interface * Update tensor_parallel.py * Update tensor_parallel.py * doc and dummy * Update test_modeling_paligemma2.py * Update modeling_falcon_h1.py * Update masking_utils.py * executorch patch * style * CIs * use register in executorch * final comments! --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
parent
f8630c778c
commit
163138a911
@ -126,3 +126,43 @@ would expect from a usual Python dictionary:
|
||||
# You can also globally `register` a new function directly on it
|
||||
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
|
||||
```
|
||||
|
||||
## Attention Mask Interface
|
||||
|
||||
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
|
||||
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
|
||||
the `AttentionInterface`:
|
||||
|
||||
```python
|
||||
from transformers import AttentionMaskInterface
|
||||
from transformers.masking_utils import sdpa_mask
|
||||
import torch
|
||||
|
||||
def my_new_sdpa_mask(*args, **kwargs):
|
||||
print("I just entered the attention mask computation")
|
||||
return sdpa_mask(*args, **kwargs)
|
||||
|
||||
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
|
||||
```
|
||||
|
||||
The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
|
||||
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
|
||||
and `attention_mask=None` will be passed along to the Attention layers.
|
||||
|
||||
The default signature of the attention mask functions is the following:
|
||||
|
||||
```python
|
||||
def custom_attention_mask(
|
||||
batch_size: int, # required arg
|
||||
cache_position: torch.Tensor, # required arg
|
||||
kv_length: int, # required arg
|
||||
kv_offset: int = 0, # required arg
|
||||
mask_function: Callable = causal_mask_function, # required arg
|
||||
attention_mask: Optional[torch.Tensor] = None, # required arg
|
||||
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
|
||||
) -> Optional[torch.Tensor]:
|
||||
```
|
||||
|
||||
It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
|
||||
|
||||
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
|
@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Attention Mask Functions
|
||||
|
||||
[[autodoc]] AttentionMaskInterface
|
||||
- register
|
||||
|
||||
## Rotary Position Embedding Functions
|
||||
|
||||
[[autodoc]] dynamic_rope_update
|
||||
|
@ -445,6 +445,7 @@ else:
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
|
||||
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
|
||||
_import_structure["optimization"] = [
|
||||
"Adafactor",
|
||||
"get_constant_schedule",
|
||||
@ -914,6 +915,7 @@ if TYPE_CHECKING:
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from .masking_utils import AttentionMaskInterface
|
||||
from .model_debugging_utils import (
|
||||
model_addition_debugger_context,
|
||||
)
|
||||
|
@ -196,6 +196,18 @@ class Cache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
query_length = cache_position.shape[0]
|
||||
past_seen_tokens = self.get_seq_length()
|
||||
kv_length = query_length + past_seen_tokens
|
||||
return kv_length, 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
@ -1084,8 +1096,6 @@ class SinkCache(Cache):
|
||||
```
|
||||
"""
|
||||
|
||||
is_sliding = True
|
||||
|
||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||
super().__init__()
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
@ -1390,6 +1400,16 @@ class StaticCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
kv_length = self.get_max_cache_shape()
|
||||
return kv_length, 0
|
||||
|
||||
|
||||
class SlidingWindowCache(StaticCache):
|
||||
"""
|
||||
@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
|
||||
```
|
||||
"""
|
||||
|
||||
is_sliding = True
|
||||
is_compileable = True
|
||||
|
||||
def __init__(
|
||||
@ -1465,6 +1484,7 @@ class SlidingWindowCache(StaticCache):
|
||||
"config and it's not set to None."
|
||||
)
|
||||
max_cache_len = min(config.sliding_window, max_cache_len)
|
||||
self.sliding_window = config.sliding_window
|
||||
super().__init__(
|
||||
config=config,
|
||||
max_batch_size=max_batch_size,
|
||||
@ -1509,6 +1529,21 @@ class SlidingWindowCache(StaticCache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
# torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
|
||||
kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
|
||||
kv_length = max(query_length, self.get_max_cache_shape())
|
||||
return kv_length, kv_offset
|
||||
|
||||
|
||||
class EncoderDecoderCache(Cache):
|
||||
"""
|
||||
@ -1761,12 +1796,17 @@ class HybridCache(Cache):
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
|
||||
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
|
||||
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
|
||||
# If the attribute does not exist in the config, fallback to a simple StaticCache
|
||||
if hasattr(config, "layer_types"):
|
||||
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
|
||||
else:
|
||||
self.is_sliding = [False] * config.num_hidden_layers
|
||||
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
|
||||
self.sliding_window = min(config.sliding_window, max_cache_len)
|
||||
device = torch.device(device) if device is not None else None
|
||||
for i in range(config.num_hidden_layers):
|
||||
if layer_device_map is not None:
|
||||
@ -1775,7 +1815,7 @@ class HybridCache(Cache):
|
||||
layer_device = device
|
||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||
# breaks when updating the cache.
|
||||
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
|
||||
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
@ -1796,7 +1836,7 @@ class HybridCache(Cache):
|
||||
if cache_position is None:
|
||||
raise ValueError("`cache_position` must be provided for HybridCache.")
|
||||
|
||||
is_sliding_layer = self.is_sliding_list[layer_idx]
|
||||
is_sliding_layer = self.is_sliding[layer_idx]
|
||||
|
||||
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
|
||||
# when the cache is initialized in the forward pass (e.g. Gemma2)
|
||||
@ -1843,6 +1883,26 @@ class HybridCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
if self.is_sliding[layer_idx]:
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
|
||||
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
|
||||
local_mask_kv_length = max(query_length, self.sliding_window)
|
||||
return local_mask_kv_length, local_mask_kv_offset
|
||||
|
||||
full_mask_kv_offset = 0
|
||||
full_mask_kv_length = self.get_max_cache_shape()
|
||||
return full_mask_kv_length, full_mask_kv_offset
|
||||
|
||||
|
||||
class HybridChunkedCache(Cache):
|
||||
"""
|
||||
@ -1912,11 +1972,11 @@ class HybridChunkedCache(Cache):
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self._dtype = dtype
|
||||
|
||||
if hasattr(config.get_text_config(), "no_rope_layers"):
|
||||
self.is_sliding = config.no_rope_layers
|
||||
# If the attribute does not exist in the config, fallback to a simple StaticCache
|
||||
if hasattr(config, "layer_types"):
|
||||
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
|
||||
else:
|
||||
layer_switch = getattr(config, "sliding_window_pattern", 2)
|
||||
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
|
||||
self.is_sliding = [False] * config.num_hidden_layers
|
||||
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
@ -1999,11 +2059,7 @@ class HybridChunkedCache(Cache):
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if self.is_sliding[layer_idx]:
|
||||
update_fn = self._sliding_update
|
||||
else:
|
||||
update_fn = self._static_update
|
||||
|
||||
update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
|
||||
return update_fn(
|
||||
cache_position,
|
||||
layer_idx,
|
||||
@ -2038,6 +2094,37 @@ class HybridChunkedCache(Cache):
|
||||
self.value_cache[layer_idx].zero_()
|
||||
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
if self.is_sliding[layer_idx]:
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
|
||||
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is the true general case for any Cache using local attention (sliding or chunked)
|
||||
if first_cache_position >= self.sliding_window:
|
||||
# Here the Cache is already full
|
||||
local_mask_kv_length = self.sliding_window + query_length - 1
|
||||
elif (
|
||||
first_cache_position < self.sliding_window
|
||||
and first_cache_position + query_length > self.sliding_window
|
||||
):
|
||||
# Here the Cache becomes full with the new input
|
||||
local_mask_kv_length = first_cache_position + query_length
|
||||
else:
|
||||
# Here the Cache is still smaller than the local size, but we return the local size as it's static
|
||||
local_mask_kv_length = self.sliding_window
|
||||
return local_mask_kv_length, local_mask_kv_offset
|
||||
|
||||
full_mask_kv_offset = 0
|
||||
full_mask_kv_length = self.get_max_cache_shape()
|
||||
return full_mask_kv_length, full_mask_kv_offset
|
||||
|
||||
|
||||
class OffloadedHybridCache(HybridChunkedCache):
|
||||
def __init__(
|
||||
|
@ -1209,3 +1209,16 @@ if PretrainedConfig.push_to_hub.__doc__ is not None:
|
||||
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
||||
object="config", object_class="AutoConfig", object_files="configuration file"
|
||||
)
|
||||
|
||||
|
||||
ALLOWED_LAYER_TYPES = (
|
||||
"full_attention",
|
||||
"sliding_attention",
|
||||
"chunked_attention",
|
||||
)
|
||||
|
||||
|
||||
def layer_type_validation(layer_types: list[str]):
|
||||
"""Check that each entry in `layer_types` are allowed."""
|
||||
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
|
||||
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
|
||||
|
@ -46,6 +46,7 @@ from ..dynamic_module_utils import (
|
||||
)
|
||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ..integrations.fsdp import is_fsdp_managed_module
|
||||
from ..masking_utils import create_masks_for_generate
|
||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
@ -74,6 +75,7 @@ from .candidate_generator import (
|
||||
from .configuration_utils import (
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING,
|
||||
QUANT_BACKEND_CLASSES_MAPPING,
|
||||
CompileConfig,
|
||||
GenerationConfig,
|
||||
GenerationMode,
|
||||
)
|
||||
@ -649,12 +651,22 @@ class GenerationMixin:
|
||||
causal_mask_creation_function = getattr(
|
||||
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
||||
)
|
||||
|
||||
# If it's not defined, it means the model uses the new general mask API
|
||||
if causal_mask_creation_function is None: # can't be found
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
||||
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
||||
"writing code, see Llama for an example implementation. If you're a user, please report this "
|
||||
"issue on GitHub."
|
||||
output_attentions = kwargs.get("output_attentions", False)
|
||||
token_type_ids = getattr(model_input, "token_type_ids", None)
|
||||
# Some models may overwrite the general one
|
||||
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
|
||||
attention_mask = causal_mask_creation_function(
|
||||
config=self.config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
else:
|
||||
attention_mask = causal_mask_creation_function(
|
||||
@ -3533,6 +3545,19 @@ class GenerationMixin:
|
||||
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
||||
if compile_forward:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
||||
if self.config._attn_implementation == "flash_attention_2" and getattr(
|
||||
model_kwargs.get("past_key_values"), "is_compileable", False
|
||||
):
|
||||
if generation_config.compile_config is None:
|
||||
generation_config.compile_config = CompileConfig(fullgraph=False)
|
||||
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
|
||||
elif generation_config.compile_config.fullgraph:
|
||||
logger.warning_once(
|
||||
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
||||
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
|
||||
)
|
||||
generation_config.compile_config.fullgraph = False
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
if generation_config.prefill_chunk_size is not None:
|
||||
|
@ -11,18 +11,21 @@
|
||||
# specific language governing permissions and limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
|
||||
from ..utils.import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import HybridCache, PreTrainedModel, StaticCache
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
||||
from ..cache_utils import DynamicCache, HybridCache, StaticCache
|
||||
from ..generation.configuration_utils import GenerationConfig
|
||||
from ..masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_ignore_causal_mask_sdpa,
|
||||
_is_torch_greater_or_equal_than_2_5,
|
||||
prepare_padding_mask,
|
||||
)
|
||||
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
|
||||
|
||||
|
||||
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
@ -54,19 +57,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
|
||||
raise ValueError("The model must have caching enabled to be performant.")
|
||||
|
||||
if not hasattr(model.config, "cache_implementation"):
|
||||
# If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
|
||||
# be used by default, so export will use `StaticCache` by default.
|
||||
logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
|
||||
if not hasattr(model.config, "layer_types"):
|
||||
# If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
|
||||
# export will use `StaticCache` by default.
|
||||
logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
|
||||
self.model = TorchExportableModuleWithStaticCache(model)
|
||||
else:
|
||||
if model.config.cache_implementation == "hybrid":
|
||||
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported cache implementation: {model.config.cache_implementation}. "
|
||||
"Please use `hybrid` or `static`."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -105,16 +102,23 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
strict(`Optional[bool]`):
|
||||
Flag to instruct `torch.export` to use `torchdynamo`.
|
||||
"""
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
self.model.model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
|
||||
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
|
||||
|
||||
return torch.export.export(
|
||||
with patch_mask_interface():
|
||||
exported_program = torch.export.export(
|
||||
self.model,
|
||||
args=(example_input_ids, example_cache_position),
|
||||
kwargs={},
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=strict if strict is not None else True,
|
||||
)
|
||||
return exported_program
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
@ -281,17 +285,6 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
|
||||
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
|
||||
|
||||
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
|
||||
if self.is_causal:
|
||||
causal_mask = torch.tril(
|
||||
torch.ones(
|
||||
self.static_cache.max_cache_len,
|
||||
self.static_cache.max_cache_len,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
)
|
||||
self.register_buffer("mask", causal_mask, persistent=False)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
|
||||
"""
|
||||
Forward pass of the module, which is compatible with the ExecuTorch runtime.
|
||||
@ -314,13 +307,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
||||
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
|
||||
"""
|
||||
_, seqlen = input_ids.shape
|
||||
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
past_key_values = self.static_cache
|
||||
|
||||
outs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attn_mask,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
@ -445,18 +437,15 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Logits output from the model.
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
# Generate position_ids from cache_position
|
||||
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# Create attention mask (always ones for token-by-token generation)
|
||||
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
|
||||
|
||||
# Forward pass with the model
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
past_key_values=self.cache,
|
||||
use_cache=True,
|
||||
@ -467,6 +456,24 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
|
||||
return outputs.logits
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_mask_interface():
|
||||
"""
|
||||
Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail
|
||||
with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles:
|
||||
Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`.
|
||||
Note that this seem to be an issue only for python<3.11.
|
||||
"""
|
||||
import transformers
|
||||
|
||||
original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
|
||||
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original
|
||||
|
||||
|
||||
def convert_and_export_with_cache(
|
||||
model: PreTrainedModel,
|
||||
example_input_ids: Optional[torch.Tensor] = None,
|
||||
@ -493,6 +500,11 @@ def convert_and_export_with_cache(
|
||||
|
||||
import torch.export._trace
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
with torch.no_grad():
|
||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||
example_input_ids = (
|
||||
@ -503,6 +515,7 @@ def convert_and_export_with_cache(
|
||||
)
|
||||
|
||||
if is_torch_greater_or_equal("2.6.0"):
|
||||
with patch_mask_interface():
|
||||
exported_program = torch.export.export(
|
||||
TorchExportableModuleWithStaticCache(model),
|
||||
args=(example_input_ids, example_cache_position),
|
||||
@ -521,6 +534,7 @@ def convert_and_export_with_cache(
|
||||
#
|
||||
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
|
||||
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
|
||||
with patch_mask_interface():
|
||||
exported_program = torch.export._trace._export(
|
||||
TorchExportableModuleWithStaticCache(model),
|
||||
args=(example_input_ids,),
|
||||
@ -620,6 +634,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||
|
||||
# Export the encoder
|
||||
with torch.no_grad():
|
||||
with patch_mask_interface():
|
||||
exported_encoder = torch.export.export(
|
||||
wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
|
||||
)
|
||||
@ -642,6 +657,7 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||
|
||||
# Export the decoder
|
||||
with torch.no_grad():
|
||||
with patch_mask_interface():
|
||||
exported_decoder = torch.export.export(
|
||||
wrapped_decoder,
|
||||
(decoder_input_ids, encoder_hidden_states, cache_position),
|
||||
@ -706,3 +722,131 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
|
||||
break
|
||||
|
||||
return generated_ids
|
||||
|
||||
|
||||
def export_with_dynamic_cache(
|
||||
model: PreTrainedModel,
|
||||
example_input_ids: Optional[torch.Tensor] = None,
|
||||
example_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Export a model with DynamicCache using `torch.export`, ensuring the exported model is compatible with `ExecuTorch`.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`): The pretrained model to be exported.
|
||||
example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
|
||||
example_attention_mask (`Optional[torch.Tensor]`): Example attention mask used by `torch.export`.
|
||||
|
||||
Returns:
|
||||
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
|
||||
"""
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
raise ImportError("torch >= 2.3 is required.")
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
with torch.no_grad():
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
(),
|
||||
{
|
||||
"input_ids": example_input_ids,
|
||||
"attention_mask": example_attention_mask,
|
||||
"past_key_values": DynamicCache(),
|
||||
"use_cache": True,
|
||||
},
|
||||
strict=False,
|
||||
)
|
||||
return exported_program
|
||||
|
||||
|
||||
def sdpa_mask_without_vmap(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
kv_offset: int = 0,
|
||||
mask_function: Optional[Callable] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_torch_fix: bool = True,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
||||
the element should take part in the attention computation, and False that it should not.
|
||||
|
||||
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
The batch size of the input sequence.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
kv_length (`int`):
|
||||
The size that the key and value states will have during the attention computation.
|
||||
kv_offset (`int`, optional):
|
||||
An optional offset to indicate at which first position the key and values states will refer to.
|
||||
mask_function (`Callable`):
|
||||
The mask factory function describing the mask pattern.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
local_size (`int`, optional):
|
||||
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
||||
to try to skip mask creation if possible.
|
||||
allow_is_causal_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
||||
`torch.sdpa` instead. Default to `True`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
|
||||
"""
|
||||
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
|
||||
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
|
||||
return None
|
||||
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
reshaped_cache_position = cache_position.view(-1, 1)
|
||||
|
||||
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
|
||||
# the config through kwargs anyway, so it allows to rely on it
|
||||
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
|
||||
# but this is more efficient
|
||||
sliding_window = getattr(kwargs["config"], "sliding_window", None)
|
||||
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
|
||||
|
||||
if sliding_window is not None and chunk_size is not None:
|
||||
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
|
||||
|
||||
# Simplest and most efficient way to obtain a causal mask
|
||||
causal_mask = kv_arange <= reshaped_cache_position
|
||||
# If using sliding window, add the sliding mask
|
||||
if sliding_window is not None:
|
||||
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
|
||||
causal_mask *= sliding_mask_overlay
|
||||
# If using chunk attention, add the chunked mask
|
||||
elif chunk_size is not None:
|
||||
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
|
||||
causal_mask *= chunked_mask_overlay
|
||||
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
|
||||
if padding_mask is not None:
|
||||
causal_mask = causal_mask * padding_mask[:, None, None, :]
|
||||
|
||||
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
|
||||
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
||||
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
||||
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
|
||||
return causal_mask
|
||||
|
@ -32,14 +32,12 @@ import torch
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_torch_flex_attn_available
|
||||
from ..utils.import_utils import _torch_version
|
||||
from ..utils.import_utils import _torch_version, is_torchdynamo_compiling
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask, flex_attention
|
||||
from torch.nn.attention.flex_attention import (
|
||||
create_block_mask as create_block_causal_mask_flex,
|
||||
)
|
||||
from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex
|
||||
|
||||
|
||||
class WrappedFlexAttention:
|
||||
@ -79,6 +77,24 @@ class WrappedFlexAttention:
|
||||
return self._compiled_flex_attention
|
||||
|
||||
|
||||
def compile_friendly_flex_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
training=False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
|
||||
# Do not use compiled version if already compiling forward (it raises issues)
|
||||
flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention
|
||||
return flex_attention_compiled(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Offset = Union[torch.Tensor, int]
|
||||
|
||||
|
||||
@ -90,6 +106,10 @@ def make_flex_block_causal_mask(
|
||||
offsets: Optional[Tuple[Offset, Offset]] = None,
|
||||
) -> "BlockMask":
|
||||
"""
|
||||
IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
|
||||
and will be removed in a future version without warnings. New code should not use it. It is only kept here
|
||||
for BC for now, while models using it are being patched accordingly.
|
||||
|
||||
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
||||
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
||||
The resultant BlockMask is a compressed representation of the full block causal
|
||||
@ -133,7 +153,6 @@ def make_flex_block_causal_mask(
|
||||
"""
|
||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||
and a block diagonal document mask.
|
||||
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
@ -174,24 +193,6 @@ def make_flex_block_causal_mask(
|
||||
)
|
||||
|
||||
|
||||
@torch.compiler.disable(recursive=False)
|
||||
def compile_friendly_flex_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
training=False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
|
||||
flex_attention_compiled = WrappedFlexAttention(training)()
|
||||
return flex_attention_compiled(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
|
@ -16,15 +16,15 @@ from __future__ import annotations
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
from collections.abc import MutableMapping
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_greater_or_equal, logging
|
||||
from ..utils.generic import GeneralInterface
|
||||
|
||||
|
||||
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
||||
@ -720,20 +720,11 @@ class SequenceParallel(TensorParallelLayer):
|
||||
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
||||
|
||||
|
||||
class ParallelInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
|
||||
"""
|
||||
|
||||
class ParallelInterface(GeneralInterface):
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
# a new instance is created (in order to locally override a given function)
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
ParallelInterface._global_mapping = {
|
||||
# a new instance is created (in order to locally override a given entry)
|
||||
_global_mapping = (
|
||||
{
|
||||
"colwise": ColwiseParallel(),
|
||||
"rowwise": RowwiseParallel(),
|
||||
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
|
||||
@ -746,41 +737,12 @@ class ParallelInterface(MutableMapping):
|
||||
"sequence_parallel": SequenceParallel(),
|
||||
"replicate": ReplicateParallel(),
|
||||
}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
return self._local_mapping[key]
|
||||
return self._global_mapping[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow local update of the default functions without impacting other instances
|
||||
self._local_mapping.update({key: value})
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._local_mapping[key]
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter({**self._global_mapping, **self._local_mapping})
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
@classmethod
|
||||
def register(cls, key: str, value: Callable):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self.keys())
|
||||
if is_torch_greater_or_equal("2.5") and _torch_distributed_available
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
|
||||
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
|
||||
else:
|
||||
ALL_PARALLEL_STYLES = None
|
||||
|
||||
|
||||
def convert_local_tensor_to_dtensor(
|
||||
|
1129
src/transformers/masking_utils.py
Normal file
1129
src/transformers/masking_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
|
||||
`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
|
||||
and will be removed in the future.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
|
@ -148,6 +148,12 @@ def _upad_input(
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
|
||||
# With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
|
||||
# It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
|
||||
if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
|
||||
key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
|
||||
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||
|
@ -27,7 +27,6 @@ import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@ -124,6 +123,7 @@ from .utils import (
|
||||
replace_return_docstrings,
|
||||
strtobool,
|
||||
)
|
||||
from .utils.generic import GeneralInterface
|
||||
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
@ -6076,7 +6076,7 @@ def get_disk_only_shard_files(device_map, weight_map):
|
||||
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
|
||||
|
||||
|
||||
class AttentionInterface(MutableMapping):
|
||||
class AttentionInterface(GeneralInterface):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
@ -6091,36 +6091,6 @@ class AttentionInterface(MutableMapping):
|
||||
"sdpa": sdpa_attention_forward,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
return self._local_mapping[key]
|
||||
return self._global_mapping[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow local update of the default functions without impacting other instances
|
||||
self._local_mapping.update({key: value})
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._local_mapping[key]
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter({**self._global_mapping, **self._local_mapping})
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
@classmethod
|
||||
def register(cls, key: str, value: Callable):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self.keys())
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
|
||||
|
@ -25,20 +25,14 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.import_utils import is_torch_available
|
||||
from ..auto import AutoModel
|
||||
from .configuration_aria import AriaConfig, AriaTextConfig
|
||||
@ -49,12 +43,6 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -818,8 +806,13 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -865,129 +858,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
@ -1521,61 +1391,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AriaForConditionalGeneration",
|
||||
|
@ -542,60 +542,5 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
|
||||
|
@ -757,7 +757,7 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -827,7 +827,7 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1602,7 +1602,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1672,7 +1672,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -482,7 +482,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -552,7 +552,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -27,23 +27,17 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_bitnet import BitNetConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -425,8 +419,13 @@ class BitNetModel(BitNetPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -472,129 +471,6 @@ class BitNetModel(BitNetPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -493,7 +493,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -563,7 +563,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -482,7 +482,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -552,7 +552,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -658,7 +658,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -728,7 +728,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1057,7 +1057,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1127,7 +1127,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -491,7 +491,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -561,7 +561,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -35,23 +35,17 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_cohere import CohereConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -462,8 +456,13 @@ class CohereModel(CoherePreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -509,129 +508,6 @@ class CohereModel(CoherePreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -19,7 +19,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
@ -119,9 +119,8 @@ class Cohere2Config(PretrainedConfig):
|
||||
The dropout ratio for the attention probabilities.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Size of the sliding window attention context.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 4):
|
||||
Pattern for the sliding window attention.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
|
||||
```python
|
||||
>>> from transformers import Cohere2Model, Cohere2Config
|
||||
@ -177,8 +176,7 @@ class Cohere2Config(PretrainedConfig):
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
sliding_window=4096,
|
||||
sliding_window_pattern=4,
|
||||
cache_implementation="hybrid",
|
||||
layer_types=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -203,10 +201,9 @@ class Cohere2Config(PretrainedConfig):
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.sliding_window = sliding_window
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.layer_types = layer_types
|
||||
# Need to specify head_dim in the config so it can be used in the attention forward functions
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.cache_implementation = cache_implementation
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
@ -219,5 +216,14 @@ class Cohere2Config(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.layer_types is None:
|
||||
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
||||
sliding_window_pattern = getattr(self, "sliding_window_pattern", 4)
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
__all__ = ["Cohere2Config"]
|
||||
|
@ -25,25 +25,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_cohere2 import Cohere2Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -186,6 +181,7 @@ class Cohere2Attention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
@ -199,9 +195,6 @@ class Cohere2Attention(nn.Module):
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
self.sliding_window = (
|
||||
config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -224,19 +217,9 @@ class Cohere2Attention(nn.Module):
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -284,12 +267,10 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Cohere2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Cohere2Attention(config, layer_idx)
|
||||
self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Cohere2MLP(config)
|
||||
self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
|
||||
self.config = config
|
||||
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
|
||||
self.sliding_window = config.sliding_window
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -322,34 +303,6 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -440,7 +393,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -467,15 +420,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -485,9 +430,22 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -505,7 +463,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -531,100 +489,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: HybridCache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Flash Attention currently doesn't support static cache but Cohere2 work only with static cache.
|
||||
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
|
||||
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
|
||||
# as it doesn't cause dynamic control issues.
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
@ -635,7 +499,7 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config: Cohere2Config):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Cohere2Model(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -740,88 +604,5 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or cache_position[-1] >= input_ids.shape[1] # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||
# which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if logits_to_keep is not None:
|
||||
model_inputs["logits_to_keep"] = logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
__all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
|
||||
|
@ -19,8 +19,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
@ -140,9 +141,8 @@ class Cohere2Config(PretrainedConfig):
|
||||
The dropout ratio for the attention probabilities.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Size of the sliding window attention context.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 4):
|
||||
Pattern for the sliding window attention.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
|
||||
```python
|
||||
>>> from transformers import Cohere2Model, Cohere2Config
|
||||
@ -198,8 +198,7 @@ class Cohere2Config(PretrainedConfig):
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
sliding_window=4096,
|
||||
sliding_window_pattern=4,
|
||||
cache_implementation="hybrid",
|
||||
layer_types=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -224,10 +223,9 @@ class Cohere2Config(PretrainedConfig):
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.sliding_window = sliding_window
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.layer_types = layer_types
|
||||
# Need to specify head_dim in the config so it can be used in the attention forward functions
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.cache_implementation = cache_implementation
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
@ -240,6 +238,15 @@ class Cohere2Config(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.layer_types is None:
|
||||
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
||||
sliding_window_pattern = getattr(self, "sliding_window_pattern", 4)
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
|
||||
pass
|
||||
@ -261,6 +268,7 @@ class Cohere2Attention(CohereAttention, nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
@ -274,9 +282,6 @@ class Cohere2Attention(CohereAttention, nn.Module):
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
self.sliding_window = (
|
||||
config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -299,19 +304,9 @@ class Cohere2Attention(CohereAttention, nn.Module):
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"sliding_window": self.sliding_window,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -342,10 +337,7 @@ class Cohere2Attention(CohereAttention, nn.Module):
|
||||
class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
def __init__(self, config: Cohere2Config, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.self_attn = Cohere2Attention(config, layer_idx)
|
||||
self.config = config
|
||||
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
|
||||
self.sliding_window = config.sliding_window
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -378,34 +370,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -451,7 +415,7 @@ class Cohere2Model(Gemma2Model):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -478,15 +442,7 @@ class Cohere2Model(Gemma2Model):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -496,9 +452,22 @@ class Cohere2Model(Gemma2Model):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -516,7 +485,7 @@ class Cohere2Model(Gemma2Model):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -544,91 +513,7 @@ class Cohere2Model(Gemma2Model):
|
||||
|
||||
|
||||
class Cohere2ForCausalLM(CohereForCausalLM):
|
||||
def __init__(self, config: Cohere2Config):
|
||||
super().__init__(config)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or cache_position[-1] >= input_ids.shape[1] # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||
# which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if logits_to_keep is not None:
|
||||
model_inputs["logits_to_keep"] = logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
|
||||
|
@ -29,25 +29,19 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, logging
|
||||
from ..auto import AutoModel
|
||||
from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
|
||||
from .generation_csm import CsmGenerationMixin
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -516,8 +510,13 @@ class CsmDepthDecoderModel(CsmPreTrainedModel):
|
||||
|
||||
inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -564,129 +563,6 @@ class CsmDepthDecoderModel(CsmPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class CsmCodebooksHead(nn.Module):
|
||||
def __init__(self, hidden_size, num_codebooks, vocab_size):
|
||||
@ -946,8 +822,13 @@ class CsmBackboneModel(CsmPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -993,129 +874,6 @@ class CsmBackboneModel(CsmPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -1477,61 +1235,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CsmPreTrainedModel",
|
||||
|
@ -21,6 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
@ -240,8 +241,13 @@ class CsmDepthDecoderModel(LlamaModel):
|
||||
|
||||
inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -835,61 +841,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CsmPreTrainedModel",
|
||||
|
@ -1001,7 +1001,7 @@ class DbrxModel(DbrxPreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1071,7 +1071,7 @@ class DbrxModel(DbrxPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -15,23 +15,17 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_deepseek_v3 import DeepseekV3Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -608,8 +602,13 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -655,129 +654,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -31,7 +31,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import (
|
||||
FlashAttentionKwargs,
|
||||
_flash_attention_forward,
|
||||
@ -48,16 +48,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_diffllama import DiffLlamaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -708,8 +702,13 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -755,129 +754,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -32,23 +32,17 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -1279,8 +1273,13 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -1326,129 +1325,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
@ -1857,62 +1733,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Emu3ForConditionalGeneration",
|
||||
|
@ -1212,62 +1212,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Emu3ForConditionalGeneration",
|
||||
|
@ -976,7 +976,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -45,12 +45,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils import auto_docstring, is_torchdynamo_compiling, logging, replace_return_docstrings
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||
from .configuration_falcon_h1 import FalconH1Config
|
||||
|
@ -19,7 +19,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -27,7 +27,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -39,16 +39,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -384,7 +378,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -422,8 +416,13 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
@ -475,129 +474,6 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
@ -22,6 +22,7 @@ from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
@ -371,7 +372,7 @@ class GemmaModel(LlamaModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -409,8 +410,13 @@ class GemmaModel(LlamaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
|
@ -19,7 +19,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
|
||||
|
||||
class Gemma2Config(PretrainedConfig):
|
||||
@ -78,12 +78,16 @@ class Gemma2Config(PretrainedConfig):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
|
||||
scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||
scaling factor when applying tanh softcapping on the attention scores.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma2Model, Gemma2Config
|
||||
@ -135,9 +139,9 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=256,
|
||||
sliding_window=4096,
|
||||
layer_types=None,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
cache_implementation="hybrid",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -166,7 +170,13 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.cache_implementation = cache_implementation
|
||||
self.layer_types = layer_types
|
||||
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
__all__ = ["Gemma2Config"]
|
||||
|
@ -26,8 +26,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -38,17 +39,11 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_gemma2 import Gemma2Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -195,7 +190,7 @@ class Gemma2Attention(nn.Module):
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
self.attn_logit_softcapping = self.config.attn_logit_softcapping
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -218,19 +213,9 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -264,7 +249,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -272,7 +257,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -287,33 +271,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(attention_mask.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -441,7 +398,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -468,15 +425,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -487,9 +436,22 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
@ -516,7 +478,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -527,7 +489,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -553,100 +515,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: HybridCache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
|
||||
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
|
||||
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
|
||||
# as it doesn't cause dynamic control issues.
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
@ -688,7 +556,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -765,60 +633,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
|
@ -21,13 +21,14 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import is_torch_flex_attn_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
@ -42,12 +43,6 @@ from ..gemma.modeling_gemma import (
|
||||
)
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -107,12 +102,16 @@ class Gemma2Config(PretrainedConfig):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
|
||||
scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||
scaling factor when applying tanh softcapping on the attention scores.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma2Model, Gemma2Config
|
||||
@ -164,9 +163,9 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=256,
|
||||
sliding_window=4096,
|
||||
layer_types=None,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
cache_implementation="hybrid",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -195,7 +194,13 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.cache_implementation = cache_implementation
|
||||
self.layer_types = layer_types
|
||||
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Gemma2RMSNorm(GemmaRMSNorm):
|
||||
@ -250,7 +255,7 @@ class Gemma2Attention(GemmaAttention):
|
||||
self.attention_dropout = self.config.attention_dropout
|
||||
self.is_causal = True
|
||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -273,19 +278,9 @@ class Gemma2Attention(GemmaAttention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -319,7 +314,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.config = config
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma2MLP(config)
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -327,7 +322,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -342,33 +336,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(attention_mask.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -414,7 +381,7 @@ class Gemma2Model(GemmaModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -441,15 +408,7 @@ class Gemma2Model(GemmaModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -460,9 +419,22 @@ class Gemma2Model(GemmaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
@ -489,7 +461,7 @@ class Gemma2Model(GemmaModel):
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -500,7 +472,7 @@ class Gemma2Model(GemmaModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -526,45 +498,6 @@ class Gemma2Model(GemmaModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: HybridCache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
|
||||
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
|
||||
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
|
||||
# as it doesn't cause dynamic control issues.
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
def __init__(self, config):
|
||||
@ -577,7 +510,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -649,60 +582,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
||||
def __init__(self, config):
|
||||
|
@ -21,7 +21,7 @@
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
from ..siglip import SiglipVisionConfig
|
||||
@ -88,13 +88,14 @@ class Gemma3TextConfig(PretrainedConfig):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
Scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
final_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
@ -134,8 +135,6 @@ class Gemma3TextConfig(PretrainedConfig):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
||||
@ -192,12 +191,11 @@ class Gemma3TextConfig(PretrainedConfig):
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=256,
|
||||
sliding_window=4096,
|
||||
layer_types=None,
|
||||
final_logit_softcapping=None,
|
||||
attn_logit_softcapping=None,
|
||||
cache_implementation="hybrid",
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq=10_000.0,
|
||||
sliding_window_pattern=6,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -226,14 +224,21 @@ class Gemma3TextConfig(PretrainedConfig):
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.cache_implementation = cache_implementation
|
||||
self.layer_types = layer_types
|
||||
|
||||
self.rope_local_base_freq = rope_local_base_freq
|
||||
# For configuring HybridCache to work with 5:1 attention pattern
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.rope_scaling = rope_scaling
|
||||
rope_config_validation(self)
|
||||
|
||||
if self.layer_types is None:
|
||||
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
||||
sliding_window_pattern = getattr(self, "sliding_window_pattern", 6)
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Gemma3Config(PretrainedConfig):
|
||||
r"""
|
||||
|
@ -29,32 +29,21 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel
|
||||
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -300,7 +289,7 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
|
||||
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
@ -351,19 +340,9 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -374,9 +353,7 @@ class Gemma3Attention(nn.Module):
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
if attention_mask is not None:
|
||||
# backwards compatibility
|
||||
attention_mask = attention_mask.to(query_states)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
@ -400,14 +377,13 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.is_sliding = self.self_attn.is_sliding
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -423,33 +399,6 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(attention_mask.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -568,7 +517,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -595,13 +544,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -614,13 +557,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
@ -643,7 +595,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
hidden_states,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -655,7 +607,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -681,100 +633,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: HybridCache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Flash Attention currently doesn't support static cache but Gemma3Text work only with static cache.
|
||||
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
|
||||
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
|
||||
# as it doesn't cause dynamic control issues.
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
@ -818,7 +676,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -895,60 +753,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
and not self.config._attn_implementation == "flash_attention_2"
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(nn.Module):
|
||||
def __init__(self, config: Gemma3Config):
|
||||
@ -986,6 +790,22 @@ class Gemma3MultiModalProjector(nn.Module):
|
||||
return projected_vision_outputs.type_as(vision_outputs)
|
||||
|
||||
|
||||
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
# If it's 1, we need to unmask it
|
||||
return token_type_ids[batch_idx, kv_idx] == 1
|
||||
|
||||
return inner_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head.,
|
||||
@ -1012,86 +832,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
input_tensor,
|
||||
is_training: bool = False,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted
|
||||
# form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
# Apply bidirectional mask on images if token type ids are provided
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
|
||||
# Find where a new image block starts: 1 if image and previous not image
|
||||
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
@ -1161,8 +901,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
@ -1202,11 +940,31 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config.get_text_config(),
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||
token_type_ids.to(cache_position.device)
|
||||
)
|
||||
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -1432,70 +1190,35 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
def create_masks_for_generate(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
past_key_values: Optional[Cache],
|
||||
output_attentions: bool = False,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
) -> dict:
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": config.get_text_config(),
|
||||
"input_embeds": input_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Add the token type ids mask for generate as well
|
||||
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
return create_masks_for_generate(**mask_kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -23,8 +23,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
@ -56,7 +57,7 @@ from ..siglip import SiglipVisionConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3TextConfig(Gemma2Config):
|
||||
class Gemma3TextConfig(Gemma2Config, PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
@ -114,13 +115,14 @@ class Gemma3TextConfig(Gemma2Config):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
Scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
final_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
@ -160,8 +162,6 @@ class Gemma3TextConfig(Gemma2Config):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
||||
@ -183,23 +183,74 @@ class Gemma3TextConfig(Gemma2Config):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=262_208,
|
||||
rope_theta=1_000_000.0,
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq=10_000.0,
|
||||
sliding_window_pattern=6,
|
||||
hidden_size=2304,
|
||||
intermediate_size=9216,
|
||||
num_hidden_layers=26,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
head_dim=256,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
max_position_embeddings=131_072,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=1_000_000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=256,
|
||||
sliding_window=4096,
|
||||
layer_types=None,
|
||||
final_logit_softcapping=None,
|
||||
attn_logit_softcapping=None,
|
||||
**super_kwargs,
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq=10_000.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(self, **super_kwargs)
|
||||
PretrainedConfig.__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_activation = hidden_activation
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.layer_types = layer_types
|
||||
|
||||
self.rope_local_base_freq = rope_local_base_freq
|
||||
# For configuring HybridCache to work with 5:1 attention pattern
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.rope_scaling = rope_scaling
|
||||
rope_config_validation(self)
|
||||
|
||||
if self.layer_types is None:
|
||||
# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
|
||||
sliding_window_pattern = getattr(self, "sliding_window_pattern", 6)
|
||||
self.layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Gemma3Config(PretrainedConfig):
|
||||
r"""
|
||||
@ -336,7 +387,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
|
||||
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
|
||||
class Gemma3Attention(Gemma2Attention):
|
||||
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
||||
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
|
||||
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
|
||||
super().__init__()
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
@ -368,19 +419,9 @@ class Gemma3Attention(Gemma2Attention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -391,9 +432,7 @@ class Gemma3Attention(Gemma2Attention):
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
if attention_mask is not None:
|
||||
# backwards compatibility
|
||||
attention_mask = attention_mask.to(query_states)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
@ -417,14 +456,13 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.is_sliding = self.self_attn.is_sliding
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
||||
def forward(
|
||||
@ -440,33 +478,6 @@ class Gemma3DecoderLayer(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(attention_mask.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
offset = cache_position[-1] - effective_seq_len + 1
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = torch.clamp(offset, min=0)
|
||||
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indexes = torch.arange(
|
||||
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
|
||||
)
|
||||
mask_indexes += offset
|
||||
attention_mask = attention_mask[:, :, :, mask_indexes]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -557,7 +568,7 @@ class Gemma3TextModel(Gemma2Model):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -584,13 +595,7 @@ class Gemma3TextModel(Gemma2Model):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -603,13 +608,22 @@ class Gemma3TextModel(Gemma2Model):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
@ -632,7 +646,7 @@ class Gemma3TextModel(Gemma2Model):
|
||||
hidden_states,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -644,7 +658,7 @@ class Gemma3TextModel(Gemma2Model):
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -716,6 +730,22 @@ class Gemma3MultiModalProjector(nn.Module):
|
||||
return projected_vision_outputs.type_as(vision_outputs)
|
||||
|
||||
|
||||
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
"""
|
||||
# Do not return an additional mask in this case
|
||||
if token_type_ids is None:
|
||||
return None
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
# If it's 1, we need to unmask it
|
||||
return token_type_ids[batch_idx, kv_idx] == 1
|
||||
|
||||
return inner_mask
|
||||
|
||||
|
||||
class Gemma3Model(PaliGemmaModel):
|
||||
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@ -731,85 +761,8 @@ class Gemma3Model(PaliGemmaModel):
|
||||
image_features = self.multi_modal_projector(vision_outputs)
|
||||
return image_features
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
input_tensor,
|
||||
is_training: bool = False,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted
|
||||
# form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
# Apply bidirectional mask on images if token type ids are provided
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
|
||||
# Find where a new image block starts: 1 if image and previous not image
|
||||
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
def _update_causal_mask(self, **super_kwargs):
|
||||
raise AttributeError("We don't want to inherit it")
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@ -839,8 +792,6 @@ class Gemma3Model(PaliGemmaModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
@ -880,11 +831,31 @@ class Gemma3Model(PaliGemmaModel):
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config.get_text_config(),
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
|
||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
|
||||
token_type_ids.to(cache_position.device)
|
||||
)
|
||||
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -1066,16 +1037,39 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs):
|
||||
raise AttributeError("We don't want to inherit it")
|
||||
|
||||
@staticmethod
|
||||
def create_masks_for_generate(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
output_attentions: bool = False,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": config.get_text_config(),
|
||||
"input_embeds": input_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Add the token type ids mask for generate as well
|
||||
if token_type_ids is not None and input_embeds.shape[1] != 1:
|
||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
|
||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
|
||||
|
||||
return create_masks_for_generate(**mask_kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Gemma3Config",
|
||||
|
@ -28,7 +28,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -40,16 +40,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_glm import GlmConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -443,8 +437,13 @@ class GlmModel(GlmPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -490,129 +489,6 @@ class GlmModel(GlmPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -28,7 +28,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -40,16 +40,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_glm4 import Glm4Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -451,8 +445,13 @@ class Glm4Model(Glm4PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -498,129 +497,6 @@ class Glm4Model(Glm4PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
|
||||
|
@ -893,60 +893,5 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]
|
||||
|
@ -682,7 +682,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -752,7 +752,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -12,7 +12,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -24,16 +24,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gpt_neox import GPTNeoXConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -409,8 +403,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
@ -479,129 +478,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -7,6 +7,7 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -349,8 +350,13 @@ class GPTNeoXModel(LlamaModel, nn.Module):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
@ -540,7 +540,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -610,7 +610,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -793,7 +793,6 @@ class GPTJModel(GPTJPreTrainedModel):
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -863,7 +862,6 @@ class GPTJModel(GPTJPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -28,23 +28,17 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_granite import GraniteConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -446,8 +440,13 @@ class GraniteModel(GranitePreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -493,129 +492,6 @@ class GraniteModel(GranitePreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -20,6 +20,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
@ -174,8 +175,13 @@ class GraniteModel(LlamaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -773,7 +773,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -843,7 +843,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -28,7 +28,7 @@ import torch.nn as nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -40,16 +40,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_helium import HeliumConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -428,8 +422,13 @@ class HeliumModel(HeliumPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -475,129 +474,6 @@ class HeliumModel(HeliumPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -1316,7 +1316,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
image_hidden_states=image_hidden_states,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1386,7 +1386,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1035,61 +1035,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InternVLVisionPreTrainedModel",
|
||||
|
@ -1014,7 +1014,7 @@ class JetMoeModel(JetMoePreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1084,7 +1084,7 @@ class JetMoeModel(JetMoePreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -26,7 +26,8 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -40,18 +41,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -433,8 +426,13 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -480,129 +478,6 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -233,12 +233,13 @@ class Llama4TextConfig(PretrainedConfig):
|
||||
`no_rope_layer_interval` layers.
|
||||
attention_chunk_size (`int`, *optional*, defaults to 8192):
|
||||
<TODO>
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
|
||||
Whether to dynamically scale the attention temperature for each query token based on sequence length.
|
||||
Recommended for long sequences (e.g., >32k tokens) to maintain stable output results.
|
||||
floor_scale (`int`, *optional*, defaults to 8192): TODO
|
||||
attn_scale (`int`, *optional*, defaults to 0.1): TODO
|
||||
cache_implementation (`<fill_type>`, *optional*, defaults to `"hybrid"`): <fill_docstring>
|
||||
|
||||
Example:
|
||||
"""
|
||||
@ -298,10 +299,10 @@ class Llama4TextConfig(PretrainedConfig):
|
||||
no_rope_layers=None,
|
||||
no_rope_layer_interval=4,
|
||||
attention_chunk_size=8192,
|
||||
layer_types=None,
|
||||
attn_temperature_tuning=True,
|
||||
floor_scale=8192,
|
||||
attn_scale=0.1,
|
||||
cache_implementation="hybrid",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -323,7 +324,6 @@ class Llama4TextConfig(PretrainedConfig):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = False
|
||||
self.cache_implementation = cache_implementation
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
@ -363,6 +363,13 @@ class Llama4TextConfig(PretrainedConfig):
|
||||
)
|
||||
self.attention_chunk_size = attention_chunk_size
|
||||
|
||||
self.layer_types = layer_types
|
||||
if layer_types is None:
|
||||
self.layer_types = [
|
||||
"chunked_attention" if no_rope else "full_attention" for no_rope in self.no_rope_layers
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Llama4Config(PretrainedConfig):
|
||||
r"""
|
||||
|
@ -24,24 +24,19 @@ import torch.nn.functional as F
|
||||
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, HybridChunkedCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.hub_kernels import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_chunked_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_llama4 import Llama4Config, Llama4TextConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -388,8 +383,9 @@ class Llama4TextDecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Llama4TextAttention(config, layer_idx)
|
||||
self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx])
|
||||
self.is_moe_layer = layer_idx in config.moe_layers
|
||||
if self.is_moe_layer: # the 128E model interleaves dense / sparse
|
||||
self.feed_forward = Llama4TextMoe(config)
|
||||
@ -399,13 +395,10 @@ class Llama4TextDecoderLayer(nn.Module):
|
||||
self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
chunk_causal_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
@ -419,10 +412,6 @@ class Llama4TextDecoderLayer(nn.Module):
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# use local attention mask for ROPE layers
|
||||
if self.use_chunked_attention and chunk_causal_mask is not None:
|
||||
attention_mask = chunk_causal_mask
|
||||
|
||||
# Self Attention
|
||||
attention_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@ -561,9 +550,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
if self.config.get_text_config().attention_chunk_size is not None:
|
||||
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
else:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
@ -575,9 +561,22 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask, chunk_causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
"chunked_attention": create_chunked_causal_mask(**mask_kwargs),
|
||||
}
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -596,8 +595,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
chunk_causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -609,8 +607,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
chunk_causal_mask=chunk_causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -638,216 +635,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.compiler.disable(recursive=False) # the operations in this method are not compilable
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
chunked_attention_mask=None,
|
||||
use_cache=True,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask, attention_mask # flash does not support chunked attn TODO support flash
|
||||
return None, None
|
||||
|
||||
if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
|
||||
return None, None
|
||||
|
||||
sequence_length = input_tensor.shape[1]
|
||||
cache_position = cache_position.to(self.device)
|
||||
attention_chunk_size = self.config.attention_chunk_size
|
||||
using_chunked_attention = attention_chunk_size is not None
|
||||
|
||||
first_cache_position = cache_position[0]
|
||||
|
||||
if past_key_values is not None:
|
||||
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
|
||||
else:
|
||||
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
|
||||
|
||||
if using_chunked_attention:
|
||||
cond1 = first_cache_position >= attention_chunk_size
|
||||
cond2 = (first_cache_position < attention_chunk_size) & (
|
||||
first_cache_position + sequence_length > attention_chunk_size
|
||||
)
|
||||
key_length = (
|
||||
torch.where(
|
||||
cond1,
|
||||
attention_chunk_size + sequence_length - 1,
|
||||
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
|
||||
)
|
||||
if use_cache
|
||||
else full_cache_length
|
||||
)
|
||||
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
if using_chunked_attention:
|
||||
offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
|
||||
chunked_attention_mask = make_flex_block_causal_mask(
|
||||
attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
|
||||
)
|
||||
attention_mask = make_flex_block_causal_mask(
|
||||
attention_mask,
|
||||
query_length=sequence_length,
|
||||
key_length=full_cache_length,
|
||||
offsets=(first_cache_position, 0),
|
||||
)
|
||||
return attention_mask, chunked_attention_mask
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask, chunked_attention_mask
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
if using_chunked_attention and full_cache_length > attention_chunk_size:
|
||||
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
|
||||
end_idx = start_idx + key_length
|
||||
chunked_attention_mask = self.create_chunked_attention_mask(
|
||||
self.config.attention_chunk_size,
|
||||
start=start_idx, # same offset as with flex
|
||||
end=end_idx,
|
||||
device=device,
|
||||
)
|
||||
|
||||
local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well
|
||||
# It may be smaller than attention_chunk_size -> pad it
|
||||
requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
|
||||
if requires_padding:
|
||||
local_attention_mask = nn.functional.pad(
|
||||
local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
|
||||
)
|
||||
# Depending on the padding, take the query tokens from the end or the cache_position
|
||||
if not requires_padding:
|
||||
chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
|
||||
else:
|
||||
chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
|
||||
|
||||
chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
|
||||
chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
|
||||
if self.config._attn_implementation == "eager":
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and attention_mask.ndim == 4
|
||||
and not output_attentions # Only unmask for 4d masks
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
|
||||
chunked_attention_mask = chunked_attention_mask.bool()
|
||||
causal_mask = causal_mask != torch.finfo(dtype).min
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=first_cache_position,
|
||||
is_training=self.training,
|
||||
):
|
||||
causal_mask = None
|
||||
return causal_mask, chunked_attention_mask
|
||||
|
||||
def create_chunked_attention_mask(
|
||||
self, attention_chunk_size: int, start: int, end: int, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate the following:
|
||||
|
||||
'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ |
|
||||
'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ |
|
||||
'▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ |
|
||||
'?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ |
|
||||
|
||||
If the chunk size is 3.
|
||||
This can just be applied over the already created attention mask
|
||||
"""
|
||||
arange_vector = torch.arange(start, end, device=device)
|
||||
block_pos = torch.abs(
|
||||
arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
|
||||
)
|
||||
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
|
||||
mask = (block_pos == 0) & (token_pos <= 0)
|
||||
return mask.to(device)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to place the 4D attention mask on.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
cache_position.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
@ -1726,61 +1513,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Llama4PreTrainedModel",
|
||||
|
@ -503,61 +503,5 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]
|
||||
|
@ -719,7 +719,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1589,7 +1589,7 @@ class LongT5Stack(LongT5PreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1659,7 +1659,7 @@ class LongT5Stack(LongT5PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -790,7 +790,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -860,7 +860,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -484,7 +484,7 @@ class MarianPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -554,7 +554,7 @@ class MarianPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -762,7 +762,7 @@ class MBartPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -832,7 +832,7 @@ class MBartPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1060,7 +1060,7 @@ class MimiTransformerModel(nn.Module):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1148,7 +1148,7 @@ class MimiTransformerModel(nn.Module):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -10,10 +10,10 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -26,16 +26,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_mistral import MistralConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -403,8 +397,14 @@ class MistralModel(MistralPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -450,161 +450,6 @@ class MistralModel(MistralPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: MistralConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`MistralConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -4,13 +4,13 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, SlidingWindowCache, StaticCache
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, QuestionAnsweringModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import is_torch_flex_attn_available, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
@ -27,12 +27,6 @@ from ..llama.modeling_llama import (
|
||||
from .configuration_mistral import MistralConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
|
||||
@ -118,166 +112,107 @@ class MistralPreTrainedModel(LlamaPreTrainedModel):
|
||||
|
||||
|
||||
class MistralModel(LlamaModel):
|
||||
def __init__(self, config: MistralConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
||||
if not isinstance(past_key_values, (type(None), Cache)):
|
||||
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
return causal_mask
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: MistralConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`MistralConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class MistralForCausalLM(LlamaForCausalLM):
|
||||
|
@ -32,12 +32,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
|
||||
from ..auto import AutoModel
|
||||
from .configuration_mistral3 import Mistral3Config
|
||||
|
||||
@ -538,60 +533,5 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]
|
||||
|
@ -32,10 +32,10 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -48,16 +48,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_mixtral import MixtralConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -524,8 +518,14 @@ class MixtralModel(MixtralPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -591,161 +591,6 @@ class MixtralModel(MixtralPreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: MixtralConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`MixtralConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -29,6 +29,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
@ -315,12 +316,6 @@ class MixtralPreTrainedModel(MistralPreTrainedModel):
|
||||
|
||||
|
||||
class MixtralModel(MistralModel):
|
||||
def __init__(self, config: MixtralConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -368,8 +363,14 @@ class MixtralModel(MistralModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -893,7 +893,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
|
||||
if module.is_gated:
|
||||
module.gate.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -963,7 +963,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -27,11 +27,8 @@ import torch.nn as nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -44,16 +41,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from .configuration_moonshine import MoonshineConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -752,8 +743,13 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -826,129 +822,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
|
@ -21,6 +21,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
@ -748,8 +749,13 @@ class MoonshineDecoder(LlamaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -1084,7 +1084,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1172,7 +1172,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->MoshiDepth
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->MoshiDepth
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
@ -1391,7 +1391,7 @@ class MoshiModel(MoshiPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Moshi
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1479,7 +1479,7 @@ class MoshiModel(MoshiPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Moshi
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Moshi
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1180,7 +1180,7 @@ class MT5Stack(MT5PreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1250,7 +1250,7 @@ class MT5Stack(MT5PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -739,7 +739,7 @@ class NemotronModel(NemotronPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -809,7 +809,7 @@ class NemotronModel(NemotronPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -13,23 +13,17 @@ import torch.nn.functional as F
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_olmo import OlmoConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -406,8 +400,13 @@ class OlmoModel(OlmoPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -453,129 +452,6 @@ class OlmoModel(OlmoPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -13,23 +13,17 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_olmo2 import Olmo2Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -412,8 +406,13 @@ class Olmo2Model(Olmo2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -459,129 +458,6 @@ class Olmo2Model(Olmo2PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -384,7 +384,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -454,7 +454,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -566,7 +566,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -481,7 +481,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -551,7 +551,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -773,7 +773,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -843,7 +843,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -539,7 +539,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -609,7 +609,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -24,16 +24,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_phi import PhiConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -400,8 +394,13 @@ class PhiModel(PhiPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
|
||||
@ -461,129 +460,6 @@ class PhiModel(PhiPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -244,8 +245,13 @@ class PhiModel(LlamaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
|
||||
|
@ -26,10 +26,10 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -41,16 +41,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_phi3 import Phi3Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -458,8 +452,14 @@ class Phi3Model(Phi3PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -505,161 +505,6 @@ class Phi3Model(Phi3PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Phi3Config,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Phi3Config`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -28,13 +28,12 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -46,16 +45,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging, torch_int
|
||||
from ...utils import auto_docstring, can_return_tuple, logging, torch_int
|
||||
from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -1766,8 +1759,14 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -1813,161 +1812,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Phi4MultimodalConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Phi4MultimodalConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):
|
||||
|
@ -21,11 +21,11 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPast,
|
||||
@ -1570,8 +1570,14 @@ class Phi4MultimodalModel(Phi3Model, nn.Module):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
||||
causal_mask = mask_function(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
@ -1068,7 +1068,6 @@ class PhimoeModel(PhimoePreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Phimoe
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1156,7 +1155,6 @@ class PhimoeModel(PhimoePreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -1345,7 +1345,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -1415,7 +1415,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -520,7 +520,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -590,7 +590,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -900,7 +900,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
@ -970,7 +970,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Qwen2 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
@ -110,6 +110,8 @@ class Qwen2Config(PretrainedConfig):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
@ -164,6 +166,7 @@ class Qwen2Config(PretrainedConfig):
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -174,7 +177,7 @@ class Qwen2Config(PretrainedConfig):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
@ -195,6 +198,16 @@ class Qwen2Config(PretrainedConfig):
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
|
@ -10,10 +10,10 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
@ -26,16 +26,10 @@ from ...modeling_outputs import (
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_qwen2 import Qwen2Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -143,6 +137,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -168,14 +163,6 @@ class Qwen2Attention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
sliding_window = None
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -194,7 +181,7 @@ class Qwen2Attention(nn.Module):
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=sliding_window, # main diff with Llama
|
||||
sliding_window=self.sliding_window, # main diff with Llama
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -228,15 +215,13 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -357,6 +342,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -416,9 +402,24 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -435,7 +436,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@ -463,161 +464,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Qwen2Config,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Qwen2Config`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -4,11 +4,15 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
@ -43,6 +47,7 @@ class Qwen2Attention(LlamaAttention):
|
||||
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -68,14 +73,6 @@ class Qwen2Attention(LlamaAttention):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
sliding_window = None
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -94,7 +91,7 @@ class Qwen2Attention(LlamaAttention):
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=sliding_window, # main diff with Llama
|
||||
sliding_window=self.sliding_window, # main diff with Llama
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -106,13 +103,7 @@ class Qwen2Attention(LlamaAttention):
|
||||
class Qwen2DecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: Qwen2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Qwen2MLP(config)
|
||||
if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
|
||||
class Qwen2PreTrainedModel(LlamaPreTrainedModel):
|
||||
@ -120,7 +111,120 @@ class Qwen2PreTrainedModel(LlamaPreTrainedModel):
|
||||
|
||||
|
||||
class Qwen2Model(MistralModel):
|
||||
pass
|
||||
def __init__(self, config: Qwen2Config):
|
||||
super().__init__(config)
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
||||
if not isinstance(past_key_values, (type(None), Cache)):
|
||||
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"output_attentions": output_attentions,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(LlamaForCausalLM):
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user