enable autoround cases on XPU (#38167)

* enable autoround cases on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Yao Matrix committed 2025-05-16 17:08:35 +08:00 (via GitHub)
parent 0f77ca72ca
commit 34c1e29cdd
2 changed files with 31 additions and 10 deletions

src/transformers/testing_utils.py

@@ -3013,6 +3013,11 @@ if is_torch_available():
         "cpu": 0,
         "default": 0,
     }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "cpu": None,
+        "default": None,
+    }
     BACKEND_TORCH_ACCELERATOR_MODULE = {
         "cuda": torch.cuda,
         "cpu": None,
@@ -3025,6 +3030,7 @@ else:
     BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
     BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
     BACKEND_MEMORY_ALLOCATED = {"default": 0}
+    BACKEND_SYNCHRONIZE = {"default": None}
     BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}
@@ -3052,6 +3058,7 @@ if is_torch_xpu_available():
    BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
    BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
    BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
+    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
    BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu
@@ -3085,6 +3092,10 @@ def backend_memory_allocated(device: str):
     return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_torch_accelerator_module(device: str):
     return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
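The backend_* helpers above all funnel through _device_agnostic_dispatch, which picks the right backend function for the current device string. A minimal self-contained sketch of that pattern, assuming fall-back-to-"default" semantics and a None-means-no-op convention (the real transformers implementation may differ in detail):

import torch

# Dispatch table mapping device type -> synchronize function (None means no-op).
BACKEND_SYNCHRONIZE = {
    "cuda": torch.cuda.synchronize,
    "cpu": None,
    "default": None,
}

# Register the XPU backend only when an Intel GPU stack is present.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize


def _device_agnostic_dispatch(device: str, dispatch_table: dict):
    # Assumed semantics: unknown devices fall back to the "default" entry;
    # a None entry means there is nothing to do (e.g. CPU needs no synchronize).
    fn = dispatch_table.get(device, dispatch_table["default"])
    return fn() if callable(fn) else fn


def backend_synchronize(device: str):
    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)


# No-op on CPU, torch.cuda.synchronize() on CUDA, torch.xpu.synchronize() on XPU.
backend_synchronize("cpu")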

tests/quantization/autoround/test_auto_round.py

@@ -17,11 +17,14 @@ import unittest
 from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
 from transformers.testing_utils import (
     backend_empty_cache,
+    backend_synchronize,
     require_accelerate,
     require_auto_round,
+    require_intel_extension_for_pytorch,
+    require_torch_accelerator,
     require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -33,7 +36,7 @@ if is_torch_available():
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 @require_auto_round
 @require_accelerate
 class AutoRoundTest(unittest.TestCase):
@@ -50,8 +53,11 @@ class AutoRoundTest(unittest.TestCase):
     EXPECTED_OUTPUTS.add(
         "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
     )
+    EXPECTED_OUTPUTS.add(
+        "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She has also climbed mountains and explored caves"
+    )
 
-    device_map = "cuda"
+    device_map = torch_device
 
     # called only once for all tests in this class
     @classmethod
@@ -59,7 +65,7 @@ class AutoRoundTest(unittest.TestCase):
         """
         Setup quantized model
         """
-        torch.cuda.synchronize()
+        backend_synchronize(torch_device)
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
@@ -67,7 +73,7 @@ class AutoRoundTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model(self):
@@ -128,14 +134,15 @@ class AutoRoundTest(unittest.TestCase):
             )
             quantized_model.save_pretrained(tmpdirname)
-            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
 
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
             output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
-            self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+            output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            self.assertIn(output_tokens, self.EXPECTED_OUTPUTS)
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_quantized_model_multi_gpu(self):
         """
         Simple test that checks if the quantized model is working properly with multiple GPUs
@@ -159,7 +166,7 @@ class AutoRoundTest(unittest.TestCase):
         quantization_config = AutoRoundConfig()
         model = AutoModelForCausalLM.from_pretrained(
-            model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
+            model_name, device_map=torch_device, quantization_config=quantization_config, torch_dtype="auto"
         )
 
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -185,6 +192,7 @@ class AutoRoundTest(unittest.TestCase):
         inputs = tokenizer(text, return_tensors="pt").to(model.device)
         tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
 
+    @require_torch_gpu
     def test_mixed_bits(self):
         """
         Simple test that checks if auto-round works properly with mixed bits
@@ -203,7 +211,9 @@ class AutoRoundTest(unittest.TestCase):
         autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
         with tempfile.TemporaryDirectory() as tmpdirname:
             autoround.quantize_and_save(output_dir=tmpdirname)
-            model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained(
+                tmpdirname, torch_dtype=torch.float16, device_map=torch_device
+            )
             text = "There is a girl who likes adventure,"
             inputs = tokenizer(text, return_tensors="pt").to(model.device)
             tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
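
Taken together, the test-file changes are one mechanical substitution: every hard-coded torch.cuda call or device_map="cuda" becomes its dispatch-table counterpart, so the same tests run unmodified on CUDA, XPU, or CPU hosts. A minimal sketch of the pattern using the helpers this commit introduces (the checkpoint name is illustrative, not from this diff):

from transformers import AutoModelForCausalLM
from transformers.testing_utils import backend_empty_cache, backend_synchronize, torch_device

# torch_device resolves to "cuda", "xpu", or "cpu" depending on the host.
backend_synchronize(torch_device)  # torch.xpu.synchronize() on an Intel GPU box
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", device_map=torch_device)
backend_empty_cache(torch_device)  # replaces the CUDA-only torch.cuda.empty_cache()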