enable finegrained_fp8 and granite_speech cases on XPU (#38036)

* enable finegrained_fp8 cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* change back to auto

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* rename per comments

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Author: Yao Matrix
Date: 2025-05-14 16:58:40 +08:00 (committed by GitHub)
Commit: 9b5ce556aa
Parent: b311a3f506
4 changed files with 42 additions and 34 deletions

Changed file 1 of 4:

@@ -15,7 +15,7 @@
 from typing import List, Optional, Tuple

-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging

 if is_torch_available():
@@ -332,8 +332,10 @@ class FP8Linear(nn.Linear):
         if self.weight.element_size() > 1:
             return F.linear(input, self.weight, self.bias)
         else:
-            # Context manager used to switch among the available cuda devices
-            with torch.cuda.device(input.device):
+            # Context manager used to switch among the available accelerators
+            device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
+            torch_accelerator_module = getattr(torch, device_type, torch.cuda)
+            with torch_accelerator_module.device(input.device):
                 qinput, scale = act_quant(input, self.block_size[1])
                 output = w8a8_block_fp8_matmul_triton(
                     qinput,
@@ -343,9 +345,9 @@ class FP8Linear(nn.Linear):
                     self.block_size,
                     output_dtype=input.dtype,
                 )
-                # Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
+                # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
                 # preceding operations are ready before proceeding
-                torch.cuda.synchronize()
+                torch_accelerator_module.synchronize()
             if self.bias is not None:
                 output = output + self.bias
             return output.to(dtype=input.dtype)
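The hunk above swaps hard-coded torch.cuda calls for a backend chosen at runtime. A minimal standalone sketch of that dispatch pattern follows; it assumes a PyTorch build that exposes the torch.accelerator API (2.6+) and falls back to CUDA otherwise, mirroring the diff's fallback.

import torch

# Pick the active accelerator type ("cuda", "xpu", ...) when torch.accelerator
# is available; otherwise assume CUDA, as the patched forward() does.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
else:
    device_type = "cuda"

# Resolve the matching backend module (torch.cuda, torch.xpu, ...) so that
# device selection and synchronization work on any accelerator.
torch_accelerator_module = getattr(torch, device_type, torch.cuda)

device = torch.device(device_type, 0)
with torch_accelerator_module.device(device):
    x = torch.ones(4, device=device)  # work queued on the selected accelerator
# Block the CPU until all queued work on that backend has finished.
torch_accelerator_module.synchronize()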

Changed file 2 of 4:

@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
@@ -44,9 +44,10 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
                 "please make sure the weights are in PyTorch format."
             )

-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
+        if not (torch.cuda.is_available() or is_torch_xpu_available()):
+            raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")

-        compute_capability = torch.cuda.get_device_capability()
-        major, minor = compute_capability
-        if (major < 8) or (major == 8 and minor < 9):
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            major, minor = compute_capability
+            if (major < 8) or (major == 8 and minor < 9):
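Read on its own, the validation flow above reduces to roughly the sketch below; it assumes is_torch_xpu_available from transformers.utils, and the ValueError message is paraphrased rather than quoted from the source.

import torch

from transformers.utils import is_torch_xpu_available

# Accept either a CUDA GPU or an Intel XPU; the compute-capability gate
# (>= 8.9, i.e. Ada/Hopper or newer) only applies to CUDA devices.
if not (torch.cuda.is_available() or is_torch_xpu_available()):
    raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major < 8) or (major == 8 and minor < 9):
        raise ValueError("FP8 quantization requires CUDA compute capability >= 8.9")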

Changed file 3 of 4:

@@ -23,8 +23,9 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, GPT2TokenizerFast
 from transformers.testing_utils import (
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchaudio,
+    torch_device,
 )
 from transformers.utils import is_torchaudio_available
@@ -195,7 +196,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
         assert num_calculated_features == [90, 171]
         assert sum(num_expected_features) == num_audio_tokens

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_device_override(self):
         """Ensure that we regardless of the processing device, the tensors
         produced are on the CPU.
@@ -214,7 +215,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
             text=f"{processor.audio_token} Can you transcribe this audio?",
             audio=wav,
             return_tensors="pt",
-            device="cuda",
+            device=torch_device,
         )
         assert inputs["input_features"].device.type == "cpu"
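The same substitution, require_torch_gpu plus a literal "cuda" becoming require_torch_accelerator plus torch_device, can be shown in isolation; in this minimal sketch the test class and method are hypothetical, while the two helpers come from transformers.testing_utils.

import unittest

import torch

from transformers.testing_utils import require_torch_accelerator, torch_device


@require_torch_accelerator  # skipped unless some accelerator (CUDA GPU, XPU, ...) is present
class DeviceAgnosticExampleTest(unittest.TestCase):  # hypothetical class, for illustration only
    def test_tensor_lands_on_accelerator(self):
        # torch_device resolves to "cuda", "xpu", ... on the current machine,
        # so the test no longer hard-codes "cuda".
        tensor = torch.ones(2, device=torch_device)
        self.assertEqual(tensor.device.type, torch.device(torch_device).type)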

Changed file 4 of 4:

@@ -18,11 +18,13 @@ import unittest
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_read_token,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.utils import is_accelerate_available, is_torch_available
@@ -34,7 +36,7 @@ if is_accelerate_available():
     from accelerate import init_empty_weights

-@require_torch_gpu
+@require_torch_accelerator
 class FineGrainedFP8ConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -60,13 +62,13 @@ class FineGrainedFP8ConfigTest(unittest.TestCase):
 @slow
 @require_accelerate
 @require_read_token
-@require_torch_gpu
+@require_torch_accelerator
 class FP8QuantizerTest(unittest.TestCase):
     model_name = "meta-llama/Llama-3.2-1B"
     input_text = "Once upon a time"
     max_new_tokens = 10
     EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
-    device_map = "cuda"
+    device_map = torch_device
     offload_device_map = {
         "model.embed_tokens": 0,
         "model.layers.0": 0,
@@ -103,7 +105,7 @@ class FP8QuantizerTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()

     def test_quantized_model_conversion(self):
@@ -151,7 +153,8 @@ class FP8QuantizerTest(unittest.TestCase):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)

     def test_save_pretrained(self):
         """
@@ -188,11 +191,12 @@ class FP8QuantizerTest(unittest.TestCase):
         )
         self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))

-    @require_torch_multi_gpu
-    def test_quantized_model_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_quantized_model_multi_accelerator(self):
         """
-        Simple test that checks if the quantized model is working properly with multiple GPUs
-        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
+        Simple test that checks if the quantized model is working properly with multiple accelerators
+        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs; or set ZE_AFFINITY_MASK=0,1 if you
+        have more than 2 XPUs.
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
         quantization_config = FineGrainedFP8Config()
@@ -204,8 +208,8 @@ class FP8QuantizerTest(unittest.TestCase):
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-    @require_torch_multi_gpu
-    def test_save_pretrained_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_save_pretrained_multi_accelerators(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
         """
@@ -245,9 +249,9 @@ class FP8QuantizerTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-@require_torch_gpu
+@require_torch_accelerator
 class FP8LinearTest(unittest.TestCase):
-    device = "cuda"
+    device = torch_device

     @unittest.skipIf(
         torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
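For completeness, the backend-agnostic teardown idiom the test file switches to looks like the sketch below when lifted out on its own; backend_empty_cache and torch_device are existing transformers.testing_utils helpers, while the class and test names here are illustrative only.

import gc
import unittest

from transformers.testing_utils import backend_empty_cache, torch_device


class Fp8TearDownExample(unittest.TestCase):  # hypothetical class, for illustration only
    def tearDown(self):
        # Drop Python references first, then free cached memory on whichever
        # backend torch_device names (torch.cuda.empty_cache, torch.xpu.empty_cache, ...).
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    def test_runs_and_cleans_up(self):
        # Minimal body so the class is runnable as-is; tearDown does the cleanup.
        self.assertTrue(True)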