Chore: Remove unnecessary LayerNorm, NormLayer layer abstractions.

Korbinian Poeppel 2025-07-01 15:37:48 +02:00
parent 1bde78a312
commit 7c239aaa1b


@@ -29,9 +29,8 @@ from ...generation.configuration_utils import xLSTMCache
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
auto_docstring,
can_return_tuple,
is_xlstm_available,
)
from .configuration_xlstm import xLSTMConfig
@@ -42,7 +41,6 @@ if is_xlstm_available():
external_xlstm = True
else:
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable, Literal
@@ -679,13 +677,9 @@ else:
DHHV = v.shape[-1]
c_state = (
c_initial
if c_initial is not None
else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
)
n_state = (
n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
c_initial if c_initial is not None else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
)
n_state = n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
m_state = m_initial if m_initial is not None else torch.zeros(B, NH, 1, device=k.device, dtype=torch.float32)
if S > 1:
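Note: the hunk above only inlines the default-state construction without changing behavior. For orientation, a minimal sketch (not part of the commit) of the zero states it creates; the names B, NH, DHQK, DHHV follow the surrounding code, the concrete sizes are made up for illustration.

import torch

B, NH, DHQK, DHHV = 2, 4, 64, 128  # batch, heads, key dim, value dim (illustrative)
device, dtype = torch.device("cpu"), torch.float32

# Matrix memory C, normalizer n and stabilizer m of the mLSTM recurrence,
# kept in float32 so the reductions stay numerically stable.
c_state = torch.zeros(B, NH, DHQK, DHHV, device=device, dtype=dtype)
n_state = torch.zeros(B, NH, DHQK, device=device, dtype=dtype)
m_state = torch.zeros(B, NH, 1, device=device, dtype=dtype)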
@@ -878,9 +872,11 @@ else:
def extra_repr(self) -> str:
return f"{self.config}"
class NormLayer(nn.Module, ABC):
"""Base class for normalization layers.
This class contains optional learnable weight and bias parameters.
class RMSNorm(nn.Module):
"""Root mean square normalization layer implementation similar
to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
It normalizes the input tensor by the root mean square of the last dimension.
Args:
num_features: The number of features in the input tensor.
@@ -920,24 +916,6 @@ else:
x = x + self.bias
return x
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class RMSNorm(NormLayer):
"""Root mean square normalization layer implementation similar
to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
It normalizes the input tensor by the root mean square of the last dimension.
Args:
num_features: The number of features in the input tensor.
eps: A small value to avoid division by zero.
use_weight: Whether to use a learnable weight.
use_bias: Whether to use a learnable bias.
force_float32_reductions: Whether to force float32 reductions.
"""
def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, ..., S,..., D)
# apply rms norm over the last dimension, i.e. D dimension
@@ -953,40 +931,7 @@ else:
x = self._apply_weight_bias(x)
return x
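For reference, a minimal functional sketch (not part of the commit) of what RMSNorm computes: normalization by the root mean square of the last dimension, an optional float32 reduction, then the optional affine. The eps default here is an assumption, not taken from the diff.

import torch

def rms_norm_reference(x: torch.Tensor, weight=None, bias=None,
                       eps: float = 1e-6, force_float32_reductions: bool = True):
    in_dtype = x.dtype
    if force_float32_reductions:
        x = x.float()
    # normalize by the root mean square over the last (feature) dimension
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = y.to(in_dtype)
    if weight is not None:  # optional learnable scale
        y = y * weight
    if bias is not None:    # optional learnable shift
        y = y + bias
    return y

x = torch.randn(2, 8, 16, dtype=torch.bfloat16)
y = rms_norm_reference(x, weight=torch.ones(16, dtype=torch.bfloat16))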
class LayerNorm(NormLayer):
"""Layer normalization layer implementation similar to
https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
The layer normalization is applied over the last dimension of the input tensor.
Args:
num_features: The number of features in the input tensor.
eps: A small value to avoid division by zero.
use_weight: Whether to use a learnable weight.
use_bias: Whether to use a learnable bias.
force_float32_reductions: Whether to force float32 reductions.
Returns:
The normalized tensor.
"""
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, ..., S,..., D)
# apply layer norm over the last dimension, i.e. D dimension
in_dtype = x.dtype
if self.force_float32_reductions:
x = x.float()
x_centered = x - x.mean(dim=-1, keepdim=True)
y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
return y.to(in_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, ..., S,..., D)
x = self._layer_normalize(x)
x = self._apply_weight_bias(x)
return x
class MultiHeadLayerNorm(LayerNorm):
class MultiHeadLayerNorm(nn.Module):
"""Multi-head version of the LayerNorm layer.
It normalizes the last dimension of the input tensor.
@@ -1020,16 +965,40 @@ else:
use_bias: bool = False,
force_float32_reductions: bool = True,
):
super().__init__(
num_features=num_heads * head_dim,
eps=eps,
use_weight=use_weight,
use_bias=use_bias,
force_float32_reductions=force_float32_reductions,
)
super().__init__()
self.num_features = num_heads * head_dim
self.eps = eps
self.force_float32_reductions = force_float32_reductions
if use_weight:
self.weight = nn.Parameter(torch.ones(self.num_features))
else:
self.weight = None
if use_bias:
self.bias = nn.Parameter(torch.zeros(self.num_features))
else:
self.bias = None
self.num_heads = num_heads
self.head_dim = head_dim
def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
if self.weight is not None:
x = x * self.weight
if self.bias is not None:
x = x + self.bias
return x
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, ..., S,..., D)
# apply layer norm over the last dimension, i.e. D dimension
in_dtype = x.dtype
if self.force_float32_reductions:
x = x.float()
x_centered = x - x.mean(dim=-1, keepdim=True)
y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
return y.to(in_dtype)
def forward(
self,
x: torch.Tensor, # (B, S, NH, DH)
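For orientation, a functional sketch (not part of the commit) of the multi-head variant: each head's DH features are layer-normalized independently, then the heads are flattened and the flat (NH * DH) weight/bias from __init__ above is applied. The output layout is an assumption inferred from the parameter shapes.

import torch

def multi_head_layer_norm_reference(x, weight, bias=None, eps: float = 1e-6):
    # x: (B, S, NH, DH) -- normalize over DH separately for every head
    in_dtype = x.dtype
    xf = x.float()
    xf = xf - xf.mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(xf.var(dim=-1, keepdim=True, unbiased=False) + eps)
    # flatten the heads and apply the flat (NH * DH) affine parameters
    y = xf.to(in_dtype).reshape(*x.shape[:-2], x.shape[-2] * x.shape[-1])
    y = y * weight
    if bias is not None:
        y = y + bias
    return y

x = torch.randn(2, 5, 4, 16)
y = multi_head_layer_norm_reference(x, weight=torch.ones(4 * 16))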
@@ -1274,163 +1243,6 @@ class xLSTMPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_is_stateful = True
@dataclass
class xLSTMOutput(ModelOutput):
"""
Class for the xLSTM model outputs
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (`xLSTMCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: Optional[torch.FloatTensor]
cache_params: Optional[xLSTMCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
class xLSTMCausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`xLSTMCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[xLSTMCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
XLSTM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`xLSTMConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
XLSTM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary.
If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
cache_params (`xLSTMCache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model added `state_input_ids + input_ids` as context).
use_cache (`bool`, *optional*):
If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
If `cache_params` is passed, `cache_position` should also be passed.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
def small_init_method(dim):
"""
Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
std = (2 / (5 * dim)) ** (1 / 2)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def wang_init_method(n_layers, dim):
"""
Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
"""
std = 2 / n_layers / dim ** (1 / 2)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
@add_start_docstrings(
"The bare xLSTM Model transformer outputting raw hidden-states without any specific head on top.",
XLSTM_START_DOCSTRING,
)
class xLSTMModel(xLSTMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# use embedding_dim and num_blocks once here to make use of them
self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
self.gradient_checkpointing = False
# actually unused, but needed in external integration
_ = (
config.add_out_norm,
config.tie_word_embeddings,
config.chunkwise_kernel,
config.sequence_kernel,
config.step_kernel,
)
self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
def _init_weights(self, module):
if self is not module:
if isinstance(module, torch.nn.Embedding):
@@ -1518,18 +1330,113 @@ class xLSTMModel(xLSTMPreTrainedModel):
torch.nn.init.zeros_(block.mlstm_layer.multihead_norm.bias)
torch.nn.init.zeros_(block.mlstm_layer.out_proj.bias)
@dataclass
class xLSTMOutput(ModelOutput):
"""
Class for the xLSTM model outputs
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (`xLSTMCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: Optional[torch.FloatTensor]
cache_params: Optional[xLSTMCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
class xLSTMCausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`xLSTMCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[xLSTMCache] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
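A short sketch of how these output containers are consumed; the field names follow the dataclasses above, while the top-level imports, the config constructor arguments, and the sizes are assumptions made only for illustration.

import torch
from transformers import xLSTMConfig, xLSTMForCausalLM  # assumed public exports

config = xLSTMConfig(vocab_size=256, embedding_dim=64, num_blocks=2)  # illustrative sizes
model = xLSTMForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids, use_cache=True, output_hidden_states=True)

out.logits.shape   # (1, 8, vocab_size): pre-softmax scores of the LM head
out.cache_params   # xLSTMCache holding the state after the last time step
out.hidden_states  # tuple of per-layer hidden states (plus the embeddings, if present)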
def small_init_method(dim):
"""
Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
std = (2 / (5 * dim)) ** (1 / 2)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def wang_init_method(n_layers, dim):
"""
Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
"""
std = 2 / n_layers / dim ** (1 / 2)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
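A usage sketch of the two initializers above, using the helpers defined directly above; the layer sizes are made up, and the split between small_init for input-side projections and wang_init for residual/output projections follows the GPT-NeoX code they are adapted from.

import torch.nn as nn

dim, n_layers = 512, 12  # illustrative model width and depth

in_proj = nn.Linear(dim, 3 * dim, bias=False)
out_proj = nn.Linear(dim, dim, bias=False)

# small_init: std = sqrt(2 / (5 * dim))
small_init_method(dim)(in_proj.weight)

# wang_init: std = 2 / (n_layers * sqrt(dim)), shrinking residual contributions with depth
wang_init_method(n_layers, dim)(out_proj.weight)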
@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# use embedding_dim and num_blocks once here to make use of them
self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
self.gradient_checkpointing = False
# actually unused, but needed in external integration
_ = (
config.add_out_norm,
config.tie_word_embeddings,
config.chunkwise_kernel,
config.sequence_kernel,
config.step_kernel,
)
self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embedding):
self.embeddings = new_embedding
@add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=xLSTMOutput,
config_class=_CONFIG_FOR_DOC,
)
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1639,13 +1546,7 @@ class xLSTMModel(xLSTMPreTrainedModel):
)
@add_start_docstrings(
"""
The xLSTM Model transformer with a language modeling head on top (linear layer with weights not tied to the input
embeddings).
""",
XLSTM_START_DOCSTRING,
)
@auto_docstring
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
_tied_weights_keys = []
@@ -1720,12 +1621,8 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
)
return model_inputs
@add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=xLSTMCausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,