Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Minor llama4 fixes (#38123)

* fix wrong scaling value/default Cache init
* style
* fix various issues on integration tests
* change expected outputs
* fixup
* fix config access
* protect default scaling

Parent: 856f034f45
Commit: 9cde2f5d42
@@ -258,6 +258,33 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
+def vision_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class Llama4TextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
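Not part of the diff, but a minimal sketch of what the "doesn't cast attn weights to fp32" comment above refers to: the stock Llama eager path runs the softmax in float32 and casts back, while the vision path added here keeps the compute dtype throughout. Tensor shapes are illustrative.

import torch
import torch.nn as nn

attn_logits = torch.randn(1, 8, 4, 4, dtype=torch.float16)

# Llama-style eager attention: softmax computed in fp32, then cast back to the query dtype
weights_upcast = nn.functional.softmax(attn_logits, dim=-1, dtype=torch.float32).to(attn_logits.dtype)

# Llama4 vision_eager_attention_forward (added above): softmax stays in the compute dtype
weights_native = nn.functional.softmax(attn_logits, dim=-1)

# Same output dtype either way; only the intermediate accumulation precision differs
print(weights_upcast.dtype, weights_native.dtype)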
@@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
 
         if use_cache and past_key_values is None:
-            if self.config.get_text_config().get("attention_chunk_size") is not None:
+            if self.config.get_text_config().attention_chunk_size is not None:
                 past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
             else:
-                past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+                past_key_values = DynamicCache()
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
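The two changes above match the commit bullets "fix config access" and "default Cache init": configs are objects, so attention_chunk_size is read as an attribute rather than with dict-style .get(), and DynamicCache is constructed without arguments. A hedged standalone sketch, assuming the cache classes are importable from transformers.cache_utils; the helper name select_cache is illustrative, not part of the library.

from transformers.cache_utils import DynamicCache, HybridChunkedCache


def select_cache(config, batch_size, seq_len):
    # Read the chunk size as an attribute; it is None/absent for non-chunked models
    text_config = config.get_text_config()
    if getattr(text_config, "attention_chunk_size", None) is not None:
        # Chunked-attention layers need the hybrid cache sized up front
        return HybridChunkedCache(config, batch_size, seq_len)
    # A plain DynamicCache takes no constructor arguments
    return DynamicCache()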
@@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attention_interface: Callable = eager_attention_forward
+        attention_interface: Callable = vision_eager_attention_forward
         # flex disable because breaks on TP 8, embed is 88 not power of 2
         if self.config._attn_implementation not in ["eager", "flex_attention"]:
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
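The swap above makes the vision attention default to the new llama4-specific eager kernel before the backend check. A schematic sketch of that dispatch pattern, with illustrative names rather than the library's actual registry:

from typing import Callable, Dict

import torch


def eager_attn(query, key, value, attention_mask=None, **kwargs):
    # Stand-in for vision_eager_attention_forward
    weights = torch.softmax(query @ key.transpose(-2, -1) * query.shape[-1] ** -0.5, dim=-1)
    return weights @ value, weights


ATTENTION_BACKENDS: Dict[str, Callable] = {"eager": eager_attn}  # sdpa/flash would register here


def run_attention(attn_implementation: str, query, key, value, **kwargs):
    attention_interface: Callable = eager_attn
    # flex_attention stays on the eager path, mirroring the TP-8 workaround noted in the diff
    if attn_implementation not in ["eager", "flex_attention"]:
        attention_interface = ATTENTION_BACKENDS[attn_implementation]
    return attention_interface(query, key, value, **kwargs)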
@@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module):
             value_states,
             None,
             dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=None,
+            scaling=None,  # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
             is_causal=False,  # HAS TO BE ENFORCED
             **kwargs,
         )
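Keeping scaling=None defers to the backend's default; for PyTorch SDPA, scale=None means 1/sqrt(head_dim), which the TODO above notes may not be the right value for the vision tower. A small sketch, independent of the diff, showing that default:

import math

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale=None falls back to 1 / sqrt(head_dim) inside SDPA
out_default = F.scaled_dot_product_attention(q, k, v, is_causal=False, scale=None)
out_explicit = F.scaled_dot_product_attention(q, k, v, is_causal=False, scale=1 / math.sqrt(q.shape[-1]))
print(torch.allclose(out_default, out_explicit))  # True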
@@ -37,7 +37,7 @@ if is_torch_available():
 @require_torch_large_gpu
 @require_read_token
 class Llama4IntegrationTest(unittest.TestCase):
-    model_id = "ll-re/Llama-4-17B-Omni-Instruct"
+    model_id = "meta-llama/Llama-4-Scout-17B-16E"
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
     # Depending on the hardware we get different logits / generations
     cuda_compute_capability_major_version = None
@@ -48,14 +48,17 @@ class Llama4IntegrationTest(unittest.TestCase):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
         cls.model = Llama4ForConditionalGeneration.from_pretrained(
-            "ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
+            "meta-llama/Llama-4-Scout-17B-16E",
+            device_map="auto",
+            torch_dtype=torch.float32,
+            attn_implementation="eager",
         )
 
     def setUp(self):
-        self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
+        self.processor = Llama4Processor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E", padding_side="left")
 
         url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
-        self.messages = [
+        self.messages_1 = [
             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
             {
                 "role": "user",
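Outside the unittest harness, loading the model and processor with the same settings the updated fixture uses would look roughly like this; the checkpoint is gated, so this assumes authenticated Hub access, and the kwargs simply mirror the diff above.

import torch
from transformers import Llama4ForConditionalGeneration, Llama4Processor

model_id = "meta-llama/Llama-4-Scout-17B-16E"
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float32,    # fp32, as in the updated setUpClass
    attn_implementation="eager",  # eager attention, as in the updated setUpClass
)
processor = Llama4Processor.from_pretrained(model_id, padding_side="left")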
@@ -66,27 +69,7 @@ class Llama4IntegrationTest(unittest.TestCase):
             },
         ]
 
-    def test_model_17b_16e_fp16(self):
-        EXPECTED_TEXT = [
-            "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
-            "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
-        ]
-
-        messages = [
-            {"role": "user", "content": "Who are you?"},
-        ]
-        inputs = self.processor.apply_chat_template(
-            messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
-        ).to(torch_device)
-
-        output = self.model.generate(**inputs, max_new_tokens=100)
-        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
-
-        print(output_text)
-        self.assertEqual(output_text, EXPECTED_TEXT)
-
-    def test_model_17b_16e_batch(self):
-        messages_2 = [
+        self.messages_2 = [
             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
             {
                 "role": "user",
@@ -101,20 +84,35 @@ class Llama4IntegrationTest(unittest.TestCase):
             },
         ]
 
+    def test_model_17b_16e_fp16(self):
+        EXPECTED_TEXT = [
+            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
+        ]  # fmt: skip
+
         inputs = self.processor.apply_chat_template(
-            [self.messages, messages_2],
+            self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        ).to(device=torch_device, dtype=self.model.dtype)
+        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        print(output_text)
+        self.assertEqual(output_text, EXPECTED_TEXT)
+
+    def test_model_17b_16e_batch(self):
+        inputs = self.processor.apply_chat_template(
+            [self.messages_1, self.messages_2],
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
             padding=True,
             add_generation_prompt=True,
-        ).to(torch_device)
+        ).to(device=torch_device, dtype=torch.float32)
 
         output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)
 
         EXPECTED_TEXTS = [
-            'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
-            "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
+            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white',
+            'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.'
         ]  # fmt: skip
         self.assertEqual(output_text, EXPECTED_TEXTS)
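For orientation, the rewritten tests boil down to the standard chat-template flow. A trimmed sketch, reusing model and processor from the snippet above; the message payload follows the processor chat-template format and is illustrative, since the exact user content in the test is not shown in this hunk.

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

inputs = processor.apply_chat_template(
    [messages],            # a list of conversations enables padding/batching
    tokenize=True,
    add_generation_prompt=True,
    padding=True,
    return_dict=True,
    return_tensors="pt",
).to(device=model.device, dtype=torch.float32)  # only floating-point tensors (pixel_values) are cast

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
print(processor.batch_decode(output, skip_special_tokens=True))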