Mirror of https://github.com/huggingface/transformers.git
Update modeling_mamba2.py, fix pad size (#32599)
* Update modeling_mamba2.py: fix pad_size calculation to ensure it's less than self.chunk_size
* [run_slow] mamba2
* [run-slow] mamba2
* [run-slow] Add @require_read_token decorator to failing tests for token propagation
* [run_slow] mamba2
This commit is contained in:
Parent: 8bd1f2f338
Commit: ec1424c6a3
src/transformers/models/mamba2/modeling_mamba2.py

@@ -510,7 +510,7 @@ class Mamba2Mixer(nn.Module):
             C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
             B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
             C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
-            pad_size = self.chunk_size - (seq_len % self.chunk_size)
+            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
 
             D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
 
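Why the outer modulo matters: with the old expression, a seq_len that is already a multiple of self.chunk_size yields pad_size == chunk_size, i.e. a whole extra chunk of zero padding, violating the invariant named in the commit message. A minimal sketch of the before/after behavior, with chunk_size and the tensor shape chosen purely for illustration, and torch.nn.functional.pad standing in for the pad_tensor_by_size helper:

import torch
import torch.nn.functional as F

chunk_size = 4  # illustrative value; the model reads it from config

def pad_size_old(seq_len: int) -> int:
    # Before the fix: returns a full chunk_size of padding when seq_len
    # is already a multiple of chunk_size.
    return chunk_size - (seq_len % chunk_size)

def pad_size_new(seq_len: int) -> int:
    # After the fix: the outer modulo maps the aligned case to 0, so the
    # result is always strictly less than chunk_size.
    return (chunk_size - seq_len % chunk_size) % chunk_size

for seq_len in (7, 8, 9):
    print(seq_len, pad_size_old(seq_len), pad_size_new(seq_len))
# 7 -> old 1, new 1
# 8 -> old 4, new 0   <- the bug: one extra, entirely empty chunk
# 9 -> old 3, new 3

# Zero-pad the sequence dimension of (batch, seq_len, hidden) so seq_len
# becomes a multiple of chunk_size; a stand-in for pad_tensor_by_size.
hidden_states = torch.randn(2, 8, 16)
pad = pad_size_new(hidden_states.shape[1])
padded = F.pad(hidden_states, (0, 0, 0, pad))  # pads dim 1 on the right
print(padded.shape)  # torch.Size([2, 8, 16]): already aligned, no padding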
tests/models/mamba2/test_modeling_mamba2.py

@@ -291,6 +291,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
         self.prompt = ("[INST]Write a hello world program in C++.",)
 
+    @require_read_token
     @parameterized.expand(
         [
             (torch_device,),
@@ -319,6 +320,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
         self.assertEqual(output_sentence, ground_truth_sentence)
 
+    @require_read_token
     @slow
     @require_torch_gpu
     def test_batched_equivalence_with_cache(self):
@@ -349,6 +351,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
         self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
 
+    @require_read_token
     @slow
     @require_torch_gpu
     def test_batched_equivalence_without_cache(self):
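The three @require_read_token additions gate these integration tests on the availability of a Hugging Face read token, since the checkpoint they pull is gated; per the commit message, the decorator was added to the failing tests so the token propagates. A rough sketch of what such a guard can look like (the real decorator lives in transformers.testing_utils; reading HF_TOKEN from the environment here is an assumption):

import os
import unittest

def require_read_token_sketch(test_case):
    # Hypothetical stand-in for transformers.testing_utils.require_read_token:
    # skip the decorated test unless a Hugging Face read token is configured.
    has_token = bool(os.environ.get("HF_TOKEN"))
    return unittest.skipUnless(has_token, "test requires a HF read token")(test_case)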