mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Update modeling_mamba2.py, fix pad size (#32599)
* Update modeling_mamba2.py Fix pad_size calculation to ensure it's less than self.chunk_size * [run_slow] mamba2 * [run-slow] mamba2 * [run-slow] Add @require_read_token decorator to failing tests for token propagation * [run_slow] mamba2
This commit is contained in:
parent
8bd1f2f338
commit
ec1424c6a3
@ -510,7 +510,7 @@ class Mamba2Mixer(nn.Module):
|
||||
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
||||
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
|
||||
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
|
||||
pad_size = self.chunk_size - (seq_len % self.chunk_size)
|
||||
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
|
||||
|
||||
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
|
||||
|
||||
|
@ -291,6 +291,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
|
||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||
|
||||
@require_read_token
|
||||
@parameterized.expand(
|
||||
[
|
||||
(torch_device,),
|
||||
@ -319,6 +320,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
|
||||
self.assertEqual(output_sentence, ground_truth_sentence)
|
||||
|
||||
@require_read_token
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_batched_equivalence_with_cache(self):
|
||||
@ -349,6 +351,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||
|
||||
@require_read_token
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_batched_equivalence_without_cache(self):
|
||||
|
Loading…
Reference in New Issue
Block a user