diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index cc4c282bf0e..d97c740042b 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -147,7 +147,7 @@ else:
 
             scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
 
-            # NOTE: no update in-place (igate.e. +=) as this gives error for autograd backward
+            # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
             matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
 
             # n_k update
@@ -169,7 +169,7 @@ else:
         matQ: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHQK)
         matK: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHQK)
         matV: torch.Tensor,  # (BATCH_SIZE, NH, SEQLEN, DHHV)
-        # these states must be all states up to the last chunk, igate.e. :-1
+        # these states must be all states up to the last chunk, i.e. :-1
         matC_states: torch.Tensor,  # (BATCH_SIZE, NH, NC * DHQK, DHHV)
         vecN_states: torch.Tensor,  # (BATCH_SIZE, NH, NC * DHQK)
         scaMinter_states: torch.Tensor,  # (BATCH_SIZE, NH, NC)
@@ -948,7 +948,7 @@ else:
 
         def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
             # x: (BATCH_SIZE, ..., SEQLEN,..., HD)
-            # apply rms norm over the last dimension, igate.e. HD dimension
+            # apply rms norm over the last dimension, i.e. HD dimension
             in_dtype = x.dtype
             if self.force_float32_reductions:
                 x = x.float()
@@ -1021,7 +1021,7 @@ else:
 
         def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
             # x: (BATCH_SIZE, ..., SEQLEN,..., HD)
-            # apply layer norm over the last dimension, igate.e. HD dimension
+            # apply layer norm over the last dimension, i.e. HD dimension
             in_dtype = x.dtype
             if self.force_float32_reductions:
                 x = x.float()
@@ -1502,15 +1502,14 @@ class xLSTMModel(xLSTMPreTrainedModel):
                 hidden_states_chunk = hidden_states[
                     :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
                 ]
-                for igate, xlstm_block in enumerate(self.blocks):
+                for layer_idx, xlstm_block in enumerate(self.blocks):
                     hidden_states_chunk, rnn_state = xlstm_block(
                         hidden_states_chunk,
-                        state=cache_params.rnn_state[igate],
+                        state=cache_params.rnn_state[layer_idx],
                     )
-                    for state_idx in range(len(cache_params.rnn_state[igate])):
+                    for state_idx in range(len(cache_params.rnn_state[layer_idx])):
                         local_rnn_state = rnn_state[state_idx]
-                        local_rnn_state = rnn_state[state_idx]
-                        cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
                     cache_params.rnn_state_initial = False
                 final_state[
                     :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
@@ -1519,23 +1518,22 @@ class xLSTMModel(xLSTMPreTrainedModel):
             hidden_states = final_state
         else:
             all_hidden_states = () if output_hidden_states else None
-            for igate, xlstm_block in enumerate(self.blocks):
+            for layer_idx, xlstm_block in enumerate(self.blocks):
                 if self.gradient_checkpointing and self.training:
                     hidden_states, rnn_state = self._gradient_checkpointing_func(
                         xlstm_block.__call__,
                         hidden_states,
-                        cache_params.rnn_state[igate] if cache_params is not None else None,
+                        cache_params.rnn_state[layer_idx] if cache_params is not None else None,
                     )
                 else:
                     hidden_states, rnn_state = xlstm_block(
                         hidden_states,
-                        state=cache_params.rnn_state[igate] if cache_params is not None else None,
+                        state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
                     )
                 if cache_params:
-                    for state_idx in range(len(cache_params.rnn_state[igate])):
+                    for state_idx in range(len(cache_params.rnn_state[layer_idx])):
                         local_rnn_state = rnn_state[state_idx]
-                        local_rnn_state = rnn_state[state_idx]
-                        cache_params.rnn_state[igate][state_idx].copy_(local_rnn_state)
+                        cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
                     cache_params.rnn_state_initial = False
 
             if output_hidden_states:
@@ -1651,7 +1649,7 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
     ) -> Union[tuple, xLSTMCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, igate.e. you can set
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
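For context on the NOTE comment touched in the first hunk: the chunkwise kernel builds matC_k_next as a fresh tensor instead of accumulating with `+=`, because autograd errors out when a tensor it saved for the backward pass is later modified in place. A minimal standalone PyTorch sketch of that failure mode and the out-of-place alternative (illustrative only, not code from this PR):

import torch

a = torch.ones(3, requires_grad=True)
b = a * 2
c = b.sin()             # sin() saves b for its backward pass
b += 1                  # in-place update bumps b's version counter
try:
    c.sum().backward()
except RuntimeError:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    pass

# Out-of-place variant, analogous to how matC_k_next is rebuilt each chunk:
a = torch.ones(3, requires_grad=True)
b = a * 2
c = b.sin()
b = b + 1               # new tensor; the saved activation is untouched
c.sum().backward()      # succeeds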
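The labels docstring in the last hunk follows the standard causal-LM convention: the shift happens inside the model, so callers can pass labels = input_ids and mask positions with -100. A hedged usage sketch (only the label-preparation pattern is shown; the model construction is elided and not part of this PR):

import torch

input_ids = torch.tensor([[11, 24, 5, 89, 2]])  # toy token ids
labels = input_ids.clone()
labels[:, :1] = -100                            # positions set to -100 are ignored by the loss

# Assuming `model` is an xLSTMForCausalLM instance:
# outputs = model(input_ids=input_ids, labels=labels)
# outputs.loss  # cross entropy over labels in [0, config.vocab_size); shifted inside the model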