Adding FP8 Quantization to transformers (#36026)

* first commit

* adding kernels

* fix create_quantized_param

* fix quantization logic

* end2end

* fix style

* fix imports

* fix consistency

* update

* fix style

* update

* update after review

* make style

* update

* update

* fix

* update

* fix docstring

* update

* update after review

* update

* fix scheme

* update

* update

* fix

* update

* fix docstring

* add source

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Mohamed Mekkouri 2025-02-13 13:01:19 +01:00 committed by GitHub
parent c82319b493
commit efe72fe21f
12 changed files with 1019 additions and 2 deletions


@ -185,6 +185,8 @@
title: BitNet
- local: quantization/compressed_tensors
title: compressed-tensors
- local: quantization/finegrained_fp8
title: Fine-grained FP8
- local: quantization/contribute
title: Contribute new quantization method
title: Quantization Methods


@ -80,3 +80,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## BitNetConfig
[[autodoc]] BitNetConfig
## FineGrainedFP8Config
[[autodoc]] FineGrainedFP8Config


@ -0,0 +1,62 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Fine-grained FP8
With the fine-grained FP8 quantization method, you can quantize your model to FP8 (W8A8):
- the weights are quantized to 8 bits (FP8) per 2D block (e.g. `weight_block_size=(128, 128)`), a scheme inspired by the DeepSeek-V3 implementation
- the activations are quantized to 8 bits (FP8) per group per token, with the group size matching the weight block size along the input channels (128 by default)
This method was added to support the DeepSeek-V3 and DeepSeek-R1 models. You can read the paper [here](https://arxiv.org/pdf/2412.19437), and the image below illustrates the quantization scheme:
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/b7b3b34bf826a6423ea82ffc57ecac80c46c3c76/transformers/quantization/quantization_deepseek.png)
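The following is a minimal, illustrative sketch of the weight scheme with the default `(128, 128)` block size (the actual conversion happens inside the quantizer; shapes here are toy values): each block gets one float32 scale, and the reciprocal of that scale is stored as `weight_scale_inv` for dequantization.
```py
import torch

# Illustrative only: quantize a toy weight matrix per (128, 128) block to float8_e4m3fn.
weight = torch.randn(256, 256)
block = 128
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# (block-row, block-col, 128, 128) layout; the real implementation reshapes back to (256, 256)
blocks = weight.reshape(256 // block, block, 256 // block, block).permute(0, 2, 1, 3)
scale = fp8_max / blocks.abs().amax(dim=(-1, -2), keepdim=True)  # one scale per block
weight_fp8 = (blocks * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
weight_scale_inv = scale.squeeze(-1).squeeze(-1).reciprocal()    # stored next to the FP8 weight
```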
> [!TIP]
> You need a GPU with compute capability >= 9.0 (e.g. H100)
Before you begin, make sure the following libraries are installed with their latest version:
```bash
pip install --upgrade accelerate torch
```
> [!TIP]
> You need to install a torch version compatible with the CUDA version of your GPU.
By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
from transformers import FineGrainedFP8Config, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FineGrainedFP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
A quantized model can be saved with `save_pretrained` and reloaded later with `from_pretrained`.
```py
quant_path = "/path/to/save/quantized/model"
quantized_model.save_pretrained(quant_path)
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
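To sanity-check a quantized checkpoint, you can inspect one of the converted linear layers: the weight is stored in `torch.float8_e4m3fn` and a `weight_scale_inv` buffer holds one `float32` scale per weight block. The module path below is illustrative and depends on the model architecture.
```py
layer = quantized_model.model.layers[0].self_attn.q_proj
print(layer.weight.dtype)            # torch.float8_e4m3fn
print(layer.weight_scale_inv.shape)  # one float32 scale per (128, 128) weight block
```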


@ -61,7 +61,7 @@ Use the table below to help you decide which quantization method to use.
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
<Tip>
**1:** bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links.


@ -1024,6 +1024,7 @@ _import_structure = {
"CompressedTensorsConfig",
"EetqConfig",
"FbgemmFp8Config",
"FineGrainedFP8Config",
"GPTQConfig",
"HiggsConfig",
"HqqConfig",
@ -6196,6 +6197,7 @@ if TYPE_CHECKING:
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
FineGrainedFP8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,


@ -54,6 +54,7 @@ _import_structure = {
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
"GGUF_CONFIG_MAPPING",
@ -157,6 +158,7 @@ if TYPE_CHECKING:
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (
GGUF_CONFIG_MAPPING,


@ -0,0 +1,420 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
import torch
import torch.nn as nn
import triton
import triton.language as tl
from torch.nn import functional as F
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.0
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous()
assert x.shape[-1] % block_size == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
def grid(meta):
return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
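# Hedged usage sketch of act_quant (illustrative shapes; requires a CUDA device for the Triton kernel):
#   x = torch.randn(2, 256, dtype=torch.bfloat16, device="cuda")
#   y, s = act_quant(x, block_size=128)
#   # y: (2, 256) float8_e4m3fn values, s: (2, 2) float32 scales, one per group of 128 values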
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul_triton(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
BLOCK_SIZE_M = 128
if M < BLOCK_SIZE_M:
BLOCK_SIZE_M = triton.next_power_of_2(M)
BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
BLOCK_SIZE_K = block_k
assert block_k % BLOCK_SIZE_K == 0
BLOCK_SIZE_N = block_n
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=8,
)
return C
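# Hedged usage sketch (illustrative only): multiply an FP8-quantized activation by an FP8 weight
# with (128, 128) weight blocks; `qx`/`xs` come from act_quant and `w_scale` holds one scale per block.
#   out = w8a8_block_fp8_matmul_triton(qx, w_q, xs, w_scale, block_size=[128, 128], output_dtype=torch.bfloat16)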
# Python (torch.compile) version of the above Triton function; much slower than the Triton version, used for testing
@torch.compile
def w8a8_block_fp8_matmul_compile(
input_q: torch.Tensor, # [batch, seq_len, hidden_dim]
weight_q: torch.Tensor, # [out_features, hidden_dim]
input_scale: torch.Tensor, # [batch * seq_len, num_input_groups]
weight_scale: torch.Tensor, # [num_weight_blocks_m, num_weight_blocks_n]
block_size: Optional[Tuple[int, int]] = None, # (M=128, N=128) for weights for example
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Performs blocked matrix multiplication with FP8 quantized matrices.
Args:
input_q: Quantized input tensor with 1x128 block quantization
weight_q: Quantized weight tensor with 128x128 block quantization
input_scale: Scaling factors for input blocks
weight_scale: Scaling factors for weight blocks
block_size: Tuple of (M, N) for weight block dimensions
output_dtype: Desired output dtype
"""
batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
out_features = weight_q.shape[0]
# Reshape input for batched matmul
input_reshaped = input_q.view(-1, hidden_dim) # [batch*seq_len, hidden_dim]
input_scale_reshaped = input_scale.view(input_scale.shape[0], -1) # [batch*seq_len, 1]
# Calculate number of blocks
num_weight_blocks_m = out_features // block_size[0]
num_weight_blocks_n = hidden_dim // block_size[1]
output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)
for i in range(num_weight_blocks_m):
m_start = i * block_size[0]
m_end = m_start + block_size[0]
for j in range(num_weight_blocks_n):
n_start = j * block_size[1]
n_end = n_start + block_size[1]
# Extract current blocks
input_block = input_reshaped[:, n_start:n_end]
weight_block = weight_q[m_start:m_end, n_start:n_end]
# Get corresponding scales
curr_input_scale = input_scale_reshaped[:, j : j + 1] # [batch*seq_len, 1]
curr_weight_scale = weight_scale[i, j] # scalar
block_result = (
torch._scaled_mm(
input_block,
weight_block.t(),
scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device),
scale_b=curr_weight_scale,
out_dtype=output_dtype,
)
* curr_input_scale
)
output[:, m_start:m_end] += block_result
output = output.view(batch_size, seq_len, out_features)
return output.to(output_dtype)
class FP8Linear(nn.Module):
dtype = torch.float8_e4m3fn
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
dtype=None,
block_size: Tuple[int, int] = (128, 128),
device=None,
activation_scheme="dynamic",
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.register_buffer(
"weight_scale_inv", torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
)
self.block_size = block_size
self.activation_scheme = activation_scheme
if bias:
self.bias = nn.Parameter(torch.empty(self.out_features))
else:
self.register_parameter("bias", None)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
# Context manager used to switch among the available cuda devices
with torch.cuda.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch.cuda.synchronize(device=input.device)
with torch.cuda.device(input.device):
output = w8a8_block_fp8_matmul_triton(
qinput,
self.weight,
scale,
self.weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
torch.cuda.synchronize(device=input.device)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
def _replace_with_fp8_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""Replace Linear layers with FP8Linear."""
if current_key_name is None:
current_key_name = []
for name, module in model.named_children():
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_fp8_linear(
model,
modules_to_not_convert=None,
quantization_config=None,
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fp8_linear(
model,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using fp8 but no linear modules were found in your model."
" Please double check your model architecture."
)
return model
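# Hedged usage sketch: convert the nn.Linear layers of a meta-initialized model, mirroring what the
# FP8 quantizer does before weight loading (config construction shown for illustration only):
#   from transformers import FineGrainedFP8Config
#   model = replace_with_fp8_linear(model, quantization_config=FineGrainedFP8Config())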


@ -24,6 +24,7 @@ from ..utils.quantization_config import (
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
FineGrainedFP8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
@ -41,6 +42,7 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
@ -64,6 +66,7 @@ AUTO_QUANTIZER_MAPPING = {
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
"vptq": VptqHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@ -81,6 +84,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
"vptq": VptqConfig,
"fp8": FineGrainedFP8Config,
}
logger = logging.get_logger(__name__)


@ -0,0 +1,207 @@
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from packaging import version
from ..utils import is_accelerate_available, is_torch_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if is_torch_available():
import torch
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
logger = logging.get_logger(__name__)
class FineGrainedFP8HfQuantizer(HfQuantizer):
"""
FP8 quantization implementation supporting both standard and MoE models.
Weights are stored in the torch.float8_e4m3fn format.
"""
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
raise ImportError(
"Using fp8 quantization requires torch >= 2.1.0"
"Please install the latest version of torch ( pip install --upgrade torch )"
)
if not is_accelerate_available():
raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Converting into FP8 weights from tf/flax weights is currently not supported, "
"please make sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if major < 9:
raise ValueError(
"FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
)
device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
)
elif device_map is not None:
if (
not self.pre_quantized
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
"This is not supported when the model is quantized on the fly. "
"Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
torch_dtype = torch.float32
return torch_dtype
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
"""
Quantizes weights to FP8 format using Block-wise quantization
"""
from accelerate.utils import set_module_tensor_to_device
set_module_tensor_to_device(model, param_name, target_device, param_value)
module, tensor_name = get_module_from_name(model, param_name)
# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
block_size_m, block_size_n = self.quantization_config.weight_block_size
rows, cols = param_value.shape[-2:]
if rows % block_size_m != 0 or cols % block_size_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
)
param_value_orig_shape = param_value.shape
param_value = param_value.reshape(
-1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
).permute(0, 1, 3, 2, 4)
# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)
# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
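# The buffer stores the reciprocal of the quantization scale, so dequantization is approximately
# weight.to(torch.float32) * weight_scale_inv broadcast over each (block_size_m, block_size_n) block.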
module._buffers[tensor_name] = quantized_param.to(target_device)
module._buffers["weight_scale_inv"] = scale.to(target_device)
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
from ..integrations.finegrained_fp8 import FP8Linear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
device_map,
modules_to_not_convert: List[str] = [],
**kwargs,
):
from ..integrations.finegrained_fp8 import replace_with_fp8_linear
self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
if self.quantization_config.modules_to_not_convert:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
model = replace_with_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
return model
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FP8Linear
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FP8Linear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
def is_serializable(self, safe_serialization=None):
return True
@property
def is_trainable(self) -> bool:
return False


@ -21,7 +21,7 @@ import os
from dataclasses import dataclass
from enum import Enum
from inspect import Parameter, signature
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum):
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"
BITNET = "bitnet"
FP8 = "fp8"
class AWQLinearVersion(str, Enum):
@ -1548,3 +1549,43 @@ class BitNetConfig(QuantizationConfigMixin):
Safety checker that arguments are correct
"""
pass
@dataclass
class FineGrainedFP8Config(QuantizationConfigMixin):
"""
FineGrainedFP8Config is a configuration class for fine-grained FP8 quantization, used mainly for DeepSeek models.
Args:
activation_scheme (`str`, *optional*, defaults to `"dynamic"`):
The scheme used for activation quantization; the default and only supported scheme for now is `"dynamic"`.
weight_block_size (`typing.Tuple[int, int]`, *optional*, defaults to `(128, 128)`):
The size of the weight blocks for quantization, default is (128, 128).
modules_to_not_convert (`list`, *optional*):
A list of module names that should not be converted during quantization.
"""
def __init__(
self,
activation_scheme: str = "dynamic",
weight_block_size: Tuple[int, int] = (128, 128),
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.FP8
self.modules_to_not_convert = modules_to_not_convert
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct
"""
self.activation_scheme = self.activation_scheme.lower()
if self.activation_scheme not in ["dynamic"]:
raise ValueError(f"Activation scheme {self.activation_scheme} not supported")
if len(self.weight_block_size) != 2:
raise ValueError("weight_block_size must be a tuple of two integers")
if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
raise ValueError("weight_block_size must be a tuple of two positive integers")


@ -0,0 +1,273 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class FineGrainedFP8ConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Simple test that checks that converting a config to a dict yields a dict with the same values as the config object
"""
quantization_config = FineGrainedFP8Config()
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks that creating a config from a dict yields a config object with the same values as the dict
"""
dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fp8"}
quantization_config = FineGrainedFP8Config.from_dict(dict)
self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
self.assertEqual(dict["quant_method"], quantization_config.quant_method)
@slow
@require_accelerate
@require_torch_gpu
class FP8QuantizerTest(unittest.TestCase):
model_name = "meta-llama/Llama-3.2-1B"
input_text = "Once upon a time"
max_new_tokens = 10
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
device_map = "cuda"
offload_device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
"model.layers.1": 0,
"model.layers.2": 0,
"model.layers.3": 0,
"model.layers.4": 0,
"model.layers.5": 0,
"model.layers.6": 0,
"model.layers.7": "cpu",
"model.layers.8": "cpu",
"model.layers.9": "cpu",
"model.layers.10": "cpu",
"model.layers.11": "cpu",
"model.layers.12": "cpu",
"model.layers.13": "cpu",
"model.layers.14": "cpu",
"model.layers.15": "cpu",
"model.rotary_emb": "disk",
"model.norm": "disk",
"lm_head": 0,
}
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
cls.quantization_config = FineGrainedFP8Config()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from transformers.integrations import FP8Linear, replace_with_fp8_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = FineGrainedFP8Config()
with init_empty_weights():
model = OPTForCausalLM(config)
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
nb_fp8_linear = 0
for module in model.modules():
if isinstance(module, FP8Linear):
nb_fp8_linear += 1
self.assertEqual(nb_linears - 1, nb_fp8_linear)
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = FineGrainedFP8Config(modules_to_not_convert=["fc1"])
model = replace_with_fp8_linear(model, quantization_config=quantization_config)
nb_fp8_linear = 0
for module in model.modules():
if isinstance(module, FP8Linear):
nb_fp8_linear += 1
self.assertEqual(nb_linears - 25, nb_fp8_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_weight_and_weight_scale_inv(self):
"""
Simple test that checks if the weight and weight_scale_inv are working properly
"""
weight = self.quantized_model.model.layers[0].self_attn.q_proj.weight
weight_scale_inv = self.quantized_model.model.layers[0].self_attn.q_proj.weight_scale_inv
self.assertEqual(weight.dtype, torch.float8_e4m3fn)
self.assertEqual(weight_scale_inv.dtype, torch.float32)
self.assertEqual(weight.shape, (weight_scale_inv.shape[0] * 128, weight_scale_inv.shape[1] * 128))
def test_block_size(self):
"""
Simple test that checks if the block size is working properly
"""
self.assertEqual(self.quantized_model.config.quantization_config.weight_block_size, (128, 128))
quantization_config = FineGrainedFP8Config(weight_block_size=(32, 32))
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map=self.device_map, quantization_config=quantization_config
)
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
quantization_config = FineGrainedFP8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
def test_save_pretrained_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
self.assertTrue(set(model.hf_device_map.values()) == {0, 1})
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_offload(self):
"""
Simple test that checks that loading the quantized model with a cpu/disk offload device_map raises an error
"""
with self.assertRaisesRegex(
ValueError, "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
):
AutoModelForCausalLM.from_pretrained(
self.model_name, device_map=self.offload_device_map, quantization_config=self.quantization_config
)
def test_save_pretrained_offload(self):
"""
Simple test that checks if the saved quantized model works properly with cpu/disk offload
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_gpu
class FP8LinearTest(unittest.TestCase):
device = "cuda"
def test_linear_preserves_shape(self):
"""
Test that FP8Linear preserves shape when in_features == out_features.
"""
from transformers.integrations import FP8Linear
linear = FP8Linear(256, 256, block_size=(128, 128), device=self.device)
x = torch.rand((1, 5, 256)).to(self.device)
x_ = linear(x)
self.assertEqual(x_.shape, x.shape)
def test_linear_with_diff_feature_size_preserves_shape(self):
"""
Test that FP8Linear generates the correct shape when in_features != out_features.
"""
from transformers.integrations import FP8Linear
linear = FP8Linear(128, 256, block_size=(128, 128), device=self.device)
x = torch.rand((1, 5, 128)).to(self.device)
x_ = linear(x)
self.assertEqual(x_.shape, (1, 5, 256))