Fix some tests (#35682)

* cohere tests

* glm tests

* cohere2 model name

* create decorator

* update

* fix cohere2 completions

* style

* style

* style

* add cuda in comments
Cyril Vallez 2025-01-17 12:10:43 +00:00 committed by GitHub
parent 8c1b5d3782
commit ab1afd56f5
6 changed files with 38 additions and 38 deletions


@@ -595,7 +595,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
- _supports_flex_attn = True
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True


@@ -431,7 +431,7 @@ class DiffLlamaDecoderLayer(LlamaDecoderLayer):
class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
- pass
+ _supports_flex_attn = False
class DiffLlamaModel(LlamaModel):


@@ -988,6 +988,17 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+ def require_torch_large_gpu(test_case, memory: float = 20):
+ """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
+ if torch_device != "cuda":
+ return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case)
+ return unittest.skipUnless(
+ torch.cuda.get_device_properties(0).total_memory / 1024**3 > memory,
+ f"test requires a GPU with more than {memory} GiB of memory",
+ )(test_case)
def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
"""
Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled.

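Note: the new require_torch_large_gpu decorator is applied like the existing require_torch_* helpers in transformers.testing_utils; a minimal usage sketch, where the test class and method names are made up:

import unittest

from transformers.testing_utils import require_torch_large_gpu, slow

@slow
@require_torch_large_gpu  # skipped unless a CUDA GPU with more than 20 GiB of memory is available
class SomeLargeModelIntegrationTest(unittest.TestCase):  # hypothetical test class
    def test_generation(self):
        ...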

@@ -347,7 +347,7 @@ class CohereIntegrationTest(unittest.TestCase):
[[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
[[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
]
- ).to(torch_device)
+ ).to(device=torch_device, dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(

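Note: the added cast keeps the expected logits in the same dtype as the fp16 model outputs before they are compared; a minimal, self-contained illustration (the tensor values, tolerances, and the use of torch.testing.assert_close are only assumptions, not the test's actual comparison):

import torch

# Expected values cast to the same device/dtype as the model output
expected = torch.tensor([[0.0000, 0.1866, -0.1997]]).to(device="cpu", dtype=torch.float16)
actual = torch.tensor([[0.0000, 0.1865, -0.1998]], dtype=torch.float16)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)  # dtypes and devices must match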

@@ -26,7 +26,7 @@ from transformers.testing_utils import (
require_flash_attn,
require_read_token,
require_torch,
- require_torch_gpu,
+ require_torch_large_gpu,
slow,
torch_device,
)
@@ -182,7 +182,8 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
@slow
- @require_torch_gpu
+ @require_read_token
+ @require_torch_large_gpu
class Cohere2IntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -195,12 +196,11 @@ class Cohere2IntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
- @require_read_token
def test_model_bf16(self):
model_id = "CohereForAI/command-r7b-12-2024"
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_TEXTS = [
"<BOS_TOKEN>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"<PAD><PAD><BOS_TOKEN>Hi today I'm going to be talking about the history of the United States. The United States of America",
"<BOS_TOKEN>Hello I am doing a project for a school assignment and I need to create a website for a fictional company. I have",
"<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n",
]
model = AutoModelForCausalLM.from_pretrained(
@@ -215,12 +215,11 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
- @require_read_token
def test_model_fp16(self):
model_id = "CohereForAI/command-r7b-12-2024"
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_TEXTS = [
"<BOS_TOKEN>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"<PAD><PAD><BOS_TOKEN>Hi today I'm going to be talking about the history of the United States. The United States of America",
"<BOS_TOKEN>Hello I am doing a project for a school assignment and I need to create a website for a fictional company. I have",
"<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n",
]
model = AutoModelForCausalLM.from_pretrained(
@@ -235,14 +234,13 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
- @require_read_token
def test_model_pipeline_bf16(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Cohere2 before this PR
model_id = "CohereForAI/command-r7b-12-2024"
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"Hi today I'm going to be talking about the history of the United States. The United States of America",
"Hello I am doing a project for a school assignment and I need to create a website for a fictional company. I have",
"Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n",
]
model = AutoModelForCausalLM.from_pretrained(
@@ -256,17 +254,14 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
- @require_read_token
@require_flash_attn
- @require_torch_gpu
@mark.flash_attn_test
@slow
def test_model_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context
model_id = "CohereForAI/command-r7b-12-2024"
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_TEXTS = [
- '<BOS_TOKEN>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
- "<PAD><PAD><BOS_TOKEN>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+ '<BOS_TOKEN>Hello I am doing a project for my school and I need to create a website for a fictional company. I have the logo and the name of the company. I need a website that is simple and easy to navigate. I need a home page, about us, services, contact us, and a gallery. I need the website to be responsive and I need it to be able to be hosted on a server. I need the website to be done in a week. I need the website to be done in HTML,',
+ "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n\nThis recipe is very simple and easy to make.\n\nYou will need:\n\n* 2 cups of flour\n* 1 cup of sugar\n* 1/2 cup of cocoa powder\n* 1 teaspoon of baking powder\n* 1 teaspoon of baking soda\n* 1/2 teaspoon of salt\n* 2 eggs\n* 1 cup of milk\n",
] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(
@@ -280,8 +275,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
- @slow
- @require_read_token
def test_export_static_cache(self):
if version.parse(torch.__version__) < version.parse("2.5.0"):
self.skipTest(reason="This test requires torch >= 2.5 to run.")
@@ -291,16 +284,12 @@ class Cohere2IntegrationTest(unittest.TestCase):
convert_and_export_with_cache,
)
- tokenizer = AutoTokenizer.from_pretrained(
- "CohereForAI/command-r7b-12-2024", pad_token="<PAD>", padding_side="right"
- )
+ model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_TEXT_COMPLETION = [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number",
"Hello I am doing a project on the effects of social media on mental health. I have a few questions. 1. What is the relationship",
]
- max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
- "input_ids"
- ].shape[-1]
+ tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
# Load model
device = "cpu"
dtype = torch.bfloat16
@@ -308,17 +297,17 @@ class Cohere2IntegrationTest(unittest.TestCase):
attn_implementation = "sdpa"
batch_size = 1
model = AutoModelForCausalLM.from_pretrained(
"CohereForAI/command-r7b-12-2024",
"CohereForAI/c4ai-command-r7b-12-2024",
device_map=device,
torch_dtype=dtype,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
- max_length=max_generation_length,
+ max_length=30,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
"max_cache_len": 30,
},
),
)
@@ -326,7 +315,7 @@ class Cohere2IntegrationTest(unittest.TestCase):
prompts = ["Hello I am doing"]
prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
prompt_token_ids = prompt_tokens["input_ids"]
- max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
+ max_new_tokens = 30 - prompt_token_ids.shape[-1]
# Static Cache + export
exported_program = convert_and_export_with_cache(model)


@@ -23,7 +23,7 @@ from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
- require_torch_accelerator,
+ require_torch_large_gpu,
require_torch_sdpa,
slow,
torch_device,
@@ -418,7 +418,7 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@slow
- @require_torch_accelerator
+ @require_torch_large_gpu
class GlmIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
model_id = "THUDM/glm-4-9b"