mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 10:41:07 +06:00
Expectations test utils (#36569)
* Add expectation classes + tests * Use typing Union instead of | * Use bits to track score in properties cmp method * Add exceptions and tests + comments * Remove compute cap minor as it is not needed currently * Simplify. Remove Properties class * Add example Exceptions usage * Expectations as dict subclass * Update example Exceptions usage * Refactor. Improve type name. Document score fn. * Rename to DeviceProperties.
This commit is contained in:
parent
179d02ffb8
commit
706703bba6
@ -32,13 +32,13 @@ import tempfile
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from collections import defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import MISSING, fields
|
from dataclasses import MISSING, fields
|
||||||
from functools import wraps
|
from functools import cache, wraps
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
|
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -3042,3 +3042,78 @@ def cleanup(device: str, gc_collect=False):
|
|||||||
if gc_collect:
|
if gc_collect:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
backend_empty_cache(device)
|
backend_empty_cache(device)
|
||||||
|
|
||||||
|
|
||||||
|
# Type definition of key used in `Expectations` class.
|
||||||
|
DeviceProperties = tuple[Union[str, None], Union[int, None]]
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_device_properties(self) -> DeviceProperties:
|
||||||
|
"""
|
||||||
|
Get environment device properties.
|
||||||
|
"""
|
||||||
|
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
if IS_ROCM_SYSTEM:
|
||||||
|
return ("rocm", major)
|
||||||
|
else:
|
||||||
|
return ("cuda", major)
|
||||||
|
else:
|
||||||
|
return (torch_device, None)
|
||||||
|
|
||||||
|
|
||||||
|
class Expectations(UserDict[DeviceProperties, Any]):
|
||||||
|
def get_expectation(self) -> Any:
|
||||||
|
"""
|
||||||
|
Find best matching expectation based on environment device properties.
|
||||||
|
"""
|
||||||
|
return self.find_expectation(get_device_properties())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_default(key: DeviceProperties) -> bool:
|
||||||
|
return all(p is None for p in key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def score(key: DeviceProperties, other: DeviceProperties) -> int:
|
||||||
|
"""
|
||||||
|
Returns score indicating how similar two instances of the `Properties` tuple are.
|
||||||
|
Points are calculated using bits, but documented as int.
|
||||||
|
Rules are as follows:
|
||||||
|
* Matching `type` gives 8 points.
|
||||||
|
* Semi-matching `type`, for example cuda and rocm, gives 4 points.
|
||||||
|
* Matching `major` (compute capability major version) gives 2 points.
|
||||||
|
* Default expectation (if present) gives 1 points.
|
||||||
|
"""
|
||||||
|
(device_type, major) = key
|
||||||
|
(other_device_type, other_major) = other
|
||||||
|
|
||||||
|
score = 0b0
|
||||||
|
if device_type == other_device_type:
|
||||||
|
score |= 0b1000
|
||||||
|
elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
|
||||||
|
score |= 0b100
|
||||||
|
|
||||||
|
if major == other_major and other_major is not None:
|
||||||
|
score |= 0b10
|
||||||
|
|
||||||
|
if Expectations.is_default(other):
|
||||||
|
score |= 0b1
|
||||||
|
|
||||||
|
return int(score)
|
||||||
|
|
||||||
|
def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
|
||||||
|
"""
|
||||||
|
Find best matching expectation based on provided device properties.
|
||||||
|
"""
|
||||||
|
(result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
|
||||||
|
|
||||||
|
if Expectations.score(key, result_key) == 0:
|
||||||
|
raise ValueError(f"No matching expectation found for {key}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.data}"
|
||||||
|
@ -20,12 +20,7 @@ import unittest
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import Expectations, require_torch, require_torch_gpu, slow, torch_device
|
||||||
require_torch,
|
|
||||||
require_torch_gpu,
|
|
||||||
slow,
|
|
||||||
torch_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -503,15 +498,18 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
|||||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
def test_simple_generate(self):
|
def test_simple_generate(self):
|
||||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
expectations = Expectations(
|
||||||
#
|
{
|
||||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
(
|
||||||
# considering differences in hardware processing and potential deviations in generated text.
|
"cuda",
|
||||||
EXPECTED_TEXTS = {
|
8,
|
||||||
# 7: "",
|
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
||||||
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
(
|
||||||
9: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
"rocm",
|
||||||
}
|
9,
|
||||||
|
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.model.to(torch_device)
|
self.model.to(torch_device)
|
||||||
|
|
||||||
@ -520,7 +518,8 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
|||||||
].to(torch_device)
|
].to(torch_device)
|
||||||
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||||
output_sentence = self.tokenizer.decode(out[0, :])
|
output_sentence = self.tokenizer.decode(out[0, :])
|
||||||
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
expected = expectations.get_expectation()
|
||||||
|
self.assertEqual(output_sentence, expected)
|
||||||
|
|
||||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||||
if self.cuda_compute_capability_major_version == 8:
|
if self.cuda_compute_capability_major_version == 8:
|
||||||
|
32
tests/utils/test_expectations.py
Normal file
32
tests/utils/test_expectations.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.testing_utils import Expectations
|
||||||
|
|
||||||
|
|
||||||
|
class ExpectationsTest(unittest.TestCase):
|
||||||
|
def test_expectations(self):
|
||||||
|
expectations = Expectations(
|
||||||
|
{
|
||||||
|
(None, None): 1,
|
||||||
|
("cuda", 8): 2,
|
||||||
|
("cuda", 7): 3,
|
||||||
|
("rocm", 8): 4,
|
||||||
|
("rocm", None): 5,
|
||||||
|
("cpu", None): 6,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def check(value, key):
|
||||||
|
assert expectations.find_expectation(key) == value
|
||||||
|
|
||||||
|
# xpu has no matches so should find default expectation
|
||||||
|
check(1, ("xpu", None))
|
||||||
|
check(2, ("cuda", 8))
|
||||||
|
check(3, ("cuda", 7))
|
||||||
|
check(4, ("rocm", 9))
|
||||||
|
check(4, ("rocm", None))
|
||||||
|
check(2, ("cuda", 2))
|
||||||
|
|
||||||
|
expectations = Expectations({("cuda", 8): 1})
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
expectations.find_expectation(("xpu", None))
|
Loading…
Reference in New Issue
Block a user