Update modeling_mamba2.py, fix pad size (#32599)

* Update modeling_mamba2.py

Fix pad_size calculation to ensure it's less than self.chunk_size

* [run_slow] mamba2

* [run-slow] mamba2

* [run-slow] Add @require_read_token decorator to failing tests for token propagation

* [run_slow] mamba2
This commit is contained in:
Lake Lee 2024-09-20 19:40:57 +09:00 committed by GitHub
parent 8bd1f2f338
commit ec1424c6a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 1 deletions

View File

@ -510,7 +510,7 @@ class Mamba2Mixer(nn.Module):
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

View File

@ -291,6 +291,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
self.prompt = ("[INST]Write a hello world program in C++.",)
@require_read_token
@parameterized.expand(
[
(torch_device,),
@ -319,6 +320,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
self.assertEqual(output_sentence, ground_truth_sentence)
@require_read_token
@slow
@require_torch_gpu
def test_batched_equivalence_with_cache(self):
@ -349,6 +351,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
@require_read_token
@slow
@require_torch_gpu
def test_batched_equivalence_without_cache(self):