Fix index names in modeling_xlstm.py

Korbinian Poeppel 2025-07-01 17:45:03 +02:00
parent c7ce6a50bb
commit 5f8a399ce2


@@ -147,7 +147,7 @@ else:
scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
-# NOTE: no update in-place (igate.e. +=) as this gives error for autograd backward
+# NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
# n_k update
@@ -169,7 +169,7 @@ else:
matQ: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
matK: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
matV: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHHV)
-# these states must be all states up to the last chunk, igate.e. :-1
+# these states must be all states up to the last chunk, i.e. :-1
matC_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK, DHHV)
vecN_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK)
scaMinter_states: torch.Tensor, # (BATCH_SIZE, NH, NC)
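
For context, the inter-chunk state update in the first hunk can be read as the following self-contained sketch. The tensor names and shape comments mirror the diff above, but the function itself is illustrative and not the library's actual kernel.

import torch

def chunk_state_update(matC_k, scaG_k, scaM_inter_k, scaM_inter_k_next,
                       matK_chunk_gated, matV_chunk):
    # scaG_k, scaM_inter_k, scaM_inter_k_next: (B, NH) log-space gate / running-max scalars
    # matC_k: (B, NH, DHQK, DHHV) key-value memory accumulated over previous chunks
    # matK_chunk_gated: (B, NH, L, DHQK), matV_chunk: (B, NH, L, DHHV)
    scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
    # out-of-place update (no +=), as the NOTE in the hunk above explains,
    # so autograd can backpropagate through every intermediate chunk state
    matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ matV_chunk
    return matC_k_next
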
@@ -948,7 +948,7 @@ else:
def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (BATCH_SIZE, ..., SEQLEN,..., HD)
-# apply rms norm over the last dimension, igate.e. HD dimension
+# apply rms norm over the last dimension, i.e. HD dimension
in_dtype = x.dtype
if self.force_float32_reductions:
x = x.float()
@@ -1021,7 +1021,7 @@ else:
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
# x: (BATCH_SIZE, ..., SEQLEN,..., HD)
-# apply layer norm over the last dimension, igate.e. HD dimension
+# apply layer norm over the last dimension, i.e. HD dimension
in_dtype = x.dtype
if self.force_float32_reductions:
x = x.float()
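
The two normalization helpers touched in these hunks both reduce over the last (HD) dimension and optionally promote to float32 for the reduction. A minimal sketch of that pattern follows; the eps value is a placeholder and the learned weight/bias handling of the actual modules is omitted.

import torch

def rms_normalize(x: torch.Tensor, force_float32_reductions: bool = True,
                  eps: float = 1e-6) -> torch.Tensor:
    # x: (B, ..., S, ..., HD); normalize over the last (HD) dimension
    in_dtype = x.dtype
    if force_float32_reductions:
        x = x.float()  # perform the reduction in float32 for numerical stability
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x.to(in_dtype)

def layer_normalize(x: torch.Tensor, force_float32_reductions: bool = True,
                    eps: float = 1e-6) -> torch.Tensor:
    # same pattern, but the input is also centered before rescaling
    in_dtype = x.dtype
    if force_float32_reductions:
        x = x.float()
    x = x - x.mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x.to(in_dtype)
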
@@ -1502,15 +1502,14 @@ class xLSTMModel(xLSTMPreTrainedModel):
hidden_states_chunk = hidden_states[
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
]
-for igate, xlstm_block in enumerate(self.blocks):
+for layer_idx, xlstm_block in enumerate(self.blocks):
hidden_states_chunk, rnn_state = xlstm_block(
hidden_states_chunk,
-state=cache_params.rnn_state[igate],
+state=cache_params.rnn_state[layer_idx],
)
-for state_idx in range(len(cache_params.rnn_state[igate])):
+for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx]
-local_rnn_state = rnn_state[state_idx]
-cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False
final_state[
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
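
This hunk is the chunked-inference path: the sequence is sliced into pieces of at most max_inference_chunksize, each block consumes its cached per-layer recurrent state, and the returned state is copied back into the cache in place. A simplified standalone sketch, assuming a block signature of block(x, state=...) returning (output, new_state):

import torch

def run_chunked(hidden_states, blocks, rnn_state, chunksize):
    # hidden_states: (B, S, D); rnn_state[layer_idx] is a list/tuple of state tensors
    final_state = torch.empty_like(hidden_states)
    for offset in range(0, hidden_states.shape[1], chunksize):
        end = min(offset + chunksize, hidden_states.shape[1])
        chunk = hidden_states[:, offset:end]
        for layer_idx, block in enumerate(blocks):
            chunk, new_state = block(chunk, state=rnn_state[layer_idx])
            for state_idx in range(len(rnn_state[layer_idx])):
                # write the updated recurrent state back into the cache in place
                rnn_state[layer_idx][state_idx].copy_(new_state[state_idx])
        final_state[:, offset:end] = chunk
    return final_state
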
@@ -1519,23 +1518,22 @@ class xLSTMModel(xLSTMPreTrainedModel):
hidden_states = final_state
else:
all_hidden_states = () if output_hidden_states else None
-for igate, xlstm_block in enumerate(self.blocks):
+for layer_idx, xlstm_block in enumerate(self.blocks):
if self.gradient_checkpointing and self.training:
hidden_states, rnn_state = self._gradient_checkpointing_func(
xlstm_block.__call__,
hidden_states,
-cache_params.rnn_state[igate] if cache_params is not None else None,
+cache_params.rnn_state[layer_idx] if cache_params is not None else None,
)
else:
hidden_states, rnn_state = xlstm_block(
hidden_states,
-state=cache_params.rnn_state[igate] if cache_params is not None else None,
+state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
)
if cache_params:
-for state_idx in range(len(cache_params.rnn_state[igate])):
+for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx]
-local_rnn_state = rnn_state[state_idx]
-cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False
if output_hidden_states:
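
The branch in this hunk follows the usual Transformers pattern of wrapping the block call in gradient checkpointing during training. The helper _gradient_checkpointing_func is set up by the base class; the sketch below approximates it with torch.utils.checkpoint directly and is not the model's exact code.

import torch
from torch.utils.checkpoint import checkpoint

def run_block(block, hidden_states, state, gradient_checkpointing, training):
    if gradient_checkpointing and training:
        # recompute the block's activations during backward instead of storing them
        hidden_states, rnn_state = checkpoint(
            block.__call__, hidden_states, state, use_reentrant=False
        )
    else:
        hidden_states, rnn_state = block(hidden_states, state=state)
    return hidden_states, rnn_state
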
@@ -1651,7 +1649,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
) -> Union[tuple, xLSTMCausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-Labels for language modeling. Note that the labels **are shifted** inside the model, igate.e. you can set
+Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""