update tests

This commit is contained in:
oweller2 2025-06-21 23:48:58 -04:00 committed by Orion Weller
parent c1e9a76641
commit 506633f5ed

View File

@ -19,7 +19,6 @@ from transformers import AutoTokenizer, ModernBertDecoderConfig, is_torch_availa
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
@ -131,9 +130,7 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = ModernBertDecoderForCausalLM.from_pretrained(
"blab-jhu/test-32m-dec", device_map=torch_device, torch_dtype=torch.float16
)
model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
# Create a longer input to test sliding window attention
@ -153,9 +150,7 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = ModernBertDecoderForSequenceClassification.from_pretrained(
"blab-jhu/test-32m-dec", num_labels=2, device_map=torch_device, torch_dtype=torch.float16
)
model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
# Test with sample input