From ec1424c6a3cc91a3fa2570bbeb9c4431072b873b Mon Sep 17 00:00:00 2001 From: Lake Lee <101966044+klae01@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:40:57 +0900 Subject: [PATCH] Update modeling_mamba2.py, fix pad size (#32599) * Update modeling_mamba2.py Fix pad_size calculation to ensure it's less than self.chunk_size * [run_slow] mamba2 * [run-slow] mamba2 * [run-slow] Add @require_read_token decorator to failing tests for token propagation * [run_slow] mamba2 --- src/transformers/models/mamba2/modeling_mamba2.py | 2 +- tests/models/mamba2/test_modeling_mamba2.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 7b414ff9570..19d53437130 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -510,7 +510,7 @@ class Mamba2Mixer(nn.Module): C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) - pad_size = self.chunk_size - (seq_len % self.chunk_size) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 55c18abe6b9..f19358a22f4 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -291,6 +291,7 @@ class Mamba2IntegrationTest(unittest.TestCase): self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False) self.prompt = ("[INST]Write a hello world program in C++.",) + @require_read_token @parameterized.expand( [ (torch_device,), @@ -319,6 +320,7 @@ class Mamba2IntegrationTest(unittest.TestCase): ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include \n\n""" self.assertEqual(output_sentence, ground_truth_sentence) + @require_read_token @slow @require_torch_gpu def test_batched_equivalence_with_cache(self): @@ -349,6 +351,7 @@ class Mamba2IntegrationTest(unittest.TestCase): individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) + @require_read_token @slow @require_torch_gpu def test_batched_equivalence_without_cache(self):