Gemma3: fix test (#36820)

* fix test

* require_read_token and public repo ids

* flash-attn test uncomment

* fix torchscript
Raushan Turganbay 2025-03-20 18:14:53 +01:00 committed by GitHub
parent 068b663f90
commit 42c489f2ae
3 changed files with 37 additions and 35 deletions
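
The modeling changes below (in both Gemma3 and PaliGemma) read the logits by index instead of by attribute. A minimal sketch of why that helps the torchscript path, using a hypothetical DummyOutput class rather than the repositories' real output classes: when a model is traced or called with return_dict=False it returns a plain tuple, which has no .logits attribute, while outputs[0] works for both tuples and ModelOutput instances.

# Sketch only: DummyOutput is a made-up stand-in, not a transformers output class.
from dataclasses import dataclass

import torch
from transformers.utils import ModelOutput


@dataclass
class DummyOutput(ModelOutput):
    logits: torch.FloatTensor = None


dict_style = DummyOutput(logits=torch.ones(2, 3))  # return_dict=True style
tuple_style = dict_style.to_tuple()                # what torchscript / return_dict=False yields

print(dict_style[0].shape)   # works: ModelOutput supports integer indexing
print(tuple_style[0].shape)  # works: a plain tuple does too
# tuple_style.logits would raise AttributeError, so attribute access breaks tuple-returning paths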


@@ -1363,7 +1363,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues


@@ -557,7 +557,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues

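The test changes below move the integration suite to the public google/gemma-3-* repo ids, gate it with @require_read_token, and re-enable the previously commented-out flash-attention test. As a rough illustration of that kind of gating (require_hf_token and the HF_TOKEN variable are assumptions for the sketch, not the transformers.testing_utils implementation):

# Illustrative sketch of a skip guard in the spirit of require_read_token.
import os
import unittest


def require_hf_token(test_case):
    """Skip the decorated test or test class unless a read token is available."""
    return unittest.skipUnless(os.getenv("HF_TOKEN"), "requires a Hugging Face read token")(test_case)


@require_hf_token
class GatedRepoIntegrationTest(unittest.TestCase):
    def test_token_present(self):
        # Only runs when HF_TOKEN is set, mirroring how gated-repo tests are skipped on CI.
        self.assertTrue(os.getenv("HF_TOKEN"))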

@@ -30,6 +30,8 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
+    require_flash_attn,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -355,10 +357,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

 @slow
 @require_torch_gpu
-# @require_read_token
+@require_read_token
 class Gemma3IntegrationTest(unittest.TestCase):
     def setUp(self):
-        self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left")
+        self.processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it", padding_side="left")

         url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
         self.messages = [
@@ -376,7 +378,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         cleanup(torch_device, gc_collect=True)

     def test_model_4b_bf16(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"

         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@ -397,7 +399,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_batch(self):
model_id = "gg-hf-g/gemma-3-4b-it"
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -437,7 +439,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)

     def test_model_4b_crops(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"

         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -465,12 +467,12 @@ class Gemma3IntegrationTest(unittest.TestCase):
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_NUM_IMAGES = 3  # one for the origin image and two crops of images
-        EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nDescribe this image in detail.\nmodel\nHere's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image is a close-up shot of a garden scene featuring several"] # fmt: skip
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
         self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
         self.assertEqual(output_text, EXPECTED_TEXTS)

     def test_model_4b_multiimage(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"

         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -503,7 +505,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)

     def test_model_1b_text_only(self):
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"

         model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
             torch_device
@@ -518,29 +520,29 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)

-    # TODO: raushan FA2 generates gibberish for no reason, check later
-    # @require_flash_attn
-    # @require_torch_gpu
-    # @mark.flash_attn_test
-    # def test_model_4b_flash_attn(self):
-    #     model_id = "gg-hf-g/gemma-3-4b-it"
-    #
-    #     model = Gemma3ForConditionalGeneration.from_pretrained(
-    #         model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-    #     ).to(torch_device)
-    #
-    #     inputs = self.processor.apply_chat_template(
-    #         self.messages,
-    #         tokenize=True,
-    #         return_dict=True,
-    #         return_tensors="pt",
-    #         add_generation_prompt=True,
-    #     ).to(torch_device)
-    #
-    #     output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
-    #     output_text = self.processor.batch_decode(output, skip_special_tokens=True)
-    #
-    #     EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**'] # fmt: skip
-    #     self.assertEqual(output_text, EXPECTED_TEXTS)
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    def test_model_4b_flash_attn(self):
+        model_id = "google/gemma-3-4b-it"
+
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        ).to(torch_device)
+
+        inputs = self.processor.apply_chat_template(
+            self.messages,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            add_generation_prompt=True,
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'] # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)

     @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
     def test_generation_beyond_sliding_window(self, attn_implementation: str):
@@ -548,7 +550,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         we need to correctly slice the attention mask in all cases (because we use a HybridCache).
         Outputs for every attention functions should be coherent and identical.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"

         input_text = [
             "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
@@ -576,7 +578,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
         ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
         attn_implementation = "sdpa"

         input_text = [