From 706703bba6c920b10aa7e7ee8163b06a8a03c450 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 18 Mar 2025 23:39:50 +0100 Subject: [PATCH] Expectations test utils (#36569) * Add expectation classes + tests * Use typing Union instead of | * Use bits to track score in properties cmp method * Add exceptions and tests + comments * Remove compute cap minor as it is not needed currently * Simplify. Remove Properties class * Add example Exceptions usage * Expectations as dict subclass * Update example Exceptions usage * Refactor. Improve type name. Document score fn. * Rename to DeviceProperties. --- src/transformers/testing_utils.py | 81 ++++++++++++++++++++++- tests/models/bamba/test_modeling_bamba.py | 31 +++++---- tests/utils/test_expectations.py | 32 +++++++++ 3 files changed, 125 insertions(+), 19 deletions(-) create mode 100644 tests/utils/test_expectations.py diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index d75c2d778fb..e2811ae9f10 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -32,13 +32,13 @@ import tempfile import threading import time import unittest -from collections import defaultdict +from collections import UserDict, defaultdict from collections.abc import Mapping from dataclasses import MISSING, fields -from functools import wraps +from functools import cache, wraps from io import StringIO from pathlib import Path -from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union from unittest import mock from unittest.mock import patch @@ -3042,3 +3042,78 @@ def cleanup(device: str, gc_collect=False): if gc_collect: gc.collect() backend_empty_cache(device) + + +# Type definition of key used in `Expectations` class. 
DeviceProperties = tuple[Union[str, None], Union[int, None]]


@cache
def get_device_properties() -> DeviceProperties:
    """
    Get environment device properties.

    Returns:
        `("rocm", major)` when `IS_ROCM_SYSTEM` is set, `("cuda", major)` when only `IS_CUDA_SYSTEM` is set, where
        `major` is the first element of `torch.cuda.get_device_capability()`. Otherwise `(torch_device, None)`.

    The result is memoized with `functools.cache`, so repeated calls reuse the first answer.
    """
    # NOTE: this is a module-level function, so it must not take `self`; `Expectations.get_expectation` calls it
    # without arguments.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        # Imported lazily so that the module stays importable when torch is not needed.
        import torch

        major, _ = torch.cuda.get_device_capability()
        if IS_ROCM_SYSTEM:
            return ("rocm", major)
        return ("cuda", major)
    return (torch_device, None)


class Expectations(UserDict[DeviceProperties, Any]):
    """
    Dictionary of expected values keyed by `DeviceProperties`, i.e. `(device type, major)` tuples.

    Instead of an exact key lookup, `find_expectation` returns the value whose key scores highest against the
    requested properties (see `score`). A `(None, None)` key acts as the default expectation.
    """

    def get_expectation(self) -> Any:
        """
        Find best matching expectation based on environment device properties.

        Raises:
            ValueError: if no stored key matches the environment at all.
        """
        return self.find_expectation(get_device_properties())

    @staticmethod
    def is_default(key: DeviceProperties) -> bool:
        """Return `True` when every element of `key` is `None`, i.e. the key is the default expectation."""
        return all(p is None for p in key)

    @staticmethod
    def score(key: DeviceProperties, other: DeviceProperties) -> int:
        """
        Returns score indicating how similar two instances of the `DeviceProperties` tuple are.
        Points are calculated using bits, but documented as int.
        Rules are as follows:
            * Matching `type` gives 8 points.
            * Semi-matching `type`, for example cuda and rocm, gives 4 points.
            * Matching `major` (compute capability major version) gives 2 points.
            * Default expectation (if present) gives 1 point.

        Args:
            key: the properties being looked up.
            other: the candidate key stored in the mapping.
        """
        (device_type, major) = key
        (other_device_type, other_major) = other

        score = 0b0
        if device_type == other_device_type:
            score |= 0b1000
        elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
            score |= 0b100

        # Two `None` majors are "unknown", not a match, so they earn no points.
        if major == other_major and other_major is not None:
            score |= 0b10

        if Expectations.is_default(other):
            score |= 0b1

        return int(score)

    def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
        """
        Find best matching expectation based on provided device properties.

        Args:
            key: the device properties to match against the stored keys.

        Returns:
            The value stored under the highest scoring key. When several keys share the highest score, the one
            inserted first wins.

        Raises:
            ValueError: if the mapping is empty or the best score is 0 (nothing matches and there is no default).
        """
        if not self.data:
            # `max()` on an empty mapping would fail with an unrelated "empty sequence" message.
            raise ValueError(f"No matching expectation found for {key}")

        # `max` keeps the first of equally scored items, so insertion order breaks ties.
        (result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))

        if Expectations.score(key, result_key) == 0:
            raise ValueError(f"No matching expectation found for {key}")

        return result

    def __repr__(self):
        return f"{self.data}"
import unittest

from transformers.testing_utils import Expectations


class ExpectationsTest(unittest.TestCase):
    """Checks the best-match lookup implemented by `Expectations.find_expectation`."""

    def test_expectations(self):
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
            }
        )

        # (expected value, looked-up key)
        cases = [
            # xpu has no matches so should find default expectation
            (1, ("xpu", None)),
            (2, ("cuda", 8)),
            (3, ("cuda", 7)),
            (4, ("rocm", 9)),
            (4, ("rocm", None)),
            (2, ("cuda", 2)),
        ]
        for value, key in cases:
            # `assertEqual` is not stripped under `python -O`, unlike a bare `assert`;
            # `subTest` reports every failing key instead of stopping at the first one.
            with self.subTest(key=key):
                self.assertEqual(expectations.find_expectation(key), value)

        # Without a default entry, a device that matches nothing must raise.
        expectations = Expectations({("cuda", 8): 1})
        with self.assertRaises(ValueError):
            expectations.find_expectation(("xpu", None))