# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from unittest import mock

from transformers import is_tf_available, is_torch_available, pipeline
from transformers.file_utils import to_py_obj
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow


VALID_INPUTS = ["A simple string", ["list of strings"]]


@is_pipeline_test
class CustomInputPipelineCommonMixin:
    pipeline_task = None
    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
    small_models = []  # Models tested without the @slow decorator
    large_models = []  # Models tested with the @slow decorator
    valid_inputs = VALID_INPUTS  # Some inputs which are valid to compare fast and slow tokenizers

    def setUp(self) -> None:
        if not is_tf_available() and not is_torch_available():
            return  # Currently no JAX pipelines

        # Download needed checkpoints
        models = self.small_models
        if _run_slow_tests:
            models = models + self.large_models

        for model_name in models:
            if is_torch_available():
                pipeline(
                    self.pipeline_task,
                    model=model_name,
                    tokenizer=model_name,
                    framework="pt",
                    **self.pipeline_loading_kwargs,
                )
            if is_tf_available():
                pipeline(
                    self.pipeline_task,
                    model=model_name,
                    tokenizer=model_name,
                    framework="tf",
                    **self.pipeline_loading_kwargs,
                )

    @require_torch
    @slow
    def test_pt_defaults(self):
        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)

    @require_tf
    @slow
    def test_tf_defaults(self):
        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)

    @require_torch
    def test_torch_small(self):
        for model_name in self.small_models:
            pipe_small = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(pipe_small)

    @require_tf
    def test_tf_small(self):
        for model_name in self.small_models:
            pipe_small = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(pipe_small)

    @require_torch
    @slow
    def test_torch_large(self):
        for model_name in self.large_models:
            pipe_large = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(pipe_large)

    @require_tf
    @slow
    def test_tf_large(self):
        for model_name in self.large_models:
            pipe_large = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                **self.pipeline_loading_kwargs,
            )
            self._test_pipeline(pipe_large)

    def _test_pipeline(self, pipe: Pipeline):
        raise NotImplementedError

    @require_torch
    def test_compare_slow_fast_torch(self):
        for model_name in self.small_models:
            pipe_slow = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                use_fast=False,
                **self.pipeline_loading_kwargs,
            )
            pipe_fast = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="pt",
                use_fast=True,
                **self.pipeline_loading_kwargs,
            )
            self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="forward")

    @require_tf
    def test_compare_slow_fast_tf(self):
        for model_name in self.small_models:
            pipe_slow = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                use_fast=False,
                **self.pipeline_loading_kwargs,
            )
            pipe_fast = pipeline(
                task=self.pipeline_task,
                model=model_name,
                tokenizer=model_name,
                framework="tf",
                use_fast=True,
                **self.pipeline_loading_kwargs,
            )
            self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="call")

    def _compare_slow_fast_pipelines(self, pipe_slow: Pipeline, pipe_fast: Pipeline, method: str):
        """Check that the inputs to the model's forward pass are identical for
        slow and fast tokenizers.
        """
        with mock.patch.object(
            pipe_slow.model, method, wraps=getattr(pipe_slow.model, method)
        ) as mock_slow, mock.patch.object(
            pipe_fast.model, method, wraps=getattr(pipe_fast.model, method)
        ) as mock_fast:
            for inputs in self.valid_inputs:
                if isinstance(inputs, dict):
                    inputs.update(self.pipeline_running_kwargs)
                    _ = pipe_slow(**inputs)
                    _ = pipe_fast(**inputs)
                else:
                    _ = pipe_slow(inputs, **self.pipeline_running_kwargs)
                    _ = pipe_fast(inputs, **self.pipeline_running_kwargs)

            mock_slow.assert_called()
            mock_fast.assert_called()
            self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
            # Compare each slow-tokenizer call against the corresponding
            # fast-tokenizer call (the original zipped mock_slow with itself,
            # which compared the slow calls against themselves).
            for mock_slow_call_args, mock_fast_call_args in zip(
                mock_slow.call_args_list, mock_fast.call_args_list
            ):
                slow_call_args, slow_call_kwargs = mock_slow_call_args
                fast_call_args, fast_call_kwargs = mock_fast_call_args

                slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
                fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)

                self.assertEqual(slow_call_args, fast_call_args)
                self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)


@is_pipeline_test
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
    """A version of the CustomInputPipelineCommonMixin with a predefined `_test_pipeline` method."""

    mandatory_keys = {}  # Keys which should be in the output
    invalid_inputs = [None]  # Inputs which are not allowed
    expected_multi_result: Optional[List] = None
    expected_check_keys: Optional[List[str]] = None

    def _test_pipeline(self, pipe: Pipeline):
        self.assertIsNotNone(pipe)

        mono_result = pipe(self.valid_inputs[0], **self.pipeline_running_kwargs)
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], (dict, list))

        if isinstance(mono_result[0], list):
            mono_result = mono_result[0]

        for key in self.mandatory_keys:
            self.assertIn(key, mono_result[0])

        multi_result = [pipe(input, **self.pipeline_running_kwargs) for input in self.valid_inputs]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        if self.expected_multi_result is not None:
            for result, expect in zip(multi_result, self.expected_multi_result):
                for key in self.expected_check_keys or []:
                    self.assertEqual(
                        set([o[key] for o in result]),
                        set([o[key] for o in expect]),
                    )

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in self.mandatory_keys:
                self.assertIn(key, result)

        self.assertRaises(Exception, pipe, self.invalid_inputs)
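

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept commented out so it does not run as
# part of this module): a task-specific test class subclasses one of the
# mixins above together with `unittest.TestCase`, sets `pipeline_task` and
# the model lists, and inherits all of the shared tests. The checkpoint name
# below is a hypothetical placeholder, not a fixture shipped with this file;
# the mandatory keys match the dicts returned by text classification
# pipelines.
#
# import unittest
#
#
# class ExampleTextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
#     pipeline_task = "sentiment-analysis"
#     small_models = ["tiny-random-text-classification-model"]  # Hypothetical tiny checkpoint
#     large_models = []  # No checkpoints gated behind @slow in this sketch
#     mandatory_keys = {"label", "score"}  # Each output dict should contain these keys
# ---------------------------------------------------------------------------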