transformers/tests/test_pipelines_dialog.py
Thomas Wolf 3a40cdf58d
[tests|tokenizers] Refactoring pipelines test backbone - Small tokenizers improvements - General tests speedups (#7970)
* WIP refactoring pipeline tests - switching to fast tokenizers

* fix dialog pipeline and fill-mask

* refactoring pipeline tests backbone

* make large tests slow

* fix tests (tf Bart inactive for now)

* fix doc...

* clean up for merge

* fixing tests - remove bart from summarization until there is TF

* fix quality and RAG

* Add new translation pipeline tests - fix JAX tests

* only slow for dialog

* Fixing the missing TF-BART imports in modeling_tf_auto

* spin out pipeline tests in separate CI job

* adding pipeline test to CI YAML

* add slow pipeline tests

* speed up tf and pt join test to avoid redoing all the standalone pt and tf tests

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* Update src/transformers/pipelines.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/testing_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add require_torch and require_tf in is_pt_tf_cross_test

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2020-10-23 15:58:19 +02:00

30 lines
1.2 KiB
Python

import unittest

from transformers.pipelines import Conversation, Pipeline

from .test_pipelines_common import CustomInputPipelineCommonMixin


class DialoguePipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "conversational"
    small_models = []  # Default model - Models tested without the @slow decorator
    large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator

    def _test_pipeline(self, nlp: Pipeline):
        valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
        invalid_inputs = ["Hi there!", Conversation()]
        self.assertIsNotNone(nlp)

        mono_result = nlp(valid_inputs[0])
        self.assertIsInstance(mono_result, Conversation)

        multi_result = nlp(valid_inputs[1])
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], Conversation)
        # Inactive conversations passed to the pipeline raise a ValueError
        self.assertRaises(ValueError, nlp, valid_inputs[1])

        for bad_input in invalid_inputs:
            self.assertRaises(Exception, nlp, bad_input)
        self.assertRaises(Exception, nlp, invalid_inputs)
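
The test above relies on an input contract that can be hard to see without running a model: a single conversation returns a single conversation, a list returns a list, and a conversation with no pending user input is rejected. The sketch below imitates that contract with stand-ins so it runs without downloading anything. `ToyConversation` and `toy_pipeline` are hypothetical names invented for this illustration; they are not the real `Conversation` class or the conversational pipeline, and the "echo" reply is a placeholder for model generation.

```python
from typing import List, Optional


class ToyConversation:
    """Hypothetical stand-in for a conversation object (not the transformers class)."""

    def __init__(self, text: Optional[str] = None):
        self.new_user_input = text  # pending user turn, or None if there is none
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []

    def mark_processed(self, response: str) -> None:
        # Move the pending input into the history and record the reply.
        self.past_user_inputs.append(self.new_user_input)
        self.generated_responses.append(response)
        self.new_user_input = None


def toy_pipeline(conversations):
    """Hypothetical pipeline mirroring the input checks exercised by the test."""
    single = isinstance(conversations, ToyConversation)
    batch = [conversations] if single else conversations
    if not isinstance(batch, list) or not all(isinstance(c, ToyConversation) for c in batch):
        raise ValueError("expected a ToyConversation or a list of ToyConversation")
    # Validate the whole batch before mutating anything.
    for conv in batch:
        if conv.new_user_input is None:
            raise ValueError("conversation has no unprocessed user input")
    for conv in batch:
        conv.mark_processed("echo: " + conv.new_user_input)  # placeholder reply
    return batch[0] if single else batch


conv = ToyConversation("Hi there!")
toy_pipeline(conv)      # consumes the pending input
try:
    toy_pipeline(conv)  # the conversation is now inactive
except ValueError as err:
    print("rejected:", err)
```

This is why the test can pass `valid_inputs[1]` a second time and expect a `ValueError`: the first call consumes each conversation's pending input, so the same objects are inactive on the second call.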