From dcb29eb805f2c24a37eb99a6439be181bff2445d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Fri, 4 Apr 2025 14:01:07 +0000
Subject: [PATCH 1/2] Add support for sparse `Llama4TextMoe` layer from the kernel hub

---
 src/transformers/integrations/hub_kernels.py      | 7 +++++++
 src/transformers/models/llama4/modeling_llama4.py | 2 ++
 2 files changed, 9 insertions(+)

diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index b2ec6b53715..3b123438603 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -26,6 +26,13 @@ try:
     _hub_kernels_available = True
 
     _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
+        "Llama4TextMoe": {
+            "cuda": LayerRepository(
+                # Move to kernels-community/moe once we release.
+                repo_id="kernels-community/moe-new-models",
+                layer_name="Llama4TextMoe",
+            )
+        },
         "MultiScaleDeformableAttention": {
             "cuda": LayerRepository(
                 repo_id="kernels-community/deformable-detr",
diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 60a8dcabe93..31d606e87c5 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -33,6 +33,7 @@ from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
+from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -146,6 +147,7 @@ class Llama4TextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
+@use_kernel_forward_from_hub("Llama4TextMoe")
 class Llama4TextMoe(nn.Module):
     def __init__(self, config):
         super().__init__()

From 373a472e939f6a1fa6a34a047c5a8c8f65131399 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 4 Apr 2025 14:10:56 +0000
Subject: [PATCH 2/2] better merge

---
 src/transformers/cache_utils.py | 122 +++++++++++++++-----------------
 1 file changed, 59 insertions(+), 63 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 6fa96e5c8e4..3fc84e05dbc 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1663,84 +1663,77 @@ class HybridCache(Cache):
         max_batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype = torch.bfloat16,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
-            raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
-                "sliding window attention, please check if there is a `sliding_window` field in the model "
-                "config and it's not set to None."
-            )
+            self.sliding_window = getattr(config, "attention_chunk_size", 8192)
+        else:
+            self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )
 
-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        if hasattr(config, "no_rope_layers"):
+            self.is_sliding = torch.tensor([k in config.no_rope_layers for k in range(config.num_hidden_layers)])
+        else:
+            layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2
+            self.is_sliding = torch.tensor(
+                [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
+            )
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (
             self.max_batch_size,
-            self.num_key_value_heads,
-            min(config.sliding_window, max_cache_len),
+            num_key_value_heads,
+            self.sliding_window,
             self.head_dim,
         )
-        device = torch.device(device) if device is not None and isinstance(device, str) else None
-        for i in range(config.num_hidden_layers):
-            if layer_device_map is not None:
-                layer_device = layer_device_map[i]
-            else:
-                layer_device = device
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
 
     def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] > max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
+        cumulative_length = self.cumulative_length[layer_idx]
+        is_full = cumulative_length >= max_cache_len
+        if is_full:
+            full_key_states = torch.cat((self.key_cache[layer_idx][:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((self.value_cache[layer_idx][:, :, 1:, :], value_states), dim=-2)
+        elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
+            full_key_states = torch.cat((self.key_cache[layer_idx][:, :, :cumulative_length, :], key_states), dim=-2)
+            full_value_states = torch.cat(
+                (self.value_cache[layer_idx][:, :, :cumulative_length, :], value_states), dim=-2
+            )
+        else:
+            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+            self.cumulative_length[layer_idx] += key_states.shape[-2]
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        to_shift = cache_position >= max_cache_len - 1
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
+        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
+        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
+        self.cumulative_length[layer_idx] += key_states.shape[-2]
+        # we should return the whole states instead of k_out, v_out to take the whole prompt
+        # into consideration when building kv cache instead of just throwing away tokens outside of the window
+        return full_key_states, full_value_states
 
     def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
         k_out[:, :, cache_position] = key_states
@@ -1760,7 +1753,7 @@ class HybridCache(Cache):
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        self.initialise_cache_layer(layer_idx, key_states)
 
         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1774,7 +1767,7 @@ class HybridCache(Cache):
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)
 
-        if sliding_window:
+        if self.is_sliding[layer_idx]:
             update_fn = self._sliding_update
         else:
             update_fn = self._static_update
@@ -1801,6 +1794,8 @@ class HybridCache(Cache):
             "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
             "Using the `layer_idx` argument is not supported."
         )
+        if len(self.key_cache) == 0:
+            return 0
         return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
 
     def reset(self):
@@ -1809,6 +1804,7 @@ class HybridCache(Cache):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
+        self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
 
 
 class MambaCache:
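
Note on PATCH 1/2: the string passed to `@use_kernel_forward_from_hub("Llama4TextMoe")` has to match the key added to `_KERNEL_MAPPING`, which maps a layer name to per-device `LayerRepository` entries. A minimal sketch of what an entry for another device could look like, following the same shape as the hunk above (the "rocm" device key and the repository id are hypothetical placeholders, not something this patch registers):

    # Sketch only: same structure as the "Llama4TextMoe" entry added above.
    # The "rocm" device key and repo_id are hypothetical placeholders.
    from kernels import LayerRepository

    extra_llama4_moe_kernels = {
        "Llama4TextMoe": {
            "rocm": LayerRepository(
                repo_id="kernels-community/llama4-moe-rocm",  # hypothetical repository
                layer_name="Llama4TextMoe",
            ),
        },
    }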
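
Note on PATCH 2/2: `_sliding_update` now tracks a per-layer `cumulative_length` instead of rolling the cache with `zero_()`/`+=`, and it returns the full key/value states for the current step so attention can see the whole prompt. A self-contained toy sketch of the same bookkeeping on a 1-D cache (illustrative only; the real method operates on `[batch, heads, seq, head_dim]` tensors and the names below are made up):

    # Toy 1-D version of the sliding-window bookkeeping used by the new _sliding_update.
    import torch

    class ToySlidingCache:
        def __init__(self, max_cache_len: int):
            self.max_cache_len = max_cache_len
            self.cache = torch.zeros(max_cache_len, dtype=torch.long)
            self.cumulative_length = 0

        def update(self, new_tokens: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
            is_full = self.cumulative_length >= self.max_cache_len
            if is_full:
                # Window already full: drop the oldest entries, keep the newest max_cache_len.
                full = torch.cat((self.cache[1:], new_tokens))
            elif self.cumulative_length + new_tokens.shape[0] > self.max_cache_len:
                # Prompt crosses the window boundary: keep everything for this step.
                full = torch.cat((self.cache[: self.cumulative_length], new_tokens))
            else:
                # Still filling: write in place at the given positions.
                self.cache.index_copy_(0, cache_position, new_tokens)
                self.cumulative_length += new_tokens.shape[0]
                return self.cache
            self.cache.copy_(full[-self.max_cache_len:])
            self.cumulative_length += new_tokens.shape[0]
            # As in the real code, return the full states rather than only the window.
            return full

    cache = ToySlidingCache(max_cache_len=4)
    print(cache.update(torch.tensor([1, 2]), torch.tensor([0, 1])))   # tensor([1, 2, 0, 0])
    print(cache.update(torch.tensor([3, 4, 5]), torch.arange(2, 5)))  # tensor([1, 2, 3, 4, 5])
    print(cache.update(torch.tensor([6]), torch.tensor([3])))         # tensor([3, 4, 5, 6])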