Avoid checking expected exception when it is on CUDA (#34408)

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-10-25 17:14:07 +02:00 committed by GitHub
parent e447185b1f
commit f73f5e62e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 10 deletions

View File

@@ -85,8 +85,9 @@ class SummarizationPipelineTests(unittest.TestCase):
             and len(summarizer.model.trainable_weights) > 0
             and "GPU" in summarizer.model.trainable_weights[0].device
         ):
-            with self.assertRaises(Exception):
-                outputs = summarizer("This " * 1000)
+            if str(summarizer.device) == "cpu":
+                with self.assertRaises(Exception):
+                    outputs = summarizer("This " * 1000)
             outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)

     @require_torch

View File

@@ -493,17 +493,19 @@ class TextGenerationPipelineTests(unittest.TestCase):
             and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
         ):
             # Handling of large generations
-            with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
-                text_generator("This is a test" * 500, max_new_tokens=20)
+            if str(text_generator.device) == "cpu":
+                with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                    text_generator("This is a test" * 500, max_new_tokens=20)

             outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
             # Hole strategy cannot work
-            with self.assertRaises(ValueError):
-                text_generator(
-                    "This is a test" * 500,
-                    handle_long_generation="hole",
-                    max_new_tokens=tokenizer.model_max_length + 10,
-                )
+            if str(text_generator.device) == "cpu":
+                with self.assertRaises(ValueError):
+                    text_generator(
+                        "This is a test" * 500,
+                        handle_long_generation="hole",
+                        max_new_tokens=tokenizer.model_max_length + 10,
+                    )

     @require_torch
     @require_accelerate