Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-25 23:38:59 +06:00
[tests] add require_torch_sdpa for test that needs sdpa support (#30408)

* add cuda flag
* check for sdpa
* add bitsandbytes
parent 04ac3245e4
commit 2d61823fa2
@@ -374,6 +374,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
 @slow
 class CohereIntegrationTest(unittest.TestCase):
     @require_torch_multi_gpu
+    @require_bitsandbytes
     def test_batched_4bit(self):
         model_id = "CohereForAI/c4ai-command-r-v01-4bit"

@@ -393,6 +394,7 @@ class CohereIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
         self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)

+    @require_torch_sdpa
     def test_batched_small_model_logits(self):
         # Since the model is very large, we created a random cohere model so that we can do a simple
         # logits check on it.
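For background, `require_torch_sdpa` and `require_bitsandbytes` are skip decorators from `transformers.testing_utils`: they mark a test to be skipped when the corresponding backend is unavailable instead of letting it fail. Below is a minimal sketch of how such a guard can be built, assuming the availability check probes `torch.nn.functional.scaled_dot_product_attention`; the helper names here are illustrative, not the library's exact internals.

import unittest


def is_torch_sdpa_available() -> bool:
    # Illustrative probe: scaled_dot_product_attention ships with PyTorch 2.0+.
    try:
        import torch
        return hasattr(torch.nn.functional, "scaled_dot_product_attention")
    except ImportError:
        return False


def require_torch_sdpa(test_case):
    # Skip the decorated test (or test class) when SDPA support is missing.
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)

With a guard like this, `@require_torch_sdpa` on `test_batched_small_model_logits` skips the test on torch builds without SDPA rather than erroring, mirroring what `@require_bitsandbytes` does for the 4-bit test above.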