mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Moving zero-shot-classification
pipeline to new testing. (#13299)
* Moving `zero-shot-classification` pipeline to new testing. * Cleaning up old mixins. * Fixing tests `sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english` is corrupted in PT. * Adding warning.
This commit is contained in:
parent
cc27ac1a87
commit
b89a964d3f
@ -2,12 +2,15 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..file_utils import add_end_docstrings
|
||||
from ..file_utils import add_end_docstrings, is_torch_available
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -85,23 +88,84 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
hypothesis_template,
|
||||
padding=True,
|
||||
add_special_tokens=True,
|
||||
truncation=TruncationStrategy.ONLY_FIRST,
|
||||
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
||||
"""
|
||||
sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
|
||||
inputs = self.tokenizer(
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
return_tensors = self.framework
|
||||
if getattr(self.tokenizer, "pad_token", None) is None:
|
||||
# XXX some tokenizers do not have a padding token, we use simple lists
|
||||
# and no padding then
|
||||
logger.warning("The tokenizer {self.tokenizer} does not have a pad token, we're not running it as a batch")
|
||||
padding = False
|
||||
inputs = []
|
||||
for sequence_pair in sequence_pairs:
|
||||
model_input = self.tokenizer(
|
||||
text=sequence_pair[0],
|
||||
text_pair=sequence_pair[1],
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
inputs.append(model_input)
|
||||
else:
|
||||
inputs = self.tokenizer(
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
def _forward(self, inputs, return_tensors=False):
|
||||
"""
|
||||
Internal framework specific forward dispatching
|
||||
|
||||
Args:
|
||||
inputs: dict holding all the keyword arguments for required by the model forward method.
|
||||
return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array
|
||||
|
||||
Returns:
|
||||
Numpy array
|
||||
"""
|
||||
# Encode for forward
|
||||
with self.device_placement():
|
||||
if self.framework == "tf":
|
||||
if isinstance(inputs, list):
|
||||
predictions = []
|
||||
for input_ in inputs:
|
||||
prediction = self.model(input_.data, training=False)[0]
|
||||
predictions.append(prediction)
|
||||
else:
|
||||
predictions = self.model(inputs.data, training=False)[0]
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if isinstance(inputs, list):
|
||||
predictions = []
|
||||
for input_ in inputs:
|
||||
model_input = self.ensure_tensor_on_device(**input_)
|
||||
prediction = self.model(**model_input)[0].cpu()
|
||||
predictions.append(prediction)
|
||||
|
||||
else:
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
predictions = self.model(**inputs)[0].cpu()
|
||||
|
||||
if return_tensors:
|
||||
return predictions
|
||||
else:
|
||||
if isinstance(predictions, list):
|
||||
predictions = np.array([p.numpy() for p in predictions])
|
||||
else:
|
||||
predictions = predictions.numpy()
|
||||
return predictions
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sequences: Union[str, List[str]],
|
||||
@ -151,6 +215,12 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
sequences = [sequences]
|
||||
|
||||
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
|
||||
if isinstance(outputs, list):
|
||||
# XXX: Some tokenizers cannot handle batching because they don't
|
||||
# have pad_token, so outputs will be a list, however, because outputs
|
||||
# is only n logits and sequence_length is not present anymore, we
|
||||
# can recreate a tensor out of outputs.
|
||||
outputs = np.array(outputs)
|
||||
num_sequences = len(sequences)
|
||||
candidate_labels = self._args_parser._parse_labels(candidate_labels)
|
||||
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
|
||||
|
@ -17,21 +17,9 @@ import logging
|
||||
import string
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from unittest import mock, skipIf
|
||||
from unittest import skipIf
|
||||
|
||||
from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.file_utils import to_py_obj
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -189,228 +177,3 @@ class PipelineTestCaseMeta(type):
|
||||
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
|
||||
|
||||
return type.__new__(mcs, name, bases, dct)
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class CustomInputPipelineCommonMixin:
|
||||
pipeline_task = None
|
||||
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
|
||||
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
|
||||
small_models = [] # Models tested without the @slow decorator
|
||||
large_models = [] # Models tested with the @slow decorator
|
||||
valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
|
||||
|
||||
def setUp(self) -> None:
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
return # Currently no JAX pipelines
|
||||
|
||||
# Download needed checkpoints
|
||||
models = self.small_models
|
||||
if _run_slow_tests:
|
||||
models = models + self.large_models
|
||||
|
||||
for model_name in models:
|
||||
if is_torch_available():
|
||||
pipeline(
|
||||
self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
if is_tf_available():
|
||||
pipeline(
|
||||
self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_pt_defaults(self):
|
||||
pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_defaults(self):
|
||||
pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
|
||||
|
||||
@require_torch
|
||||
def test_torch_small(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_small = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_small)
|
||||
|
||||
@require_tf
|
||||
def test_tf_small(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_small = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_small)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_large(self):
|
||||
for model_name in self.large_models:
|
||||
pipe_large = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_large)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_large(self):
|
||||
for model_name in self.large_models:
|
||||
pipe_large = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_large)
|
||||
|
||||
def _test_pipeline(self, pipe: Pipeline):
|
||||
raise NotImplementedError
|
||||
|
||||
@require_torch
|
||||
def test_compare_slow_fast_torch(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_slow = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
use_fast=False,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
pipe_fast = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
use_fast=True,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="forward")
|
||||
|
||||
@require_tf
|
||||
def test_compare_slow_fast_tf(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_slow = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
use_fast=False,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
pipe_fast = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
use_fast=True,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="call")
|
||||
|
||||
def _compare_slow_fast_pipelines(self, pipe_slow: Pipeline, pipe_fast: Pipeline, method: str):
|
||||
"""We check that the inputs to the models forward passes are identical for
|
||||
slow and fast tokenizers.
|
||||
"""
|
||||
with mock.patch.object(
|
||||
pipe_slow.model, method, wraps=getattr(pipe_slow.model, method)
|
||||
) as mock_slow, mock.patch.object(
|
||||
pipe_fast.model, method, wraps=getattr(pipe_fast.model, method)
|
||||
) as mock_fast:
|
||||
for inputs in self.valid_inputs:
|
||||
if isinstance(inputs, dict):
|
||||
inputs.update(self.pipeline_running_kwargs)
|
||||
_ = pipe_slow(**inputs)
|
||||
_ = pipe_fast(**inputs)
|
||||
else:
|
||||
_ = pipe_slow(inputs, **self.pipeline_running_kwargs)
|
||||
_ = pipe_fast(inputs, **self.pipeline_running_kwargs)
|
||||
|
||||
mock_slow.assert_called()
|
||||
mock_fast.assert_called()
|
||||
|
||||
self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
|
||||
for mock_slow_call_args, mock_fast_call_args in zip(
|
||||
mock_slow.call_args_list, mock_slow.call_args_list
|
||||
):
|
||||
slow_call_args, slow_call_kwargs = mock_slow_call_args
|
||||
fast_call_args, fast_call_kwargs = mock_fast_call_args
|
||||
|
||||
slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
|
||||
fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
|
||||
|
||||
self.assertEqual(slow_call_args, fast_call_args)
|
||||
self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
|
||||
"""A version of the CustomInputPipelineCommonMixin
|
||||
with a predefined `_test_pipeline` method.
|
||||
"""
|
||||
|
||||
mandatory_keys = {} # Keys which should be in the output
|
||||
invalid_inputs = [None] # inputs which are not allowed
|
||||
expected_multi_result: Optional[List] = None
|
||||
expected_check_keys: Optional[List[str]] = None
|
||||
|
||||
def _test_pipeline(self, pipe: Pipeline):
|
||||
self.assertIsNotNone(pipe)
|
||||
|
||||
mono_result = pipe(self.valid_inputs[0], **self.pipeline_running_kwargs)
|
||||
self.assertIsInstance(mono_result, list)
|
||||
self.assertIsInstance(mono_result[0], (dict, list))
|
||||
|
||||
if isinstance(mono_result[0], list):
|
||||
mono_result = mono_result[0]
|
||||
|
||||
for key in self.mandatory_keys:
|
||||
self.assertIn(key, mono_result[0])
|
||||
|
||||
multi_result = [pipe(input, **self.pipeline_running_kwargs) for input in self.valid_inputs]
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
if self.expected_multi_result is not None:
|
||||
for result, expect in zip(multi_result, self.expected_multi_result):
|
||||
for key in self.expected_check_keys or []:
|
||||
self.assertEqual(
|
||||
set([o[key] for o in result]),
|
||||
set([o[key] for o in expect]),
|
||||
)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
|
||||
for result in multi_result:
|
||||
for key in self.mandatory_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
self.assertRaises(Exception, pipe, self.invalid_inputs)
|
||||
|
@ -13,39 +13,82 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
Pipeline,
|
||||
ZeroShotClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "zero-shot-classification"
|
||||
small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
|
||||
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator
|
||||
valid_inputs = [
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]},
|
||||
{"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"},
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": "politics",
|
||||
"hypothesis_template": "This text is about {}",
|
||||
},
|
||||
]
|
||||
@is_pipeline_test
|
||||
class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
|
||||
def _test_scores_sum_to_one(self, result):
|
||||
sum = 0.0
|
||||
for score in result["scores"]:
|
||||
sum += score
|
||||
self.assertAlmostEqual(sum, 1.0, places=5)
|
||||
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
||||
classifier = ZeroShotClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
def _test_entailment_id(self, zero_shot_classifier: Pipeline):
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics, public health")
|
||||
self.assertEqual(
|
||||
outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
)
|
||||
self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)
|
||||
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics", "public health"])
|
||||
self.assertEqual(
|
||||
outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
)
|
||||
self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)
|
||||
|
||||
outputs = classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels="politics", hypothesis_template="This text is about {}"
|
||||
)
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
classifier("", candidate_labels="politics")
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
classifier(None, candidate_labels="politics")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
classifier("Who are you voting for in 2020?", candidate_labels="")
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
classifier("Who are you voting for in 2020?", candidate_labels=None)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
classifier(
|
||||
"Who are you voting for in 2020?",
|
||||
candidate_labels="politics",
|
||||
hypothesis_template="Not formatting template",
|
||||
)
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
classifier(
|
||||
"Who are you voting for in 2020?",
|
||||
candidate_labels="politics",
|
||||
hypothesis_template=None,
|
||||
)
|
||||
|
||||
self.run_entailment_id(classifier)
|
||||
|
||||
def run_entailment_id(self, zero_shot_classifier: Pipeline):
|
||||
config = zero_shot_classifier.model.config
|
||||
original_config = deepcopy(config)
|
||||
original_label2id = config.label2id
|
||||
original_entailment = zero_shot_classifier.entailment_id
|
||||
|
||||
config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
|
||||
self.assertEqual(zero_shot_classifier.entailment_id, -1)
|
||||
@ -59,107 +102,105 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
||||
config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
|
||||
self.assertEqual(zero_shot_classifier.entailment_id, 2)
|
||||
|
||||
zero_shot_classifier.model.config = original_config
|
||||
zero_shot_classifier.model.config.label2id = original_label2id
|
||||
self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)
|
||||
|
||||
def _test_pipeline(self, zero_shot_classifier: Pipeline):
|
||||
output_keys = {"sequence", "labels", "scores"}
|
||||
valid_mono_inputs = [
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]},
|
||||
{"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"},
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
zero_shot_classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||
framework="pt",
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": "politics",
|
||||
"hypothesis_template": "This text is about {}",
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["science", "public health", "politics"],
|
||||
"scores": [0.333, 0.333, 0.333],
|
||||
},
|
||||
]
|
||||
valid_multi_input = {
|
||||
"sequences": ["Who are you voting for in 2020?", "What is the capital of Spain?"],
|
||||
"candidate_labels": "politics",
|
||||
}
|
||||
invalid_inputs = [
|
||||
{"sequences": None, "candidate_labels": "politics"},
|
||||
{"sequences": "", "candidate_labels": "politics"},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": None},
|
||||
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ""},
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
zero_shot_classifier = pipeline(
|
||||
"zero-shot-classification",
|
||||
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||
framework="tf",
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": "politics",
|
||||
"hypothesis_template": None,
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["science", "public health", "politics"],
|
||||
"scores": [0.333, 0.333, 0.333],
|
||||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="pt")
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": "politics",
|
||||
"hypothesis_template": "",
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["politics", "public health", "science"],
|
||||
"scores": [0.976, 0.015, 0.009],
|
||||
},
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
candidate_labels=["machine learning", "statistics", "translation", "vision"],
|
||||
multi_label=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": "politics",
|
||||
"hypothesis_template": "Template without formatting syntax.",
|
||||
"sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
"labels": ["translation", "machine learning", "vision", "statistics"],
|
||||
"scores": [0.817, 0.713, 0.018, 0.018],
|
||||
},
|
||||
]
|
||||
self.assertIsNotNone(zero_shot_classifier)
|
||||
)
|
||||
|
||||
self._test_entailment_id(zero_shot_classifier)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_large_model_tf(self):
|
||||
zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="tf")
|
||||
outputs = zero_shot_classifier(
|
||||
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
|
||||
)
|
||||
|
||||
for mono_input in valid_mono_inputs:
|
||||
mono_result = zero_shot_classifier(**mono_input)
|
||||
self.assertIsInstance(mono_result, dict)
|
||||
if len(mono_result["labels"]) > 1:
|
||||
self._test_scores_sum_to_one(mono_result)
|
||||
|
||||
for key in output_keys:
|
||||
self.assertIn(key, mono_result)
|
||||
|
||||
multi_result = zero_shot_classifier(**valid_multi_input)
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], dict)
|
||||
self.assertEqual(len(multi_result), len(valid_multi_input["sequences"]))
|
||||
|
||||
for result in multi_result:
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
if len(result["labels"]) > 1:
|
||||
self._test_scores_sum_to_one(result)
|
||||
|
||||
for bad_input in invalid_inputs:
|
||||
self.assertRaises(Exception, zero_shot_classifier, **bad_input)
|
||||
|
||||
if zero_shot_classifier.model.name_or_path in self.large_models:
|
||||
# We also check the outputs for the large models
|
||||
inputs = [
|
||||
{
|
||||
"sequences": "Who are you voting for in 2020?",
|
||||
"candidate_labels": ["politics", "public health", "science"],
|
||||
},
|
||||
{
|
||||
"sequences": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
"candidate_labels": ["machine learning", "statistics", "translation", "vision"],
|
||||
"multi_label": True,
|
||||
},
|
||||
]
|
||||
|
||||
expected_outputs = [
|
||||
{
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["politics", "public health", "science"],
|
||||
"scores": [0.975, 0.015, 0.008],
|
||||
},
|
||||
{
|
||||
"sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
"labels": ["translation", "machine learning", "vision", "statistics"],
|
||||
"scores": [0.817, 0.712, 0.018, 0.017],
|
||||
},
|
||||
]
|
||||
|
||||
for input, expected_output in zip(inputs, expected_outputs):
|
||||
output = zero_shot_classifier(**input)
|
||||
for key in output:
|
||||
if key == "scores":
|
||||
for output_score, expected_score in zip(output[key], expected_output[key]):
|
||||
self.assertAlmostEqual(output_score, expected_score, places=2)
|
||||
else:
|
||||
self.assertEqual(output[key], expected_output[key])
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequence": "Who are you voting for in 2020?",
|
||||
"labels": ["politics", "public health", "science"],
|
||||
"scores": [0.976, 0.015, 0.009],
|
||||
},
|
||||
)
|
||||
outputs = zero_shot_classifier(
|
||||
"The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
candidate_labels=["machine learning", "statistics", "translation", "vision"],
|
||||
multi_label=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
{
|
||||
"sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
|
||||
"labels": ["translation", "machine learning", "vision", "statistics"],
|
||||
"scores": [0.817, 0.713, 0.018, 0.018],
|
||||
},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user