mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
enable finegrained_fp8 and granite_speech cases on XPU (#38036)
* enable finegrained_fp8 cases on XPU Signed-off-by: Yao Matrix <matrix.yao@intel.com> * fix style Signed-off-by: Yao Matrix <matrix.yao@intel.com> * change back to auto Signed-off-by: Yao Matrix <matrix.yao@intel.com> * rename per comments Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Yao Matrix <matrix.yao@intel.com> Signed-off-by: Matrix Yao <matrix.yao@intel.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
b311a3f506
commit
9b5ce556aa
@ -15,7 +15,7 @@
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -332,8 +332,10 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
# Context manager used to switch among the available cuda devices
|
||||
with torch.cuda.device(input.device):
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
@ -343,9 +345,9 @@ class FP8Linear(nn.Linear):
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch.cuda.synchronize()
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output.to(dtype=input.dtype)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
|
||||
from .base import HfQuantizer
|
||||
from .quantizers_utils import get_module_from_name
|
||||
|
||||
@ -44,9 +44,10 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
||||
"please make sure the weights are in PyTorch format."
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
|
||||
if not (torch.cuda.is_available() or is_torch_xpu_available()):
|
||||
raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, minor = compute_capability
|
||||
if (major < 8) or (major == 8 and minor < 9):
|
||||
|
@ -23,8 +23,9 @@ from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, GPT2TokenizerFast
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_torchaudio,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torchaudio_available
|
||||
|
||||
@ -195,7 +196,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
|
||||
assert num_calculated_features == [90, 171]
|
||||
assert sum(num_expected_features) == num_audio_tokens
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_device_override(self):
|
||||
"""Ensure that we regardless of the processing device, the tensors
|
||||
produced are on the CPU.
|
||||
@ -214,7 +215,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
|
||||
text=f"{processor.audio_token} Can you transcribe this audio?",
|
||||
audio=wav,
|
||||
return_tensors="pt",
|
||||
device="cuda",
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
assert inputs["input_features"].device.type == "cpu"
|
||||
|
@ -18,11 +18,13 @@ import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_accelerate,
|
||||
require_read_token,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_torch_available
|
||||
|
||||
@ -34,7 +36,7 @@ if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class FineGrainedFP8ConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@ -60,13 +62,13 @@ class FineGrainedFP8ConfigTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_accelerate
|
||||
@require_read_token
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class FP8QuantizerTest(unittest.TestCase):
|
||||
model_name = "meta-llama/Llama-3.2-1B"
|
||||
input_text = "Once upon a time"
|
||||
max_new_tokens = 10
|
||||
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
|
||||
device_map = "cuda"
|
||||
device_map = torch_device
|
||||
offload_device_map = {
|
||||
"model.embed_tokens": 0,
|
||||
"model.layers.0": 0,
|
||||
@ -103,7 +105,7 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
def test_quantized_model_conversion(self):
|
||||
@ -151,7 +153,8 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_save_pretrained(self):
|
||||
"""
|
||||
@ -188,11 +191,12 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_quantized_model_multi_gpu(self):
|
||||
@require_torch_multi_accelerator
|
||||
def test_quantized_model_multi_accelerator(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly with multiple GPUs
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
|
||||
Simple test that checks if the quantized model is working properly with multiple accelerators
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs; or set ZE_AFFINITY_MASK=0,1 if you
|
||||
have more than 2 XPUs.
|
||||
"""
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
quantization_config = FineGrainedFP8Config()
|
||||
@ -204,8 +208,8 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_save_pretrained_multi_gpu(self):
|
||||
@require_torch_multi_accelerator
|
||||
def test_save_pretrained_multi_accelerators(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly after being saved and loaded
|
||||
"""
|
||||
@ -245,9 +249,9 @@ class FP8QuantizerTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class FP8LinearTest(unittest.TestCase):
|
||||
device = "cuda"
|
||||
device = torch_device
|
||||
|
||||
@unittest.skipIf(
|
||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||
|
Loading…
Reference in New Issue
Block a user