Fix Gemma2IntegrationTest (#38492)

* fix

* fix

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* skip-ci

* update

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-02 22:45:09 +02:00 committed by GitHub
parent 1094dd34f7
commit ccc859620a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 18 deletions

View File

@ -116,6 +116,7 @@ from .utils import (
is_peft_available,
is_phonemizer_available,
is_pretty_midi_available,
is_psutil_available,
is_pyctcdecode_available,
is_pytesseract_available,
is_pytest_available,
@ -1053,6 +1054,19 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_large_cpu_ram(test_case, memory: float = 80):
"""Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
if not is_psutil_available():
return test_case
import psutil
return unittest.skipUnless(
psutil.virtual_memory().total / 1024**3 > memory,
f"test requires a machine with more than {memory} GiB of CPU RAM memory",
)(test_case)
def require_torch_large_gpu(test_case, memory: float = 20):
"""Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
if torch_device != "cuda":

View File

@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
is_flash_attn_2_available,
require_flash_attn,
require_read_token,
require_torch,
@ -282,6 +283,9 @@ class Cohere2IntegrationTest(unittest.TestCase):
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
self.skipTest("FlashAttention2 is required for this test.")
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")

View File

@ -23,13 +23,17 @@ from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
cleanup,
is_flash_attn_2_available,
require_flash_attn,
require_large_cpu_ram,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_large_accelerator,
require_torch_large_gpu,
slow,
tooslow,
torch_device,
)
@ -177,7 +181,13 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase):
class Gemma2IntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
@tooslow
def setUp(self):
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@require_torch_large_accelerator
@require_read_token
def test_model_9b_bf16(self):
model_id = "google/gemma-2-9b"
@ -198,7 +208,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
@tooslow
@require_torch_large_accelerator
@require_read_token
def test_model_9b_fp16(self):
model_id = "google/gemma-2-9b"
@ -220,7 +230,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
@tooslow
@require_torch_large_accelerator
def test_model_9b_pipeline_bf16(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-9b"
@ -246,10 +256,15 @@ class Gemma2IntegrationTest(unittest.TestCase):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-2b"
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
]
EXPECTED_BATCH_TEXTS = Expectations(
{
("cuda", 8): [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 most powerful characters in the Naruto series.",
]
}
)
EXPECTED_BATCH_TEXT = EXPECTED_BATCH_TEXTS.get_expectation()
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
@ -259,21 +274,20 @@ class Gemma2IntegrationTest(unittest.TestCase):
output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
self.assertEqual(output[0][0]["generated_text"], EXPECTED_BATCH_TEXT[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_BATCH_TEXT[1])
@require_read_token
@require_flash_attn
@require_torch_gpu
@require_torch_large_gpu
@mark.flash_attn_test
@slow
@tooslow
def test_model_9b_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
'<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(
@ -299,9 +313,17 @@ class Gemma2IntegrationTest(unittest.TestCase):
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
EXPECTED_TEXT_COMPLETION = [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number",
]
EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("cuda", 7): [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number"
],
("cuda", 8): [
"Hello I am doing a project for my class and I am having trouble with the code. I am trying to make a"
],
}
)
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]
@ -343,6 +365,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
@slow
@require_read_token
@require_large_cpu_ram
def test_export_hybrid_cache(self):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
from transformers.pytorch_utils import is_torch_greater_or_equal
@ -379,8 +402,8 @@ class Gemma2IntegrationTest(unittest.TestCase):
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
self.assertEqual(export_generated_text, eager_generated_text)
@require_torch_large_accelerator
@require_read_token
@tooslow
def test_model_9b_bf16_flex_attention(self):
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
@ -407,6 +430,8 @@ class Gemma2IntegrationTest(unittest.TestCase):
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
self.skipTest("FlashAttention2 is required for this test.")
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")