mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[tests] make quanto tests device-agnostic (#36328)
* make device-agnostic * name change
This commit is contained in:
parent
678885bbbd
commit
7c5bd24ffa
@ -22,7 +22,6 @@ from transformers.testing_utils import (
|
||||
require_optimum_quanto,
|
||||
require_read_token,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -181,11 +180,11 @@ class QuantoQuantizationTest(unittest.TestCase):
|
||||
"""
|
||||
self.check_inference_correctness(self.quantized_model, "cpu")
|
||||
|
||||
def test_generate_quality_cuda(self):
|
||||
def test_generate_quality_accelerator(self):
|
||||
"""
|
||||
Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
|
||||
Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens
|
||||
"""
|
||||
self.check_inference_correctness(self.quantized_model, "cuda")
|
||||
self.check_inference_correctness(self.quantized_model, torch_device)
|
||||
|
||||
def test_quantized_model_layers(self):
|
||||
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
|
||||
@ -215,7 +214,7 @@ class QuantoQuantizationTest(unittest.TestCase):
|
||||
)
|
||||
self.quantized_model.to(0)
|
||||
self.assertEqual(
|
||||
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
|
||||
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, torch_device
|
||||
)
|
||||
|
||||
def test_serialization_bin(self):
|
||||
@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa
|
||||
weights = "int4"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class QuantoQuantizationActivationTest(unittest.TestCase):
|
||||
def test_quantize_activation(self):
|
||||
quantization_config = QuantoConfig(
|
||||
@ -443,7 +442,7 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
|
||||
|
||||
|
||||
@require_optimum_quanto
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class QuantoKVCacheQuantizationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_read_token
|
||||
|
Loading…
Reference in New Issue
Block a user