mirror of https://github.com/huggingface/transformers.git
parent f5865d32a2
commit 6d8b0b3378
@@ -731,7 +731,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
         attention_chunk_size = self.config.attention_chunk_size
 
         first_cache_position = cache_position[0]
-        last_cache_position = cache_position[-1]
 
         if past_key_values is not None:
             full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
@@ -754,7 +753,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
 
         if self.config._attn_implementation == "flex_attention":
             if isinstance(attention_mask, torch.Tensor):
-                offsets = (first_cache_position, max(last_cache_position - key_length, 0))
+                offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
                 chunked_attention_mask = make_flex_block_causal_mask(
                     attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
                 )
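A minimal sketch of what the new flex-attention offsets evaluate to, assuming hypothetical values (attention_chunk_size = 4, a single decoded token at cache position 10, and the local name kv_offset are illustrative, not part of the diff):

import torch

attention_chunk_size = 4                 # hypothetical config value
cache_position = torch.tensor([10])      # hypothetical: one decode step at position 10
first_cache_position = cache_position[0]

# The second offset is now anchored on first_cache_position: a lower bound on the
# earliest key the current query could still need under chunked attention,
# clamped at 0 so early (prefill) positions are unaffected.
kv_offset = max(first_cache_position - attention_chunk_size + 1, 0)
offsets = (first_cache_position, kv_offset)
print(offsets)                           # (tensor(10), tensor(7))

The old second offset was derived from last_cache_position and key_length; the new one depends only on first_cache_position and the chunk size, matching the start_idx used for the eager/SDPA mask in the next hunk.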
@@ -780,10 +779,8 @@ class Llama4TextModel(Llama4PreTrainedModel):
             batch_size=input_tensor.shape[0],
         )
         if full_cache_length > self.config.attention_chunk_size:
-            start_idx = max(last_cache_position - key_length, 0)
-            end_idx = last_cache_position + 1 if sequence_length > 1 else last_cache_position
-            # We always need a mask of at least attention_chunk_size, so we use the max here
-            end_idx = max(end_idx, start_idx + attention_chunk_size)
+            start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
+            end_idx = start_idx + key_length
             chunked_attention_mask = self.create_chunked_attention_mask(
                 self.config.attention_chunk_size,
                 start=start_idx,  # same offset as with flex
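For the non-flex path, the same hypothetical numbers show the new window (attention_chunk_size = 4, key_length = 4, first_cache_position = 10 are illustrative; this is a sketch of the arithmetic, not the model code):

attention_chunk_size = 4     # hypothetical config value
key_length = 4               # hypothetical: number of key positions covered by the cache view
first_cache_position = 10    # hypothetical: cache_position[0] for the current step

# The window now starts at the same offset as the flex path and is always exactly
# key_length wide, so it lines up with the cached keys.
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
end_idx = start_idx + key_length
print(start_idx, end_idx)    # 7 11 -> chunked mask covers key positions 7..10

Previously start_idx and end_idx were computed from last_cache_position and then widened to at least attention_chunk_size; the replacement derives the window directly from first_cache_position and key_length.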