Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Fix Gemma2IntegrationTest (#38492)

* fix
* fix
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* skip-ci
* update
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 1094dd34f7
commit ccc859620a
@@ -116,6 +116,7 @@ from .utils import (
    is_peft_available,
    is_phonemizer_available,
    is_pretty_midi_available,
    is_psutil_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytest_available,
@@ -1053,6 +1054,19 @@ def require_torch_gpu(test_case):
    return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


def require_large_cpu_ram(test_case, memory: float = 80):
    """Decorator marking a test that requires more than `memory` GiB of CPU RAM."""
    if not is_psutil_available():
        return test_case

    import psutil

    return unittest.skipUnless(
        psutil.virtual_memory().total / 1024**3 > memory,
        f"test requires a machine with more than {memory} GiB of CPU RAM",
    )(test_case)


def require_torch_large_gpu(test_case, memory: float = 20):
    """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
    if torch_device != "cuda":
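Usage note (not part of the diff): the new helper is applied like the other require_* decorators in transformers.testing_utils. A minimal sketch, assuming a made-up test class and method name:

# Sketch only: gating a RAM-heavy test with the new decorator.
# The class and method names below are hypothetical, for illustration.
import unittest

from transformers.testing_utils import require_large_cpu_ram


class CpuHeavyTest(unittest.TestCase):
    @require_large_cpu_ram  # skipped unless the machine has more than 80 GiB of CPU RAM
    def test_cpu_offloaded_generation(self):
        self.assertTrue(True)  # placeholder body

Because the helper takes the test as its first positional argument, it is used bare like require_torch_large_gpu; raising the threshold would require something like functools.partial(require_large_cpu_ram, memory=100) rather than a decorator-factory call.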
@@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    Expectations,
    is_flash_attn_2_available,
    require_flash_attn,
    require_read_token,
    require_torch,
@@ -282,6 +283,9 @@ class Cohere2IntegrationTest(unittest.TestCase):
        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
        Outputs for every attention function should be coherent and identical.
        """
        if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
            self.skipTest("FlashAttention2 is required for this test.")

        if torch_device == "xpu" and attn_implementation == "flash_attention_2":
            self.skipTest(reason="Intel XPU doesn't support flash_attention_2 as of now.")

@@ -23,13 +23,17 @@ from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    Expectations,
    cleanup,
    is_flash_attn_2_available,
    require_flash_attn,
    require_large_cpu_ram,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_large_accelerator,
    require_torch_large_gpu,
    slow,
    tooslow,
    torch_device,
)

@@ -177,7 +181,13 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase):
class Gemma2IntegrationTest(unittest.TestCase):
    input_text = ["Hello I am doing", "Hi today"]

    @tooslow
    def setUp(self):
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @require_torch_large_accelerator
    @require_read_token
    def test_model_9b_bf16(self):
        model_id = "google/gemma-2-9b"
@@ -198,7 +208,7 @@ class Gemma2IntegrationTest(unittest.TestCase):

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @tooslow
    @require_torch_large_accelerator
    @require_read_token
    def test_model_9b_fp16(self):
        model_id = "google/gemma-2-9b"
@@ -220,7 +230,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
        self.assertEqual(output_text, EXPECTED_TEXTS)

    @require_read_token
    @tooslow
    @require_torch_large_accelerator
    def test_model_9b_pipeline_bf16(self):
        # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
        model_id = "google/gemma-2-9b"
@@ -246,10 +256,15 @@ class Gemma2IntegrationTest(unittest.TestCase):
        # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
        model_id = "google/gemma-2-2b"
        # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
        EXPECTED_TEXTS = [
            "Hello I am doing a project on the 1960s and I am trying to find out what the average",
            "Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
        ]
        EXPECTED_BATCH_TEXTS = Expectations(
            {
                ("cuda", 8): [
                    "Hello I am doing a project on the 1960s and I am trying to find out what the average",
                    "Hi today I'm going to be talking about the 10 most powerful characters in the Naruto series.",
                ]
            }
        )
        EXPECTED_BATCH_TEXT = EXPECTED_BATCH_TEXTS.get_expectation()

        model = AutoModelForCausalLM.from_pretrained(
            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
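For context (an assumption about intent, not text from the diff): Expectations maps a (device type, major compute capability) key such as ("cuda", 8) to the outputs expected on that hardware, and get_expectation() returns the entry for the current machine. A rough sketch of the lookup idea, not the actual transformers implementation:

# Illustration only; the real Expectations class in transformers.testing_utils may use
# different matching and fallback rules.
import torch


def pick_expectation(table):
    if torch.cuda.is_available():
        # e.g. ("cuda", 8) on an Ampere GPU such as an A100 (compute capability 8.x)
        key = ("cuda", torch.cuda.get_device_capability()[0])
    else:
        key = (None, None)
    # Hypothetical fallback: use a default entry when the exact device is not listed.
    return table.get(key, table.get((None, None)))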
@@ -259,21 +274,20 @@ class Gemma2IntegrationTest(unittest.TestCase):

        output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)

        self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
        self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
        self.assertEqual(output[0][0]["generated_text"], EXPECTED_BATCH_TEXT[0])
        self.assertEqual(output[1][0]["generated_text"], EXPECTED_BATCH_TEXT[1])

    @require_read_token
    @require_flash_attn
    @require_torch_gpu
    @require_torch_large_gpu
    @mark.flash_attn_test
    @slow
    @tooslow
    def test_model_9b_flash_attn(self):
        # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
        model_id = "google/gemma-2-9b"
        EXPECTED_TEXTS = [
            '<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
        ]  # fmt: skip

        model = AutoModelForCausalLM.from_pretrained(
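One more aside (an assumption about the pipeline output layout, not stated in the diff): with a list of prompts, the text-generation pipeline returns one result list per prompt, each holding dicts with a generated_text key, which is why the assertions near the top of this hunk index output[0][0] and output[1][0]. A tiny sketch with placeholder strings:

# Illustrative shape only; the strings are placeholders, not real model output.
output = [
    [{"generated_text": "Hello I am doing ..."}],  # completions for input_text[0]
    [{"generated_text": "Hi today ..."}],          # completions for input_text[1]
]
first_text = output[0][0]["generated_text"]  # first (and only) candidate for the first prompt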
@@ -299,9 +313,17 @@ class Gemma2IntegrationTest(unittest.TestCase):
        )

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
        EXPECTED_TEXT_COMPLETION = [
            "Hello I am doing a project for my school and I need to know how to make a program that will take a number",
        ]
        EXPECTED_TEXT_COMPLETIONS = Expectations(
            {
                ("cuda", 7): [
                    "Hello I am doing a project for my school and I need to know how to make a program that will take a number"
                ],
                ("cuda", 8): [
                    "Hello I am doing a project for my class and I am having trouble with the code. I am trying to make a"
                ],
            }
        )
        EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
            "input_ids"
        ].shape[-1]
@@ -343,6 +365,7 @@ class Gemma2IntegrationTest(unittest.TestCase):

    @slow
    @require_read_token
    @require_large_cpu_ram
    def test_export_hybrid_cache(self):
        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
        from transformers.pytorch_utils import is_torch_greater_or_equal
@@ -379,8 +402,8 @@ class Gemma2IntegrationTest(unittest.TestCase):
        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
        self.assertEqual(export_generated_text, eager_generated_text)

    @require_torch_large_accelerator
    @require_read_token
    @tooslow
    def test_model_9b_bf16_flex_attention(self):
        model_id = "google/gemma-2-9b"
        EXPECTED_TEXTS = [
@@ -407,6 +430,8 @@ class Gemma2IntegrationTest(unittest.TestCase):
        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
        Outputs for every attention function should be coherent and identical.
        """
        if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
            self.skipTest("FlashAttention2 is required for this test.")

        if torch_device == "xpu" and attn_implementation == "flash_attention_2":
            self.skipTest(reason="Intel XPU doesn't support flash_attention_2 as of now.")
