Expectations test utils (#36569)

* Add expectation classes + tests

* Use typing Union instead of |

* Use bits to track score in properties cmp method

* Add exceptions and tests + comments

* Remove compute cap minor as it is not needed currently

* Simplify. Remove Properties class

* Add example Expectations usage

* Expectations as dict subclass

* Update example Expectations usage

* Refactor. Improve type name. Document score fn.

* Rename to DeviceProperties.
This commit is contained in:
ivarflakstad 2025-03-18 23:39:50 +01:00 committed by GitHub
parent 179d02ffb8
commit 706703bba6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 125 additions and 19 deletions

View File

@ -32,13 +32,13 @@ import tempfile
import threading
import time
import unittest
from collections import defaultdict
from collections import UserDict, defaultdict
from collections.abc import Mapping
from dataclasses import MISSING, fields
from functools import wraps
from functools import cache, wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch
@ -3042,3 +3042,78 @@ def cleanup(device: str, gc_collect=False):
if gc_collect:
gc.collect()
backend_empty_cache(device)
# Type definition of key used in `Expectations` class.
DeviceProperties = tuple[Union[str, None], Union[int, None]]
@cache
def get_device_properties(self) -> DeviceProperties:
"""
Get environment device properties.
"""
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
import torch
major, _ = torch.cuda.get_device_capability()
if IS_ROCM_SYSTEM:
return ("rocm", major)
else:
return ("cuda", major)
else:
return (torch_device, None)
class Expectations(UserDict[DeviceProperties, Any]):
def get_expectation(self) -> Any:
"""
Find best matching expectation based on environment device properties.
"""
return self.find_expectation(get_device_properties())
@staticmethod
def is_default(key: DeviceProperties) -> bool:
return all(p is None for p in key)
@staticmethod
def score(key: DeviceProperties, other: DeviceProperties) -> int:
"""
Returns score indicating how similar two instances of the `Properties` tuple are.
Points are calculated using bits, but documented as int.
Rules are as follows:
* Matching `type` gives 8 points.
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
* Matching `major` (compute capability major version) gives 2 points.
* Default expectation (if present) gives 1 points.
"""
(device_type, major) = key
(other_device_type, other_major) = other
score = 0b0
if device_type == other_device_type:
score |= 0b1000
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
score |= 0b100
if major == other_major and other_major is not None:
score |= 0b10
if Expectations.is_default(other):
score |= 0b1
return int(score)
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
"""
Find best matching expectation based on provided device properties.
"""
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
if Expectations.score(key, result_key) == 0:
raise ValueError(f"No matching expectation found for {key}")
return result
def __repr__(self):
return f"{self.data}"

View File

@ -20,12 +20,7 @@ import unittest
import pytest
from transformers import AutoTokenizer, BambaConfig, is_torch_available
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from transformers.testing_utils import Expectations, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -503,15 +498,18 @@ class BambaModelIntegrationTest(unittest.TestCase):
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
def test_simple_generate(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = {
# 7: "",
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
9: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
}
expectations = Expectations(
{
(
"cuda",
8,
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
(
"rocm",
9,
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
}
)
self.model.to(torch_device)
@ -520,7 +518,8 @@ class BambaModelIntegrationTest(unittest.TestCase):
].to(torch_device)
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
output_sentence = self.tokenizer.decode(out[0, :])
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
expected = expectations.get_expectation()
self.assertEqual(output_sentence, expected)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:

View File

@ -0,0 +1,32 @@
import unittest
from transformers.testing_utils import Expectations
class ExpectationsTest(unittest.TestCase):
def test_expectations(self):
expectations = Expectations(
{
(None, None): 1,
("cuda", 8): 2,
("cuda", 7): 3,
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
}
)
def check(value, key):
assert expectations.find_expectation(key) == value
# xpu has no matches so should find default expectation
check(1, ("xpu", None))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))
check(4, ("rocm", None))
check(2, ("cuda", 2))
expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))