diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index b5199d324d6..d8b880906f7 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -735,10 +735,8 @@ class Pipeline(_ScikitCompat):
             supported_models_names.append(model.__name__)
         supported_models = supported_models_names
         if self.model.__class__.__name__ not in supported_models:
-            raise PipelineException(
-                self.task,
-                self.model.base_model_prefix,
-                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
+            logger.error(
+                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
             )
 
     def _parse_and_tokenize(
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index 1f98d374795..291b7c56445 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,3 +1,5 @@
+from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
+
 from ..file_utils import add_end_docstrings
 from .base import PIPELINE_INIT_ARGS, Pipeline
 
@@ -30,25 +32,12 @@ class TextGenerationPipeline(Pipeline):
     begging for his blessing.
     """
 
-    ALLOWED_MODELS = [
-        "XLNetLMHeadModel",
-        "TransfoXLLMHeadModel",
-        "ReformerModelWithLMHead",
-        "GPT2LMHeadModel",
-        "GPTNeoForCausalLM",
-        "OpenAIGPTLMHeadModel",
-        "CTRLLMHeadModel",
-        "TFXLNetLMHeadModel",
-        "TFTransfoXLLMHeadModel",
-        "TFGPT2LMHeadModel",
-        "TFOpenAIGPTLMHeadModel",
-        "TFCTRLLMHeadModel",
-    ]
-
     def __init__(self, *args, return_full_text=True, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self.check_model_type(self.ALLOWED_MODELS)
+        self.check_model_type(
+            TF_MODEL_FOR_CAUSAL_LM_MAPPING if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING
+        )
         self.return_full_text = return_full_text
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
@@ -124,6 +113,9 @@ class TextGenerationPipeline(Pipeline):
             prefix_length = prefix_inputs["input_ids"].shape[-1]
             if generate_kwargs.get("max_length", None) is not None:
                 generate_kwargs["max_length"] += prefix_length
+            else:
+                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
+
             if generate_kwargs.get("min_length", None) is not None:
                 generate_kwargs["min_length"] += prefix_length
 
diff --git a/tests/test_pipelines_text_generation.py b/tests/test_pipelines_text_generation.py
index 1a2d77b55e5..22bc8bf42cd 100644
--- a/tests/test_pipelines_text_generation.py
+++ b/tests/test_pipelines_text_generation.py
@@ -14,49 +14,95 @@
 import unittest
 
-from transformers import pipeline
-from transformers.testing_utils import require_torch
+from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
+from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
 
-from .test_pipelines_common import MonoInputPipelineCommonMixin
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
-class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "text-generation"
-    pipeline_running_kwargs = {"prefix": "This is "}
-    small_models = ["sshleifer/tiny-ctrl"]  # Models tested without the @slow decorator
-    large_models = []  # Models tested with the @slow decorator
-
-    def test_simple_generation(self):
-        text_generator = pipeline(task="text-generation", model=self.small_models[0])
-        # text-generation is non-deterministic by nature, we can't fully test the output
-
-        outputs = text_generator("This is a test")
-
-        self.assertEqual(len(outputs), 1)
-        self.assertEqual(list(outputs[0].keys()), ["generated_text"])
-        self.assertEqual(type(outputs[0]["generated_text"]), str)
-
-        outputs = text_generator(["This is a test", "This is a second test"])
-        self.assertEqual(len(outputs[0]), 1)
-        self.assertEqual(list(outputs[0][0].keys()), ["generated_text"])
-        self.assertEqual(type(outputs[0][0]["generated_text"]), str)
-        self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
-        self.assertEqual(type(outputs[1][0]["generated_text"]), str)
+@is_pipeline_test
+class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+    tf_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
 
     @require_torch
-    def test_generation_output_style(self):
-        text_generator = pipeline(task="text-generation", model=self.small_models[0])
-        # text-generation is non-deterministic by nature, we can't fully test the output
+    def test_small_model_pt(self):
+        text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="pt")
+        # Using `do_sample=False` to force deterministic output
+        outputs = text_generator("This is a test", do_sample=False)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "generated_text": "This is a test ☃ ☃ segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+                }
+            ],
+        )
+        outputs = text_generator(["This is a test", "This is a second test"])
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {
+                        "generated_text": "This is a test ☃ ☃ segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+                    }
+                ],
+                [
+                    {
+                        "generated_text": "This is a second test ☃ segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+                    }
+                ],
+            ],
+        )
+
+    @require_tf
+    def test_small_model_tf(self):
+        text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
+
+        # Using `do_sample=False` to force deterministic output
+        outputs = text_generator("This is a test", do_sample=False)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes 閲閲Cannes Cannes Cannes 攵 please,"
+                }
+            ],
+        )
+
+        outputs = text_generator(["This is a test", "This is a second test"], do_sample=False)
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {
+                        "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes 閲閲Cannes Cannes Cannes 攵 please,"
+                    }
+                ],
+                [
+                    {
+                        "generated_text": "This is a second test Chieftain Chieftain prefecture prefecture prefecture Cannes Cannes Cannes 閲閲Cannes Cannes Cannes 攵 please,"
+                    }
+                ],
+            ],
+        )
+
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
 
         outputs = text_generator("This is a test")
-        self.assertIn("This is a test", outputs[0]["generated_text"])
+        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
+        self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
 
         outputs = text_generator("This is a test", return_full_text=False)
+        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         self.assertNotIn("This is a test", outputs[0]["generated_text"])
 
-        text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
+        text_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer, return_full_text=False)
         outputs = text_generator("This is a test")
+        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         self.assertNotIn("This is a test", outputs[0]["generated_text"])
 
         outputs = text_generator("This is a test", return_full_text=True)
-        self.assertIn("This is a test", outputs[0]["generated_text"])
+        self.assertEqual(outputs, [{"generated_text": ANY(str)}])
+        self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
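
Note on the base.py hunk: it softens the unsupported-model check from a raised PipelineException to a logged error, so pipeline construction no longer fails hard on a mismatched head. A minimal sketch of the observable difference, assuming the patch is applied; the masked-LM checkpoint below is only an illustrative choice, not something this patch touches:

    from transformers import AutoModelForMaskedLM, AutoTokenizer, TextGenerationPipeline

    # BertForMaskedLM is not in MODEL_FOR_CAUSAL_LM_MAPPING. Before this patch,
    # constructing the pipeline raised PipelineException; after it,
    # check_model_type only calls logger.error, naming the unsupported class
    # and the supported models, and construction proceeds.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer)  # logs an error, does not raise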
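
Note on the max_length hunk in text_generation.py: previously, when a prefix was supplied but the caller did not set max_length, the prefix tokens silently consumed part of the model's default generation budget. A minimal sketch of the patched arithmetic, with hypothetical token counts (the values below are illustrative, not taken from the patch):

    # Hypothetical numbers: a 3-token prefix and a config default of 20.
    prefix_length = 3          # stands in for prefix_inputs["input_ids"].shape[-1]
    config_max_length = 20     # stands in for self.model.config.max_length
    generate_kwargs = {}       # the caller did not set max_length

    if generate_kwargs.get("max_length", None) is not None:
        generate_kwargs["max_length"] += prefix_length
    else:
        # New in this patch: fall back to the config default, extended by the
        # prefix length, so the prefix no longer eats into the budget.
        generate_kwargs["max_length"] = config_max_length + prefix_length

    assert generate_kwargs["max_length"] == 23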