transformers/tests/test_pipelines_common.py
Nicolas Patry 84caa23301
Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)
* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
2020-11-02 12:33:50 +01:00

275 lines
9.5 KiB
Python

from typing import List, Optional
from transformers import is_tf_available, is_torch_available, pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
# Sample inputs shared by the pipeline tests: one bare string and one list of strings.
VALID_INPUTS = [
    "A simple string",
    ["list of strings"],
]
@is_pipeline_test
class CustomInputPipelineCommonMixin:
    """Common pipeline tests for test cases that supply their own `_test_pipeline`.

    Subclasses set `pipeline_task` and the model lists, and must override
    `_test_pipeline`, which raises `NotImplementedError` here.

    `small_models` / `large_models` default to `None`; an unset list is treated
    as empty so that a subclass defining only one of them still runs.
    """

    pipeline_task = None
    pipeline_loading_kwargs = {}  # Extra kwargs forwarded to `pipeline(...)` when preloading in setUp
    small_models = None  # Models tested without the @slow decorator
    large_models = None  # Models tested with the @slow decorator

    def setUp(self) -> None:
        """Instantiate every configured pipeline once, for each available framework."""
        if not is_tf_available() and not is_torch_available():
            return  # Currently no JAX pipelines

        # Download needed checkpoints
        models = list(self.small_models or [])
        if _run_slow_tests:
            # Large checkpoints are only needed when the slow tests are enabled.
            models.extend(self.large_models or [])

        for model_name in models:
            if is_torch_available():
                pipeline(
                    self.pipeline_task,
                    model=model_name,
                    tokenizer=model_name,
                    framework="pt",
                    **self.pipeline_loading_kwargs,
                )

            if is_tf_available():
                pipeline(
                    self.pipeline_task,
                    model=model_name,
                    tokenizer=model_name,
                    framework="tf",
                    **self.pipeline_loading_kwargs,
                )

    @require_torch
    @slow
    def test_pt_defaults(self):
        """The task can be loaded with its default model under PyTorch."""
        pipeline(self.pipeline_task, framework="pt")

    @require_tf
    @slow
    def test_tf_defaults(self):
        """The task can be loaded with its default model under TensorFlow."""
        pipeline(self.pipeline_task, framework="tf")

    @require_torch
    def test_torch_small(self):
        """Run `_test_pipeline` on every small model with the PyTorch framework."""
        for model_name in self.small_models or []:
            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
            self._test_pipeline(nlp)

    @require_tf
    def test_tf_small(self):
        """Run `_test_pipeline` on every small model with the TensorFlow framework."""
        for model_name in self.small_models or []:
            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
            self._test_pipeline(nlp)

    @require_torch
    @slow
    def test_torch_large(self):
        """Run `_test_pipeline` on every large model with the PyTorch framework."""
        for model_name in self.large_models or []:
            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
            self._test_pipeline(nlp)

    @require_tf
    @slow
    def test_tf_large(self):
        """Run `_test_pipeline` on every large model with the TensorFlow framework."""
        for model_name in self.large_models or []:
            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
            self._test_pipeline(nlp)

    def _test_pipeline(self, nlp: Pipeline):
        """Check a loaded pipeline; must be overridden by the concrete test case.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
@is_pipeline_test
class MonoInputPipelineCommonMixin:
    """Common pipeline tests that call the pipeline with one input argument at a time.

    Subclasses configure the class attributes below; `_test_pipeline` then runs
    the pipeline on `valid_inputs`, checks the shape of the results and the
    presence of `mandatory_keys`, and checks that `invalid_inputs` raises.
    The assertions rely on `unittest.TestCase` methods being available on `self`.
    """

    pipeline_task = None
    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
    small_models = []  # Models tested without the @slow decorator
    large_models = []  # Models tested with the @slow decorator
    mandatory_keys = {}  # Keys which should be in the output
    valid_inputs = VALID_INPUTS  # inputs which are valid
    invalid_inputs = [None]  # inputs which are not allowed
    expected_multi_result: Optional[List] = None  # Expected outputs, aligned with `valid_inputs`
    expected_check_keys: Optional[List[str]] = None  # Keys compared against `expected_multi_result`

    def setUp(self) -> None:
        """Instantiate every configured pipeline once before the tests run."""
        if not is_tf_available() and not is_torch_available():
            return  # Currently no JAX pipelines

        models = list(self.small_models)
        if _run_slow_tests:
            # Large models are only exercised by @slow tests, so skip preloading them otherwise.
            models.extend(self.large_models)

        for model_name in models:
            pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)

    @require_torch
    @slow
    def test_pt_defaults_loads(self):
        """The task can be loaded with its default model under PyTorch."""
        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)

    @require_tf
    @slow
    def test_tf_defaults_loads(self):
        """The task can be loaded with its default model under TensorFlow."""
        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)

    @require_torch
    def test_torch_small(self):
        """Run `_test_pipeline` on every small model with the PyTorch framework."""
        for model_name in self.small_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(nlp)

    @require_tf
    def test_tf_small(self):
        """Run `_test_pipeline` on every small model with the TensorFlow framework."""
        for model_name in self.small_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(nlp)

    @require_torch
    @slow
    def test_torch_large(self):
        """Run `_test_pipeline` on every large model with the PyTorch framework."""
        for model_name in self.large_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(nlp)

    @require_tf
    @slow
    def test_tf_large(self):
        """Run `_test_pipeline` on every large model with the TensorFlow framework."""
        for model_name in self.large_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(nlp)

    def _test_pipeline(self, nlp: Pipeline):
        """Exercise `nlp` on the configured valid and invalid inputs.

        Args:
            nlp: The pipeline instance under test.
        """
        self.assertIsNotNone(nlp)

        # Single call on the first valid input.
        mono_result = nlp(self.valid_inputs[0], **self.pipeline_running_kwargs)
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], (dict, list))

        # When the first element is itself a list, inspect that inner list instead.
        if isinstance(mono_result[0], list):
            mono_result = mono_result[0]

        for key in self.mandatory_keys:
            self.assertIn(key, mono_result[0])

        # One call per valid input, collected in order.
        multi_result = [nlp(valid_input, **self.pipeline_running_kwargs) for valid_input in self.valid_inputs]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        if self.expected_multi_result is not None:
            for result, expect in zip(multi_result, self.expected_multi_result):
                for key in self.expected_check_keys or []:
                    # Compared as sets: ordering and duplicates within a result are ignored.
                    self.assertEqual(
                        {o[key] for o in result},
                        {o[key] for o in expect},
                    )

        # Only the results for the first valid input are checked for mandatory keys below.
        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in self.mandatory_keys:
                self.assertIn(key, result)

        # The whole `invalid_inputs` list is passed as a single argument.
        self.assertRaises(Exception, nlp, self.invalid_inputs)
# @is_pipeline_test
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
# def setUp(self) -> None:
# self.handler = DefaultArgumentHandler()
#
# def test_kwargs_x(self):
# mono_data = {"X": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_kwargs_data(self):
# mono_data = {"data": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_multi_kwargs(self):
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 2)
#
# multi_data = {
# "data": ["This is a sample input", "This is a second sample input"],
# "test": ["This is a sample input 2", "This is a second sample input 2"],
# }
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 4)
#
# def test_args(self):
# mono_data = "This is a sample input"
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# mono_data = ["This is a sample input"]
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(*multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)