mirror of
https://github.com/huggingface/transformers.git
Add optional RMSNorm support to BitNet quantization (config + layers) (#38087)
* enable optional RMS in BitLinear
* Fix naming
* Import RMS from Llama using config.*
* make fix-copies
* ran CI loop
* remove default BitNetQuantConfig values
* Fix BitNetQuantConfig to be Optional
* Fix config docstrings to match Optional
* Edit docstrings to match standards

Co-authored-by: steinmetzc <codysteinmetz7@gmail.com>
Co-authored-by: codys12 <steinmetzc@dh-mgmt4.hpc.msoe.edu>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Parent: 57a79f51b2
Commit: 1e921a3a9c
@@ -1584,7 +1584,9 @@ class TikTokenConverter:
         self.pattern = pattern
         self.add_prefix_space = add_prefix_space
         self.additional_special_tokens = (
-            additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
+            additional_special_tokens.keys()
+            if isinstance(additional_special_tokens, dict)
+            else additional_special_tokens
         )
 
     def extract_vocab_merges_from_model(self, tiktoken_url: str):
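For context on the change above: `isinstance(additional_special_tokens, dict)` also accepts dict subclasses, which the exact-type check rejected. A minimal illustration (the `OrderedDict` value is only an example, not taken from the converter):

from collections import OrderedDict

tokens = OrderedDict(eos="</s>")    # a dict subclass standing in for real special-tokens data

print(type(tokens) is dict)         # False: exact-type check rejects subclasses
print(isinstance(tokens, dict))     # True: isinstance accepts dict and its subclasses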
@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 
 
 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
@@ -150,6 +159,13 @@ class BitLinear(nn.Module):
         else:
             self.bias = None
 
+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """
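The norm itself is reused from the Llama implementation rather than redefined here. For reference, a minimal sketch of what `LlamaRMSNorm` computes (the dtype-casting details of the real implementation are omitted):

import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Root-mean-square norm: x / sqrt(mean(x^2) + eps), followed by a learned per-channel scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)   # mean of squares over the hidden dimension
        x = x * torch.rsqrt(variance + self.eps)     # normalize by the RMS
        return self.weight * x                       # learned scale, initialized to ones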
@@ -180,6 +196,10 @@ class BitLinear(nn.Module):
         return out
 
     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)
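With the flag set, the forward pass normalizes the activations first and only then quantizes them. A hedged sketch of the per-token absmax scheme that `activation_quant` performs (the constants and clamping below follow the common BitNet recipe, not the literal integration code):

import torch


def activation_quant_sketch(x: torch.Tensor, num_bits: int = 8):
    qmax = 2 ** (num_bits - 1) - 1                                       # 127 for 8 bits
    # One scale per token, taken from the absolute maximum over the hidden dimension.
    scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_q = (x * scale).round().clamp(-qmax - 1, qmax)                     # values in the int8 range
    return x_q, scale


x = torch.randn(2, 4, 16)
x_q, scale = activation_quant_sketch(x)   # the matmul output is later rescaled by 1 / scale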
@@ -245,9 +265,17 @@ class AutoBitLinear(nn.Linear):
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",
@@ -271,6 +299,10 @@ class AutoBitLinear(nn.Linear):
         return state_dict
 
     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:
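In online mode, `WeightQuant` ternarizes the weights on the fly. A hedged sketch of the usual BitNet b1.58 absmean recipe (three levels, scaled by the mean absolute value), not the literal `WeightQuant` autograd function:

import torch


def weight_quant_sketch(w: torch.Tensor) -> torch.Tensor:
    # Scale by the mean absolute value, snap to {-1, 0, +1}, then undo the scale
    # so the quantized tensor keeps roughly the original magnitude.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale


w = torch.randn(8, 8)
print(torch.unique(weight_quant_sketch(w)))   # three levels: -mean(|w|), 0, +mean(|w|)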
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                     device=module.weight.device,
                     dtype=module.weight.dtype,
                     online_quant=(quantization_config.quantization_mode == "online"),
+                    use_rms_norm=quantization_config.use_rms_norm,
+                    rms_norm_eps=quantization_config.rms_norm_eps,
                 )
                 if quantization_config.quantization_mode == "offline":
                     model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                     bias=module.bias is not None,
                     device=module.weight.device,
                     dtype=module.weight.dtype,
+                    use_rms_norm=quantization_config.use_rms_norm,
+                    rms_norm_eps=quantization_config.rms_norm_eps,
                 )
                 model._modules[name].requires_grad_(False)
                 has_been_replaced = True
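Both call sites above live in `_replace_with_bitnet_linear`, which walks the module tree and swaps eligible `nn.Linear` layers in place. A generic sketch of that replacement pattern, with a hypothetical `make_bitlinear` factory standing in for the real constructor calls:

from torch import nn


def replace_linear_sketch(model: nn.Module, make_bitlinear, skip=("lm_head",)):
    # Recurse through named children; leaves that are plain Linear layers get swapped.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in skip:
            setattr(model, name, make_bitlinear(module))
        else:
            replace_linear_sketch(module, make_bitlinear, skip)
    return model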
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
             In `offline` mode, quantization parameters are pre-calculated *before* inference.
             These parameters are then fixed and loaded into the quantized model. This
             generally results in lower runtime overhead compared to online quantization.
+        use_rms_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
+            of normalizing activations before quantization/packing.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon value used in the RMSNorm layer for numerical stability.
         kwargs (`Dict[str, Any]`, *optional*):
             Additional keyword arguments that may be used by specific quantization
             backends or future versions.
@@ -1801,6 +1806,8 @@ class BitNetQuantConfig(QuantizationConfigMixin):
         modules_to_not_convert: Optional[List] = None,
         linear_class: Optional[str] = "bitlinear",
         quantization_mode: Optional[str] = "offline",
+        use_rms_norm: Optional[bool] = False,
+        rms_norm_eps: Optional[float] = 1e-6,
         **kwargs,
     ):
         if linear_class not in ["bitlinear", "autobitlinear"]:
@@ -1811,6 +1818,8 @@ class BitNetQuantConfig(QuantizationConfigMixin):
         self.modules_to_not_convert = modules_to_not_convert
         self.linear_class = linear_class
         self.quantization_mode = quantization_mode
+        self.use_rms_norm = use_rms_norm
+        self.rms_norm_eps = rms_norm_eps
         self.post_init()
 
     def post_init(self):
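Since the layers only ever read these flags from the config, enabling the norm amounts to passing the new arguments at load time. A hedged usage sketch (the checkpoint id is a placeholder; the keyword arguments are the ones introduced in this commit):

from transformers import AutoModelForCausalLM, BitNetQuantConfig

quant_config = BitNetQuantConfig(
    linear_class="autobitlinear",
    quantization_mode="online",
    use_rms_norm=True,       # apply LlamaRMSNorm to activations before quantization
    rms_norm_eps=1e-6,
)

model = AutoModelForCausalLM.from_pretrained(
    "path/or/hub-id-of-a-bitnet-checkpoint",   # placeholder, not a real checkpoint name
    quantization_config=quant_config,
)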