mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
587197dcd2
commit
900daec24e
@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
|
||||
Handles arguments for token classification.
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||
|
||||
if args is not None and len(args) > 0:
|
||||
inputs = list(args)
|
||||
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
|
||||
inputs = list(inputs)
|
||||
batch_size = len(inputs)
|
||||
elif isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
batch_size = 1
|
||||
else:
|
||||
raise ValueError("At least one input is required.")
|
||||
|
||||
@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
|
||||
Only exists if the offsets are available within the tokenizer
|
||||
"""
|
||||
|
||||
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||
|
||||
answers = []
|
||||
|
||||
for i, sentence in enumerate(inputs):
|
||||
for i, sentence in enumerate(_inputs):
|
||||
|
||||
# Manage correct placement of the tensors
|
||||
with self.device_placement():
|
||||
|
@ -14,14 +14,17 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
from transformers import AutoTokenizer, is_torch_available, pipeline
|
||||
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||
|
||||
|
||||
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
def test_simple(self):
|
||||
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
|
||||
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")
|
||||
sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
|
||||
sentence2 = "This is a simple test"
|
||||
output = nlp(sentence)
|
||||
|
||||
def simplify(output):
|
||||
for i in range(len(output)):
|
||||
output[i]["score"] = round(output[i]["score"], 3)
|
||||
return output
|
||||
if isinstance(output, (list, tuple)):
|
||||
return [simplify(item) for item in output]
|
||||
elif isinstance(output, dict):
|
||||
return {simplify(k): simplify(v) for k, v in output.items()}
|
||||
elif isinstance(output, (str, int, np.int64)):
|
||||
return output
|
||||
elif isinstance(output, float):
|
||||
return round(output, 3)
|
||||
else:
|
||||
raise Exception(f"Cannot handle {type(output)}")
|
||||
|
||||
output = simplify(output)
|
||||
output_ = simplify(output)
|
||||
|
||||
self.assertEqual(
|
||||
output,
|
||||
output_,
|
||||
[
|
||||
{
|
||||
"entity_group": "PER",
|
||||
@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
output = nlp([sentence, sentence2])
|
||||
output_ = simplify(output)
|
||||
|
||||
self.assertEqual(
|
||||
output_,
|
||||
[
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26},
|
||||
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
|
||||
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
|
||||
],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
|
||||
for model_name in self.small_models:
|
||||
@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
||||
self.assertEqual(inputs, [string])
|
||||
self.assertEqual(offset_mapping, None)
|
||||
|
||||
inputs, offset_mapping = self.args_parser(string, string)
|
||||
inputs, offset_mapping = self.args_parser([string, string])
|
||||
self.assertEqual(inputs, [string, string])
|
||||
self.assertEqual(offset_mapping, None)
|
||||
|
||||
@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
||||
self.assertEqual(inputs, [string])
|
||||
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
|
||||
|
||||
inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||
inputs, offset_mapping = self.args_parser(
|
||||
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
|
||||
)
|
||||
self.assertEqual(inputs, [string, string])
|
||||
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||
|
||||
def test_errors(self):
|
||||
string = "This is a simple input"
|
||||
|
||||
# 2 sentences, 1 offset_mapping
|
||||
with self.assertRaises(ValueError):
|
||||
# 2 sentences, 1 offset_mapping, args
|
||||
with self.assertRaises(TypeError):
|
||||
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
|
||||
|
||||
# 2 sentences, 1 offset_mapping
|
||||
with self.assertRaises(ValueError):
|
||||
# 2 sentences, 1 offset_mapping, args
|
||||
with self.assertRaises(TypeError):
|
||||
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
|
||||
|
||||
# 2 sentences, 1 offset_mapping, input_list
|
||||
with self.assertRaises(ValueError):
|
||||
self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])
|
||||
|
||||
# 2 sentences, 1 offset_mapping, input_list
|
||||
with self.assertRaises(ValueError):
|
||||
self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])
|
||||
|
||||
# 1 sentences, 2 offset_mapping
|
||||
with self.assertRaises(ValueError):
|
||||
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||
|
||||
# 0 sentences, 1 offset_mapping
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(TypeError):
|
||||
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
|
||||
|
Loading…
Reference in New Issue
Block a user