Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Chore: Remove unnecessary LayerNorm, NormLayer layer abstractions.
Commit 7c239aaa1b (parent 1bde78a312)
@@ -29,9 +29,8 @@ from ...generation.configuration_utils import xLSTMCache
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    auto_docstring,
    can_return_tuple,
    is_xlstm_available,
)
from .configuration_xlstm import xLSTMConfig
@@ -42,7 +41,6 @@ if is_xlstm_available():

    external_xlstm = True
else:
    from abc import ABC, abstractmethod
    from functools import partial
    from typing import Callable, Literal

@@ -679,13 +677,9 @@ else:
        DHHV = v.shape[-1]

        c_state = (
            c_initial
            if c_initial is not None
            else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
        )
        n_state = (
            n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
            c_initial if c_initial is not None else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
        )
        n_state = n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
        m_state = m_initial if m_initial is not None else torch.zeros(B, NH, 1, device=k.device, dtype=torch.float32)

        if S > 1:
@@ -878,9 +872,11 @@ else:
        def extra_repr(self) -> str:
            return f"{self.config}"

    class NormLayer(nn.Module, ABC):
        """Base class for normalization layers.
        This class contains optional learnable weight and bias parameters.
    class RMSNorm(nn.Module):
        """Root mean square normalization layer implementation similar
        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.

        It normalizes the input tensor by the root mean square of the last dimension.

        Args:
            num_features: The number of features in the input tensor.
@@ -920,24 +916,6 @@ else:
                x = x + self.bias
            return x

        @abstractmethod
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

    class RMSNorm(NormLayer):
        """Root mean square normalization layer implementation similar
        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.

        It normalizes the input tensor by the root mean square of the last dimension.

        Args:
            num_features: The number of features in the input tensor.
            eps: A small value to avoid division by zero.
            use_weight: Whether to use a learnable weight.
            use_bias: Whether to use a learnable bias.
            force_float32_reductions: Whether to force float32 reductions.
        """

        def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, ..., S,..., D)
            # apply rms norm over the last dimension, i.e. D dimension
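
As a reference sketch (not part of this patch): the RMSNorm docstring above describes normalizing by the root mean square of the last dimension, with an optional float32 reduction. A minimal standalone version of that computation could look as follows; the function name and the eps default are illustrative, not the module's exact code.

import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6, force_float32: bool = True) -> torch.Tensor:
    # Scale by the reciprocal root mean square of the last (feature) dimension.
    in_dtype = x.dtype
    if force_float32:
        x = x.float()
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return y.to(in_dtype)

# e.g. y = rms_normalize(torch.randn(2, 4, 8))  # (B, S, D) -> (B, S, D)
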
@@ -953,40 +931,7 @@ else:
            x = self._apply_weight_bias(x)
            return x

    class LayerNorm(NormLayer):
        """Layer normalization layer implementation similar to
        https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.

        The layer normalization is applied over the last dimension of the input tensor.

        Args:
            num_features: The number of features in the input tensor.
            eps: A small value to avoid division by zero.
            use_weight: Whether to use a learnable weight.
            use_bias: Whether to use a learnable bias.
            force_float32_reductions: Whether to force float32 reductions.

        Returns:
            The normalized tensor.
        """

        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, ..., S,..., D)
            # apply layer norm over the last dimension, i.e. D dimension
            in_dtype = x.dtype
            if self.force_float32_reductions:
                x = x.float()
            x_centered = x - x.mean(dim=-1, keepdim=True)
            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
            return y.to(in_dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, ..., S,..., D)
            x = self._layer_normalize(x)
            x = self._apply_weight_bias(x)
            return x

    class MultiHeadLayerNorm(LayerNorm):
    class MultiHeadLayerNorm(nn.Module):
        """Multi-head version of the LayerNorm layer.

        It normalizes the last dimension of the input tensor.
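
As a quick cross-check (not part of this patch): the `_layer_normalize` body shown above, mean-centering followed by `rsqrt` of the biased variance plus `eps`, is the same computation as PyTorch's built-in layer norm without affine parameters. A small sketch with an illustrative eps value:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, dtype=torch.float32)
eps = 1e-6
y_manual = (x - x.mean(dim=-1, keepdim=True)) * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
y_builtin = F.layer_norm(x, normalized_shape=(8,), eps=eps)  # no weight/bias
assert torch.allclose(y_manual, y_builtin, atol=1e-5)
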
@@ -1020,16 +965,40 @@ else:
            use_bias: bool = False,
            force_float32_reductions: bool = True,
        ):
            super().__init__(
                num_features=num_heads * head_dim,
                eps=eps,
                use_weight=use_weight,
                use_bias=use_bias,
                force_float32_reductions=force_float32_reductions,
            )
            super().__init__()
            self.num_features = num_heads * head_dim
            self.eps = eps
            self.force_float32_reductions = force_float32_reductions

            if use_weight:
                self.weight = nn.Parameter(torch.ones(self.num_features))
            else:
                self.weight = None

            if use_bias:
                self.bias = nn.Parameter(torch.zeros(self.num_features))
            else:
                self.bias = None
            self.num_heads = num_heads
            self.head_dim = head_dim

        def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
            if self.weight is not None:
                x = x * self.weight
            if self.bias is not None:
                x = x + self.bias
            return x

        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, ..., S,..., D)
            # apply layer norm over the last dimension, i.e. D dimension
            in_dtype = x.dtype
            if self.force_float32_reductions:
                x = x.float()
            x_centered = x - x.mean(dim=-1, keepdim=True)
            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
            return y.to(in_dtype)

        def forward(
            self,
            x: torch.Tensor,  # (B, S, NH, DH)
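
As a reference sketch (not part of this patch): with the NormLayer/LayerNorm base classes removed, MultiHeadLayerNorm now carries its own flat `(num_heads * head_dim,)` weight and bias plus the normalization helpers shown above. The forward body is not visible in this hunk, so the free function below is only a plausible reconstruction under the assumption that each head is normalized over its own `DH` dimension and the heads are then flattened before the affine step; names and defaults are illustrative.

import torch

def multi_head_layer_norm(x: torch.Tensor, weight=None, bias=None, eps: float = 1e-6) -> torch.Tensor:
    # x: (B, S, NH, DH); normalize each head over its last dimension.
    x_centered = x - x.mean(dim=-1, keepdim=True)
    y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
    B, S, NH, DH = y.shape
    y = y.reshape(B, S, NH * DH)  # flatten heads so a (NH * DH,) weight/bias can be applied
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y
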
@@ -1274,163 +1243,6 @@ class xLSTMPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _is_stateful = True


@dataclass
class xLSTMOutput(ModelOutput):
    """
    Class for the xLSTM model outputs

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`xLSTMCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

    """

    last_hidden_state: Optional[torch.FloatTensor]
    cache_params: Optional[xLSTMCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None


@dataclass
class xLSTMCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`xLSTMCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[xLSTMCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None


XLSTM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`xLSTMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

XLSTM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        cache_params (`xLSTMCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
"""


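As a usage sketch (not part of this patch): the `use_cache`, `cache_params`, and `cache_position` arguments documented above are what allow feeding only the new tokens on subsequent calls. The snippet below assumes a transformers build that exports the xLSTM classes and that the default `xLSTMConfig()` is usable with the pure-PyTorch fallback in this file; all sizes and values are placeholders.

import torch
from transformers import xLSTMConfig, xLSTMForCausalLM  # import path assumed

config = xLSTMConfig()  # default config; real weights would come from from_pretrained(...)
model = xLSTMForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)
    cache = out.cache_params  # recurrent state after consuming `input_ids`
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    # Only the new token is passed; `cache_params` supplies the context of the old `input_ids`.
    out = model(
        input_ids=next_token,
        cache_params=cache,
        use_cache=True,
        cache_position=torch.tensor([input_ids.shape[1]]),
    )
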
def small_init_method(dim):
    """
    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = (2 / (5 * dim)) ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def wang_init_method(n_layers, dim):
    """
    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    """
    std = 2 / n_layers / dim ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


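As a usage sketch (not part of this patch): both factories above return a closure that fills a tensor in place with a normal distribution whose standard deviation depends on the width (and, for `wang_init_method`, also the depth). Using the two functions defined just above, with placeholder sizes:

import torch

embedding_dim = 512
embedding = torch.nn.Embedding(32000, embedding_dim)
small_init_method(embedding_dim)(embedding.weight.data)  # std = sqrt(2 / (5 * dim))

proj = torch.nn.Linear(embedding_dim, embedding_dim)
wang_init_method(n_layers=24, dim=embedding_dim)(proj.weight.data)  # std = 2 / (n_layers * sqrt(dim))
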
@add_start_docstrings(
    "The bare xLSTM Model transformer outputting raw hidden-states without any specific head on top.",
    XLSTM_START_DOCSTRING,
)
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embedding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)

        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])

        self.gradient_checkpointing = False
        # actually unused, but needed in external integration
        _ = (
            config.add_out_norm,
            config.tie_word_embeddings,
            config.chunkwise_kernel,
            config.sequence_kernel,
            config.step_kernel,
        )

        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        if self is not module:
            if isinstance(module, torch.nn.Embedding):
@@ -1518,18 +1330,113 @@ class xLSTMModel(xLSTMPreTrainedModel):
                    torch.nn.init.zeros_(block.mlstm_layer.multihead_norm.bias)
                    torch.nn.init.zeros_(block.mlstm_layer.out_proj.bias)


@dataclass
class xLSTMOutput(ModelOutput):
    """
    Class for the xLSTM model outputs

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`xLSTMCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

    """

    last_hidden_state: Optional[torch.FloatTensor]
    cache_params: Optional[xLSTMCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None


@dataclass
class xLSTMCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`xLSTMCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[xLSTMCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None


def small_init_method(dim):
    """
    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = (2 / (5 * dim)) ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def wang_init_method(n_layers, dim):
    """
    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    """
    std = 2 / n_layers / dim ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embedding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)

        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])

        self.gradient_checkpointing = False
        # actually unused, but needed in external integration
        _ = (
            config.add_out_norm,
            config.tie_word_embeddings,
            config.chunkwise_kernel,
            config.sequence_kernel,
            config.step_kernel,
        )

        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embedding):
        self.embeddings = new_embedding

    @add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=xLSTMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -1639,13 +1546,7 @@ class xLSTMModel(xLSTMPreTrainedModel):
        )


@add_start_docstrings(
    """
    The xLSTM Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    """,
    XLSTM_START_DOCSTRING,
)
@auto_docstring
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
    _tied_weights_keys = []

@@ -1720,12 +1621,8 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=xLSTMCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
