Mamba slow_forward gradient fix (#29563)

* FIX: Cached slow forward in mamba
- additionally added a cached slow-forward test for Mamba
- added a test for the previously unused check (Mamba causal LM forward and backward)
- fixed typo: "causl" --> "causal"

* formatting

* fix: use real `slow_forward` call instead of torch module's

* add shape assertion for mixer block test

* adjust shape assertion
Anton Vlasjuk 2024-03-27 04:52:12 +01:00 committed by GitHub
parent 1c39974a4c
commit cefb819f7a
2 changed files with 35 additions and 4 deletions
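In practice, the pre-fix failure mode looks roughly like the sketch below: a forward pass that populates the Mamba cache, followed by a backward pass through the slow path. This is an illustrative reconstruction rather than code from the PR; the config sizes and sequence length are arbitrary, and the slow path is only taken when the fused mamba-ssm / causal-conv1d kernels are not installed.

import torch
from transformers import MambaConfig, MambaForCausalLM

# Tiny config purely for illustration; any sizes work.
config = MambaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2)
model = MambaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# Forward through the cached slow path, then backprop.
outputs = model(input_ids, use_cache=True, labels=input_ids)
outputs.loss.backward()  # before this commit, this could raise an in-place modification error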

src/transformers/models/mamba/modeling_mamba.py

@@ -230,7 +230,7 @@ class MambaMixer(nn.Module):
         # 2. Convolution sequence transformation
         if cache_params is not None:
-            ssm_state = cache_params.ssm_states[self.layer_idx]
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
             if cache_params.seqlen_offset > 0:
                 conv_state = cache_params.conv_states[self.layer_idx]  # [batch, intermediate_size, conv_kernel_size]
                 conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
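Why the `.clone()` matters: without it, the local `ssm_state` aliases the tensor stored in the cache, and the in-place `copy_` back into `cache_params.ssm_states` at the end of `slow_forward` mutates a tensor that autograd saved during the recurrence, so `backward()` fails. The following minimal sketch reproduces that pitfall with stand-in tensors; it is not the actual Mamba code.

import torch

# Stand-ins for cache_params.ssm_states[layer_idx] and a learnable parameter.
cache_state = torch.zeros(2, 4)
weight = torch.randn(2, 4, requires_grad=True)

# Buggy pattern: `state` shares storage with the cache.
state = cache_state
out = weight * state                 # autograd saves `state` to compute weight's grad
cache_state.copy_(out.detach())      # in-place write mutates the saved tensor
try:
    out.sum().backward()
except RuntimeError as err:          # "... modified by an inplace operation"
    print(err)

# Fixed pattern, mirroring the .clone() above: work on a private copy.
cache_state = torch.zeros(2, 4)
weight = torch.randn(2, 4, requires_grad=True)
state = cache_state.clone()
out = weight * state
cache_state.copy_(out.detach())      # cache update no longer touches the graph
out.sum().backward()                 # succeeds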

tests/models/mamba/test_modeling_mamba.py

@@ -170,7 +170,7 @@ class MambaModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)

-    def create_and_check_causl_lm(self, config, input_ids, *args):
+    def create_and_check_causal_lm(self, config, input_ids, *args):
         model = MambaForCausalLM(config)
         model.to(torch_device)
         model.eval()
@@ -197,7 +197,30 @@ class MambaModelTester:
         self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
         # TODO the orignal mamba does not support decoding more than 1 token neither do we

-    def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False):
+    def create_and_check_mamba_cached_slow_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
+        model = MambaModel(config)
+        model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+
+        # create cache
+        cache = model(input_ids, use_cache=True).cache_params
+        cache.seqlen_offset = 0
+
+        # use cache
+        token_emb = model.embeddings(input_ids)
+        outputs = model.layers[0].mixer.slow_forward(token_emb, cache)
+        loss = torch.log(1 + torch.abs(outputs.sum()))
+
+        self.parent.assertEqual(loss.shape, ())
+        self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+        loss.backward()
+
+    def create_and_check_mamba_lm_head_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
         model = MambaForCausalLM(config)
         model.to(torch_device)
         if gradient_checkpointing:
@@ -304,12 +327,20 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_mamba_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_causl_lm(*config_and_inputs)
+        self.model_tester.create_and_check_causal_lm(*config_and_inputs)

     def test_state_equivalency(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_state_equivalency(*config_and_inputs)

+    def test_mamba_cached_slow_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_cached_slow_forward_and_backwards(*config_and_inputs)
+
+    def test_mamba_lm_head_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_lm_head_forward_and_backwards(*config_and_inputs)
+
     def test_initialization(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
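The two new tests can be run in isolation with pytest's keyword selection, assuming a standard transformers checkout layout, for example:

python -m pytest tests/models/mamba/test_modeling_mamba.py -k "cached_slow_forward or mamba_lm_head_forward"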