Fix Mamba slow path bug with dtype mismatch. (#32691)

* Fix Mamba slow path bug with dtype mismatch.

* Update test_modeling_mamba.py

* Improve style.

* Fix issue with the cache position in the dtype mismatch test.

* Change test for slow path.

* Revert changes.

* Switch to buggy code and add test to catch it.

* Fix the dtype mismatch bug and add test code to verify it.

* Fix minor bug with test.

* Fix incorrect dtype of model output.

* Fix incorrect dtype of cache.

* Fix incorrect dtype of ssm cache.

* Fix incorrect dtype of conv state.

* Remove assertion for ssm state.

* Add assertion for conv state dtype.

* Fix all issues with dtype mismatch test.
Author: Adibvafa Fallahpour, committed via GitHub on 2024-10-01 03:28:40 -04:00
Commit: c269c5c74d (parent: 570c89625b)
2 changed files with 25 additions and 1 deletion
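For context, the failure this commit addresses can be reproduced with a few lines of plain PyTorch. The sketch below is illustrative (tensor names and shapes are not taken from the library): writing a float16 tensor into a float32 cache through advanced indexing fails with a dtype-mismatch RuntimeError, while casting to the cache's dtype first, as the fix does, succeeds.

import torch

# Illustrative stand-ins: a float32 conv-state cache and a float16 update,
# as produced by a model running in half precision.
conv_state = torch.zeros(2, 8, 4, dtype=torch.float32)
new_conv_state = torch.randn(2, 8, 1, dtype=torch.float16)
cache_position = torch.tensor([3])

try:
    # Old behaviour: only the device is aligned, so the indexed write sees a
    # float16 source and a float32 destination.
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# Fixed behaviour: align both device and dtype before writing into the cache.
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)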


@@ -1797,7 +1797,7 @@ class MambaCache:
         cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
 
         conv_state = conv_state.roll(shifts=-1, dims=-1)
-        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
         self.conv_states[layer_idx].zero_()
         self.conv_states[layer_idx] += conv_state
         return self.conv_states[layer_idx]
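As a reading aid, here is the updated cache-update logic from the hunk above, repackaged as a self-contained function (the free-function signature and the example shapes are illustrative, not the library's API):

import torch

def update_conv_state(conv_states, layer_idx, new_conv_state, cache_position, conv_kernel_size):
    # Simplified stand-in for MambaCache.update_conv_state, for illustration only.
    conv_state = conv_states[layer_idx]
    # Clamp positions so writes never land outside the kernel-sized window.
    cache_position = cache_position.clamp(0, conv_kernel_size - 1)
    # Shift the window left by one step to make room for the newest inputs.
    conv_state = conv_state.roll(shifts=-1, dims=-1)
    # Align device and dtype: a float16 update written into a float32 cache
    # would otherwise fail the advanced-indexing assignment.
    conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
    # Write back in place so the stored cache tensor object is preserved.
    conv_states[layer_idx].zero_()
    conv_states[layer_idx] += conv_state
    return conv_states[layer_idx]

# Example: (batch, channels, kernel) cache in float32, float16 update of width 1.
conv_states = [torch.zeros(2, 8, 4, dtype=torch.float32)]
new_state = torch.randn(2, 8, 1, dtype=torch.float16)
update_conv_state(conv_states, 0, new_state, cache_position=torch.tensor([3]), conv_kernel_size=4)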


@@ -421,6 +421,30 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_beam_sample_generate(self):
         pass
 
+    def test_dtype_mismatch_handled_in_cache(self):
+        config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
+        model = MambaModel(config)
+        model.to(torch_device).to(torch.float16)
+        model.eval()
+
+        # Create cache with float32 dtype
+        cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
+
+        # If code is correct, no error occurs and test passes
+        outputs = model(
+            input_ids,
+            cache_params=cache_params,
+            use_cache=True,
+            cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
+        )
+        self.assertIsNotNone(outputs)
+        self.assertIsNotNone(outputs.last_hidden_state)
+
+        self.assertEqual(
+            outputs.last_hidden_state.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
+        )
+
 
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):
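To run just the new regression test locally, something like the following should work (the path assumes the standard transformers repository layout): python -m pytest tests/models/mamba/test_modeling_mamba.py -k test_dtype_mismatch_handled_in_cache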