mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Mamba2 remove unecessary test parameterization (#38227)
This commit is contained in:
parent
9cde2f5d42
commit
3f0b7d0fac
@ -15,8 +15,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
@ -362,14 +360,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
self.prompt = ("[INST]Write a hello world program in C++.",)
|
||||
|
||||
@require_read_token
|
||||
@parameterized.expand(
|
||||
[
|
||||
(torch_device,),
|
||||
]
|
||||
)
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple_generate(self, device):
|
||||
def test_simple_generate(self):
|
||||
"""
|
||||
Simple generate test to avoid regressions.
|
||||
Note: state-spaces (cuda) implementation and pure torch implementation
|
||||
@ -380,9 +373,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
|
||||
model.to(device)
|
||||
model.to(torch_device)
|
||||
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
|
||||
device
|
||||
torch_device
|
||||
)
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||
|
Loading…
Reference in New Issue
Block a user