mirror of https://github.com/huggingface/transformers.git (synced 2025-07-15 02:28:24 +06:00)
Chore: Remove unnecessary LayerNorm, NormLayer layer abstractions.
This commit is contained in:
parent 1bde78a312
commit 7c239aaa1b
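In short, the commit flattens the normalization-layer hierarchy in `modeling_xlstm.py`: the abstract `NormLayer` base and the intermediate `LayerNorm` class go away, `RMSNorm` and `MultiHeadLayerNorm` become standalone `nn.Module`s that own their `weight`/`bias` directly, the output dataclasses and init helpers move below `xLSTMPreTrainedModel`, and the legacy docstring decorators are swapped for `@auto_docstring`/`@can_return_tuple`. For reference, here is a minimal sketch of the RMS normalization that the kept `_rms_normalize` helper performs (its body is not shown in the hunks below; this is the standard formulation, and the keyword defaults here are assumptions):

```python
import torch


def rms_normalize_last_dim(x: torch.Tensor, eps: float = 1e-6, force_float32: bool = True) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, optionally
    # doing the reduction in float32 and casting back to the input dtype.
    in_dtype = x.dtype
    if force_float32:
        x = x.float()
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return y.to(in_dtype)
```

A learnable weight (and optional bias) is applied on top of this, as the `_apply_weight_bias` helper in the hunks below shows.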
@@ -29,9 +29,8 @@ from ...generation.configuration_utils import xLSTMCache
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
+    auto_docstring,
+    can_return_tuple,
     is_xlstm_available,
 )
 from .configuration_xlstm import xLSTMConfig
@@ -42,7 +41,6 @@ if is_xlstm_available():
 
     external_xlstm = True
 else:
-    from abc import ABC, abstractmethod
     from functools import partial
     from typing import Callable, Literal
 
@@ -679,13 +677,9 @@ else:
         DHHV = v.shape[-1]
 
         c_state = (
-            c_initial
-            if c_initial is not None
-            else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
-        )
-        n_state = (
-            n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
+            c_initial if c_initial is not None else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
         )
+        n_state = n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
         m_state = m_initial if m_initial is not None else torch.zeros(B, NH, 1, device=k.device, dtype=torch.float32)
 
         if S > 1:
@@ -878,9 +872,11 @@ else:
         def extra_repr(self) -> str:
             return f"{self.config}"
 
-    class NormLayer(nn.Module, ABC):
-        """Base class for normalization layers.
-        This class contains optional learnable weight and bias parameters.
+    class RMSNorm(nn.Module):
+        """Root mean square normalization layer implementation similar
+        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
+
+        It normalizes the input tensor by the root mean square of the last dimension.
 
         Args:
             num_features: The number of features in the input tensor.
@@ -920,24 +916,6 @@ else:
                 x = x + self.bias
             return x
 
-        @abstractmethod
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            raise NotImplementedError
-
-    class RMSNorm(NormLayer):
-        """Root mean square normalization layer implementation similar
-        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
-
-        It normalizes the input tensor by the root mean square of the last dimension.
-
-        Args:
-            num_features: The number of features in the input tensor.
-            eps: A small value to avoid division by zero.
-            use_weight: Whether to use a learnable weight.
-            use_bias: Whether to use a learnable bias.
-            force_float32_reductions: Whether to force float32 reductions.
-        """
-
         def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
             # x: (B, ..., S,..., D)
             # apply rms norm over the last dimension, i.e. D dimension
@@ -953,40 +931,7 @@ else:
             x = self._apply_weight_bias(x)
             return x
 
-    class LayerNorm(NormLayer):
-        """Layer normalization layer implementation similar to
-        https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
-
-        The layer normalization is applied over the last dimension of the input tensor.
-
-        Args:
-            num_features: The number of features in the input tensor.
-            eps: A small value to avoid division by zero.
-            use_weight: Whether to use a learnable weight.
-            use_bias: Whether to use a learnable bias.
-            force_float32_reductions: Whether to force float32 reductions.
-
-        Returns:
-            The normalized tensor.
-        """
-
-        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
-            # x: (B, ..., S,..., D)
-            # apply layer norm over the last dimension, i.e. D dimension
-            in_dtype = x.dtype
-            if self.force_float32_reductions:
-                x = x.float()
-            x_centered = x - x.mean(dim=-1, keepdim=True)
-            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
-            return y.to(in_dtype)
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            # x: (B, ..., S,..., D)
-            x = self._layer_normalize(x)
-            x = self._apply_weight_bias(x)
-            return x
-
-    class MultiHeadLayerNorm(LayerNorm):
+    class MultiHeadLayerNorm(nn.Module):
         """Multi-head version of the LayerNorm layer.
 
         It normalizes the last dimension of the input tensor.
@@ -1020,16 +965,40 @@
             use_bias: bool = False,
             force_float32_reductions: bool = True,
         ):
-            super().__init__(
-                num_features=num_heads * head_dim,
-                eps=eps,
-                use_weight=use_weight,
-                use_bias=use_bias,
-                force_float32_reductions=force_float32_reductions,
-            )
+            super().__init__()
+            self.num_features = num_heads * head_dim
+            self.eps = eps
+            self.force_float32_reductions = force_float32_reductions
+
+            if use_weight:
+                self.weight = nn.Parameter(torch.ones(self.num_features))
+            else:
+                self.weight = None
+
+            if use_bias:
+                self.bias = nn.Parameter(torch.zeros(self.num_features))
+            else:
+                self.bias = None
             self.num_heads = num_heads
             self.head_dim = head_dim
 
+        def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
+            if self.weight is not None:
+                x = x * self.weight
+            if self.bias is not None:
+                x = x + self.bias
+            return x
+
+        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
+            # x: (B, ..., S,..., D)
+            # apply layer norm over the last dimension, i.e. D dimension
+            in_dtype = x.dtype
+            if self.force_float32_reductions:
+                x = x.float()
+            x_centered = x - x.mean(dim=-1, keepdim=True)
+            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
+            return y.to(in_dtype)
+
         def forward(
             self,
             x: torch.Tensor,  # (B, S, NH, DH)
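Since `MultiHeadLayerNorm` no longer inherits its parameters, the constructor above registers `weight`/`bias` itself and the layer-norm math lives in `_layer_normalize`. Below is a self-contained sketch of that logic for a `(B, S, NH, DH)` input; the flattening to `(B, S, NH * DH)` before applying weight and bias is an assumption, since the hunk only shows the forward signature:

```python
import torch
from torch import nn


class MultiHeadLayerNormSketch(nn.Module):
    # Illustrative restatement of the logic added above, not the library class itself.
    def __init__(self, num_heads, head_dim, eps=1e-6, use_weight=True, use_bias=False, force_float32_reductions=True):
        super().__init__()
        self.num_features = num_heads * head_dim
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        self.weight = nn.Parameter(torch.ones(self.num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(self.num_features)) if use_bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, S, NH, DH)
        B, S, NH, DH = x.shape
        in_dtype = x.dtype
        h = x.float() if self.force_float32_reductions else x
        h = h - h.mean(dim=-1, keepdim=True)                 # per-head layer norm over DH
        h = h * torch.rsqrt(h.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        h = h.to(in_dtype).reshape(B, S, NH * DH)            # assumed: flatten heads before weight/bias
        if self.weight is not None:
            h = h * self.weight
        if self.bias is not None:
            h = h + self.bias
        return h


out = MultiHeadLayerNormSketch(num_heads=4, head_dim=64)(torch.randn(2, 16, 4, 64))
print(out.shape)  # torch.Size([2, 16, 256]) under the assumptions above
```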
@@ -1274,163 +1243,6 @@ class xLSTMPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _is_stateful = True
 
-
-@dataclass
-class xLSTMOutput(ModelOutput):
-    """
-    Class for the xLSTM model outputs
-
-    Args:
-        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-            Sequence of hidden-states at the output of the last layer of the model.
-        cache_params (`xLSTMCache`):
-            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
-            avoid providing the old `input_ids`.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
-
-    """
-
-    last_hidden_state: Optional[torch.FloatTensor]
-    cache_params: Optional[xLSTMCache] = None
-    hidden_states: Optional[tuple[torch.FloatTensor]] = None
-
-
-@dataclass
-class xLSTMCausalLMOutput(ModelOutput):
-    """
-    Base class for causal language model (or autoregressive) outputs.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Language modeling loss (for next-token prediction).
-        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        cache_params (`xLSTMCache`):
-            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
-            avoid providing the old `input_ids`.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[xLSTMCache] = None
-    hidden_states: Optional[tuple[torch.FloatTensor]] = None
-
-
-XLSTM_START_DOCSTRING = r"""
-
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`xLSTMConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-XLSTM_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            Indices of input sequence tokens in the vocabulary.
-
-            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
-            `input_ids`.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        cache_params (`xLSTMCache`, *optional*):
-            If passed along, the model uses the previous state in all the blocks (which will give the output for the
-            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
-        use_cache (`bool`, *optional*):
-            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
-            If `cache_params` is passed, `cache_position` should also be passed.
-        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-"""
-
-
-def small_init_method(dim):
-    """
-    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
-    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
-    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
-    std = (2 / (5 * dim)) ** (1 / 2)
-
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
-def wang_init_method(n_layers, dim):
-    """
-    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
-    """
-    std = 2 / n_layers / dim ** (1 / 2)
-
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
-@add_start_docstrings(
-    "The bare xLSTM Model transformer outputting raw hidden-states without any specific head on top.",
-    XLSTM_START_DOCSTRING,
-)
-class xLSTMModel(xLSTMPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        # use embbeding_dim and num_blocks once here to make use of them
-        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
-
-        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
-
-        self.gradient_checkpointing = False
-        # actually unused, but needed in external integration
-        _ = (
-            config.add_out_norm,
-            config.tie_word_embeddings,
-            config.chunkwise_kernel,
-            config.sequence_kernel,
-            config.step_kernel,
-        )
-
-        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def _init_weights(self, module):
         if self is not module:
             if isinstance(module, torch.nn.Embedding):
@@ -1518,18 +1330,113 @@ class xLSTMModel(xLSTMPreTrainedModel):
                 torch.nn.init.zeros_(block.mlstm_layer.multihead_norm.bias)
                 torch.nn.init.zeros_(block.mlstm_layer.out_proj.bias)
 
+
+@dataclass
+class xLSTMOutput(ModelOutput):
+    """
+    Class for the xLSTM model outputs
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        cache_params (`xLSTMCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor]
+    cache_params: Optional[xLSTMCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class xLSTMCausalLMOutput(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        cache_params (`xLSTMCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[xLSTMCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+def small_init_method(dim):
+    """
+    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
+    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
+    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
+    std = (2 / (5 * dim)) ** (1 / 2)
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+def wang_init_method(n_layers, dim):
+    """
+    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
+    """
+    std = 2 / n_layers / dim ** (1 / 2)
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+@auto_docstring
+class xLSTMModel(xLSTMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        # use embbeding_dim and num_blocks once here to make use of them
+        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
+
+        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
+
+        self.gradient_checkpointing = False
+        # actually unused, but needed in external integration
+        _ = (
+            config.add_out_norm,
+            config.tie_word_embeddings,
+            config.chunkwise_kernel,
+            config.sequence_kernel,
+            config.step_kernel,
+        )
+
+        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def get_input_embeddings(self):
         return self.embeddings
 
     def set_input_embeddings(self, new_embedding):
         self.embeddings = new_embedding
 
-    @add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=xLSTMOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
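The `small_init_method` and `wang_init_method` helpers (and the two output dataclasses) are only relocated by this hunk; their bodies are unchanged. As a standalone illustration of what the init helpers produce (bodies restated from the diff, with short comments and the example tensor added here):

```python
import torch


def small_init_method(dim):
    # Returns a closure that fills a tensor in place from N(0, sqrt(2 / (5 * dim))).
    std = (2 / (5 * dim)) ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def wang_init_method(n_layers, dim):
    # Returns a closure that fills a tensor in place from N(0, 2 / (n_layers * sqrt(dim))).
    std = 2 / n_layers / dim ** (1 / 2)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


w = torch.empty(256, 256)
small_init_method(256)(w)     # std = sqrt(2 / (5 * 256)) ≈ 0.0395
wang_init_method(12, 256)(w)  # std = 2 / (12 * sqrt(256)) ≈ 0.0104
```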
|