mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Change a logic in pipeline test regarding TF (#20710)
* Fix the pipeline test regarding TF * Fix the pipeline test regarding TF * update comment Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
1af4bee896
commit
a12c5cbcd8
@ -18,9 +18,10 @@ from transformers import (
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
SummarizationPipeline,
|
||||
TFPreTrainedModel,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import get_gpu_count, require_tf, require_torch, slow, torch_device
|
||||
from transformers.tokenization_utils import TruncationStrategy
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
@ -51,6 +52,7 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
|
||||
)
|
||||
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
|
||||
|
||||
# Some models (Switch Transformers, LED, T5, LongT5, etc) can handle long sequences.
|
||||
model_can_handle_longer_seq = [
|
||||
"SwitchTransformersConfig",
|
||||
"T5Config",
|
||||
@ -62,10 +64,16 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
|
||||
"ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values)
|
||||
]
|
||||
if model.config.__class__.__name__ not in model_can_handle_longer_seq:
|
||||
# Switch Transformers, LED, T5, LongT5 can handle it.
|
||||
# Too long.
|
||||
with self.assertRaises(Exception):
|
||||
outputs = summarizer("This " * 1000)
|
||||
# Too long and exception is expected.
|
||||
# For TF models, if the weights are initialized in GPU context, we won't get expected index error from
|
||||
# the embedding layer.
|
||||
if not (
|
||||
isinstance(model, TFPreTrainedModel)
|
||||
and get_gpu_count() > 0
|
||||
and len(summarizer.model.trainable_weights) > 0
|
||||
):
|
||||
with self.assertRaises(Exception):
|
||||
outputs = summarizer("This " * 1000)
|
||||
outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)
|
||||
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user