diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 7b414ff9570..19d53437130 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -510,7 +510,7 @@ class Mamba2Mixer(nn.Module):
             C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
             B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
             C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
-            pad_size = self.chunk_size - (seq_len % self.chunk_size)
+            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
 
             D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
 
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 55c18abe6b9..f19358a22f4 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -291,6 +291,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
         self.prompt = ("[INST]Write a hello world program in C++.",)
 
+    @require_read_token
     @parameterized.expand(
         [
             (torch_device,),
@@ -319,6 +320,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
         self.assertEqual(output_sentence, ground_truth_sentence)
 
+    @require_read_token
     @slow
     @require_torch_gpu
     def test_batched_equivalence_with_cache(self):
@@ -349,6 +351,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
         self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
 
+    @require_read_token
     @slow
     @require_torch_gpu
     def test_batched_equivalence_without_cache(self):
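
Not part of the patch, but a minimal standalone sketch of what the one-line `pad_size` change does: the old expression returns a full `chunk_size` of padding when `seq_len` is already an exact multiple of `chunk_size`, while the new trailing `% chunk_size` wraps that case back to zero. The value `chunk_size = 4` below is arbitrary and chosen only for illustration.

```python
# Sketch (illustrative only): compare the old and new pad_size formulas.
chunk_size = 4

for seq_len in (5, 7, 8, 12):
    old_pad = chunk_size - (seq_len % chunk_size)
    new_pad = (chunk_size - seq_len % chunk_size) % chunk_size
    # For seq_len divisible by chunk_size, seq_len % chunk_size == 0,
    # so the old formula pads a whole spurious chunk; the new one pads 0.
    print(f"seq_len={seq_len}: old={old_pad}, new={new_pad}")

# seq_len=5: old=3, new=3
# seq_len=7: old=1, new=1
# seq_len=8: old=4, new=0   <- old formula adds a full extra chunk
# seq_len=12: old=4, new=0
```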