mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Fix index names in modeling_xlstm.py
This commit is contained in:
parent
c7ce6a50bb
commit
5f8a399ce2
@ -147,7 +147,7 @@ else:
|
||||
|
||||
scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
|
||||
|
||||
# NOTE: no update in-place (igate.e. +=) as this gives error for autograd backward
|
||||
# NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
|
||||
matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
|
||||
|
||||
# n_k update
|
||||
@ -169,7 +169,7 @@ else:
|
||||
matQ: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
|
||||
matK: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHQK)
|
||||
matV: torch.Tensor, # (BATCH_SIZE, NH, SEQLEN, DHHV)
|
||||
# these states must be all states up to the last chunk, igate.e. :-1
|
||||
# these states must be all states up to the last chunk, i.e. :-1
|
||||
matC_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK, DHHV)
|
||||
vecN_states: torch.Tensor, # (BATCH_SIZE, NH, NC * DHQK)
|
||||
scaMinter_states: torch.Tensor, # (BATCH_SIZE, NH, NC)
|
||||
@ -948,7 +948,7 @@ else:
|
||||
|
||||
def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (BATCH_SIZE, ..., SEQLEN,..., HD)
|
||||
# apply rms norm over the last dimension, igate.e. HD dimension
|
||||
# apply rms norm over the last dimension, i.e. HD dimension
|
||||
in_dtype = x.dtype
|
||||
if self.force_float32_reductions:
|
||||
x = x.float()
|
||||
@ -1021,7 +1021,7 @@ else:
|
||||
|
||||
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (BATCH_SIZE, ..., SEQLEN,..., HD)
|
||||
# apply layer norm over the last dimension, igate.e. HD dimension
|
||||
# apply layer norm over the last dimension, i.e. HD dimension
|
||||
in_dtype = x.dtype
|
||||
if self.force_float32_reductions:
|
||||
x = x.float()
|
||||
@ -1502,15 +1502,14 @@ class xLSTMModel(xLSTMPreTrainedModel):
|
||||
hidden_states_chunk = hidden_states[
|
||||
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
|
||||
]
|
||||
for igate, xlstm_block in enumerate(self.blocks):
|
||||
for layer_idx, xlstm_block in enumerate(self.blocks):
|
||||
hidden_states_chunk, rnn_state = xlstm_block(
|
||||
hidden_states_chunk,
|
||||
state=cache_params.rnn_state[igate],
|
||||
state=cache_params.rnn_state[layer_idx],
|
||||
)
|
||||
for state_idx in range(len(cache_params.rnn_state[igate])):
|
||||
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
|
||||
local_rnn_state = rnn_state[state_idx]
|
||||
local_rnn_state = rnn_state[state_idx]
|
||||
cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
|
||||
cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
|
||||
cache_params.rnn_state_initial = False
|
||||
final_state[
|
||||
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
|
||||
@ -1519,23 +1518,22 @@ class xLSTMModel(xLSTMPreTrainedModel):
|
||||
hidden_states = final_state
|
||||
else:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for igate, xlstm_block in enumerate(self.blocks):
|
||||
for layer_idx, xlstm_block in enumerate(self.blocks):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states, rnn_state = self._gradient_checkpointing_func(
|
||||
xlstm_block.__call__,
|
||||
hidden_states,
|
||||
cache_params.rnn_state[igate] if cache_params is not None else None,
|
||||
cache_params.rnn_state[layer_idx] if cache_params is not None else None,
|
||||
)
|
||||
else:
|
||||
hidden_states, rnn_state = xlstm_block(
|
||||
hidden_states,
|
||||
state=cache_params.rnn_state[igate] if cache_params is not None else None,
|
||||
state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
|
||||
)
|
||||
if cache_params:
|
||||
for state_idx in range(len(cache_params.rnn_state[igate])):
|
||||
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
|
||||
local_rnn_state = rnn_state[state_idx]
|
||||
local_rnn_state = rnn_state[state_idx]
|
||||
cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
|
||||
cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
|
||||
cache_params.rnn_state_initial = False
|
||||
|
||||
if output_hidden_states:
|
||||
@ -1651,7 +1649,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
|
||||
) -> Union[tuple, xLSTMCausalLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, igate.e. you can set
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user