mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Minor llama4 fixes (#38123)
* fix wrong scaling value/default Cache init * style * fix various issues on integration tests * change expected outputs * fixup * fix config access * protect default scaling
This commit is contained in:
parent
856f034f45
commit
9cde2f5d42
@ -258,6 +258,33 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
|
||||
def vision_eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Llama4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
if self.config.get_text_config().get("attention_chunk_size") is not None:
|
||||
if self.config.get_text_config().attention_chunk_size is not None:
|
||||
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
else:
|
||||
past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
attention_interface: Callable = vision_eager_attention_forward
|
||||
# flex disable because breaks on TP 8, embed is 88 not power of 2
|
||||
if self.config._attn_implementation not in ["eager", "flex_attention"]:
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
value_states,
|
||||
None,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=None,
|
||||
scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
|
||||
is_causal=False, # HAS TO BE ENFORCED
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ if is_torch_available():
|
||||
@require_torch_large_gpu
|
||||
@require_read_token
|
||||
class Llama4IntegrationTest(unittest.TestCase):
|
||||
model_id = "ll-re/Llama-4-17B-Omni-Instruct"
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E"
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
@ -48,14 +48,17 @@ class Llama4IntegrationTest(unittest.TestCase):
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
cls.model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
"ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
|
||||
"meta-llama/Llama-4-Scout-17B-16E",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.float32,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
|
||||
self.processor = Llama4Processor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E", padding_side="left")
|
||||
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||
self.messages = [
|
||||
self.messages_1 = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
@ -66,27 +69,7 @@ class Llama4IntegrationTest(unittest.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
def test_model_17b_16e_fp16(self):
|
||||
EXPECTED_TEXT = [
|
||||
"The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
|
||||
"Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
|
||||
).to(torch_device)
|
||||
|
||||
output = self.model.generate(**inputs, max_new_tokens=100)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
print(output_text)
|
||||
self.assertEqual(output_text, EXPECTED_TEXT)
|
||||
|
||||
def test_model_17b_16e_batch(self):
|
||||
messages_2 = [
|
||||
self.messages_2 = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
@ -101,20 +84,35 @@ class Llama4IntegrationTest(unittest.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
def test_model_17b_16e_fp16(self):
|
||||
EXPECTED_TEXT = [
|
||||
'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
|
||||
] # fmt: skip
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
[self.messages, messages_2],
|
||||
self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
|
||||
).to(device=torch_device, dtype=self.model.dtype)
|
||||
output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
print(output_text)
|
||||
self.assertEqual(output_text, EXPECTED_TEXT)
|
||||
|
||||
def test_model_17b_16e_batch(self):
|
||||
inputs = self.processor.apply_chat_template(
|
||||
[self.messages_1, self.messages_2],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
).to(torch_device)
|
||||
).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = [
|
||||
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
|
||||
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
|
||||
'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white',
|
||||
'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.'
|
||||
] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
Loading…
Reference in New Issue
Block a user