diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 500910404b0..aa6d456cf4a 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -1363,7 +1363,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 5d8542c1d39..64ebee00a01 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -557,7 +557,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index a5a44330d81..06a476c69a3 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -30,6 +30,8 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
+    require_flash_attn,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -355,10 +357,10 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
 
 @slow
 @require_torch_gpu
-# @require_read_token
+@require_read_token
 class Gemma3IntegrationTest(unittest.TestCase):
     def setUp(self):
-        self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left")
+        self.processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
 
         url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
         self.messages = [
@@ -376,7 +378,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         cleanup(torch_device, gc_collect=True)
 
     def test_model_4b_bf16(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -397,7 +399,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_batch(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -437,7 +439,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_crops(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -465,12 +467,12 @@ class Gemma3IntegrationTest(unittest.TestCase):
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)
 
         EXPECTED_NUM_IMAGES = 3  # one for the origin image and two crops of images
-        EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nDescribe this image in detail.\nmodel\nHere's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image is a close-up shot of a garden scene featuring several"]  # fmt: skip
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.']  # fmt: skip
         self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_4b_multiimage(self):
-        model_id = "gg-hf-g/gemma-3-4b-it"
+        model_id = "google/gemma-3-4b-it"
 
         model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
@@ -503,7 +505,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     def test_model_1b_text_only(self):
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
 
         model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
             torch_device
@@ -518,29 +520,29 @@ class Gemma3IntegrationTest(unittest.TestCase):
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
     # TODO: raushan FA2 generates gibberish for no reason, check later
-    # @require_flash_attn
-    # @require_torch_gpu
-    # @mark.flash_attn_test
-    # def test_model_4b_flash_attn(self):
-    #     model_id = "gg-hf-g/gemma-3-4b-it"
-    #
-    #     model = Gemma3ForConditionalGeneration.from_pretrained(
-    #         model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-    #     ).to(torch_device)
-    #
-    #     inputs = self.processor.apply_chat_template(
-    #         self.messages,
-    #         tokenize=True,
-    #         return_dict=True,
-    #         return_tensors="pt",
-    #         add_generation_prompt=True,
-    #     ).to(torch_device)
-    #
-    #     output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
-    #     output_text = self.processor.batch_decode(output, skip_special_tokens=True)
-    #
-    #     EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**']  # fmt: skip
-    #     self.assertEqual(output_text, EXPECTED_TEXTS)
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    def test_model_4b_flash_attn(self):
+        model_id = "google/gemma-3-4b-it"
+
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        ).to(torch_device)
+
+        inputs = self.processor.apply_chat_template(
+            self.messages,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            add_generation_prompt=True,
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and']  # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)
 
     @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
     def test_generation_beyond_sliding_window(self, attn_implementation: str):
@@ -548,7 +550,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         we need to correctly slice the attention mask in all cases (because we use a HybridCache).
         Outputs for every attention functions should be coherent and identical.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
 
         input_text = [
             "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
@@ -576,7 +578,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
         Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
         ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
         """
-        model_id = "gg-hf-g/gemma-3-1b-it"
+        model_id = "google/gemma-3-1b-it"
         attn_implementation = "sdpa"
 
         input_text = [
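Note on the `logits = outputs[0]` change in both modeling files: below is a minimal standalone sketch of the indexing behavior it relies on (it assumes only that `torch` and `transformers` are installed; the `CausalLMOutputWithPast` instance and tensor shape are illustrative, not taken from this diff). Positional indexing works whether the inner language model returns a dict-like `ModelOutput` or a plain tuple (`return_dict=False`), while `.logits` attribute access only exists on the dict-like output.

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Dict-like output (return_dict=True): exposes .logits and also supports integer indexing.
dict_like = CausalLMOutputWithPast(logits=torch.zeros(1, 4, 8))
# Tuple output (return_dict=False): only positional indexing is available.
tuple_like = dict_like.to_tuple()

# With loss=None, logits is the first non-None field, so index 0 reaches it in both cases.
assert torch.equal(dict_like[0], dict_like.logits)
assert torch.equal(tuple_like[0], dict_like.logits)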