Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Gemma3: fix test (#36820)

* fix test
* require_read_token and public repo ids
* flash-attn test uncomment
* fix torchscript
parent 068b663f90
commit 42c489f2ae
@@ -1363,7 +1363,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
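The `fix torchscript` bullet maps to this one-line change: a model traced with `torch.jit.trace` runs with tuple outputs rather than a `ModelOutput`, so `outputs.logits` fails while `outputs[0]` works in both modes (`ModelOutput` also supports positional indexing). A minimal sketch of the failure mode, using a hypothetical toy module rather than the real Gemma3 code:

    import torch
    from transformers.modeling_outputs import CausalLMOutputWithPast

    class ToyLM(torch.nn.Module):
        """Hypothetical stand-in, not the real model: traced modules return
        plain tuples, while eager return_dict=True gives a ModelOutput."""
        def forward(self, x, return_dict: bool = False):
            logits = x * 2.0
            hidden = x + 1.0
            if return_dict:
                return CausalLMOutputWithPast(logits=logits, hidden_states=hidden)
            return (logits, hidden)

    model = ToyLM()
    eager_out = model(torch.ones(1, 4), return_dict=True)  # ModelOutput: .logits works
    traced = torch.jit.trace(model, torch.ones(1, 4))      # traces the tuple-returning path
    traced_out = traced(torch.ones(1, 4))                  # tuple: .logits would fail here

    # Indexing with [0] is valid for both return types.
    assert torch.equal(eager_out[0], traced_out[0])

The same indexing fix is applied to PaliGemma in the next hunk.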
@@ -557,7 +557,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
@@ -30,6 +30,8 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
+    require_flash_attn,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
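For readers unfamiliar with the new gates: `require_flash_attn` and `require_read_token` are transformers test decorators that skip a test when flash-attn or a Hugging Face auth token is unavailable. A rough sketch of what a token gate like this can look like (assumed behavior, not the actual transformers implementation):

    import os
    import unittest

    def require_hf_read_token(test_case):
        # Hypothetical re-implementation: skip unless an HF token is configured.
        has_token = bool(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN"))
        return unittest.skipUnless(has_token, "test requires a Hugging Face read token")(test_case)

Applied to a class, as in the next hunk, `unittest.skipUnless` skips every test method in it when no token is present.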
@@ -355,10 +357,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
 
 @slow
 @require_torch_gpu
-# @require_read_token
+@require_read_token
 class Gemma3IntegrationTest(unittest.TestCase):
     def setUp(self):
-        self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left")
+        self.processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
 
         url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
         self.messages = [
@@ -376,7 +378,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         cleanup(torch_device, gc_collect=True)
 
     def test_model_4b_bf16(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -397,7 +399,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_batch(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -437,7 +439,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_crops(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -465,12 +467,12 @@ class Gemma3IntegrationTest(unittest.TestCase):
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)
 
         EXPECTED_NUM_IMAGES = 3  # one for the origin image and two crops of images
-        EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nDescribe this image in detail.\nmodel\nHere's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image is a close-up shot of a garden scene featuring several"]  # fmt: skip
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.']  # fmt: skip
         self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_multiimage(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
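Context for `EXPECTED_NUM_IMAGES = 3` above: Gemma3's image processing supports pan-and-scan, which keeps the original image and appends crops of it, so `pixel_values` holds three images here. A hedged sketch of exercising that path; `do_pan_and_scan` is the Gemma3 image-processor flag, and forwarding it through `apply_chat_template` is assumed:

    from transformers import Gemma3Processor

    processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it")
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }]
    inputs = processor.apply_chat_template(
        messages, tokenize=True, return_dict=True, return_tensors="pt",
        add_generation_prompt=True, do_pan_and_scan=True,
    )
    # With pan-and-scan enabled: original image plus its crops.
    print(len(inputs["pixel_values"]))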
@@ -503,7 +505,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_1b_text_only(self):
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
 
         model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
             torch_device
@@ -518,29 +520,29 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
-    # TODO: raushan FA2 generates gibberish for no reason, check later
-    # @require_flash_attn
-    # @require_torch_gpu
-    # @mark.flash_attn_test
-    # def test_model_4b_flash_attn(self):
-    #     model_id = "gg-hf-g/gemma-3-4b-it"
-    #
-    #     model = Gemma3ForConditionalGeneration.from_pretrained(
-    #         model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-    #     ).to(torch_device)
-    #
-    #     inputs = self.processor.apply_chat_template(
-    #         self.messages,
-    #         tokenize=True,
-    #         return_dict=True,
-    #         return_tensors="pt",
-    #         add_generation_prompt=True,
-    #     ).to(torch_device)
-    #
-    #     output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
-    #     output_text = self.processor.batch_decode(output, skip_special_tokens=True)
-    #
-    #     EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**']  # fmt: skip
-    #     self.assertEqual(output_text, EXPECTED_TEXTS)
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    def test_model_4b_flash_attn(self):
+        model_id = "google/gemma-3-4b-it"
+
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        ).to(torch_device)
+
+        inputs = self.processor.apply_chat_template(
+            self.messages,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            add_generation_prompt=True,
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and']  # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)
 
     @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
     def test_generation_beyond_sliding_window(self, attn_implementation: str):
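This is the `flash-attn test uncomment` bullet: the gibberish output noted in the old TODO is gone, and the FA2 expectation now matches the coherent output of the other backends. A sketch of running just this test locally (a CUDA GPU with flash-attn installed is assumed, and the usual tests/models/gemma3/test_modeling_gemma3.py location is assumed; `RUN_SLOW=1` lifts the `@slow` gate transformers uses for integration tests):

    import os
    import pytest

    # Set RUN_SLOW before pytest collects the test module, then select only the FA2 test.
    os.environ["RUN_SLOW"] = "1"
    pytest.main(["tests/models/gemma3/test_modeling_gemma3.py", "-k", "test_model_4b_flash_attn"])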
@@ -548,7 +550,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         we need to correctly slice the attention mask in all cases (because we use a HybridCache).
         Outputs for every attention function should be coherent and identical.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
 
         input_text = [
             "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
@@ -576,7 +578,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
         ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
         attn_implementation = "sdpa"
 
         input_text = [
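The #36684 regression that this last test guards boils down to: a user-supplied GenerationConfig that does not set `cache_implementation` must still pick up `'hybrid'` from `model.generation_config`, so that generation past the 4096-token sliding window stays correct. A minimal sketch mirroring the test setup (illustrative, not the test itself):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    model_id = "google/gemma-3-1b-it"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, attn_implementation="sdpa")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # No cache_implementation set here: it should be inherited from model.generation_config.
    generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)

    # Prompt longer than the 4096-token sliding window, as in the test above.
    inputs = tokenizer("This is a nice place. " * 800 + "I really enjoy the scenery,", return_tensors="pt")
    out = model.generate(**inputs, generation_config=generation_config)
    print(tokenizer.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))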