Fixing NER pipeline for list inputs. (#10184)

Fixes #10168
This commit is contained in:
Nicolas Patry 2021-02-15 12:22:45 +01:00 committed by GitHub
parent 587197dcd2
commit 900daec24e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 20 deletions

View File

@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
Handles arguments for token classification.
"""
def __call__(self, *args, **kwargs):
def __call__(self, inputs: Union[str, List[str]], **kwargs):
if args is not None and len(args) > 0:
inputs = list(args)
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
inputs = list(inputs)
batch_size = len(inputs)
elif isinstance(inputs, str):
inputs = [inputs]
batch_size = 1
else:
raise ValueError("At least one input is required.")
@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
Only exists if the offsets are available within the tokenizer
"""
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
answers = []
for i, sentence in enumerate(inputs):
for i, sentence in enumerate(_inputs):
# Manage correct placement of the tensors
with self.device_placement():

View File

@ -14,14 +14,17 @@
import unittest
from transformers import AutoTokenizer, pipeline
from transformers import AutoTokenizer, is_torch_available, pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch, slow
from .test_pipelines_common import CustomInputPipelineCommonMixin
VALID_INPUTS = ["A simple string", ["list of strings"]]
if is_torch_available():
import numpy as np
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@require_torch
def test_simple(self):
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")
sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
sentence2 = "This is a simple test"
output = nlp(sentence)
def simplify(output):
for i in range(len(output)):
output[i]["score"] = round(output[i]["score"], 3)
return output
if isinstance(output, (list, tuple)):
return [simplify(item) for item in output]
elif isinstance(output, dict):
return {simplify(k): simplify(v) for k, v in output.items()}
elif isinstance(output, (str, int, np.int64)):
return output
elif isinstance(output, float):
return round(output, 3)
else:
raise Exception(f"Cannot handle {type(output)}")
output = simplify(output)
output_ = simplify(output)
self.assertEqual(
output,
output_,
[
{
"entity_group": "PER",
@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
],
)
output = nlp([sentence, sentence2])
output_ = simplify(output)
self.assertEqual(
output_,
[
[
{"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26},
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
],
[],
],
)
@require_torch
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models:
@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser(string, string)
inputs, offset_mapping = self.args_parser([string, string])
self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, None)
@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
inputs, offset_mapping = self.args_parser(
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
)
self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
def test_errors(self):
string = "This is a simple input"
# 2 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
# 2 sentences, 1 offset_mapping, args
with self.assertRaises(TypeError):
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
# 2 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
# 2 sentences, 1 offset_mapping, args
with self.assertRaises(TypeError):
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
# 2 sentences, 1 offset_mapping, input_list
with self.assertRaises(ValueError):
self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])
# 2 sentences, 1 offset_mapping, input_list
with self.assertRaises(ValueError):
self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])
# 1 sentences, 2 offset_mapping
with self.assertRaises(ValueError):
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
# 0 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])