mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
enable finegrained_fp8 and granite_speech cases on XPU (#38036)
* enable finegrained_fp8 cases on XPU Signed-off-by: Yao Matrix <matrix.yao@intel.com> * fix style Signed-off-by: Yao Matrix <matrix.yao@intel.com> * change back to auto Signed-off-by: Yao Matrix <matrix.yao@intel.com> * rename per comments Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Yao Matrix <matrix.yao@intel.com> Signed-off-by: Matrix Yao <matrix.yao@intel.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
b311a3f506
commit
9b5ce556aa
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@ -332,8 +332,10 @@ class FP8Linear(nn.Linear):
|
|||||||
if self.weight.element_size() > 1:
|
if self.weight.element_size() > 1:
|
||||||
return F.linear(input, self.weight, self.bias)
|
return F.linear(input, self.weight, self.bias)
|
||||||
else:
|
else:
|
||||||
# Context manager used to switch among the available cuda devices
|
# Context manager used to switch among the available accelerators
|
||||||
with torch.cuda.device(input.device):
|
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||||
|
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||||
|
with torch_accelerator_module.device(input.device):
|
||||||
qinput, scale = act_quant(input, self.block_size[1])
|
qinput, scale = act_quant(input, self.block_size[1])
|
||||||
output = w8a8_block_fp8_matmul_triton(
|
output = w8a8_block_fp8_matmul_triton(
|
||||||
qinput,
|
qinput,
|
||||||
@ -343,9 +345,9 @@ class FP8Linear(nn.Linear):
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
output_dtype=input.dtype,
|
output_dtype=input.dtype,
|
||||||
)
|
)
|
||||||
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||||
# preceding operations are ready before proceeding
|
# preceding operations are ready before proceeding
|
||||||
torch.cuda.synchronize()
|
torch_accelerator_module.synchronize()
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
output = output + self.bias
|
output = output + self.bias
|
||||||
return output.to(dtype=input.dtype)
|
return output.to(dtype=input.dtype)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
|
||||||
from .base import HfQuantizer
|
from .base import HfQuantizer
|
||||||
from .quantizers_utils import get_module_from_name
|
from .quantizers_utils import get_module_from_name
|
||||||
|
|
||||||
@ -44,16 +44,17 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|||||||
"please make sure the weights are in PyTorch format."
|
"please make sure the weights are in PyTorch format."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
if not (torch.cuda.is_available() or is_torch_xpu_available()):
|
||||||
raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
|
raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
|
||||||
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
if torch.cuda.is_available():
|
||||||
major, minor = compute_capability
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
if (major < 8) or (major == 8 and minor < 9):
|
major, minor = compute_capability
|
||||||
raise ValueError(
|
if (major < 8) or (major == 8 and minor < 9):
|
||||||
"FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
|
raise ValueError(
|
||||||
f", actual = `{major}.{minor}`"
|
"FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
|
||||||
)
|
f", actual = `{major}.{minor}`"
|
||||||
|
)
|
||||||
|
|
||||||
device_map = kwargs.get("device_map", None)
|
device_map = kwargs.get("device_map", None)
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
@ -217,7 +218,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
|
|||||||
|
|
||||||
config.base_model_tp_plan = text_plan
|
config.base_model_tp_plan = text_plan
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def is_serializable(self, safe_serialization=None):
|
def is_serializable(self, safe_serialization=None):
|
||||||
return True
|
return True
|
||||||
|
@ -23,8 +23,9 @@ from parameterized import parameterized
|
|||||||
from transformers import AutoTokenizer, GPT2TokenizerFast
|
from transformers import AutoTokenizer, GPT2TokenizerFast
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_accelerator,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import is_torchaudio_available
|
from transformers.utils import is_torchaudio_available
|
||||||
|
|
||||||
@ -195,7 +196,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
|
|||||||
assert num_calculated_features == [90, 171]
|
assert num_calculated_features == [90, 171]
|
||||||
assert sum(num_expected_features) == num_audio_tokens
|
assert sum(num_expected_features) == num_audio_tokens
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_device_override(self):
|
def test_device_override(self):
|
||||||
"""Ensure that we regardless of the processing device, the tensors
|
"""Ensure that we regardless of the processing device, the tensors
|
||||||
produced are on the CPU.
|
produced are on the CPU.
|
||||||
@ -214,7 +215,7 @@ class GraniteSpeechProcessorTest(unittest.TestCase):
|
|||||||
text=f"{processor.audio_token} Can you transcribe this audio?",
|
text=f"{processor.audio_token} Can you transcribe this audio?",
|
||||||
audio=wav,
|
audio=wav,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
device="cuda",
|
device=torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert inputs["input_features"].device.type == "cpu"
|
assert inputs["input_features"].device.type == "cpu"
|
||||||
|
@ -18,11 +18,13 @@ import unittest
|
|||||||
|
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch_gpu,
|
require_torch_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_accelerator,
|
||||||
slow,
|
slow,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import is_accelerate_available, is_torch_available
|
from transformers.utils import is_accelerate_available, is_torch_available
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ if is_accelerate_available():
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
class FineGrainedFP8ConfigTest(unittest.TestCase):
|
class FineGrainedFP8ConfigTest(unittest.TestCase):
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
"""
|
"""
|
||||||
@ -60,13 +62,13 @@ class FineGrainedFP8ConfigTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@require_read_token
|
@require_read_token
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
class FP8QuantizerTest(unittest.TestCase):
|
class FP8QuantizerTest(unittest.TestCase):
|
||||||
model_name = "meta-llama/Llama-3.2-1B"
|
model_name = "meta-llama/Llama-3.2-1B"
|
||||||
input_text = "Once upon a time"
|
input_text = "Once upon a time"
|
||||||
max_new_tokens = 10
|
max_new_tokens = 10
|
||||||
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
|
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
|
||||||
device_map = "cuda"
|
device_map = torch_device
|
||||||
offload_device_map = {
|
offload_device_map = {
|
||||||
"model.embed_tokens": 0,
|
"model.embed_tokens": 0,
|
||||||
"model.layers.0": 0,
|
"model.layers.0": 0,
|
||||||
@ -103,7 +105,7 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
backend_empty_cache(torch_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def test_quantized_model_conversion(self):
|
def test_quantized_model_conversion(self):
|
||||||
@ -151,7 +153,8 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||||
|
|
||||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
def test_save_pretrained(self):
|
def test_save_pretrained(self):
|
||||||
"""
|
"""
|
||||||
@ -188,11 +191,12 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
|
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
def test_quantized_model_multi_gpu(self):
|
def test_quantized_model_multi_accelerator(self):
|
||||||
"""
|
"""
|
||||||
Simple test that checks if the quantized model is working properly with multiple GPUs
|
Simple test that checks if the quantized model is working properly with multiple accelerators
|
||||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
|
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs; or set ZE_AFFINITY_MASK=0,1 if you
|
||||||
|
have more than 2 XPUs.
|
||||||
"""
|
"""
|
||||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||||
quantization_config = FineGrainedFP8Config()
|
quantization_config = FineGrainedFP8Config()
|
||||||
@ -204,8 +208,8 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
def test_save_pretrained_multi_gpu(self):
|
def test_save_pretrained_multi_accelerators(self):
|
||||||
"""
|
"""
|
||||||
Simple test that checks if the quantized model is working properly after being saved and loaded
|
Simple test that checks if the quantized model is working properly after being saved and loaded
|
||||||
"""
|
"""
|
||||||
@ -245,9 +249,9 @@ class FP8QuantizerTest(unittest.TestCase):
|
|||||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
class FP8LinearTest(unittest.TestCase):
|
class FP8LinearTest(unittest.TestCase):
|
||||||
device = "cuda"
|
device = torch_device
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
|
||||||
|
Loading…
Reference in New Issue
Block a user