enable 2 llama UT cases on xpu (#37126)

* enable tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits and tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 on xpu

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* switch to use Expectations

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* extract gen bits from architecture and use them

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* add cross reference

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Yao Matrix authored 2025-04-07 22:02:14 +08:00, committed by GitHub
commit 12bf24d6ae, parent e7ad077012
3 changed files with 47 additions and 33 deletions


@@ -202,9 +202,11 @@ if is_torch_available():
     IS_ROCM_SYSTEM = torch.version.hip is not None
     IS_CUDA_SYSTEM = torch.version.cuda is not None
+    IS_XPU_SYSTEM = torch.version.xpu is not None
 else:
     IS_ROCM_SYSTEM = False
     IS_CUDA_SYSTEM = False
+    IS_XPU_SYSTEM = False

 logger = transformers_logging.get_logger(__name__)
@@ -3097,6 +3099,14 @@ def get_device_properties() -> DeviceProperties:
             return ("rocm", major)
         else:
             return ("cuda", major)
+    elif IS_XPU_SYSTEM:
+        import torch
+
+        # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+        arch = torch.xpu.get_device_capability()["architecture"]
+        gen_mask = 0x000000FF00000000
+        gen = (arch & gen_mask) >> 32
+        return ("xpu", gen)
     else:
         return (torch_device, None)
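A quick worked example of the mask-and-shift above: gen_mask keeps bits 32..39 of the architecture value, and the shift moves that byte down to a plain generation number. The architecture value below is a made-up placeholder, not a real Intel device ID; see the device_architecture.def link above for the actual encoding.

# Hypothetical worked example of the gen-bit extraction; `arch` is a placeholder.
arch = 0x0000_0003_0000_0001   # byte at bits 32..39 is 0x03
gen_mask = 0x000000FF00000000  # selects bits 32..39
gen = (arch & gen_mask) >> 32  # shift the selected byte down to an integer
print(gen)  # 3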


@@ -22,6 +22,7 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_read_token,
     require_torch,
@@ -429,16 +430,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
 @require_torch_accelerator
 class LlamaIntegrationTest(unittest.TestCase):
-    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
-    # Depending on the hardware we get different logits / generations
-    cuda_compute_capability_major_version = None
-
-    @classmethod
-    def setUpClass(cls):
-        if is_torch_available() and torch.cuda.is_available():
-            # 8 is for A100 / A10 and 7 for T4
-            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
     def tearDown(self):
         # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
         # some memory allocated in the cache, which means some object is not being released properly. This causes some
@@ -490,14 +481,17 @@ class LlamaIntegrationTest(unittest.TestCase):
         # Expected mean on dim = -1
         # fmt: off
-        EXPECTED_MEAN = {
-            7: torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
-            8: torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
-        }
+        expected_means = Expectations(
+            {
+                ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
+                ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
+                ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
+            })
+        expected_mean = expected_means.get_expectation()
         self.assertTrue(
             torch.allclose(
-                EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
+                expected_mean.to(torch_device),
                 out.logits.float().mean(-1),
                 atol=1e-2,
                 rtol=1e-2
@@ -505,15 +499,17 @@ class LlamaIntegrationTest(unittest.TestCase):
         )
         # slicing logits[0, 0, 0:15]
-        EXPECTED_SLICE = {
-            7: torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
-            8: torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
-        }
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]),
+                ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
+                ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
+            })
         # fmt: on
+        expected_slice = expected_slices.get_expectation()
         self.assertTrue(
             torch.allclose(
-                EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
+                expected_slice.to(torch_device),
                 out.logits[0, 0, :15].float(),
                 atol=1e-2,
                 rtol=1e-2,
@@ -534,14 +530,17 @@ class LlamaIntegrationTest(unittest.TestCase):
         # fmt: off
         # Expected mean on dim = -1
-        EXPECTED_MEAN = {
-            7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
-            8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]])
-        }
+        expected_means = Expectations(
+            {
+                ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
+                ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
+                ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
+            })
+        expected_mean = expected_means.get_expectation()
         self.assertTrue(
             torch.allclose(
-                EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
+                expected_mean.to(torch_device),
                 out.logits.float().mean(-1),
                 atol=1e-2,
                 rtol=1e-2
@@ -549,15 +548,18 @@ class LlamaIntegrationTest(unittest.TestCase):
         )
         # slicing logits[0, 0, 0:15]
-        EXPECTED_SLICE = {
-            7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
-            8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
-        }
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]),
+                ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
+                ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
+            })
         # fmt: on
+        expected_slice = expected_slices.get_expectation()
         self.assertTrue(
             torch.allclose(
-                EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
+                expected_slice.to(torch_device),
                 out.logits[0, 0, :15].float(),
                 atol=1e-2,
                 rtol=1e-2,
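To make the new pattern above easier to follow, here is a hypothetical, minimal re-implementation of a device-keyed expectation lookup. It is only an illustration: the real Expectations class comes from transformers.testing_utils, and the matching priority used below (exact (device, generation) match, then device-only entry, then global default) is an assumption consistent with the tests in this commit, not the library's exact logic.

# Hypothetical, minimal sketch of a device-keyed expectation lookup.
from typing import Any, Optional

DeviceKey = tuple[Optional[str], Optional[int]]


class MiniExpectations:
    def __init__(self, data: dict[DeviceKey, Any]) -> None:
        self.data = data

    def find_expectation(self, key: DeviceKey) -> Any:
        device, major = key

        def score(candidate: DeviceKey) -> int:
            cand_device, cand_major = candidate
            if cand_device == device and cand_major == major:
                return 3  # exact (device, generation) match
            if cand_device == device and cand_major is None:
                return 2  # device matches, generation unspecified
            if cand_device is None:
                return 1  # global default entry
            return 0      # not applicable

        best = max(self.data, key=score)
        if score(best) == 0:
            raise KeyError(f"no expectation found for {key}")
        return self.data[best]


expectations = MiniExpectations(
    {
        (None, None): 1,
        ("cuda", 8): 2,
        ("cuda", 7): 3,
        ("rocm", None): 5,
        ("xpu", 3): 7,
    }
)
assert expectations.find_expectation(("xpu", 3)) == 7     # exact match
assert expectations.find_expectation(("rocm", 9)) == 5    # device-level fallback
assert expectations.find_expectation(("npu", None)) == 1  # global default

The get_expectation() calls in the tests above follow the same idea, presumably deriving the current (device, generation) key from get_device_properties(), which is what the XPU branch added earlier in this commit feeds into.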


@@ -13,14 +13,16 @@ class ExpectationsTest(unittest.TestCase):
                 ("rocm", 8): 4,
                 ("rocm", None): 5,
                 ("cpu", None): 6,
+                ("xpu", 3): 7,
             }
         )

         def check(value, key):
             assert expectations.find_expectation(key) == value

-        # xpu has no matches so should find default expectation
-        check(1, ("xpu", None))
+        # npu has no matches so should find default expectation
+        check(1, ("npu", None))
+        check(7, ("xpu", 3))
         check(2, ("cuda", 8))
         check(3, ("cuda", 7))
         check(4, ("rocm", 9))