Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix Mamba slow path bug with dtype mismatch. (#32691)
* Fix Mamba slow path bug with dtype mismatch.
* Update test_modeling_mamba.py
* Improve style.
* Fix issue with cache position of dtype mismatch test.
* Change test for slow path.
* Revert changes.
* Switch to buggy code and add test to catch it.
* Fix the dtype mismatch bug and add test code to verify it.
* Fix minor bug with test.
* Fix incorrect dtype of model output.
* Fix incorrect dtype of cache.
* Fix incorrect dtype of ssm cache.
* Fix incorrect dtype of conv state.
* Remove assertion for ssm state.
* Add assertion for conv state dtype.
* Fix all issues with dtype mismatch test.
This commit is contained in:
parent 570c89625b
commit c269c5c74d
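The failure mode is easy to reproduce in isolation: writing into a tensor through advanced (tensor-index) indexing lowers to index_put_, which requires the source and destination dtypes to match. A minimal standalone sketch of the mismatch (not taken from the patch; shapes are illustrative):

    import torch

    conv_state = torch.zeros(1, 4, 8, dtype=torch.float32)     # cache buffer kept in float32
    new_conv_state = torch.ones(1, 4, 3, dtype=torch.float16)  # activations from a float16 model
    cache_position = torch.arange(3)                           # tensor index -> advanced indexing

    # Pre-fix behavior: only the device is aligned, so index_put_ sees a
    # float16 source and a float32 destination and raises a RuntimeError.
    try:
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
    except RuntimeError as err:
        print(err)  # "Index put requires the source and destination dtypes match, ..."

    # Post-fix behavior: cast to the buffer's dtype as well; the write succeeds
    # and the cache keeps its own precision.
    conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)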
@@ -1797,7 +1797,7 @@ class MambaCache:
         cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
 
         conv_state = conv_state.roll(shifts=-1, dims=-1)
-        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
         self.conv_states[layer_idx].zero_()
         self.conv_states[layer_idx] += conv_state
         return self.conv_states[layer_idx]
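For context, the changed line lives in MambaCache's conv-state update. A sketch of the full method after the fix, with the signature inferred from the calls visible in the hunk (the exact code in the repository may differ slightly):

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        # Shift the rolling window left by one step and write the new values
        # at the clamped positions, matching both device and dtype.
        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
        # Mutate the stored buffer in place so existing references stay valid.
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]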
@@ -421,6 +421,30 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_beam_sample_generate(self):
         pass
 
+    def test_dtype_mismatch_handled_in_cache(self):
+        config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
+        model = MambaModel(config)
+        model.to(torch_device).to(torch.float16)
+        model.eval()
+
+        # Create cache with float32 dtype
+        cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
+
+        # If code is correct, no error occurs and test passes
+        outputs = model(
+            input_ids,
+            cache_params=cache_params,
+            use_cache=True,
+            cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
+        )
+
+        self.assertIsNotNone(outputs)
+        self.assertIsNotNone(outputs.last_hidden_state)
+        self.assertEqual(
+            outputs.last_hidden_state.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
+        )
 
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):
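End to end, the regression test drives a float16 model with an explicitly float32 MambaCache, exactly the combination that used to crash the slow path. A hedged usage sketch mirroring the test's calling convention (the MambaCache import location and the tiny config values are assumptions, not from the patch; float16 support on CPU also depends on your torch version):

    import torch
    from transformers import MambaConfig, MambaModel
    from transformers.cache_utils import MambaCache  # import path may vary across versions

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tiny illustrative config so the sketch runs quickly.
    config = MambaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2)
    model = MambaModel(config).to(device=device, dtype=torch.float16).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, config.conv_kernel), device=device)
    # A float32 cache against a float16 model: the mismatch the fix handles.
    cache = MambaCache(config, batch_size=1, dtype=torch.float32, device=device)
    outputs = model(
        input_ids,
        cache_params=cache,
        use_cache=True,
        cache_position=torch.arange(0, config.conv_kernel, device=device),  # mirrors the test
    )
    print(outputs.last_hidden_state.dtype)  # torch.float16 after the fix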