Fix index names in modeling_xlstm.py

This commit is contained in:
Korbinian Poeppel 2025-07-01 17:45:03 +02:00
parent c7ce6a50bb
commit 5f8a399ce2

View File

@ -147,7 +147,7 @@ else:
scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None] scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
# NOTE: no update in-place (igate.e. +=) as this gives error for autograd backward # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk) matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
# n_k update # n_k update
@ -169,7 +169,7 @@ else:
matQ: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK) matQ: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
matK: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK) matK: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
matV: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHHV) matV: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHHV)
# these states must be all states up to the last chunk, igate.e. :-1 # these states must be all states up to the last chunk, i.e. :-1
matC_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK, DHHV) matC_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK, DHHV)
vecN_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK) vecN_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK)
scaMinter_states: torch.Tensor, # (BATCH_SIZE, NH, NC) scaMinter_states: torch.Tensor, # (BATCH_SIZE, NH, NC)
@ -948,7 +948,7 @@ else:
def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor: def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (BATCH_SIZE, ..., SEQLEN,..., HD) # x: (BATCH_SIZE, ..., SEQLEN,..., HD)
# apply rms norm over the last dimension, igate.e. HD dimension # apply rms norm over the last dimension, i.e. HD dimension
in_dtype = x.dtype in_dtype = x.dtype
if self.force_float32_reductions: if self.force_float32_reductions:
x = x.float() x = x.float()
@ -1021,7 +1021,7 @@ else:
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor: def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (BATCH_SIZE, ..., SEQLEN,..., HD) # x: (BATCH_SIZE, ..., SEQLEN,..., HD)
# apply layer norm over the last dimension, igate.e. HD dimension # apply layer norm over the last dimension, i.e. HD dimension
in_dtype = x.dtype in_dtype = x.dtype
if self.force_float32_reductions: if self.force_float32_reductions:
x = x.float() x = x.float()
@ -1502,15 +1502,14 @@ class xLSTMModel(xLSTMPreTrainedModel):
hidden_states_chunk = hidden_states[ hidden_states_chunk = hidden_states[
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
] ]
for igate, xlstm_block in enumerate(self.blocks): for layer_idx, xlstm_block in enumerate(self.blocks):
hidden_states_chunk, rnn_state = xlstm_block( hidden_states_chunk, rnn_state = xlstm_block(
hidden_states_chunk, hidden_states_chunk,
state=cache_params.rnn_state[igate], state=cache_params.rnn_state[layer_idx],
) )
for state_idx in range(len(cache_params.rnn_state[igate])): for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx] local_rnn_state = rnn_state[state_idx]
local_rnn_state = rnn_state[state_idx] cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False cache_params.rnn_state_initial = False
final_state[ final_state[
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
@ -1519,23 +1518,22 @@ class xLSTMModel(xLSTMPreTrainedModel):
hidden_states = final_state hidden_states = final_state
else: else:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
for igate, xlstm_block in enumerate(self.blocks): for layer_idx, xlstm_block in enumerate(self.blocks):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
hidden_states, rnn_state = self._gradient_checkpointing_func( hidden_states, rnn_state = self._gradient_checkpointing_func(
xlstm_block.__call__, xlstm_block.__call__,
hidden_states, hidden_states,
cache_params.rnn_state[igate] if cache_params is not None else None, cache_params.rnn_state[layer_idx] if cache_params is not None else None,
) )
else: else:
hidden_states, rnn_state = xlstm_block( hidden_states, rnn_state = xlstm_block(
hidden_states, hidden_states,
state=cache_params.rnn_state[igate] if cache_params is not None else None, state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
) )
if cache_params: if cache_params:
for state_idx in range(len(cache_params.rnn_state[igate])): for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx] local_rnn_state = rnn_state[state_idx]
local_rnn_state = rnn_state[state_idx] cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False cache_params.rnn_state_initial = False
if output_hidden_states: if output_hidden_states:
@ -1651,7 +1649,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
) -> Union[tuple, xLSTMCausalLMOutput]: ) -> Union[tuple, xLSTMCausalLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, igate.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
""" """