mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update tests
This commit is contained in:
parent
c1e9a76641
commit
506633f5ed
@ -19,7 +19,6 @@ from transformers import AutoTokenizer, ModernBertDecoderConfig, is_torch_availa
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
@ -131,9 +130,7 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
|
||||
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
|
||||
model = ModernBertDecoderForCausalLM.from_pretrained(
|
||||
"blab-jhu/test-32m-dec", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
|
||||
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||
|
||||
# Create a longer input to test sliding window attention
|
||||
@ -153,9 +150,7 @@ class ModernBertDecoderIntegrationTest(unittest.TestCase):
|
||||
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
|
||||
model = ModernBertDecoderForSequenceClassification.from_pretrained(
|
||||
"blab-jhu/test-32m-dec", num_labels=2, device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", num_labels=2)
|
||||
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||
|
||||
# Test with sample input
|
||||
|
Loading…
Reference in New Issue
Block a user