Fix Llama4 offset (#37414)

* add +1

* Update modeling_llama4.py
Cyril Vallez 2025-04-10 11:40:58 +02:00 committed by GitHub
parent f5865d32a2
commit 6d8b0b3378

@@ -731,7 +731,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
         attention_chunk_size = self.config.attention_chunk_size
 
         first_cache_position = cache_position[0]
-        last_cache_position = cache_position[-1]
 
         if past_key_values is not None:
             full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
@@ -754,7 +753,7 @@
 
         if self.config._attn_implementation == "flex_attention":
             if isinstance(attention_mask, torch.Tensor):
-                offsets = (first_cache_position, max(last_cache_position - key_length, 0))
+                offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
                 chunked_attention_mask = make_flex_block_causal_mask(
                     attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
                 )
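
A quick sanity check on the new second offset (not part of the commit; the values below are made up, only the formula comes from the diff): with it, the addressable key window spans exactly attention_chunk_size positions, while dropping the "+1" would widen it by one.

# Sketch with hypothetical values; only the formula itself is taken from the diff above.
attention_chunk_size = 4   # hypothetical chunk size
first_cache_position = 10  # hypothetical position of the first query token

start = max(first_cache_position - attention_chunk_size + 1, 0)
assert start == 7  # keys 7..10 -> exactly attention_chunk_size positions

# Dropping the +1 would widen the window to attention_chunk_size + 1 keys (6..10).
no_plus_one = max(first_cache_position - attention_chunk_size, 0)
assert first_cache_position - no_plus_one + 1 == attention_chunk_size + 1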
@@ -780,10 +779,8 @@
             batch_size=input_tensor.shape[0],
         )
         if full_cache_length > self.config.attention_chunk_size:
-            start_idx = max(last_cache_position - key_length, 0)
-            end_idx = last_cache_position + 1 if sequence_length > 1 else last_cache_position
-            # We always need a mask of at least attention_chunk_size, so we use the max here
-            end_idx = max(end_idx, start_idx + attention_chunk_size)
+            start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
+            end_idx = start_idx + key_length
             chunked_attention_mask = self.create_chunked_attention_mask(
                 self.config.attention_chunk_size,
                 start=start_idx,  # same offset as with flex
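
The non-flex branch above now uses the same start offset as the flex branch and derives the end directly from key_length. A small illustrative check with hypothetical values (not from the commit; only the two formulas come from the diff):

# Sketch with made-up values; the two formulas are the ones added in the diff.
attention_chunk_size = 4   # hypothetical
first_cache_position = 10  # hypothetical
key_length = 4             # hypothetical number of key positions held in the cache

start_idx = max(first_cache_position - attention_chunk_size + 1, 0)  # same offset as the flex branch
end_idx = start_idx + key_length
assert (start_idx, end_idx) == (7, 11)    # mask columns 7..10
assert end_idx - start_idx == key_length  # mask width matches the key cache exactly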