Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Mamba slow_forward gradient fix (#29563)

* FIX: Cached slow forward in mamba
  - additionally added mamba cached test
  - added unused test (mamba causal lm forward and backward)
  - fixed typo: "causl" --> "causal"
* formatting
* fix: use real `slow_forward` call instead of torch module's
* add shape assertion for mixer block test
* adjust shape assertion
parent 1c39974a4c
commit cefb819f7a
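The functional change is the `.clone()` on the cached `ssm_state` in the first hunk below: the recurrence then works on a copy rather than on the tensor stored in `cache_params`. A minimal sketch of that pattern in plain PyTorch (the shapes, the 0.9 decay, and the loss are illustrative stand-ins, not the MambaMixer math):

import torch

# Stand-ins for cache_params.ssm_states[layer_idx] and the layer input.
cached_state = torch.zeros(2, 4)
inputs = torch.randn(2, 4, 6, requires_grad=True)

ssm_state = cached_state.clone()       # the `.clone()` this commit adds
outputs = []
for i in range(inputs.shape[-1]):      # sequential scan over the time dimension
    ssm_state = 0.9 * ssm_state + inputs[..., i]
    outputs.append(ssm_state)

loss = torch.log(1 + torch.abs(torch.stack(outputs, dim=-1).sum()))
loss.backward()                        # gradients flow back to `inputs`

assert cached_state.abs().sum() == 0   # the cached tensor was never modified

The same idea applies to the real slow_forward: working on a clone leaves the cached state untouched while the differentiable scan builds a graph that backward() can traverse.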
@@ -230,7 +230,7 @@ class MambaMixer(nn.Module):
         # 2. Convolution sequence transformation
         if cache_params is not None:
-            ssm_state = cache_params.ssm_states[self.layer_idx]
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
             if cache_params.seqlen_offset > 0:
                 conv_state = cache_params.conv_states[self.layer_idx]  # [batch, intermediate_size, conv_kernel_size]
                 conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
@@ -170,7 +170,7 @@ class MambaModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)

-    def create_and_check_causl_lm(self, config, input_ids, *args):
+    def create_and_check_causal_lm(self, config, input_ids, *args):
         model = MambaForCausalLM(config)
         model.to(torch_device)
         model.eval()
@@ -197,7 +197,30 @@ class MambaModelTester:
         self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
         # TODO the orignal mamba does not support decoding more than 1 token neither do we

-    def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False):
+    def create_and_check_mamba_cached_slow_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
+        model = MambaModel(config)
+        model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+
+        # create cache
+        cache = model(input_ids, use_cache=True).cache_params
+        cache.seqlen_offset = 0
+
+        # use cache
+        token_emb = model.embeddings(input_ids)
+        outputs = model.layers[0].mixer.slow_forward(token_emb, cache)
+
+        loss = torch.log(1 + torch.abs(outputs.sum()))
+        self.parent.assertEqual(loss.shape, ())
+        self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        loss.backward()
+
+    def create_and_check_mamba_lm_head_forward_and_backwards(
+        self, config, input_ids, *args, gradient_checkpointing=False
+    ):
         model = MambaForCausalLM(config)
         model.to(torch_device)
         if gradient_checkpointing:
@@ -304,12 +327,20 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_mamba_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_causl_lm(*config_and_inputs)
+        self.model_tester.create_and_check_causal_lm(*config_and_inputs)

     def test_state_equivalency(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_state_equivalency(*config_and_inputs)

+    def test_mamba_cached_slow_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_cached_slow_forward_and_backwards(*config_and_inputs)
+
+    def test_mamba_lm_head_forward_and_backwards(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba_lm_head_forward_and_backwards(*config_and_inputs)
+
     def test_initialization(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
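The two new test entry points can be run in isolation. A possible local invocation, assuming the tests live in tests/models/mamba/test_modeling_mamba.py as in the usual transformers test layout (the path is an assumption, not stated in this diff):

# Hypothetical local run of just the new Mamba forward/backward tests.
import pytest

pytest.main([
    "tests/models/mamba/test_modeling_mamba.py",
    "-k", "cached_slow_forward_and_backwards or lm_head_forward_and_backwards",
])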