Mamba2 remove unnecessary test parameterization (#38227)

This commit is contained in:
ivarflakstad 2025-05-20 15:54:04 +02:00 committed by GitHub
parent 9cde2f5d42
commit 3f0b7d0fac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,8 +15,6 @@
import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import (
Expectations,
@ -362,14 +360,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
self.prompt = ("[INST]Write a hello world program in C++.",)
@require_read_token
@parameterized.expand(
[
(torch_device,),
]
)
@slow
@require_torch
def test_simple_generate(self, device):
def test_simple_generate(self):
"""
Simple generate test to avoid regressions.
Note: state-spaces (cuda) implementation and pure torch implementation
@ -380,9 +373,9 @@ class Mamba2IntegrationTest(unittest.TestCase):
tokenizer.pad_token_id = tokenizer.eos_token_id
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
model.to(device)
model.to(torch_device)
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
device
torch_device
)
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)