diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index 5a249343eef..42a371e3681 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -29,9 +29,8 @@ from ...generation.configuration_utils import xLSTMCache
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
+    auto_docstring,
+    can_return_tuple,
     is_xlstm_available,
 )
 from .configuration_xlstm import xLSTMConfig
@@ -42,7 +41,6 @@ if is_xlstm_available():
 
     external_xlstm = True
 else:
-    from abc import ABC, abstractmethod
     from functools import partial
     from typing import Callable, Literal
 
@@ -679,13 +677,9 @@ else:
         DHHV = v.shape[-1]
 
         c_state = (
-            c_initial
-            if c_initial is not None
-            else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
-        )
-        n_state = (
-            n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
+            c_initial if c_initial is not None else torch.zeros(B, NH, DHQK, DHHV, device=k.device, dtype=torch.float32)
         )
+        n_state = n_initial if n_initial is not None else torch.zeros(B, NH, DHQK, device=k.device, dtype=torch.float32)
         m_state = m_initial if m_initial is not None else torch.zeros(B, NH, 1, device=k.device, dtype=torch.float32)
 
         if S > 1:
@@ -878,9 +872,11 @@ else:
         def extra_repr(self) -> str:
             return f"{self.config}"
 
-    class NormLayer(nn.Module, ABC):
-        """Base class for normalization layers.
-        This class contains optional learnable weight and bias parameters.
+    class RMSNorm(nn.Module):
+        """Root mean square normalization layer implementation similar
+        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
+
+        It normalizes the input tensor by the root mean square of the last dimension.
 
         Args:
             num_features: The number of features in the input tensor.
@@ -920,24 +916,6 @@ else:
                 x = x + self.bias
             return x
 
-        @abstractmethod
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            raise NotImplementedError
-
-    class RMSNorm(NormLayer):
-        """Root mean square normalization layer implementation similar
-        to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
-
-        It normalizes the input tensor by the root mean square of the last dimension.
-
-        Args:
-            num_features: The number of features in the input tensor.
-            eps: A small value to avoid division by zero.
-            use_weight: Whether to use a learnable weight.
-            use_bias: Whether to use a learnable bias.
-            force_float32_reductions: Whether to force float32 reductions.
-        """
-
         def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
             # x: (B, ..., S,..., D)
             # apply rms norm over the last dimension, i.e. D dimension
@@ -953,40 +931,7 @@ else:
             x = self._apply_weight_bias(x)
             return x
 
-    class LayerNorm(NormLayer):
-        """Layer normalization layer implementation similar to
-        https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
-
-        The layer normalization is applied over the last dimension of the input tensor.
-
-        Args:
-            num_features: The number of features in the input tensor.
-            eps: A small value to avoid division by zero.
-            use_weight: Whether to use a learnable weight.
-            use_bias: Whether to use a learnable bias.
-            force_float32_reductions: Whether to force float32 reductions.
-
-        Returns:
-            The normalized tensor.
-        """
-
-        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
-            # x: (B, ..., S,..., D)
-            # apply layer norm over the last dimension, i.e. D dimension
-            in_dtype = x.dtype
-            if self.force_float32_reductions:
-                x = x.float()
-            x_centered = x - x.mean(dim=-1, keepdim=True)
-            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
-            return y.to(in_dtype)
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            # x: (B, ..., S,..., D)
-            x = self._layer_normalize(x)
-            x = self._apply_weight_bias(x)
-            return x
-
-    class MultiHeadLayerNorm(LayerNorm):
+    class MultiHeadLayerNorm(nn.Module):
         """Multi-head version of the LayerNorm layer.
 
         It normalizes the last dimension of the input tensor.
@@ -1020,16 +965,40 @@ else:
             use_bias: bool = False,
             force_float32_reductions: bool = True,
         ):
-            super().__init__(
-                num_features=num_heads * head_dim,
-                eps=eps,
-                use_weight=use_weight,
-                use_bias=use_bias,
-                force_float32_reductions=force_float32_reductions,
-            )
+            super().__init__()
+            self.num_features = num_heads * head_dim
+            self.eps = eps
+            self.force_float32_reductions = force_float32_reductions
+
+            if use_weight:
+                self.weight = nn.Parameter(torch.ones(self.num_features))
+            else:
+                self.weight = None
+
+            if use_bias:
+                self.bias = nn.Parameter(torch.zeros(self.num_features))
+            else:
+                self.bias = None
             self.num_heads = num_heads
             self.head_dim = head_dim
 
+        def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
+            if self.weight is not None:
+                x = x * self.weight
+            if self.bias is not None:
+                x = x + self.bias
+            return x
+
+        def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
+            # x: (B, ..., S,..., D)
+            # apply layer norm over the last dimension, i.e. D dimension
+            in_dtype = x.dtype
+            if self.force_float32_reductions:
+                x = x.float()
+            x_centered = x - x.mean(dim=-1, keepdim=True)
+            y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
+            return y.to(in_dtype)
+
         def forward(
             self,
             x: torch.Tensor,  # (B, S, NH, DH)
@@ -1274,163 +1243,6 @@ class xLSTMPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _is_stateful = True
 
-
-@dataclass
-class xLSTMOutput(ModelOutput):
-    """
-    Class for the xLSTM model outputs
-
-    Args:
-        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-            Sequence of hidden-states at the output of the last layer of the model.
-        cache_params (`xLSTMCache`):
-            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
-            avoid providing the old `input_ids`.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
-
-    """
-
-    last_hidden_state: Optional[torch.FloatTensor]
-    cache_params: Optional[xLSTMCache] = None
-    hidden_states: Optional[tuple[torch.FloatTensor]] = None
-
-
-@dataclass
-class xLSTMCausalLMOutput(ModelOutput):
-    """
-    Base class for causal language model (or autoregressive) outputs.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Language modeling loss (for next-token prediction).
-        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        cache_params (`xLSTMCache`):
-            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
-            avoid providing the old `input_ids`.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[xLSTMCache] = None
-    hidden_states: Optional[tuple[torch.FloatTensor]] = None
-
-
-XLSTM_START_DOCSTRING = r"""
-
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`xLSTMConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-XLSTM_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            Indices of input sequence tokens in the vocabulary.
-
-            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
-            `input_ids`.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        cache_params (`xLSTMCache`, *optional*):
-            If passed along, the model uses the previous state in all the blocks (which will give the output for the
-            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
-        use_cache (`bool`, *optional*):
-            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
-            If `cache_params` is passed, `cache_position` should also be passed.
-        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-"""
-
-
-def small_init_method(dim):
-    """
-    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
-    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
-    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
-    std = (2 / (5 * dim)) ** (1 / 2)
-
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
-def wang_init_method(n_layers, dim):
-    """
-    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
-    """
-    std = 2 / n_layers / dim ** (1 / 2)
-
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
-@add_start_docstrings(
-    "The bare xLSTM Model transformer outputting raw hidden-states without any specific head on top.",
-    XLSTM_START_DOCSTRING,
-)
-class xLSTMModel(xLSTMPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        # use embbeding_dim and num_blocks once here to make use of them
-        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
-
-        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
-
-        self.gradient_checkpointing = False
-        # actually unused, but needed in external integration
-        _ = (
-            config.add_out_norm,
-            config.tie_word_embeddings,
-            config.chunkwise_kernel,
-            config.sequence_kernel,
-            config.step_kernel,
-        )
-
-        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def _init_weights(self, module):
         if self is not module:
             if isinstance(module, torch.nn.Embedding):
@@ -1518,18 +1330,113 @@ class xLSTMModel(xLSTMPreTrainedModel):
                 torch.nn.init.zeros_(block.mlstm_layer.multihead_norm.bias)
                 torch.nn.init.zeros_(block.mlstm_layer.out_proj.bias)
 
+
+@dataclass
+class xLSTMOutput(ModelOutput):
+    """
+    Class for the xLSTM model outputs
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        cache_params (`xLSTMCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor]
+    cache_params: Optional[xLSTMCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class xLSTMCausalLMOutput(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        cache_params (`xLSTMCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[xLSTMCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+def small_init_method(dim):
+    """
+    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
+    Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
+    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
+    std = (2 / (5 * dim)) ** (1 / 2)
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+def wang_init_method(n_layers, dim):
+    """
+    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
+    """
+    std = 2 / n_layers / dim ** (1 / 2)
+
+    def init_(tensor):
+        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+    return init_
+
+
+@auto_docstring
+class xLSTMModel(xLSTMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        # use embbeding_dim and num_blocks once here to make use of them
+        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
+
+        self.blocks = nn.ModuleList([mLSTMBlock(config) for _ in range(config.num_blocks)])
+
+        self.gradient_checkpointing = False
+        # actually unused, but needed in external integration
+        _ = (
+            config.add_out_norm,
+            config.tie_word_embeddings,
+            config.chunkwise_kernel,
+            config.sequence_kernel,
+            config.step_kernel,
+        )
+
+        self.out_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def get_input_embeddings(self):
         return self.embeddings
 
     def set_input_embeddings(self, new_embedding):
         self.embeddings = new_embedding
 
-    @add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=xLSTMOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1639,13 +1546,7 @@ class xLSTMModel(xLSTMPreTrainedModel):
         )
 
 
-@add_start_docstrings(
-    """
-    The xLSTM Model transformer with a language modeling head on top (linear layer with weights not tied to the input
-    embeddings).
-    """,
-    XLSTM_START_DOCSTRING,
-)
+@auto_docstring
 class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
     _tied_weights_keys = []
 
@@ -1720,12 +1621,8 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
        )
         return model_inputs
 
-    @add_start_docstrings_to_model_forward(XLSTM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=xLSTMCausalLMOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
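
Note on the inlined normalization: the per-head layer norm that MultiHeadLayerNorm now carries directly (instead of inheriting it from the removed NormLayer/LayerNorm hierarchy) can be exercised in isolation. The sketch below mirrors the _layer_normalize and _apply_weight_bias methods added in the hunk at old line 1020; the example tensor sizes and the final flatten-to-(num_heads * head_dim) step before the affine transform are illustrative assumptions, since the unchanged remainder of forward is not part of this diff.

import torch

# Hypothetical example sizes: batch, sequence length, heads, head dim.
B, S, NH, DH = 2, 16, 4, 64
eps = 1e-6

x = torch.randn(B, S, NH, DH, dtype=torch.bfloat16)

# _layer_normalize: center and scale over the last (head) dimension,
# reducing in float32 as the force_float32_reductions path does.
x32 = x.float()
x_centered = x32 - x32.mean(dim=-1, keepdim=True)
y = (x_centered * torch.rsqrt(x32.var(dim=-1, keepdim=True, unbiased=False) + eps)).to(x.dtype)

# _apply_weight_bias on flattened features; reshaping to
# (B, S, num_heads * head_dim) is an assumption based on
# num_features = num_heads * head_dim in __init__.
weight = torch.ones(NH * DH)
bias = torch.zeros(NH * DH)
out = y.reshape(B, S, NH * DH) * weight.to(y.dtype) + bias.to(y.dtype)
print(out.shape)  # torch.Size([2, 16, 256])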