Avoid check expected exception when it is on CUDA (#34408)

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-10-25 17:14:07 +02:00 committed by GitHub
parent e447185b1f
commit f73f5e62e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 10 deletions

View File

@ -85,8 +85,9 @@ class SummarizationPipelineTests(unittest.TestCase):
and len(summarizer.model.trainable_weights) > 0
and "GPU" in summarizer.model.trainable_weights[0].device
):
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
if str(summarizer.device) == "cpu":
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)
@require_torch

View File

@ -493,17 +493,19 @@ class TextGenerationPipelineTests(unittest.TestCase):
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
):
# Handling of large generations
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
text_generator("This is a test" * 500, max_new_tokens=20)
if str(text_generator.device) == "cpu":
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
text_generator("This is a test" * 500, max_new_tokens=20)
outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
# Hole strategy cannot work
with self.assertRaises(ValueError):
text_generator(
"This is a test" * 500,
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
if str(text_generator.device) == "cpu":
with self.assertRaises(ValueError):
text_generator(
"This is a test" * 500,
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
@require_torch
@require_accelerate