Adding FP8 Quantization to transformers (#36026)
* first commit * adding kernels * fix create_quantized_param * fix quantization logic * end2end * fix style * fix imports * fix consistency * update * fix style * update * udpate after review * make style * update * update * fix * update * fix docstring * update * update after review * update * fix scheme * update * update * fix * update * fix docstring * add source * fix test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent c82319b493
commit efe72fe21f
@ -185,6 +185,8 @@
    title: BitNet
  - local: quantization/compressed_tensors
    title: compressed-tensors
  - local: quantization/finegrained_fp8
    title: Fine-grained FP8
  - local: quantization/contribute
    title: Contribute new quantization method
  title: Quantization Methods
@ -80,3 +80,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

## BitNetConfig

[[autodoc]] BitNetConfig

## FineGrainedFP8Config

[[autodoc]] FineGrainedFP8Config
62  docs/source/en/quantization/finegrained_fp8.md  Normal file
@ -0,0 +1,62 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Fine-grained FP8

With the fine-grained FP8 quantization method, you can quantize your model to FP8 (W8A8):

- The weights are quantized to 8 bits (FP8) per 2D block (e.g. `weight_block_size=(128, 128)`), a scheme inspired by the DeepSeek implementation.
- Activations are quantized to 8 bits (FP8) per group per token, with the group size matching that of the weights along the input channels (128 by default).

This method was implemented to add support for the DeepSeek-V3 and DeepSeek-R1 models; see the [paper](https://arxiv.org/pdf/2412.19437) for details. The image below illustrates the quantization scheme, and a short sketch of it follows.

![quantization_scheme](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/fp8_scheme.png)

> [!TIP]
> You need a GPU with compute capability >= 9.0 (e.g. H100).
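
If you are unsure about your hardware, you can query the compute capability directly with PyTorch:

```py
import torch

print(torch.cuda.get_device_capability())  # e.g. (9, 0) on an H100
```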

Before you begin, make sure the following libraries are installed with their latest versions:

```bash
pip install --upgrade accelerate torch
```

> [!TIP]
> You need to install a torch version compatible with the CUDA version of your GPU.

By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.

```py
from transformers import FineGrainedFP8Config, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FineGrainedFP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

A quantized model can be saved with `save_pretrained` and reloaded later with `from_pretrained`.

```py
quant_path = "/path/to/save/quantized/model"
quantized_model.save_pretrained(quant_path)
quantized_model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
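
After loading, you can check that the linear layers were converted by inspecting a layer's weight dtype and block scales. The module path below assumes a Llama-style architecture; adjust it for other models:

```py
layer = quantized_model.model.layers[0].self_attn.q_proj
print(layer.weight.dtype)            # torch.float8_e4m3fn
print(layer.weight_scale_inv.shape)  # one float32 scale per (128, 128) weight block
```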
@ -61,7 +61,7 @@ Use the table below to help you decide which quantization method to use.
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |

<Tip>

**1:** bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links.
@ -1024,6 +1024,7 @@ _import_structure = {
|
||||
"CompressedTensorsConfig",
|
||||
"EetqConfig",
|
||||
"FbgemmFp8Config",
|
||||
"FineGrainedFP8Config",
|
||||
"GPTQConfig",
|
||||
"HiggsConfig",
|
||||
"HqqConfig",
|
||||
@ -6196,6 +6197,7 @@ if TYPE_CHECKING:
|
||||
CompressedTensorsConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
FineGrainedFP8Config,
|
||||
GPTQConfig,
|
||||
HiggsConfig,
|
||||
HqqConfig,
|
||||
|
@ -54,6 +54,7 @@ _import_structure = {
|
||||
],
|
||||
"eetq": ["replace_with_eetq_linear"],
|
||||
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
|
||||
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
|
||||
"fsdp": ["is_fsdp_managed_module"],
|
||||
"ggml": [
|
||||
"GGUF_CONFIG_MAPPING",
|
||||
@ -157,6 +158,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .eetq import replace_with_eetq_linear
|
||||
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
|
||||
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
|
||||
from .fsdp import is_fsdp_managed_module
|
||||
from .ggml import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
|
420  src/transformers/integrations/finegrained_fp8.py  Normal file
@ -0,0 +1,420 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.nn import functional as F
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@triton.jit
|
||||
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(x_ptr + offs).to(tl.float32)
|
||||
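# 448.0 is the maximum representable value of float8_e4m3fn; dividing by s = max|x| / 448 maps x into the FP8 range.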
s = tl.max(tl.abs(x)) / 448.0
|
||||
y = x / s
|
||||
y = y.to(y_ptr.dtype.element_ty)
|
||||
tl.store(y_ptr + offs, y)
|
||||
tl.store(s_ptr + pid, s)
|
||||
|
||||
|
||||
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.is_contiguous()
|
||||
assert x.shape[-1] % block_size == 0
|
||||
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
||||
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
|
||||
|
||||
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
|
||||
return y, s
|
||||
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
# Pointers to inputs and output
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
# Shape for matmul
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# Block size for block-wise quantization
|
||||
group_n,
|
||||
group_k,
|
||||
# Stride for inputs and output
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_As_m,
|
||||
stride_As_k,
|
||||
stride_Bs_k,
|
||||
stride_Bs_n,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Triton-accelerated function used to perform linear operations (dot
|
||||
product) on input tensors `A` and `B` with block-wise quantization, and
|
||||
store the result in output tensor `C`.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
As_ptrs = As + offs_am * stride_As_m
|
||||
offs_bsn = offs_bn // group_n
|
||||
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if C.dtype.element_ty == tl.bfloat16:
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
elif C.dtype.element_ty == tl.float16:
|
||||
c = accumulator.to(tl.float16)
|
||||
else:
|
||||
c = accumulator.to(tl.float32)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_triton(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization.
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization. It should
|
||||
be 2-dim, e.g., [128, 128].
|
||||
output_dtype: The dtype of the returned tensor.
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
BLOCK_SIZE_M = 128
|
||||
if M < BLOCK_SIZE_M:
|
||||
BLOCK_SIZE_M = triton.next_power_of_2(M)
|
||||
BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
|
||||
BLOCK_SIZE_K = block_k
|
||||
assert block_k % BLOCK_SIZE_K == 0
|
||||
BLOCK_SIZE_N = block_n
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
|
||||
|
||||
_w8a8_block_fp8_matmul[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
GROUP_SIZE_M=8,
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
# Python version of the above Triton function; it is much slower than the Triton version and is intended for testing.
|
||||
@torch.compile
|
||||
def w8a8_block_fp8_matmul_compile(
|
||||
input_q: torch.Tensor, # [batch, seq_len, hidden_dim]
|
||||
weight_q: torch.Tensor, # [out_features, hidden_dim]
|
||||
input_scale: torch.Tensor, # [batch * seq_len, num_input_groups]
|
||||
weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n]
|
||||
block_size: Optional[Tuple[int, int]] = None, # (M=128, N=128) for weights for example
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs blocked matrix multiplication with FP8 quantized matrices.
|
||||
|
||||
Args:
|
||||
input_q: Quantized input tensor with 1x128 block quantization
|
||||
weight_q: Quantized weight tensor with 128x128 block quantization
|
||||
input_scale: Scaling factors for input blocks
|
||||
weight_scale: Scaling factors for weight blocks
|
||||
block_size: Tuple of (M, N) for weight block dimensions
|
||||
output_dtype: Desired output dtype
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
|
||||
out_features = weight_q.shape[0]
|
||||
|
||||
# Reshape input for batched matmul
|
||||
input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim]
|
||||
input_scale_reshaped = input_scale.view(input_scale.shape[0], -1) # [batch*seq_len, 1]
|
||||
# Calculate number of blocks
|
||||
num_weight_blocks_m = out_features // block_size[0]
|
||||
num_weight_blocks_n = hidden_dim // block_size[1]
|
||||
|
||||
output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)
|
||||
|
||||
for i in range(num_weight_blocks_m):
|
||||
m_start = i * block_size[0]
|
||||
m_end = m_start + block_size[0]
|
||||
|
||||
for j in range(num_weight_blocks_n):
|
||||
n_start = j * block_size[1]
|
||||
n_end = n_start + block_size[1]
|
||||
|
||||
# Extract current blocks
|
||||
input_block = input_reshaped[:, n_start:n_end]
|
||||
weight_block = weight_q[m_start:m_end, n_start:n_end]
|
||||
|
||||
# Get corresponding scales
|
||||
curr_input_scale = input_scale_reshaped[:, j : j + 1] # [batch*seq_len, 1]
|
||||
curr_weight_scale = weight_scale[i, j] # scalar
|
||||
|
||||
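# torch._scaled_mm expects scalar scales here, so pass 1.0 for the input scale and apply the per-token input scale to the result afterwards.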
block_result = (
|
||||
torch._scaled_mm(
|
||||
input_block,
|
||||
weight_block.t(),
|
||||
scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device),
|
||||
scale_b=curr_weight_scale,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
* curr_input_scale
|
||||
)
|
||||
|
||||
output[:, m_start:m_end] += block_result
|
||||
|
||||
output = output.view(batch_size, seq_len, out_features)
|
||||
|
||||
return output.to(output_dtype)
|
||||
|
||||
|
||||
class FP8Linear(nn.Module):
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = False,
|
||||
dtype=None,
|
||||
block_size: Optional[Tuple[int, int]] = None,
|
||||
device=None,
|
||||
activation_scheme="dynamic",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer("weight", torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
|
||||
|
||||
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
|
||||
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
|
||||
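# One float32 inverse scale per (block_size[0], block_size[1]) weight block; ceil division handles shapes that are not exact multiples.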
self.register_buffer(
|
||||
"weight_scale_inv", torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
|
||||
)
|
||||
|
||||
self.block_size = block_size
|
||||
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.out_features))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
# Context manager used to switch among the available cuda devices
|
||||
with torch.cuda.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch.cuda.synchronize(device=input.device)
|
||||
with torch.cuda.device(input.device):
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
self.weight,
|
||||
scale,
|
||||
self.weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
torch.cuda.synchronize(device=input.device)
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def _replace_with_fp8_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""Replace Linear layers with FP8Linear."""
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
model._modules[name] = FP8Linear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
)
|
||||
has_been_replaced = True
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fp8_linear(
|
||||
module,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
|
||||
current_key_name.pop(-1)
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_fp8_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
quantization_config=None,
|
||||
):
|
||||
"""Helper function to replace model layers with FP8 versions."""
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
modules_to_not_convert = list(set(modules_to_not_convert))
|
||||
|
||||
model, has_been_replaced = _replace_with_fp8_linear(
|
||||
model,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model using fp8 but no linear modules were found in your model."
|
||||
" Please double check your model architecture."
|
||||
)
|
||||
|
||||
return model
|
@ -24,6 +24,7 @@ from ..utils.quantization_config import (
|
||||
CompressedTensorsConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
FineGrainedFP8Config,
|
||||
GPTQConfig,
|
||||
HiggsConfig,
|
||||
HqqConfig,
|
||||
@ -41,6 +42,7 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
|
||||
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
|
||||
from .quantizer_eetq import EetqHfQuantizer
|
||||
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
|
||||
from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
|
||||
from .quantizer_gptq import GptqHfQuantizer
|
||||
from .quantizer_higgs import HiggsHfQuantizer
|
||||
from .quantizer_hqq import HqqHfQuantizer
|
||||
@ -64,6 +66,7 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
"bitnet": BitNetHfQuantizer,
|
||||
"vptq": VptqHfQuantizer,
|
||||
"fp8": FineGrainedFP8HfQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@ -81,6 +84,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"torchao": TorchAoConfig,
|
||||
"bitnet": BitNetConfig,
|
||||
"vptq": VptqConfig,
|
||||
"fp8": FineGrainedFP8Config,
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
207  src/transformers/quantizers/quantizer_finegrained_fp8.py  Normal file
@ -0,0 +1,207 @@
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
from .base import HfQuantizer
|
||||
from .quantizers_utils import get_module_from_name
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FineGrainedFP8HfQuantizer(HfQuantizer):
|
||||
"""
|
||||
FP8 quantization implementation supporting both standard and MoE models.
|
||||
Supports both e4m3fn formats based on platform.
|
||||
"""
|
||||
|
||||
requires_parameters_quantization = True
|
||||
requires_calibration = False
|
||||
required_packages = ["accelerate"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
|
||||
raise ImportError(
|
||||
"Using fp8 quantization requires torch >= 2.1.0"
|
||||
"Please install the latest version of torch ( pip install --upgrade torch )"
|
||||
)
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
|
||||
|
||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||
raise ValueError(
|
||||
"Converting into FP8 weights from tf/flax weights is currently not supported, "
|
||||
"please make sure the weights are in PyTorch format."
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
|
||||
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, minor = compute_capability
|
||||
if major < 9:
|
||||
raise ValueError(
|
||||
"FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
|
||||
)
|
||||
|
||||
device_map = kwargs.get("device_map", None)
|
||||
if device_map is None:
|
||||
logger.warning_once(
|
||||
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
|
||||
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
|
||||
)
|
||||
elif device_map is not None:
|
||||
if (
|
||||
not self.pre_quantized
|
||||
and isinstance(device_map, dict)
|
||||
and ("cpu" in device_map.values() or "disk" in device_map.values())
|
||||
):
|
||||
raise ValueError(
|
||||
"You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
|
||||
"This is not supported when the model is quantized on the fly. "
|
||||
"Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
|
||||
torch_dtype = torch.float32
|
||||
return torch_dtype
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Quantizes weights to FP8 format using Block-wise quantization
|
||||
"""
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
|
||||
set_module_tensor_to_device(model, param_name, target_device, param_value)
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
# Get FP8 min/max values
|
||||
fp8_min = torch.finfo(torch.float8_e4m3fn).min
|
||||
fp8_max = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
block_size_m, block_size_n = self.quantization_config.weight_block_size
|
||||
|
||||
rows, cols = param_value.shape[-2:]
|
||||
|
||||
if rows % block_size_m != 0 or cols % block_size_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
|
||||
)
|
||||
param_value_orig_shape = param_value.shape
|
||||
|
||||
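# Split the (rows, cols) matrix into (rows // block_size_m, cols // block_size_n) blocks so that the per-block max and scale can be computed below.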
param_value = param_value.reshape(
|
||||
-1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
|
||||
).permute(0, 1, 3, 2, 4)
|
||||
|
||||
# Calculate scaling factor for each block
|
||||
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
|
||||
scale = fp8_max / max_abs
|
||||
scale_orig_shape = scale.shape
|
||||
scale = scale.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# Quantize the weights
|
||||
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
|
||||
# Reshape back to matrix shape
|
||||
quantized_param = quantized_param.reshape(param_value_orig_shape)
|
||||
|
||||
# Reshape scale to match the number of blocks
|
||||
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
|
||||
|
||||
module._buffers[tensor_name] = quantized_param.to(target_device)
|
||||
module._buffers["weight_scale_inv"] = scale.to(target_device)
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
from ..integrations.finegrained_fp8 import FP8Linear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if isinstance(module, FP8Linear):
|
||||
if self.pre_quantized or tensor_name == "bias":
|
||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expect quantized weights but got an unquantized weight")
|
||||
return False
|
||||
else:
|
||||
if tensor_name == "weight_scale_inv":
|
||||
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
device_map,
|
||||
modules_to_not_convert: List[str] = [],
|
||||
**kwargs,
|
||||
):
|
||||
from ..integrations.finegrained_fp8 import replace_with_fp8_linear
|
||||
|
||||
self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
|
||||
|
||||
if self.quantization_config.modules_to_not_convert:
|
||||
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
|
||||
|
||||
model = replace_with_fp8_linear(
|
||||
model,
|
||||
modules_to_not_convert=self.modules_to_not_convert,
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
return model
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
||||
from ..integrations import FP8Linear
|
||||
|
||||
not_missing_keys = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, FP8Linear):
|
||||
for missing in missing_keys:
|
||||
if (
|
||||
(name in missing or name in f"{prefix}.{missing}")
|
||||
and not missing.endswith(".weight")
|
||||
and not missing.endswith(".bias")
|
||||
):
|
||||
not_missing_keys.append(missing)
|
||||
return [k for k in missing_keys if k not in not_missing_keys]
|
||||
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return False
|
@ -21,7 +21,7 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum):
|
||||
FBGEMM_FP8 = "fbgemm_fp8"
|
||||
TORCHAO = "torchao"
|
||||
BITNET = "bitnet"
|
||||
FP8 = "fp8"
|
||||
|
||||
|
||||
class AWQLinearVersion(str, Enum):
|
||||
@ -1548,3 +1549,43 @@ class BitNetConfig(QuantizationConfigMixin):
|
||||
Safety checker that arguments are correct
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineGrainedFP8Config(QuantizationConfigMixin):
|
||||
"""
|
||||
FineGrainedFP8Config is a configuration class for fine-grained FP8 quantization, used mainly for DeepSeek models.
|
||||
|
||||
Args:
|
||||
activation_scheme (`str`, *optional*, defaults to `"dynamic"`):
|
||||
The scheme used for activation quantization; the default and only supported scheme for now is "dynamic".
|
||||
weight_block_size (`typing.Tuple[int, int]`, *optional*, defaults to `(128, 128)`):
|
||||
The size of the weight blocks for quantization, default is (128, 128).
|
||||
modules_to_not_convert (`list`, *optional*):
|
||||
A list of module names that should not be converted during quantization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
weight_block_size: Tuple[int, int] = (128, 128),
|
||||
modules_to_not_convert: Optional[List] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.FP8
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
self.activation_scheme = activation_scheme
|
||||
self.weight_block_size = weight_block_size
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct
|
||||
"""
|
||||
self.activation_scheme = self.activation_scheme.lower()
|
||||
if self.activation_scheme not in ["dynamic"]:
|
||||
raise ValueError(f"Activation scheme {self.activation_scheme} not supported")
|
||||
if len(self.weight_block_size) != 2:
|
||||
raise ValueError("weight_block_size must be a tuple of two integers")
|
||||
if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
|
||||
raise ValueError("weight_block_size must be a tuple of two positive integers")
|
||||
|
0  tests/quantization/finegrained_fp8/__init__.py  Normal file
273  tests/quantization/finegrained_fp8/test_fp8.py  Normal file
@ -0,0 +1,273 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class FineGrainedFP8ConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
|
||||
"""
|
||||
quantization_config = FineGrainedFP8Config()
|
||||
config_to_dict = quantization_config.to_dict()
|
||||
|
||||
for key in config_to_dict:
|
||||
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
|
||||
|
||||
def test_from_dict(self):
|
||||
"""
|
||||
Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
|
||||
"""
|
||||
dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fp8"}
|
||||
quantization_config = FineGrainedFP8Config.from_dict(dict)
|
||||
|
||||
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
|
||||
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
|
||||
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
class FP8QuantizerTest(unittest.TestCase):
|
||||
model_name = "meta-llama/Llama-3.2-1B"
|
||||
input_text = "Once upon a time"
|
||||
max_new_tokens = 10
|
||||
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
|
||||
device_map = "cuda"
|
||||
offload_device_map = {
|
||||
"model.embed_tokens": 0,
|
||||
"model.layers.0": 0,
|
||||
"model.layers.1": 0,
|
||||
"model.layers.2": 0,
|
||||
"model.layers.3": 0,
|
||||
"model.layers.4": 0,
|
||||
"model.layers.5": 0,
|
||||
"model.layers.6": 0,
|
||||
"model.layers.7": "cpu",
|
||||
"model.layers.8": "cpu",
|
||||
"model.layers.9": "cpu",
|
||||
"model.layers.10": "cpu",
|
||||
"model.layers.11": "cpu",
|
||||
"model.layers.12": "cpu",
|
||||
"model.layers.13": "cpu",
|
||||
"model.layers.14": "cpu",
|
||||
"model.layers.15": "cpu",
|
||||
"model.rotary_emb": "disk",
|
||||
"model.norm": "disk",
|
||||
"lm_head": 0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Setup quantized model
|
||||
"""
|
||||
cls.quantization_config = FineGrainedFP8Config()
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def test_quantized_model_conversion(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model has been converted properly
|
||||
"""
|
||||
|
||||
from transformers.integrations import FP8Linear, replace_with_fp8_linear
|
||||
|
||||
model_id = "facebook/opt-350m"
|
||||
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
|
||||
quantization_config = FineGrainedFP8Config()
|
||||
|
||||
with init_empty_weights():
|
||||
model = OPTForCausalLM(config)
|
||||
|
||||
nb_linears = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
nb_linears += 1
|
||||
|
||||
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
|
||||
nb_fp8_linear = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, FP8Linear):
|
||||
nb_fp8_linear += 1
|
||||
|
||||
self.assertEqual(nb_linears - 1, nb_fp8_linear)
|
||||
|
||||
with init_empty_weights():
|
||||
model = OPTForCausalLM(config)
|
||||
quantization_config = FineGrainedFP8Config(modules_to_not_convert=["fc1"])
|
||||
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
|
||||
nb_fp8_linear = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, FP8Linear):
|
||||
nb_fp8_linear += 1
|
||||
|
||||
self.assertEqual(nb_linears - 25, nb_fp8_linear)
|
||||
|
||||
def test_quantized_model(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly
|
||||
"""
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_save_pretrained(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly after being saved and loaded
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
|
||||
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
|
||||
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_weight_and_weight_scale_inv(self):
|
||||
"""
|
||||
Simple test that checks if the weight and weight_scale_inv are working properly
|
||||
"""
|
||||
weight = self.quantized_model.model.layers[0].self_attn.q_proj.weight
|
||||
weight_scale_inv = self.quantized_model.model.layers[0].self_attn.q_proj.weight_scale_inv
|
||||
self.assertEqual(weight.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(weight_scale_inv.dtype, torch.float32)
|
||||
self.assertEqual(weight.shape, (weight_scale_inv.shape[0] * 128, weight_scale_inv.shape[1] * 128))
|
||||
|
||||
def test_block_size(self):
|
||||
"""
|
||||
Simple test that checks if the block size is working properly
|
||||
"""
|
||||
self.assertEqual(self.quantized_model.config.quantization_config.weight_block_size, (128, 128))
|
||||
quantization_config = FineGrainedFP8Config(weight_block_size=(32, 32))
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, device_map=self.device_map, quantization_config=quantization_config
|
||||
)
|
||||
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_quantized_model_multi_gpu(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly with multiple GPUs
|
||||
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS
|
||||
"""
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
quantization_config = FineGrainedFP8Config()
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, device_map="auto", quantization_config=quantization_config
|
||||
)
|
||||
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_save_pretrained_multi_gpu(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model is working properly after being saved and loaded
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
|
||||
self.assertTrue(set(model.hf_device_map.values()) == {0, 1})
|
||||
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
|
||||
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_quantized_model_offload(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model returns an error when loading with cpu/disk offloaded
|
||||
"""
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
|
||||
):
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, device_map=self.offload_device_map, quantization_config=self.quantization_config
|
||||
)
|
||||
|
||||
def test_save_pretrained_offload(self):
|
||||
"""
|
||||
Simple test that checks if the saved quantized model is working properly cpu/disk offload
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname)
|
||||
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
|
||||
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class FP8LinearTest(unittest.TestCase):
|
||||
device = "cuda"
|
||||
|
||||
def test_linear_preserves_shape(self):
|
||||
"""
|
||||
Test that FP8Linear preserves shape when in_features == out_features.
|
||||
"""
|
||||
from transformers.integrations import FP8Linear
|
||||
|
||||
linear = FP8Linear(256, 256, block_size=(128, 128), device=self.device)
|
||||
x = torch.rand((1, 5, 256)).to(self.device)
|
||||
|
||||
x_ = linear(x)
|
||||
self.assertEqual(x_.shape, x.shape)
|
||||
|
||||
def test_linear_with_diff_feature_size_preserves_shape(self):
|
||||
"""
|
||||
Test that FP8Linear generates the correct shape when in_features != out_features.
|
||||
"""
|
||||
from transformers.integrations import FP8Linear
|
||||
|
||||
linear = FP8Linear(128, 256, block_size=(128, 128), device=self.device)
|
||||
x = torch.rand((1, 5, 128)).to(self.device)
|
||||
|
||||
x_ = linear(x)
|
||||
self.assertEqual(x_.shape, (1, 5, 256))