Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)
Fix index names in modeling_xlstm.py
Parent: c7ce6a50bb
Commit: 5f8a399ce2
@@ -147,7 +147,7 @@ else:

         scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]

-        # NOTE: no update in-place (igate.e. +=) as this gives error for autograd backward
+        # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
         matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)

         # n_k update
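
The NOTE in this hunk explains why the state is rebuilt out of place (matC_k_next = ...) rather than accumulated with +=: autograd rejects in-place updates of tensors it has saved for the backward pass. A minimal, self-contained repro of that failure mode (illustrative only, not code from the patch):

import torch

# Illustration of the failure the NOTE refers to: an in-place `+=` on a tensor
# that autograd has saved for backward raises a RuntimeError, while the
# out-of-place form (new_state = gate * state + update) does not.
c = torch.zeros(2, 2, requires_grad=True)
state = c * 1.0              # non-leaf tensor, part of the autograd graph
out = (state ** 2).sum()     # backward of ** needs `state` unmodified

state += torch.ones(2, 2)    # in-place update bumps the tensor's version counter

try:
    out.backward()
except RuntimeError as err:
    print("autograd error:", err)
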
@@ -169,7 +169,7 @@ else:
     matQ: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHQK)
     matK: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHQK)
     matV: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHHV)
-    # these states must be all states up to the last chunk, igate.e. :-1
+    # these states must be all states up to the last chunk, i.e. :-1
     matC_states: torch.Tensor,  # (BATCH_SIZE, NH, NC * DHQK, DHHV)
     vecN_states: torch.Tensor,  # (BATCH_SIZE, NH, NC * DHQK)
     scaMinter_states: torch.Tensor,  # (BATCH_SIZE, NH, NC)
@@ -948,7 +948,7 @@ else:

     def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
         # x: (BATCH_SIZE, ..., SEQLEN, ..., HD)
-        # apply rms norm over the last dimension, igate.e. HD dimension
+        # apply rms norm over the last dimension, i.e. HD dimension
         in_dtype = x.dtype
         if self.force_float32_reductions:
             x = x.float()
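
For context, the helper touched here normalizes over the last (HD) dimension and can promote the reduction to float32 before casting back to the input dtype. A rough sketch of that computation; the eps value and the absence of a learned scale are assumptions, not details read from the patch:

import torch

# Rough sketch of an RMS-norm helper of this shape; eps and the missing learned
# scale are assumptions rather than facts taken from modeling_xlstm.py.
def rms_normalize(x: torch.Tensor, eps: float = 1e-6, force_float32_reductions: bool = True) -> torch.Tensor:
    # x: (BATCH_SIZE, ..., SEQLEN, ..., HD); normalize over the last (HD) dimension
    in_dtype = x.dtype
    if force_float32_reductions:
        x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x.to(in_dtype)
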
@@ -1021,7 +1021,7 @@ else:

     def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
         # x: (BATCH_SIZE, ..., SEQLEN, ..., HD)
-        # apply layer norm over the last dimension, igate.e. HD dimension
+        # apply layer norm over the last dimension, i.e. HD dimension
         in_dtype = x.dtype
         if self.force_float32_reductions:
             x = x.float()
@@ -1502,15 +1502,14 @@ class xLSTMModel(xLSTMPreTrainedModel):
                 hidden_states_chunk = hidden_states[
                     :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
                 ]
-                for igate, xlstm_block in enumerate(self.blocks):
+                for layer_idx, xlstm_block in enumerate(self.blocks):
                     hidden_states_chunk, rnn_state = xlstm_block(
                         hidden_states_chunk,
-                        state=cache_params.rnn_state[igate],
+                        state=cache_params.rnn_state[layer_idx],
                     )
-                    for state_idx in range(len(cache_params.rnn_state[igate])):
+                    for state_idx in range(len(cache_params.rnn_state[layer_idx])):
                         local_rnn_state = rnn_state[state_idx]
-                        local_rnn_state = rnn_state[state_idx]
-                        cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
                     cache_params.rnn_state_initial = False
                 final_state[
                     :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
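
The slicing in this hunk is the chunked-inference pattern: the sequence is processed max_inference_chunksize tokens at a time and written back into a preallocated final_state. A standalone sketch of that loop, with block as a stand-in callable rather than the model's actual stack of xLSTM blocks:

import torch

# Standalone sketch of the chunked-inference slicing above; `block` is a
# placeholder callable, not the real model.
def process_in_chunks(hidden_states: torch.Tensor, max_inference_chunksize: int, block) -> torch.Tensor:
    final_state = torch.empty_like(hidden_states)
    seq_len = hidden_states.shape[1]
    for offset in range(0, seq_len, max_inference_chunksize):
        end = min(offset + max_inference_chunksize, seq_len)
        final_state[:, offset:end] = block(hidden_states[:, offset:end])
    return final_state
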
@@ -1519,23 +1518,22 @@ class xLSTMModel(xLSTMPreTrainedModel):
             hidden_states = final_state
         else:
             all_hidden_states = () if output_hidden_states else None
-            for igate, xlstm_block in enumerate(self.blocks):
+            for layer_idx, xlstm_block in enumerate(self.blocks):
                 if self.gradient_checkpointing and self.training:
                     hidden_states, rnn_state = self._gradient_checkpointing_func(
                         xlstm_block.__call__,
                         hidden_states,
-                        cache_params.rnn_state[igate] if cache_params is not None else None,
+                        cache_params.rnn_state[layer_idx] if cache_params is not None else None,
                     )
                 else:
                     hidden_states, rnn_state = xlstm_block(
                         hidden_states,
-                        state=cache_params.rnn_state[igate] if cache_params is not None else None,
+                        state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
                     )
                 if cache_params:
-                    for state_idx in range(len(cache_params.rnn_state[igate])):
+                    for state_idx in range(len(cache_params.rnn_state[layer_idx])):
                         local_rnn_state = rnn_state[state_idx]
-                        local_rnn_state = rnn_state[state_idx]
-                        cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
                     cache_params.rnn_state_initial = False

                 if output_hidden_states:
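
Both of the preceding hunks apply the same cleanup: the loop variable indexes layers/blocks rather than input gates, so igate becomes layer_idx, and the duplicated local_rnn_state assignment is dropped. A distilled sketch of the resulting per-layer cache update, using stand-in names for the blocks and cache rather than the real modules:

# Distilled version of the per-layer cache update both hunks fix: the index is a
# layer index (hence layer_idx), and each returned state tensor is copied in
# place into that layer's cache slot so the cached storage and device stay put.
# `blocks` and `cache_rnn_state` are stand-ins, not the real xLSTM objects.
def run_blocks_with_cache(blocks, cache_rnn_state, hidden_states):
    for layer_idx, block in enumerate(blocks):
        hidden_states, rnn_state = block(hidden_states, state=cache_rnn_state[layer_idx])
        for state_idx in range(len(cache_rnn_state[layer_idx])):
            # in-place copy_ overwrites the cached tensor's contents
            cache_rnn_state[layer_idx][state_idx].copy_(rnn_state[state_idx])
    return hidden_states
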
@@ -1651,7 +1649,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
     ) -> Union[tuple, xLSTMCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, igate.e. you can set
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
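
As the docstring notes, labels are shifted inside the model, so labels=input_ids gives the standard next-token loss. A minimal usage sketch; the checkpoint path is a placeholder and the example assumes an xLSTM checkpoint loadable through the Auto classes:

import torch
from transformers import AutoModelForCausalLM

# Usage sketch for the `labels` argument documented above: labels are shifted
# inside the model, so labels=input_ids yields the usual next-token loss.
# The checkpoint path below is a placeholder, not a real repository id.
model = AutoModelForCausalLM.from_pretrained("path/to/xlstm-checkpoint")
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss)   # positions whose label is -100 are ignored in the loss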