Mirror of https://github.com/huggingface/transformers.git
clean autoawq cases on xpu (#38163)
* clean autoawq cases on xpu

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
commit 7f28da2850
parent 01ad9f4b49
@@ -21,6 +21,7 @@ from transformers.testing_utils import (
     backend_empty_cache,
     require_accelerate,
     require_auto_awq,
+    require_flash_attn,
     require_intel_extension_for_pytorch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -243,7 +244,7 @@ class AwqTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     @require_torch_multi_accelerator
-    def test_quantized_model_multi_gpu(self):
+    def test_quantized_model_multi_accelerator(self):
         """
         Simple test that checks if the quantized model is working properly with multiple GPUs
         """
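For reference, a test renamed along these lines typically shards the quantized model across whatever accelerators are visible rather than hard-coding CUDA. The sketch below only illustrates that pattern; the checkpoint name, prompt, and assertion are illustrative stand-ins, not code taken from this file.

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_multi_accelerator


class MultiAcceleratorSketch(unittest.TestCase):
    # Hypothetical AWQ checkpoint used only for illustration.
    model_name = "TheBloke/Mistral-7B-v0.1-AWQ"

    @require_torch_multi_accelerator
    def test_quantized_model_multi_accelerator(self):
        """Shard the quantized model over all visible accelerators (CUDA or XPU) and generate."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # device_map="auto" lets accelerate spread layers over cuda:0/cuda:1, xpu:0/xpu:1, ...
        model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
        output = model.generate(**inputs, max_new_tokens=20)

        # The real test compares the decoded text against a fixed EXPECTED_OUTPUT string.
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        self.assertIsInstance(text, str)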
@@ -305,7 +306,7 @@ class AwqFusedTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def _check_fused_modules(self, model):
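The tearDown hunk above replaces the CUDA-only cache flush with the backend-aware helper imported from transformers.testing_utils. The snippet below is a simplified stand-in written for illustration; the real backend_empty_cache may differ in detail, but the idea is to dispatch on the active device string instead of hard-coding torch.cuda.

import gc

import torch


def empty_device_cache(device: str) -> None:
    """Illustrative stand-in for backend_empty_cache: flush whichever backend is in use."""
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device.startswith("xpu") and hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    # On plain CPU there is no allocator cache to flush, so do nothing.


def teardown_like_above(torch_device: str) -> None:
    """Mirrors the tearDown body in the hunk above: collect, flush, collect again."""
    gc.collect()
    empty_device_cache(torch_device)
    gc.collect()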
@@ -359,6 +360,8 @@ class AwqFusedTest(unittest.TestCase):
         torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
         "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
     )
+    @require_flash_attn
+    @require_torch_gpu
     def test_generation_fused(self):
         """
         Test generation quality for fused models - single batch case
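This hunk, and the two that follow, attach the same guard stack to the fused-module tests: fused AWQ modules rely on FlashAttention kernels and an Ampere-or-newer CUDA GPU, so the requirements are declared up front instead of surfacing as runtime errors on other accelerators such as XPU. A minimal sketch of the pattern (the test body is a placeholder, not code from this file):

import unittest

import torch

from transformers.testing_utils import require_flash_attn, require_torch_gpu


class FusedGuardSketch(unittest.TestCase):
    @unittest.skipIf(
        torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
        "FlashAttention only supports Ampere GPUs or newer (compute capability >= 8.0)",
    )
    @require_flash_attn  # skipped when flash-attn is not installed (e.g. on XPU-only hosts)
    @require_torch_gpu   # skipped when no CUDA device is present
    def test_generation_fused(self):
        # Placeholder body: the real test loads a fused AWQ model and checks its generations.
        self.assertTrue(torch.cuda.is_available())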
@@ -382,6 +385,8 @@ class AwqFusedTest(unittest.TestCase):
 
         self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
 
+    @require_flash_attn
+    @require_torch_gpu
     @unittest.skipIf(
         torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
         "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
@@ -433,6 +438,7 @@ class AwqFusedTest(unittest.TestCase):
 
         self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
 
+    @require_flash_attn
     @require_torch_multi_gpu
     @unittest.skipIf(
         torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
@@ -473,8 +479,9 @@ class AwqFusedTest(unittest.TestCase):
         outputs = model.generate(**inputs, max_new_tokens=12)
         self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)
 
+    @unittest.skip(reason="Not enough GPU memory on CI runners")
+    @require_flash_attn
     @require_torch_multi_gpu
-    @unittest.skip(reason="Not enough GPU memory on CI runners")
    def test_generation_mixtral_fused(self):
         """
         Text generation test for Mixtral + AWQ + fused
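In the final hunk the unconditional CI skip moves to the top of the decorator stack and @require_flash_attn is added beneath it. A small sketch of how such a stack reads once assembled (the body is a placeholder, not code from this file):

import unittest

from transformers.testing_utils import require_flash_attn, require_torch_multi_gpu


class MixtralFusedSkipSketch(unittest.TestCase):
    @unittest.skip(reason="Not enough GPU memory on CI runners")
    @require_flash_attn
    @require_torch_multi_gpu
    def test_generation_mixtral_fused(self):
        # Never executed while the unconditional skip is in place; the requirement
        # decorators still document what the test needs (flash-attn and at least two
        # CUDA GPUs) if the skip is ever lifted.
        pass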