Mirror of https://github.com/huggingface/transformers.git
enable autoround cases on XPU (#38167)
* enable autoround cases on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
parent 0f77ca72ca
commit 34c1e29cdd
src/transformers/testing_utils.py

@@ -3013,6 +3013,11 @@ if is_torch_available():
         "cpu": 0,
         "default": 0,
     }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "cpu": None,
+        "default": None,
+    }
     BACKEND_TORCH_ACCELERATOR_MODULE = {
         "cuda": torch.cuda,
         "cpu": None,
@@ -3025,6 +3030,7 @@ else:
     BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
     BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
     BACKEND_MEMORY_ALLOCATED = {"default": 0}
+    BACKEND_SYNCHRONIZE = {"default": None}
     BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}
 
 
@@ -3052,6 +3058,7 @@ if is_torch_xpu_available():
     BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
     BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
     BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
+    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
     BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu
 
 
@@ -3085,6 +3092,10 @@ def backend_memory_allocated(device: str):
     return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_torch_accelerator_module(device: str):
     return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
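The helpers above follow a simple dispatch-table pattern: each backend registers its hook (or None, or a constant) under its device key, and a thin wrapper picks the matching entry at call time. A minimal sketch of that pattern, assuming a simplified _device_agnostic_dispatch (the real helper lives in transformers.testing_utils and also forwards extra arguments; this body is illustrative, not the library's code):

from typing import Callable, Union

BackendEntry = Union[Callable, int, None]

def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, BackendEntry]):
    # Unknown devices fall back to the "default" entry, so a new accelerator
    # only needs to register itself once in each table.
    entry = dispatch_table.get(device, dispatch_table["default"])
    # CPU rows register None (a no-op) or a constant such as 0 for the memory
    # counters; only invoke the entry when it is actually callable.
    if callable(entry):
        return entry()
    return entry

With BACKEND_SYNCHRONIZE wired up this way, backend_synchronize("xpu") ends up calling torch.xpu.synchronize, while backend_synchronize("cpu") is a harmless no-op.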
tests/quantization/autoround/test_auto_round.py

@@ -17,11 +17,14 @@ import unittest
 
 from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
 from transformers.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     require_accelerate,
     require_auto_round,
     require_intel_extension_for_pytorch,
+    require_torch_accelerator,
     require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -33,7 +36,7 @@ if is_torch_available():
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 @require_auto_round
 @require_accelerate
 class AutoRoundTest(unittest.TestCase):
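Swapping @require_torch_gpu for @require_torch_accelerator is what lets this suite run on XPU: the former gates on CUDA specifically, the latter on any supported accelerator. A hedged sketch of roughly how such a skip decorator works (the real implementations live in transformers.testing_utils; the body below is an assumption for illustration):

import unittest

import torch

# In the real module torch_device is resolved once at import time; this
# one-liner stands in for it in the sketch.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

def require_torch_accelerator(test_case):
    # Skips unless the resolved device is a non-CPU accelerator
    # (cuda, xpu, ...); require_torch_gpu, by contrast, checks for CUDA.
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)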
@@ -50,8 +53,11 @@ class AutoRoundTest(unittest.TestCase):
     EXPECTED_OUTPUTS.add(
         "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
     )
+    EXPECTED_OUTPUTS.add(
+        "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She has also climbed mountains and explored caves"
+    )
 
-    device_map = "cuda"
+    device_map = torch_device
 
     # called only once for all test in this class
     @classmethod
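device_map = torch_device is the key substitution: torch_device is resolved to the best available device when testing_utils is imported, so the same test targets "cuda" on NVIDIA machines and "xpu" on Intel GPUs. A rough sketch of that resolution, assuming the TRANSFORMERS_TEST_DEVICE override honored by testing_utils (the exact precedence in the library may differ):

import os

import torch

def _resolve_torch_device() -> str:
    # An explicit override wins, which is how CI can pin a device.
    override = os.environ.get("TRANSFORMERS_TEST_DEVICE")
    if override:
        return override
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu ships with XPU-enabled PyTorch builds.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

torch_device = _resolve_torch_device()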
@@ -59,7 +65,7 @@ class AutoRoundTest(unittest.TestCase):
         """
         Setup quantized model
         """
-        torch.cuda.synchronize()
+        backend_synchronize(torch_device)
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
@@ -67,7 +73,7 @@ class AutoRoundTest(unittest.TestCase):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model(self):
@@ -128,14 +134,15 @@ class AutoRoundTest(unittest.TestCase):
             )
 
             quantized_model.save_pretrained(tmpdirname)
-            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)
 
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
             output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
-            self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+            output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            self.assertIn(output_tokens, self.EXPECTED_OUTPUTS)
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_quantized_model_multi_gpu(self):
         """
         Simple test that checks if the quantized model is working properly with multiple GPUs
@@ -159,7 +166,7 @@ class AutoRoundTest(unittest.TestCase):
         quantization_config = AutoRoundConfig()
 
         model = AutoModelForCausalLM.from_pretrained(
-            model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
+            model_name, device_map=torch_device, quantization_config=quantization_config, torch_dtype="auto"
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -185,6 +192,7 @@ class AutoRoundTest(unittest.TestCase):
         inputs = tokenizer(text, return_tensors="pt").to(model.device)
         tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
 
+    @require_torch_gpu
     def test_mixed_bits(self):
         """
         Simple test that checks if auto-round work properly with mixed bits
@@ -203,7 +211,9 @@ class AutoRoundTest(unittest.TestCase):
         autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
         with tempfile.TemporaryDirectory() as tmpdirname:
             autoround.quantize_and_save(output_dir=tmpdirname)
-            model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
+            model = AutoModelForCausalLM.from_pretrained(
+                tmpdirname, torch_dtype=torch.float16, device_map=torch_device
+            )
             text = "There is a girl who likes adventure,"
             inputs = tokenizer(text, return_tensors="pt").to(model.device)
             tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
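Taken together, the recipe the commit applies is: replace "cuda" literals with torch_device, CUDA-only decorators with their accelerator-agnostic counterparts, and direct torch.cuda calls with the backend_* helpers. A minimal sketch of a device-agnostic test written this way (a hypothetical test module; the checkpoint name and assertion are placeholders, not part of the commit):

import gc
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import (
    backend_empty_cache,
    backend_synchronize,
    require_torch_accelerator,
    torch_device,
)

@require_torch_accelerator
class DeviceAgnosticSmokeTest(unittest.TestCase):
    model_name = "facebook/opt-125m"  # placeholder checkpoint

    def tearDown(self):
        gc.collect()
        # Routes to torch.cuda.empty_cache on CUDA, torch.xpu.empty_cache on XPU.
        backend_empty_cache(torch_device)

    def test_generate(self):
        backend_synchronize(torch_device)  # no-op on CPU
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map=torch_device)
        inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt").to(torch_device)
        output = model.generate(**inputs, max_new_tokens=5)
        self.assertGreater(len(output[0]), 0)

The same test body then runs unchanged on NVIDIA and Intel hardware, which is exactly what lets the autoround cases above execute on XPU.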