enable more test cases on xpu (#38572)

* enable glm4 integration cases on XPU, set XPU expectations for blip2

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* more

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* refine wording

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* refine test case names

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* run

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* add gemma2 and chameleon

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix review comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Yao Matrix 2025-06-06 15:29:51 +08:00 committed by GitHub
parent 31023b6909
commit 89542fb81c
23 changed files with 150 additions and 72 deletions

View File

@ -25,8 +25,8 @@ _EXPECTED_OUTPUTS = [
@slow
@require_torch_gpu
@require_flash_attn
@require_torch_gpu
class TestBatchGeneration(unittest.TestCase):
@classmethod
def setUpClass(cls):

View File

@ -34,6 +34,7 @@ from transformers.models.bark.generation_configuration_bark import (
BarkSemanticGenerationConfig,
)
from transformers.testing_utils import (
backend_torch_accelerator_module,
require_flash_attn,
require_torch,
require_torch_accelerator,
@ -1306,7 +1307,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
# standard generation
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
torch_accelerator_module.empty_cache()
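
The getattr-based lookup above is replaced by the backend_torch_accelerator_module helper imported from transformers.testing_utils, which returns the torch backend module (torch.cuda, torch.xpu, ...) matching the active torch_device. A minimal sketch of the resulting device-agnostic cache clearing, assuming an accelerator backend is present:

from transformers.testing_utils import backend_torch_accelerator_module, torch_device

# Resolve the backend module for whichever accelerator torch_device points to
# (torch.cuda on CUDA machines, torch.xpu on XPU machines).
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
torch_accelerator_module.empty_cache()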

View File

@ -1708,10 +1708,14 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
("xpu", 3): [
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
"a woman is playing with her dog on the beach",
],
("cuda", 7): [
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
"a woman is playing with her dog on the beach",
]
],
}
)
expected_outputs = expectations.get_expectation()
@ -1729,10 +1733,14 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
("xpu", 3): [
[0, 3, 7, 152, 2515, 11389, 3523, 1],
"san francisco",
],
("cuda", 7): [
[0, 3, 7, 152, 2515, 11389, 3523, 1],
"san francisco",
]
],
}
)
expected_outputs = expectations.get_expectation()
@ -1755,10 +1763,14 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
("xpu", 3): [
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
],
("cuda", 7): [
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
[0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
]
],
}
)
expected_predictions = expectations.get_expectation()
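
The Blip2 hunks above add ("xpu", 3) entries alongside the existing ("cuda", 7) ones: Expectations maps a (device type, major version) key to the expected output, and get_expectation() resolves the entry for the device the test is running on. A hedged sketch of the pattern with placeholder values (the captions below are illustrative, not real expectations):

from transformers.testing_utils import Expectations

expectations = Expectations(
    {
        # (device_type, major_version) -> expected output for that backend;
        # the values below are placeholders, not outputs of any real model.
        ("xpu", 3): "placeholder caption for XPU",
        ("cuda", 7): "placeholder caption for one CUDA generation",
        ("cuda", 8): "placeholder caption for another CUDA generation",
    }
)
expected_output = expectations.get_expectation()  # picks the entry matching the current device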

View File

@ -420,6 +420,7 @@ class ChameleonIntegrationTest(unittest.TestCase):
# greedy generation outputs
EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("xpu", 3): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Altair. The star map is set against a black background, with the constellations visible in the night'],
("cuda", 7): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in'],
("cuda", 8): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot representing the position of the star Alpha Centauri. Alpha Centauri is the brightest star in the constellation Centaurus and is located'],
}
@ -457,6 +458,10 @@ class ChameleonIntegrationTest(unittest.TestCase):
# greedy generation outputs
EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("xpu", 3): [
'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Altair. The star map is set against a black background, with the constellations visible in the night',
'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.',
],
("cuda", 7): [
'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and',
'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.',

View File

@ -19,7 +19,7 @@ from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_multi_gpu,
require_torch_multi_accelerator,
require_torch_sdpa,
slow,
torch_device,
@ -203,7 +203,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
@require_torch_multi_gpu
@require_torch_multi_accelerator
@require_bitsandbytes
def test_batched_4bit(self):
model_id = "CohereForAI/c4ai-command-r-v01-4bit"

View File

@ -14,7 +14,6 @@
# limitations under the License.
"""Testing suite for the PyTorch ColQwen2 model."""
import gc
import unittest
from typing import ClassVar
@ -27,7 +26,7 @@ from transformers import is_torch_available
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrieval, ColQwen2ForRetrievalOutput
from transformers.models.colqwen2.processing_colqwen2 import ColQwen2Processor
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import cleanup, require_torch, require_vision, slow, torch_device
if is_torch_available():
@ -282,8 +281,7 @@ class ColQwen2ModelIntegrationTest(unittest.TestCase):
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
cleanup(torch_device, gc_collect=True)
@slow
def test_model_integration_test(self):
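
The manual gc.collect() / torch.cuda.empty_cache() pair is collapsed into the cleanup helper from transformers.testing_utils, which also handles non-CUDA backends. A minimal sketch of the teardown, assuming only what the diff imports:

import unittest

from transformers.testing_utils import cleanup, torch_device


class ExampleModelIntegrationTest(unittest.TestCase):
    def tearDown(self):
        # Runs garbage collection and empties the cache of the backend that
        # torch_device resolves to (cuda, xpu, ...), instead of hard-coding torch.cuda.
        cleanup(torch_device, gc_collect=True)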

View File

@ -19,7 +19,13 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@ -607,9 +613,9 @@ class DeformableDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessi
self.assertEqual(inputs["pixel_values"].shape, torch.Size([1, 3, 50, 50]))
@slow
@require_torch_gpu
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
@require_torch_accelerator
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_accelerator_coco_detection_annotations
def test_fast_processor_equivalence_cpu_accelerator_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt") as f:
@ -622,8 +628,8 @@ class DeformableDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessi
# 1. run processor on CPU
encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu")
# 2. run processor on GPU
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device="cuda")
# 2. run processor on accelerator
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device=torch_device)
# verify pixel values
self.assertEqual(encoding_cpu["pixel_values"].shape, encoding_gpu["pixel_values"].shape)
@ -665,9 +671,9 @@ class DeformableDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessi
torch.testing.assert_close(encoding_cpu["labels"][0]["size"], encoding_gpu["labels"][0]["size"].to("cpu"))
@slow
@require_torch_gpu
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations
def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self):
@require_torch_accelerator
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_accelerator_coco_panoptic_annotations
def test_fast_processor_equivalence_cpu_accelerator_coco_panoptic_annotations(self):
# prepare image, target and masks_path
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_panoptic_annotations.txt") as f:
@ -684,9 +690,9 @@ class DeformableDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessi
encoding_cpu = processor(
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device="cpu"
)
# 2. run processor on GPU
# 2. run processor on accelerator
encoding_gpu = processor(
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device="cuda"
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device=torch_device
)
# verify pixel values
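
Passing device=torch_device instead of the literal "cuda" lets the CPU-vs-accelerator equivalence checks run unchanged on XPU. A condensed sketch of the comparison; processor, image, and target stand in for the fixtures the real tests build:

from transformers.testing_utils import torch_device


def check_cpu_accelerator_equivalence(processor, image, target):
    # 1. run the fast processor on CPU
    encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu")
    # 2. run it on whichever accelerator is available (CUDA, XPU, ...)
    encoding_accel = processor(images=image, annotations=target, return_tensors="pt", device=torch_device)
    # both paths should produce pixel values of identical shape (and near-identical content)
    assert encoding_cpu["pixel_values"].shape == encoding_accel["pixel_values"].shape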

View File

@ -746,7 +746,7 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase):
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=1e-4, atol=1e-4)
@require_torch_accelerator
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
def test_inference_object_detection_head_equivalence_cpu_accelerator(self):
image_processor = self.default_image_processor
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt")
@ -759,7 +759,7 @@ class DeformableDetrModelIntegrationTests(unittest.TestCase):
with torch.no_grad():
cpu_outputs = model(pixel_values, pixel_mask)
# 2. run model on GPU
# 2. run model on accelerator
model.to(torch_device)
with torch.no_grad():

View File

@ -18,7 +18,14 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torchvision,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@ -666,9 +673,9 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixi
self.assertEqual(inputs["pixel_values"].shape, torch.Size([1, 3, 50, 50]))
@slow
@require_torch_gpu
@require_torch_accelerator
@require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
def test_fast_processor_equivalence_cpu_accelerator_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt") as f:
@ -679,8 +686,8 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixi
processor = self.image_processor_list[1]()
# 1. run processor on CPU
encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu")
# 2. run processor on GPU
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device="cuda")
# 2. run processor on accelerator
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device=torch_device)
# verify pixel values
self.assertEqual(encoding_cpu["pixel_values"].shape, encoding_gpu["pixel_values"].shape)
@ -722,9 +729,9 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixi
torch.testing.assert_close(encoding_cpu["labels"][0]["size"], encoding_gpu["labels"][0]["size"].to("cpu"))
@slow
@require_torch_gpu
@require_torch_accelerator
@require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self):
def test_fast_processor_equivalence_cpu_accelerator_coco_panoptic_annotations(self):
# prepare image, target and masks_path
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_panoptic_annotations.txt") as f:
@ -739,9 +746,9 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixi
encoding_cpu = processor(
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device="cpu"
)
# 2. run processor on GPU
# 2. run processor on accelerator
encoding_gpu = processor(
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device="cuda"
images=image, annotations=target, masks_path=masks_path, return_tensors="pt", device=torch_device
)
# verify pixel values

View File

@ -258,10 +258,14 @@ class Gemma2IntegrationTest(unittest.TestCase):
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_BATCH_TEXTS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 most powerful characters in the Naruto series.",
],
("cuda", 8): [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 most powerful characters in the Naruto series.",
]
],
}
)
EXPECTED_BATCH_TEXT = EXPECTED_BATCH_TEXTS.get_expectation()
@ -315,6 +319,9 @@ class Gemma2IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number"
],
("cuda", 7): [
"Hello I am doing a project for my school and I need to know how to make a program that will take a number"
],

View File

@ -31,6 +31,7 @@ from transformers.testing_utils import (
Expectations,
cleanup,
is_flash_attn_2_available,
require_deterministic_for_xpu,
require_flash_attn,
require_read_token,
require_torch,
@ -386,6 +387,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@require_deterministic_for_xpu
def test_model_4b_bf16(self):
model_id = "google/gemma-3-4b-it"
@ -406,6 +408,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
("cuda", 7): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'],
}
@ -414,6 +417,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
@require_deterministic_for_xpu
def test_model_4b_batch(self):
model_id = "google/gemma-3-4b-it"
@ -450,12 +454,17 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and',
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes:\n\n* **Image 1** shows a cow standing on a beach.',
],
("cuda", 7): [],
("cuda", 8):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and',
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes:\n\n* **Image 1** shows a cow standing on a beach.',
]
],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
@ -493,8 +502,9 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
("cuda", 7): [],
("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.']
("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
@ -502,6 +512,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
@require_deterministic_for_xpu
def test_model_4b_batch_crops(self):
model_id = "google/gemma-3-4b-it"
@ -546,11 +557,15 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
],
("cuda", 7): [],
("cuda", 8): [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
]
],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
@ -589,13 +604,15 @@ class Gemma3IntegrationTest(unittest.TestCase):
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
("cuda", 7): [],
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"]
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
@require_deterministic_for_xpu
def test_model_1b_text_only(self):
model_id = "google/gemma-3-1b-it"
@ -610,6 +627,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'],
("cuda", 7): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
("cuda", 8): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
}
@ -641,6 +659,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
("cuda", 7): [],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
}
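
Several Gemma3 tests gain require_deterministic_for_xpu on top of the new XPU expectations; the decorator name suggests it gates the test on deterministic execution when running on XPU, though this diff only shows its import and placement. A hedged sketch of that placement (class name, test body, and behaviour comments are assumptions, not taken from the diff):

import unittest

from transformers.testing_utils import require_deterministic_for_xpu, slow


@slow
class ExampleGenerationIntegrationTest(unittest.TestCase):
    @require_deterministic_for_xpu  # assumed: only meaningful on XPU, a no-op elsewhere
    def test_model_bf16(self):
        # greedy generation compared against an Expectations-resolved string, as in the diff
        ...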

View File

@ -24,6 +24,7 @@ from transformers.testing_utils import (
cleanup,
require_flash_attn,
require_torch,
require_torch_large_accelerator,
require_torch_large_gpu,
require_torch_sdpa,
slow,
@ -79,7 +80,7 @@ class Glm4ModelTest(CausalLMModelTest, unittest.TestCase):
@slow
@require_torch_large_gpu
@require_torch_large_accelerator
class Glm4IntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
model_id = "THUDM/GLM-4-9B-0414"
@ -90,6 +91,10 @@ class Glm4IntegrationTest(unittest.TestCase):
def test_model_9b_fp16(self):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
"Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
],
("cuda", 7): [],
("cuda", 8): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
@ -114,6 +119,10 @@ class Glm4IntegrationTest(unittest.TestCase):
def test_model_9b_bf16(self):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
"Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
],
("cuda", 7): [],
("cuda", 8): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
@ -138,6 +147,10 @@ class Glm4IntegrationTest(unittest.TestCase):
def test_model_9b_eager(self):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and who",
"Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
],
("cuda", 7): [],
("cuda", 8): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
@ -167,6 +180,10 @@ class Glm4IntegrationTest(unittest.TestCase):
def test_model_9b_sdpa(self):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
"Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
],
("cuda", 7): [],
("cuda", 8): [
"Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
@ -193,6 +210,7 @@ class Glm4IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXT)
@require_flash_attn
@require_torch_large_gpu
@pytest.mark.flash_attn_test
def test_model_9b_flash_attn(self):
EXPECTED_TEXTS = Expectations(

View File

@ -718,7 +718,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
@require_torch_accelerator
@is_flaky()
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
def test_inference_object_detection_head_equivalence_cpu_accelerator(self):
processor = self.default_processor
image = prepare_img()
text = prepare_text()
@ -730,7 +730,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
with torch.no_grad():
cpu_outputs = model(**encoding)
# 2. run model on GPU
# 2. run model on accelerator
model.to(torch_device)
encoding = encoding.to(torch_device)
with torch.no_grad():

View File

@ -18,7 +18,7 @@ import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch_large_gpu,
require_torch_large_accelerator,
slow,
torch_device,
)
@ -34,7 +34,7 @@ if is_torch_available():
@slow
@require_torch_large_gpu
@require_torch_large_accelerator
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
model_id = "meta-llama/Llama-4-Scout-17B-16E"

View File

@ -870,7 +870,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
self.assertListEqual([result["text_labels"] for result in results], expected_text_labels)
@require_torch_accelerator
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
def test_inference_object_detection_head_equivalence_cpu_accelerator(self):
processor = self.default_processor
image = prepare_img()
text_labels, task = prepare_text()
@ -881,7 +881,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
with torch.no_grad():
cpu_outputs = model(**encoding)
# 2. run model on GPU
# 2. run model on accelerator
model.to(torch_device)
encoding = encoding.to(torch_device)
with torch.no_grad():

View File

@ -19,10 +19,11 @@ import requests
from transformers.testing_utils import (
is_flaky,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_torchvision,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -379,10 +380,10 @@ class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
torch.testing.assert_close(encoding["labels"][1]["boxes"], expected_boxes_1, atol=1, rtol=1)
@slow
@require_torch_gpu
@require_torch_accelerator
@require_torchvision
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_accelerator_coco_detection_annotations
def test_fast_processor_equivalence_cpu_accelerator_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt") as f:
@ -393,8 +394,8 @@ class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
processor = self.image_processor_list[1]()
# 1. run processor on CPU
encoding_cpu = processor(images=image, annotations=target, return_tensors="pt", device="cpu")
# 2. run processor on GPU
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device="cuda")
# 2. run processor on accelerator
encoding_gpu = processor(images=image, annotations=target, return_tensors="pt", device=torch_device)
# verify pixel values
self.assertEqual(encoding_cpu["pixel_values"].shape, encoding_gpu["pixel_values"].shape)

View File

@ -22,7 +22,7 @@ from PIL import Image
from transformers import is_torch_available
from transformers.testing_utils import (
cleanup,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@ -35,7 +35,7 @@ if is_torch_available():
@slow
@require_torch_gpu
@require_torch_accelerator
# @require_read_token
class ShieldGemma2IntegrationTest(unittest.TestCase):
def tearDown(self):

View File

@ -23,7 +23,11 @@ import numpy as np
from datasets import load_dataset
from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_torch,
require_torch_accelerator,
)
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@ -254,7 +258,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
return [x["array"] for x in speech_samples]
@require_torch_gpu
@require_torch_accelerator
@require_torch
def test_torch_integration(self):
# fmt: off
@ -303,7 +307,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertTrue(np.all(np.mean(audio) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
@require_torch_gpu
@require_torch_accelerator
@require_torch
def test_torch_integration_batch(self):
# fmt: off

View File

@ -730,7 +730,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True)
self.assertIn(output_text, self.EXPECTED_OUTPUTS)
def test_cpu_gpu_loading_random_device_map(self):
def test_cpu_accelerator_loading_random_device_map(self):
r"""
A test to check is dispatching a model on cpu & gpu works correctly using a random `device_map`.
"""
@ -778,7 +778,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_loading_custom_device_map(self):
def test_cpu_accelerator_loading_custom_device_map(self):
r"""
A test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time the device map is more organized than the test above and uses the abstraction
@ -805,7 +805,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_disk_loading_custom_device_map(self):
def test_cpu_accelerator_disk_loading_custom_device_map(self):
r"""
A test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time we also add `disk` on the device_map.
@ -832,7 +832,7 @@ class MixedInt8TestCpuGpu(BaseMixedInt8Test):
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
def test_cpu_accelerator_disk_loading_custom_device_map_kwargs(self):
r"""
A test to check is dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time we also add `disk` on the device_map - using the kwargs directly instead of the quantization config

View File

@ -20,7 +20,7 @@ from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers.testing_utils import (
require_gguf,
require_read_token,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@ -35,7 +35,7 @@ if is_gguf_available():
@require_gguf
@require_torch_gpu
@require_torch_accelerator
@slow
class GgufQuantizationTests(unittest.TestCase):
"""
@ -107,7 +107,7 @@ class GgufQuantizationTests(unittest.TestCase):
@require_gguf
@require_torch_gpu
@require_torch_accelerator
@slow
class GgufIntegrationTests(unittest.TestCase):
"""
@ -263,7 +263,7 @@ class GgufIntegrationTests(unittest.TestCase):
@require_gguf
@require_torch_gpu
@require_torch_accelerator
@slow
class GgufModelTests(unittest.TestCase):
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"

View File

@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, QuarkConfig
from transformers.testing_utils import (
cleanup,
is_torch_available,
require_accelerate,
require_quark,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils.import_utils import is_quark_available
@ -79,11 +80,10 @@ class QuarkTest(unittest.TestCase):
def tearDown(self):
r"""
TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
TearDown function needs to be called at the end of each test to free the accelerator memory and cache, also to
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
"""
gc.collect()
torch.cuda.empty_cache()
cleanup(torch_device, gc_collect=True)
def test_memory_footprint(self):
mem_quantized = self.quantized_model.get_memory_footprint()

View File

@ -30,7 +30,7 @@ from transformers.testing_utils import (
check_json_file_has_correct_format,
is_flaky,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_vision,
slow,
torch_device,
@ -562,7 +562,7 @@ class ImageProcessingTestMixin:
self.skipTest(reason="No validation found for `preprocess` method")
@slow
@require_torch_gpu
@require_torch_accelerator
@require_vision
def test_can_compile_fast_image_processor(self):
if self.fast_image_processing_class is None:

View File

@ -26,7 +26,7 @@ from transformers import AutoVideoProcessor
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_vision,
slow,
torch_device,
@ -165,7 +165,7 @@ class VideoProcessingTestMixin:
self.assertIsNotNone(video_processor)
@slow
@require_torch_gpu
@require_torch_accelerator
@require_vision
def test_can_compile_fast_video_processor(self):
if self.fast_video_processing_class is None: