diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index d41bc99eea5..0b82b17dcde 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1797,7 +1797,7 @@ class MambaCache:
         cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
 
         conv_state = conv_state.roll(shifts=-1, dims=-1)
-        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
         self.conv_states[layer_idx].zero_()
         self.conv_states[layer_idx] += conv_state
         return self.conv_states[layer_idx]
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 3b4a18bb48e..d432dfa93df 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -421,6 +421,30 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_beam_sample_generate(self):
         pass
 
+    def test_dtype_mismatch_handled_in_cache(self):
+        config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
+        model = MambaModel(config)
+        model.to(torch_device).to(torch.float16)
+        model.eval()
+
+        # Create cache with float32 dtype
+        cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
+
+        # If code is correct, no error occurs and test passes
+        outputs = model(
+            input_ids,
+            cache_params=cache_params,
+            use_cache=True,
+            cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
+        )
+
+        self.assertIsNotNone(outputs)
+        self.assertIsNotNone(outputs.last_hidden_state)
+        self.assertEqual(
+            outputs.last_hidden_state.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
+        )
+
 
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):
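
For context, a minimal standalone sketch of the failure mode the patch guards against: PyTorch's advanced-index assignment requires matching source and destination dtypes, so writing float16 model states into a float32 cache raises a `RuntimeError` unless the value is cast first. Tensor shapes below are illustrative, not the model's actual ones.

```python
import torch

conv_state = torch.zeros(2, 8, 4, dtype=torch.float32)      # cache kept in float32
new_conv_state = torch.randn(2, 8, 4, dtype=torch.float16)  # model running in float16
cache_position = torch.arange(4)

try:
    # Old behavior: casts device only, so the index-put sees mismatched dtypes
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
except RuntimeError as err:
    print(err)  # "Index put requires the source and destination dtypes match, ..."

# Patched behavior: cast to the cache's device *and* dtype before assignment
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
```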