diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index ab9f8c3d853..2a8e1c25f63 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -254,8 +254,8 @@ class RecurrentGemmaSdpaAttention(nn.Module): k_out = k_out[:, :, indices] v_out = v_out[:, :, indices] - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + k_out[:, :, cache_position] = key_states.to(k_out.dtype) + v_out[:, :, cache_position] = value_states.to(v_out.dtype) self.key_states, self.value_states = k_out, v_out return k_out, v_out