Minor llama4 fixes (#38123)

* fix wrong scaling value/default Cache init

* style

* fix various issues on integration tests

* change expected outputs

* fixup

* fix config access

* protect default scaling
Pablo Montalvo 2025-05-20 15:15:54 +02:00 committed by GitHub
parent 856f034f45
commit 9cde2f5d42
2 changed files with 58 additions and 33 deletions

modeling_llama4.py

@@ -258,6 +258,33 @@ def eager_attention_forward(
     return attn_output, attn_weights


+# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
+def vision_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class Llama4TextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

         if use_cache and past_key_values is None:
-            if self.config.get_text_config().get("attention_chunk_size") is not None:
+            if self.config.get_text_config().attention_chunk_size is not None:
                 past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
             else:
-                past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+                past_key_values = DynamicCache()

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
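Two things change in this hunk: the config value is read as an attribute (config objects are attribute-based, not dicts, so .get() was the wrong accessor), and the default cache becomes a plain DynamicCache(), which takes no constructor arguments. A hedged sketch of the same selection logic, where select_cache is a hypothetical helper rather than anything in transformers, and which assumes a transformers version that ships HybridChunkedCache:

# Hypothetical helper mirroring the cache selection above (not part of transformers).
from transformers.cache_utils import DynamicCache, HybridChunkedCache

def select_cache(config, batch_size, seq_len):
    text_config = config.get_text_config()
    # Attribute access (via getattr here, to tolerate configs without the field),
    # not dict-style .get(), is how config values are read.
    if getattr(text_config, "attention_chunk_size", None) is not None:
        return HybridChunkedCache(text_config, batch_size, seq_len)
    return DynamicCache()  # DynamicCache needs no arguments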
@@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        attention_interface: Callable = eager_attention_forward
+        attention_interface: Callable = vision_eager_attention_forward
         # flex disable because breaks on TP 8, embed is 88 not power of 2
         if self.config._attn_implementation not in ["eager", "flex_attention"]:
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module):
             value_states,
             None,
             dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=None,
+            scaling=None,  # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
             is_causal=False,  # HAS TO BE ENFORCED
             **kwargs,
         )
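On the scaling=None TODO: when no scaling is supplied, PyTorch's scaled_dot_product_attention falls back to 1/sqrt(head_dim), which is exactly what the comment flags as possibly insufficient for the vision tower under tensor parallelism. A small sketch (assuming PyTorch >= 2.1, where the scale keyword exists) showing that scale=None and the explicit default scale are equivalent:

# scale=None in SDPA means 1/sqrt(head_dim); any other value changes the output.
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

default = F.scaled_dot_product_attention(q, k, v, scale=None)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(q.shape[-1]))
print(torch.allclose(default, explicit))  # True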

test_modeling_llama4.py

@@ -37,7 +37,7 @@ if is_torch_available():
 @require_torch_large_gpu
 @require_read_token
 class Llama4IntegrationTest(unittest.TestCase):
-    model_id = "ll-re/Llama-4-17B-Omni-Instruct"
+    model_id = "meta-llama/Llama-4-Scout-17B-16E"
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
     cuda_compute_capability_major_version = None
@@ -48,14 +48,17 @@ class Llama4IntegrationTest(unittest.TestCase):
             # 8 is for A100 / A10 and 7 for T4
             cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

         cls.model = Llama4ForConditionalGeneration.from_pretrained(
-            "ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
+            "meta-llama/Llama-4-Scout-17B-16E",
+            device_map="auto",
+            torch_dtype=torch.float32,
+            attn_implementation="eager",
         )

     def setUp(self):
-        self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
+        self.processor = Llama4Processor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E", padding_side="left")
         url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"

-        self.messages = [
+        self.messages_1 = [
             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
             {
                 "role": "user",
@@ -66,27 +69,7 @@ class Llama4IntegrationTest(unittest.TestCase):
             },
         ]

-    def test_model_17b_16e_fp16(self):
-        EXPECTED_TEXT = [
-            "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
-            "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
-        ]
-
-        messages = [
-            {"role": "user", "content": "Who are you?"},
-        ]
-        inputs = self.processor.apply_chat_template(
-            messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
-        ).to(torch_device)
-
-        output = self.model.generate(**inputs, max_new_tokens=100)
-        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
-        print(output_text)
-        self.assertEqual(output_text, EXPECTED_TEXT)
-
-    def test_model_17b_16e_batch(self):
-        messages_2 = [
+        self.messages_2 = [
             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
             {
                 "role": "user",
@@ -101,20 +84,35 @@ class Llama4IntegrationTest(unittest.TestCase):
             },
         ]

-        inputs = self.processor.apply_chat_template(
-            [self.messages, messages_2],
+    def test_model_17b_16e_fp16(self):
+        EXPECTED_TEXT = [
+            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
+        ]  # fmt: skip
+
+        inputs = self.processor.apply_chat_template(
+            self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        ).to(device=torch_device, dtype=self.model.dtype)
+
+        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+        print(output_text)
+        self.assertEqual(output_text, EXPECTED_TEXT)
+
+    def test_model_17b_16e_batch(self):
+        inputs = self.processor.apply_chat_template(
+            [self.messages_1, self.messages_2],
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
             padding=True,
             add_generation_prompt=True,
-        ).to(torch_device)
+        ).to(device=torch_device, dtype=torch.float32)

         output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_TEXTS = [
-            'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
-            "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
+            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white',
+            'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.'
         ]  # fmt: skip

         self.assertEqual(output_text, EXPECTED_TEXTS)