diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py
index c4c646cee96..eb688d69c70 100644
--- a/tests/pipelines/test_pipelines_summarization.py
+++ b/tests/pipelines/test_pipelines_summarization.py
@@ -17,11 +17,7 @@ import unittest
 from transformers import (
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    LEDConfig,
-    LongT5Config,
     SummarizationPipeline,
-    SwitchTransformersConfig,
-    T5Config,
     pipeline,
 )
 from transformers.testing_utils import require_tf, require_torch, slow, torch_device
@@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])
 
-        if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)):
+        model_can_handle_longer_seq = [
+            "SwitchTransformersConfig",
+            "T5Config",
+            "LongT5Config",
+            "LEDConfig",
+            "PegasusXConfig",
+            "FSMTConfig",
+            "M2M100Config",
+            "ProphetNetConfig",  # positional embeddings up to a fixed maximum size (otherwise clamping the values)
+        ]
+        if model.config.__class__.__name__ not in model_can_handle_longer_seq:
             # Switch Transformers, LED, T5, LongT5 can handle it.
             # Too long.
             with self.assertRaises(Exception):