enable 2 llama UT cases on xpu (#37126)

* enable tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits and tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 on xpu

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* switch to use Expectations

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* extract gen bits from architecture and use it

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* add cross refererence

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Yao Matrix 2025-04-07 22:02:14 +08:00 committed by GitHub
parent e7ad077012
commit 12bf24d6ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 33 deletions

View File

@ -202,9 +202,11 @@ if is_torch_available():
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = torch.version.xpu is not None
else:
IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False
IS_XPU_SYSTEM = False
logger = transformers_logging.get_logger(__name__)
@ -3097,6 +3099,14 @@ def get_device_properties() -> DeviceProperties:
return ("rocm", major)
else:
return ("cuda", major)
elif IS_XPU_SYSTEM:
import torch
# To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
arch = torch.xpu.get_device_capability()["architecture"]
gen_mask = 0x000000FF00000000
gen = (arch & gen_mask) >> 32
return ("xpu", gen)
else:
return (torch_device, None)

View File

@ -22,6 +22,7 @@ from parameterized import parameterized
from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
cleanup,
require_read_token,
require_torch,
@ -429,16 +430,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_torch_accelerator
class LlamaIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
def tearDown(self):
# TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
# some memory allocated in the cache, which means some object is not being released properly. This causes some
@ -490,14 +481,17 @@ class LlamaIntegrationTest(unittest.TestCase):
# Expected mean on dim = -1
# fmt: off
EXPECTED_MEAN = {
7: torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
8: torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
}
expected_means = Expectations(
{
("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
})
expected_mean = expected_means.get_expectation()
self.assertTrue(
torch.allclose(
EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
expected_mean.to(torch_device),
out.logits.float().mean(-1),
atol=1e-2,
rtol=1e-2
@ -505,15 +499,17 @@ class LlamaIntegrationTest(unittest.TestCase):
)
# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = {
7: torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
8: torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
}
expected_slices = Expectations(
{
("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]),
("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
})
# fmt: on
expected_slice = expected_slices.get_expectation()
self.assertTrue(
torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
expected_slice.to(torch_device),
out.logits[0, 0, :15].float(),
atol=1e-2,
rtol=1e-2,
@ -534,14 +530,17 @@ class LlamaIntegrationTest(unittest.TestCase):
# fmt: off
# Expected mean on dim = -1
EXPECTED_MEAN = {
7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]])
}
expected_means = Expectations(
{
("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
})
expected_mean = expected_means.get_expectation()
self.assertTrue(
torch.allclose(
EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
expected_mean.to(torch_device),
out.logits.float().mean(-1),
atol=1e-2,
rtol=1e-2
@ -549,15 +548,18 @@ class LlamaIntegrationTest(unittest.TestCase):
)
# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = {
7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
}
expected_slices = Expectations(
{
("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]),
("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
})
# fmt: on
expected_slice = expected_slices.get_expectation()
self.assertTrue(
torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
expected_slice.to(torch_device),
out.logits[0, 0, :15].float(),
atol=1e-2,
rtol=1e-2,

View File

@ -13,14 +13,16 @@ class ExpectationsTest(unittest.TestCase):
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
("xpu", 3): 7,
}
)
def check(value, key):
assert expectations.find_expectation(key) == value
# xpu has no matches so should find default expectation
check(1, ("xpu", None))
# npu has no matches so should find default expectation
check(1, ("npu", None))
check(7, ("xpu", 3))
check(2, ("cuda", 8))
check(3, ("cuda", 7))
check(4, ("rocm", 9))