clean autoawq cases on xpu (#38163)

* clean autoawq cases on xpu

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix 2025-05-16 19:56:43 +08:00 committed by GitHub
parent 01ad9f4b49
commit 7f28da2850
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,7 @@ from transformers.testing_utils import (
backend_empty_cache,
require_accelerate,
require_auto_awq,
require_flash_attn,
require_intel_extension_for_pytorch,
require_torch_accelerator,
require_torch_gpu,
@ -243,7 +244,7 @@ class AwqTest(unittest.TestCase):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_accelerator
def test_quantized_model_multi_gpu(self):
def test_quantized_model_multi_accelerator(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
@ -305,7 +306,7 @@ class AwqFusedTest(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
gc.collect()
def _check_fused_modules(self, model):
@ -359,6 +360,8 @@ class AwqFusedTest(unittest.TestCase):
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
)
@require_flash_attn
@require_torch_gpu
def test_generation_fused(self):
"""
Test generation quality for fused models - single batch case
@ -382,6 +385,8 @@ class AwqFusedTest(unittest.TestCase):
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
@require_flash_attn
@require_torch_gpu
@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
@ -433,6 +438,7 @@ class AwqFusedTest(unittest.TestCase):
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
@require_flash_attn
@require_torch_multi_gpu
@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
@ -473,8 +479,9 @@ class AwqFusedTest(unittest.TestCase):
outputs = model.generate(**inputs, max_new_tokens=12)
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)
@unittest.skip(reason="Not enough GPU memory on CI runners")
@require_flash_attn
@require_torch_multi_gpu
@unittest.skip(reason="Not enough GPU memory on CI runners")
def test_generation_mixtral_fused(self):
"""
Text generation test for Mixtral + AWQ + fused