Update summarization run_pipeline_test (#20623)

* update summarization run_pipeline_test

* update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

parent 3e4c9e5c64
commit cec5f7abd1
@@ -17,11 +17,7 @@ import unittest
 from transformers import (
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    LEDConfig,
-    LongT5Config,
     SummarizationPipeline,
-    SwitchTransformersConfig,
-    T5Config,
     pipeline,
 )

 from transformers.testing_utils import require_tf, require_torch, slow, torch_device
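The four config imports can be dropped because the rewritten check in the next hunk compares class names as strings instead of types. A minimal sketch of the difference, using a hypothetical SomeConfig stand-in rather than real transformers classes:

# Hypothetical stand-in config; not a transformers class.
class SomeConfig:
    pass

config = SomeConfig()

# isinstance check: every config class must be imported into the test module.
by_type = isinstance(config, (SomeConfig,))

# Class-name check: plain strings, no imports of the config classes required.
by_name = config.__class__.__name__ in ["SomeConfig", "OtherConfig"]

assert by_type == by_name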
@@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])

-        if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)):
+        model_can_handle_longer_seq = [
+            "SwitchTransformersConfig",
+            "T5Config",
+            "LongT5Config",
+            "LEDConfig",
+            "PegasusXConfig",
+            "FSMTConfig",
+            "M2M100Config",
+            "ProphetNetConfig",  # positional embeddings up to a fixed maximum size (otherwise clamping the values)
+        ]
+        if model.config.__class__.__name__ not in model_can_handle_longer_seq:
             # Switch Transformers, LED, T5, LongT5 can handle it.
             # Too long.
             with self.assertRaises(Exception):
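In isolation, the new guard keeps an allow-list of config class names whose positional embeddings can stretch to (or clamp at) long inputs, and only expects a failure for models outside that list. Below is a self-contained sketch of the same pattern; ShortRangeConfig, LongRangeConfig, and summarize are invented placeholders, not transformers APIs:

import unittest

MODEL_CAN_HANDLE_LONGER_SEQ = ["LongRangeConfig"]  # allow-list by class name


class ShortRangeConfig:
    """Stand-in for a config with a fixed maximum sequence length."""


class LongRangeConfig:
    """Stand-in for a config that can handle (or clamp) long inputs."""


def summarize(config, num_tokens):
    # Placeholder for the pipeline call: raise when the input exceeds a
    # fixed window and the model is not on the allow-list.
    if num_tokens > 512 and config.__class__.__name__ not in MODEL_CAN_HANDLE_LONGER_SEQ:
        raise ValueError("input sequence too long")
    return "summary"


class GuardPatternTest(unittest.TestCase):
    def test_too_long_input(self):
        # Models outside the allow-list are expected to raise on long inputs.
        if ShortRangeConfig.__name__ not in MODEL_CAN_HANDLE_LONGER_SEQ:
            with self.assertRaises(Exception):
                summarize(ShortRangeConfig(), num_tokens=100_000)
        # Allow-listed models should succeed instead.
        self.assertEqual(summarize(LongRangeConfig(), num_tokens=100_000), "summary")


if __name__ == "__main__":
    unittest.main()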