Expectation fixes and added AMD expectations (#38729)

Rémi Ouazan 2025-06-13 16:14:58 +02:00 committed by GitHub
parent e39172ecab
commit 9ff246db00
30 changed files with 311 additions and 146 deletions

View File

@ -3237,7 +3237,9 @@ def cleanup(device: str, gc_collect=False):
# Type definition of key used in `Expectations` class.
DeviceProperties = tuple[Union[str, None], Union[int, None]]
DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]]
# Helper type. Makes creating instances of `Expectations` smoother.
PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]]
@cache
@ -3248,11 +3250,11 @@ def get_device_properties() -> DeviceProperties:
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
import torch
major, _ = torch.cuda.get_device_capability()
major, minor = torch.cuda.get_device_capability()
if IS_ROCM_SYSTEM:
return ("rocm", major)
return ("rocm", major, minor)
else:
return ("cuda", major)
return ("cuda", major, minor)
elif IS_XPU_SYSTEM:
import torch
@ -3260,58 +3262,81 @@ def get_device_properties() -> DeviceProperties:
arch = torch.xpu.get_device_capability()["architecture"]
gen_mask = 0x000000FF00000000
gen = (arch & gen_mask) >> 32
return ("xpu", gen)
return ("xpu", gen, None)
else:
return (torch_device, None)
return (torch_device, None, None)
class Expectations(UserDict[DeviceProperties, Any]):
def unpack_device_properties(
properties: Optional[PackedDeviceProperties] = None,
) -> DeviceProperties:
"""
Unpack a `PackedDeviceProperties` tuple into a consistently formatted `DeviceProperties` tuple. If `properties` is None, it is fetched from the current environment.
"""
if properties is None:
return get_device_properties()
device_type, major_minor = properties
if major_minor is None:
major, minor = None, None
elif isinstance(major_minor, int):
major, minor = major_minor, None
else:
major, minor = major_minor
return device_type, major, minor
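# Illustration (not part of this commit): the packed -> unpacked mapping performed by the
# helper above, with hypothetical inputs. Major/minor may be packed as a tuple, a bare int,
# or left out entirely.
#   unpack_device_properties(("rocm", (9, 5)))  ->  ("rocm", 9, 5)
#   unpack_device_properties(("cuda", 8))       ->  ("cuda", 8, None)
#   unpack_device_properties((None, None))      ->  (None, None, None)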
class Expectations(UserDict[PackedDeviceProperties, Any]):
def get_expectation(self) -> Any:
"""
Find best matching expectation based on environment device properties.
"""
return self.find_expectation(get_device_properties())
@staticmethod
def is_default(key: DeviceProperties) -> bool:
return all(p is None for p in key)
def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
return [(unpack_device_properties(k), v) for k, v in self.data.items()]
@staticmethod
def score(key: DeviceProperties, other: DeviceProperties) -> int:
def is_default(properties: DeviceProperties) -> bool:
return all(p is None for p in properties)
@staticmethod
def score(properties: DeviceProperties, other: DeviceProperties) -> float:
"""
Returns a score indicating how similar two `DeviceProperties` tuples are.
Points are calculated using bits, but documented as int.
Rules are as follows:
* Matching `type` gives 8 points.
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
* Matching `major` (compute capability major version) gives 2 points.
* Default expectation (if present) gives 1 points.
* Matching `type` adds one point, semi-matching `type` adds half a point (e.g. cuda and rocm).
* If types match, matching `major` adds another point, and then matching `minor` adds another.
* Default expectation (if present) is worth 0.1 point to distinguish it from a straight-up zero.
"""
(device_type, major) = key
(other_device_type, other_major) = other
device_type, major, minor = properties
other_device_type, other_major, other_minor = other
score = 0b0
if device_type == other_device_type:
score |= 0b1000
score = 0
# Matching device type, maybe major and minor
if device_type is not None and device_type == other_device_type:
score += 1
if major is not None and major == other_major:
score += 1
if minor is not None and minor == other_minor:
score += 1
# Semi-matching device type
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
score |= 0b100
if major == other_major and other_major is not None:
score |= 0b10
score = 0.5
# Default expectation
if Expectations.is_default(other):
score |= 0b1
score = 0.1
return int(score)
return score
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
"""
Find best matching expectation based on provided device properties.
"""
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
(result_key, result) = max(self.unpacked(), key=lambda x: Expectations.score(properties, x[0]))
if Expectations.score(key, result_key) == 0:
raise ValueError(f"No matching expectation found for {key}")
if Expectations.score(properties, result_key) == 0:
raise ValueError(f"No matching expectation found for {properties}")
return result
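For reference, a minimal usage sketch (not part of the diff) of how the reworked lookup resolves with packed keys. It assumes `Expectations` is imported from `transformers.testing_utils`, as the test files below do; the device tuples and expectation values are illustrative only.

from transformers.testing_utils import Expectations

expectations = Expectations(
    {
        (None, None): "default output",       # default entry, scores 0.1, wins only if nothing else matches
        ("cuda", None): "cuda output",        # semi-matches a rocm device with 0.5
        ("rocm", (9, 5)): "rocm 9.5 output",  # exact type + major + minor match scores 3
    }
)

# On a ("rocm", 9, 5) runner the most specific key wins; on an unknown device
# type only the default entry scores above zero, so it is returned instead.
assert expectations.find_expectation(("rocm", 9, 5)) == "rocm 9.5 output"
assert expectations.find_expectation(("npu", None, None)) == "default output"

In the actual tests, `get_expectation()` is used instead, which calls `get_device_properties()` on the current machine before delegating to `find_expectation()`.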

View File

@ -347,7 +347,8 @@ class AyaVisionIntegrationTest(unittest.TestCase):
@classmethod
def get_model(cls):
# Use 4-bit on T4
load_in_4bit = get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8
device_type, major, _ = get_device_properties()
load_in_4bit = (device_type == "cuda") and (major < 8)
torch_dtype = None if load_in_4bit else torch.float16
if cls.model is None:

View File

@ -27,6 +27,7 @@ from transformers import (
is_torch_available,
)
from transformers.testing_utils import (
DeviceProperties,
Expectations,
get_device_properties,
require_deterministic_for_xpu,
@ -594,7 +595,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
tokenizer = None
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
device_properties = None
device_properties: DeviceProperties = (None, None, None)
@classmethod
def setUpClass(cls):
@ -637,7 +638,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output_sentence, expected)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.device_properties == ("cuda", 8):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
with torch.no_grad():
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
@ -690,7 +691,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.device_properties == ("cuda", 8):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

View File

@ -21,6 +21,7 @@ from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
DeviceProperties,
Expectations,
cleanup,
get_device_properties,
@ -108,7 +109,7 @@ class GemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
# This variable is used to determine which accelerator we are using for our runners (e.g. A10 or T4)
# Depending on the hardware we get different logits / generations
device_properties = None
device_properties: DeviceProperties = (None, None, None)
@classmethod
def setUpClass(cls):
@ -241,7 +242,7 @@ class GemmaIntegrationTest(unittest.TestCase):
@require_read_token
def test_model_7b_fp16(self):
if self.device_properties == ("cuda", 7):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
model_id = "google/gemma-7b"
@ -262,7 +263,7 @@ class GemmaIntegrationTest(unittest.TestCase):
@require_read_token
def test_model_7b_bf16(self):
if self.device_properties == ("cuda", 7):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
model_id = "google/gemma-7b"
@ -293,7 +294,7 @@ class GemmaIntegrationTest(unittest.TestCase):
@require_read_token
def test_model_7b_fp16_static_cache(self):
if self.device_properties == ("cuda", 7):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
model_id = "google/gemma-7b"

View File

@ -19,6 +19,7 @@ import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer, GlmConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
require_flash_attn,
require_torch,
require_torch_large_accelerator,
@ -118,10 +119,17 @@ class GlmIntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_9b_eager(self):
EXPECTED_TEXTS = [
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
"Hi today I am going to show you how to make a simple and easy to make a DIY paper flower.",
]
expected_texts = Expectations({
("cuda", None): [
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
"Hi today I am going to show you how to make a simple and easy to make a DIY paper flower.",
],
("rocm", (9, 5)) : [
"Hello I am doing a project on the history of the internetSolution:\n\nStep 1: Introduction\nThe history of the",
"Hi today I am going to show you how to make a simple and easy to make a paper airplane. First",
]
}) # fmt: skip
EXPECTED_TEXTS = expected_texts.get_expectation()
model = AutoModelForCausalLM.from_pretrained(
self.model_id,

View File

@ -821,6 +821,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
expected_outputs = Expectations(
{
("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
("rocm", (9, 5)): "Today is a nice day and if you don't know anything about the state of play during your holiday",
("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
}
) # fmt: skip

View File

@ -17,6 +17,7 @@ import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, HeliumConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
require_read_token,
require_torch,
slow,
@ -83,9 +84,13 @@ class HeliumIntegrationTest(unittest.TestCase):
@require_read_token
def test_model_2b(self):
model_id = "kyutai/helium-1-preview"
EXPECTED_TEXTS = [
"Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"
]
expected_texts = Expectations(
{
("rocm", (9, 5)): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now, and I"],
("cuda", None): ["Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"],
}
) # fmt: skip
EXPECTED_TEXTS = expected_texts.get_expectation()
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, revision="refs/pr/1").to(
torch_device

View File

@ -30,6 +30,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
require_bitsandbytes,
require_flash_attn,
@ -621,8 +622,14 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
self.assertEqual(generated_texts[0], expected_generated_text)
expected_generated_texts = Expectations(
{
("cuda", None): "In this image, we see the Statue of Liberty, the Hudson River,",
("rocm", (9, 5)): "In this image, we see the Statue of Liberty, the New York City",
}
)
EXPECTED_GENERATED_TEXT = expected_generated_texts.get_expectation()
self.assertEqual(generated_texts[0], EXPECTED_GENERATED_TEXT)
@slow
@require_bitsandbytes

View File

@ -537,6 +537,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
{
("xpu", 3): "The man is performing a volley.",
("cuda", 7): "The man is performing a forehand shot.",
("rocm", (9, 5)): "The man is performing a volley shot.",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()

View File

@ -21,6 +21,7 @@ import pytest
from transformers import AutoTokenizer, JambaConfig, is_torch_available
from transformers.testing_utils import (
DeviceProperties,
Expectations,
get_device_properties,
require_bitsandbytes,
@ -557,7 +558,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
tokenizer = None
# This variable is used to determine which accelerator we are using for our runners (e.g. A10 or T4)
# Depending on the hardware we get different logits / generations
device_properties = None
device_properties: DeviceProperties = (None, None, None)
@classmethod
def setUpClass(cls):
@ -595,7 +596,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output_sentence, expected_sentence)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.device_properties == ("cuda", 8):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
with torch.no_grad():
logits = self.model(input_ids=input_ids).logits
@ -638,7 +639,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output_sentences[1], expected_sentences[1])
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.device_properties == ("cuda", 8):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 8:
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

View File

@ -541,16 +541,24 @@ class JanusIntegrationTest(unittest.TestCase):
# fmt: off
expected_tokens = Expectations(
{
("rocm", None): [10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770,
12353, 5481, 11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376,
13219, 6007, 14292, 12696, 10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335,
6135, 2316, 15423, 311, 5460, 12218, 14172, 8583, 14577, 3648
],
("cuda", None): [4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548,
1820, 1465, 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146,
10417, 1951, 7713, 14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676
],
("rocm", None): [
10367, 1380, 4841, 15155, 1224, 16361, 15834, 13722, 15258, 8321, 10496, 14532, 8770, 12353, 5481,
11484, 2585, 8587, 3201, 14292, 3356, 2037, 3077, 6107, 3758, 2572, 9376, 13219, 6007, 14292, 12696,
10666, 10046, 13483, 8282, 9101, 5208, 4260, 13886, 13335, 6135, 2316, 15423, 311, 5460, 12218,
14172, 8583, 14577, 3648
],
("rocm", (9, 5)): [
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
],
("cuda", None): [
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
897, 4044, 1762, 4676
],
}
)
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)

View File

@ -113,7 +113,7 @@ class LlamaIntegrationTest(unittest.TestCase):
"""
# diff on `EXPECTED_TEXT`:
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
EXPECTED_TEXT = (
expected_base_text = (
"Tell me about the french revolution. The french revolution was a period of radical political and social "
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
@ -122,6 +122,13 @@ class LlamaIntegrationTest(unittest.TestCase):
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
)
expected_texts = Expectations(
{
("rocm", (9, 5)): expected_base_text.replace("political and social", "social and political"),
("cuda", None): expected_base_text,
}
) # fmt: skip
EXPECTED_TEXT = expected_texts.get_expectation()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = LlamaForCausalLM.from_pretrained(

View File

@ -341,7 +341,11 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
expected_decoded_texts = Expectations({
("cuda", None): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,",
("rocm", (9, 5)): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. First, the",
}) # fmt: skip
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
@ -397,12 +401,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
expected_decoded_texts = Expectations(
{
("cuda", None): [
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
"with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, "
"you",
"USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat "
"is located on",
],
("rocm", (9, 5)): [
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
"with me? ASSISTANT: When visiting this serene location, which features a wooden pier overlooking a "
"lake, you should",
"USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat "
"is located on",
],
}
)
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
decoded_output = processor.batch_decode(output, skip_special_tokens=True)
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
@ -433,6 +453,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along',
'USER: \nWhat is this?\nASSISTANT: Cats',
],
("rocm", (9, 5)): [
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this dock on a lake, there are several things to be cautious about and items to",
"USER: \nWhat is this?\nASSISTANT: This is a picture of two cats lying on a couch.",
],
}
) # fmt: skip
EXPECTED_DECODED_TEXT = EXPECTED_DECODED_TEXTS.get_expectation()
@ -467,12 +491,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
expected_decoded_texts = Expectations(
{
("cuda", None): [
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
"with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a "
"body of water",
"USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat "
"sleeping on a bed.",
],
("rocm", (9, 5)): [
"USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring "
"with me?\nASSISTANT: When visiting this place, which is a pier or dock overlooking a lake, you should "
"be",
"USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat "
"sleeping on a bed.",
],
}
)
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
decoded_output = processor.batch_decode(output, skip_special_tokens=True)
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
@slow
@require_torch

View File

@ -21,6 +21,7 @@ from packaging import version
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
DeviceProperties,
Expectations,
backend_empty_cache,
cleanup,
@ -114,7 +115,7 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
class MistralIntegrationTest(unittest.TestCase):
# This variable is used to determine which accelerator we are using for our runners (e.g. A10 or T4)
# Depending on the hardware we get different logits / generations
device_properties = None
device_properties: DeviceProperties = (None, None, None)
@classmethod
def setUpClass(cls):
@ -279,7 +280,7 @@ class MistralIntegrationTest(unittest.TestCase):
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
if self.device_properties == ("cuda", 7):
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
NUM_TOKENS_TO_GENERATE = 40

View File

@ -20,7 +20,6 @@ import pytest
from transformers import MixtralConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
get_device_properties,
require_flash_attn,
require_torch,
require_torch_accelerator,
@ -142,14 +141,6 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch
class MixtralIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
device_properties = None
@classmethod
def setUpClass(cls):
cls.device_properties = get_device_properties()
@slow
@require_torch_accelerator
def test_small_model_logits(self):

View File

@ -445,7 +445,11 @@ class MptIntegrationTests(unittest.TestCase):
)
input_text = "Hello"
expected_output = "Hello, I'm a new user of the forum. I have a question about the \"Solaris"
expected_outputs = Expectations({
("cuda", None): "Hello, I'm a new user of the forum. I have a question about the \"Solaris",
("rocm", (9, 5)): "Hello, I'm a newbie to the forum. I have a question about the \"B\" in",
}) # fmt: skip
expected_output = expected_outputs.get_expectation()
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=20)
@ -463,19 +467,12 @@ class MptIntegrationTests(unittest.TestCase):
)
input_text = "Hello"
expected_outputs = Expectations(
{
(
"xpu",
3,
): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n",
("cuda", 7): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n",
(
"cuda",
8,
): "Hello and welcome to the first day of the new release countdown for the month of May!\nToday",
}
)
expected_outputs = Expectations({
("rocm", (9, 5)): "Hello and welcome to the first day of the new release at The Stamp Man!\nToday we are",
("xpu", 3): "Hello and welcome to the first ever episode of the new and improved, and hopefully improved, podcast.\n",
("cuda", 7): "Hello and welcome to the first episode of the new podcast, The Frugal Feminist.\n",
("cuda", 8): "Hello and welcome to the first day of the new release countdown for the month of May!\nToday",
}) # fmt: skip
expected_output = expected_outputs.get_expectation()
inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
@ -510,6 +507,10 @@ class MptIntegrationTests(unittest.TestCase):
"Hello my name is Tiffany and I am a mother of two beautiful children. I have been a nanny for the",
"Today I am going at the gym and then I am going to go to the grocery store. I am going to buy some food and some",
],
("rocm", (9, 5)): [
"Hello my name is Jasmine and I am a very sweet and loving dog. I am a very playful dog and I",
"Today I am going at the gym and then I am going to go to the mall. I am going to buy a new pair of jeans",
],
}
)
expected_output = expected_outputs.get_expectation()
@ -535,9 +536,10 @@ class MptIntegrationTests(unittest.TestCase):
{
("xpu", 3): torch.Tensor([-0.2090, -0.2061, -0.1465]),
("cuda", 7): torch.Tensor([-0.2520, -0.2178, -0.1953]),
# TODO: This is quite a bit off, check BnB
("rocm", (9, 5)): torch.Tensor([-0.3008, -0.1309, -0.1562]),
}
)
expected_slice = expected_slices.get_expectation().to(torch_device, torch.bfloat16)
predicted_slice = outputs.hidden_states[-1][0, 0, :3]
torch.testing.assert_close(expected_slice, predicted_slice, rtol=1e-3, atol=1e-3)

View File

@ -1041,7 +1041,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
(device_type, major) = get_device_properties()
device_type, major, _ = get_device_properties()
if device_type == "cuda" and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:

View File

@ -1041,7 +1041,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
(device_type, major) = get_device_properties()
device_type, major, _ = get_device_properties()
if device_type == "cuda" and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:

View File

@ -27,6 +27,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
require_read_token,
require_torch,
@ -590,7 +591,13 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
expected_decoded_texts = Expectations(
{
("rocm", (9, 5)): "detect shoe\n<loc0051><loc0309><loc0708><loc0644> shoe",
("cuda", None): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe",
}
) # fmt: skip
EXPECTED_DECODED_TEXT = expected_decoded_texts.get_expectation()
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
def test_paligemma_index_error_bug(self):

View File

@ -19,6 +19,7 @@ import unittest
from transformers import Phi3Config, StaticCache, is_torch_available
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.testing_utils import (
Expectations,
require_torch,
slow,
torch_device,
@ -352,9 +353,14 @@ class Phi3IntegrationTest(unittest.TestCase):
model_id = "microsoft/Phi-4-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
EXPECTED_TEXT_COMPLETION = [
"You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus, who is currently on metformin and a SGLT2 inhibitor, presents with a 2-year history"
]
expected_text_completions = Expectations(
{
("rocm", (9, 5)): ["You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus presents with a 2-year history of progressive, non-healing, and painful, 2.5 cm"],
("cuda", None): ["You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user. A 45-year-old patient with a 10-year history of type 2 diabetes mellitus, who is currently on metformin and a SGLT2 inhibitor, presents with a 2-year history"],
}
) # fmt: skip
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]

View File

@ -22,6 +22,7 @@ from packaging import version
from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
@ -250,9 +251,17 @@ class Qwen2IntegrationTest(unittest.TestCase):
qwen_model = "Qwen/Qwen2-0.5B"
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
EXPECTED_TEXT_COMPLETION = [
"My favourite condiment is 100% natural, organic, gluten free, vegan, and free from preservatives. I"
]
expected_text_completions = Expectations({
("cuda", None): [
"My favourite condiment is 100% natural, organic, gluten free, vegan, and free from preservatives. I"
],
("rocm", (9, 5)): [
"My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
]
}) # fmt: skip
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]

View File

@ -22,6 +22,7 @@ from packaging import version
from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
@ -246,10 +247,18 @@ class Qwen3IntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
if version.parse(torch.__version__) == version.parse("2.7.0"):
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
cuda_expectation = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
else:
strict = True
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
cuda_expectation = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
expected_text_completions = Expectations(
{
("rocm", (9, 5)): ["My favourite condiment is 100% plain, unflavoured, and unadulterated."],
("cuda", None): cuda_expectation,
}
) # fmt: skip
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"

View File

@ -23,6 +23,7 @@ from pytest import mark
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
from transformers.testing_utils import (
Expectations,
is_flaky,
require_flash_attn,
require_torch,
@ -760,17 +761,19 @@ class Siglip2ModelIntegrationTest(unittest.TestCase):
# verify the logits values
# fmt: off
expected_logits_per_text = torch.tensor(
[
[ 1.0195, -0.0280, -1.4468],
[ -4.5395, -6.2269, -1.5667],
[ 4.1757, 5.0358, 3.5159],
[ 9.4264, 10.1879, 6.3353],
[ 2.4409, 3.1058, 4.5491],
[-12.3230, -13.7355, -13.4632],
expected_logits_per_texts = Expectations({
("cuda", None): [
[ 1.0195, -0.0280, -1.4468], [ -4.5395, -6.2269, -1.5667], [ 4.1757, 5.0358, 3.5159],
[ 9.4264, 10.1879, 6.3353], [ 2.4409, 3.1058, 4.5491], [-12.3230, -13.7355, -13.4632],
[ 1.1520, 1.1687, -1.9647],
]
).to(torch_device)
],
("rocm", (9, 5)): [
[ 1.0236, -0.0376, -1.4464], [ -4.5358, -6.2235, -1.5628], [ 4.1708, 5.0334, 3.5187],
[ 9.4241, 10.1828, 6.3366], [ 2.4371, 3.1062, 4.5530], [-12.3173, -13.7240, -13.4580],
[ 1.1502, 1.1716, -1.9623]
],
})
EXPECTED_LOGITS_PER_TEXT = torch.tensor(expected_logits_per_texts.get_expectation()).to(torch_device)
# fmt: on
torch.testing.assert_close(outputs.logits_per_text, expected_logits_per_text, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(outputs.logits_per_text, EXPECTED_LOGITS_PER_TEXT, rtol=1e-3, atol=1e-3)

View File

@ -22,6 +22,7 @@ import numpy as np
from transformers import PretrainedConfig, VitsConfig
from transformers.testing_utils import (
Expectations,
is_flaky,
is_torch_available,
require_torch,
@ -454,13 +455,21 @@ class VitsModelIntegrationTests(unittest.TestCase):
self.assertEqual(outputs.waveform.shape, (1, 87040))
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
expected_logits = Expectations({
("cuda", None): [
0.0101, 0.0318, 0.0489, 0.0627, 0.0728, 0.0865, 0.1053, 0.1279,
0.1514, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1332, 0.1188,
0.1066, 0.0978, 0.0936, 0.0867, 0.0724, 0.0493, 0.0197, -0.0141,
-0.0501, -0.0817, -0.1065, -0.1223, -0.1311, -0.1339
],
("rocm", (9, 5)): [
0.0097, 0.0315, 0.0486, 0.0626, 0.0728, 0.0865, 0.1053, 0.1279,
0.1515, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1333, 0.1189,
0.1066, 0.0978, 0.0937, 0.0868, 0.0726, 0.0496, 0.0200, -0.0138,
-0.0500, -0.0817, -0.1067, -0.1225, -0.1313, -0.1340
]
).to(torch.float16)
})
EXPECTED_LOGITS = torch.tensor(expected_logits.get_expectation(), dtype=torch.float16)
# fmt: on
torch.testing.assert_close(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)

View File

@ -17,7 +17,9 @@ import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
cleanup,
is_torch_greater_or_equal,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@ -422,13 +424,21 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STRS = [
# torch 2.6
"Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a",
# torch 2.7
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
]
self.assertIn(output_str, EXPECTED_OUTPUT_STRS)
if is_torch_greater_or_equal("2.7.0"):
cuda_expectation = (
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today."
)
else:
cuda_expectation = "Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a"
expected_output_strings = Expectations(
{
("rocm", (9, 5)): "Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
("cuda", None): cuda_expectation,
}
) # fmt: skip
EXPECTED_OUTPUT_STR = expected_output_strings.get_expectation()
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
@require_torch_accelerator
@require_torch_fp16

View File

@ -26,6 +26,7 @@ from transformers import (
)
from transformers.pipelines import MaskGenerationPipeline
from transformers.testing_utils import (
Expectations,
is_pipeline_test,
nested_simplify,
require_tf,
@ -120,6 +121,11 @@ class MaskGenerationPipelineTests(unittest.TestCase):
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
# fmt: off
last_output = Expectations({
("cuda", None): {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871},
("rocm", (9, 5)): {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8872}
}).get_expectation()
self.assertEqual(
nested_simplify(new_outupt, decimals=4),
[
@ -152,7 +158,7 @@ class MaskGenerationPipelineTests(unittest.TestCase):
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
{'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
last_output
],
)
# fmt: on

View File

@ -62,7 +62,7 @@ class AwqConfigTest(unittest.TestCase):
# Only cuda and xpu devices can run this function
support_llm_awq = False
device_type, major = get_device_properties()
device_type, major, _ = get_device_properties()
if device_type == "cuda" and major >= 8:
support_llm_awq = True
elif device_type == "xpu":

View File

@ -552,7 +552,8 @@ class TorchAoSerializationFP8AcceleratorTest(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
if get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9:
device_type, major, minor = get_device_properties()
if device_type == "cuda" and major < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Float8WeightOnlyConfig
@ -573,7 +574,8 @@ class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
if get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9:
device_type, major, minor = get_device_properties()
if device_type == "cuda" and major < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Int8DynamicActivationInt4WeightConfig

View File

@ -3775,7 +3775,7 @@ class ModelTesterMixin:
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
(device_type, major) = get_device_properties()
device_type, major, minor = get_device_properties()
if device_type == "cuda" and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
@ -3823,7 +3823,7 @@ class ModelTesterMixin:
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
(device_type, major) = get_device_properties()
device_type, major, minor = get_device_properties()
if device_type == "cuda" and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:

View File

@ -5,6 +5,8 @@ from transformers.testing_utils import Expectations
class ExpectationsTest(unittest.TestCase):
def test_expectations(self):
# We use the expectations below to make sure the right expectations are found for the right devices.
# Each value is just a unique ID.
expectations = Expectations(
{
(None, None): 1,
@ -17,18 +19,20 @@ class ExpectationsTest(unittest.TestCase):
}
)
def check(value, key):
assert expectations.find_expectation(key) == value
def check(expected_id, device_prop):
found_id = expectations.find_expectation(device_prop)
assert found_id == expected_id, f"Expected {expected_id} for {device_prop}, found {found_id}"
# npu has no matches so should find default expectation
check(1, ("npu", None))
check(7, ("xpu", 3))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))
check(4, ("rocm", None))
check(2, ("cuda", 2))
check(1, ("npu", None, None))
check(7, ("xpu", 3, None))
check(2, ("cuda", 8, None))
check(3, ("cuda", 7, None))
check(4, ("rocm", 9, None))
check(4, ("rocm", None, None))
check(2, ("cuda", 2, None))
# We also test that if there is no default expectation and no match is found, a ValueError is raised.
expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))