Add optional RMSNorm support to BitNet quantization (config + layers) (#38087)

* enable optional RMS in BitLinear

* Fix naming

* Import RMS from Llama using config.*

* make fix-copies

* ran CI loop

* remove default BitNetQuantConfig values

* Fix BitNetQuantConfig to be Optional

* Fix config docstrings to match Optional

* Edit docstrings to match standards

---------

Co-authored-by: steinmetzc <codysteinmetz7@gmail.com>
Co-authored-by: codys12 <steinmetzc@dh-mgmt4.hpc.msoe.edu>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Codys12 authored on 2025-05-16 05:38:06 -05:00, committed by GitHub
parent 57a79f51b2
commit 1e921a3a9c
3 changed files with 50 additions and 3 deletions
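
As a quick orientation before the diffs: the new flags are consumed like any other BitNet quantization option. A minimal sketch, assuming `BitNetQuantConfig` is importable from the top-level `transformers` package; only `use_rms_norm` and `rms_norm_eps` come from this commit, everything else is the existing config API.

from transformers import BitNetQuantConfig

# Sketch only: enable the optional activation RMSNorm added by this commit.
quant_config = BitNetQuantConfig(
    linear_class="autobitlinear",
    quantization_mode="online",
    use_rms_norm=True,   # new: apply RMSNorm to activations before quantizing them
    rms_norm_eps=1e-6,   # new: epsilon forwarded to LlamaRMSNorm
)
print(quant_config.use_rms_norm, quant_config.rms_norm_eps)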


@@ -1584,7 +1584,9 @@ class TikTokenConverter:
         self.pattern = pattern
         self.add_prefix_space = add_prefix_space
         self.additional_special_tokens = (
-            additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
+            additional_special_tokens.keys()
+            if isinstance(additional_special_tokens, dict)
+            else additional_special_tokens
         )

     def extract_vocab_merges_from_model(self, tiktoken_url: str):


@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:


 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
@@ -150,6 +159,13 @@ class BitLinear(nn.Module):
         else:
             self.bias = None

+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """
@@ -180,6 +196,10 @@ class BitLinear(nn.Module):
         return out

     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)
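
For intuition about the ordering this hunk introduces in BitLinear.forward (normalize first, quantize second), here is a self-contained sketch, not the library code: the normalization math mirrors LlamaRMSNorm, and the activation quantizer is a simplified per-token absmax int8 scheme of the kind activation_quant implements.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as LlamaRMSNorm: divide by the RMS over the last dim, then apply a learned gain.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

def activation_quant(x: torch.Tensor, num_bits: int = 8):
    # Simplified per-token absmax quantization to the signed integer range.
    qmax = 2 ** (num_bits - 1) - 1
    scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-qmax - 1, qmax), scale

x = torch.randn(2, 4, 64)            # (batch, seq, in_features)
gain = torch.ones(64)                # RMSNorm weight
x = rms_norm(x, gain)                # the new optional step
x_q, x_scale = activation_quant(x)   # the existing quantization step
print(x_q.abs().max().item(), x_scale.shape)
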
@@ -245,9 +265,17 @@ class AutoBitLinear(nn.Linear):
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",
@@ -271,6 +299,10 @@ class AutoBitLinear(nn.Linear):
         return state_dict

     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                         device=module.weight.device,
                         dtype=module.weight.dtype,
                         online_quant=(quantization_config.quantization_mode == "online"),
+                        use_rms_norm=quantization_config.use_rms_norm,
+                        rms_norm_eps=quantization_config.rms_norm_eps,
                     )
                     if quantization_config.quantization_mode == "offline":
                         model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                         bias=module.bias is not None,
                         device=module.weight.device,
                         dtype=module.weight.dtype,
+                        use_rms_norm=quantization_config.use_rms_norm,
+                        rms_norm_eps=quantization_config.rms_norm_eps,
                     )
                     model._modules[name].requires_grad_(False)
                     has_been_replaced = True
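
Taken together, these two hunks thread the config flags into whichever layer class replaces nn.Linear. A hedged sketch of the call path, assuming replace_with_bitnet_linear is exported from transformers.integrations as in current releases and that accelerate is installed (the replacement runs under init_empty_weights, so the new parameters live on the meta device):

import torch.nn as nn
from transformers import BitNetQuantConfig
from transformers.integrations import replace_with_bitnet_linear

toy = nn.Sequential(nn.Linear(64, 64, bias=False), nn.Linear(64, 16, bias=False))
cfg = BitNetQuantConfig(
    linear_class="autobitlinear",
    quantization_mode="online",
    use_rms_norm=True,
    rms_norm_eps=1e-6,
)
toy = replace_with_bitnet_linear(toy, quantization_config=cfg)
# Each converted layer now carries the optional LlamaRMSNorm submodule.
print(type(toy[0]).__name__, toy[0].rms_norm)
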
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of


@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
             In `offline` mode, quantization parameters are pre-calculated *before* inference.
             These parameters are then fixed and loaded into the quantized model. This
             generally results in lower runtime overhead compared to online quantization.
+        use_rms_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
+            of normalizing activations before quantization/packing.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon value used in the RMSNorm layer for numerical stability.
         kwargs (`Dict[str, Any]`, *optional*):
             Additional keyword arguments that may be used by specific quantization
             backends or future versions.
@@ -1801,6 +1806,8 @@ class BitNetQuantConfig(QuantizationConfigMixin):
         modules_to_not_convert: Optional[List] = None,
         linear_class: Optional[str] = "bitlinear",
         quantization_mode: Optional[str] = "offline",
+        use_rms_norm: Optional[bool] = False,
+        rms_norm_eps: Optional[float] = 1e-6,
         **kwargs,
     ):
         if linear_class not in ["bitlinear", "autobitlinear"]:
@@ -1811,6 +1818,8 @@ class BitNetQuantConfig(QuantizationConfigMixin):
         self.modules_to_not_convert = modules_to_not_convert
         self.linear_class = linear_class
         self.quantization_mode = quantization_mode
+        self.use_rms_norm = use_rms_norm
+        self.rms_norm_eps = rms_norm_eps
         self.post_init()

     def post_init(self):
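
Because the new options are stored as plain attributes in __init__, they follow the config through the usual serialization path. A small hedged check, assuming QuantizationConfigMixin.to_dict() keeps its current behavior of dumping instance attributes:

from transformers import BitNetQuantConfig

cfg = BitNetQuantConfig(use_rms_norm=True, rms_norm_eps=1e-5)
data = cfg.to_dict()
# The new attributes round-trip with the rest of the quantization config.
assert data["use_rms_norm"] is True
assert data["rms_norm_eps"] == 1e-5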