Mirror of https://github.com/huggingface/transformers.git
Moving `token-classification` pipeline to new testing. (#13286)

* Moving `token-classification` pipeline to new testing.
* Fix tests.
parent a6e36558ef
commit 45a8eb66bb
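Context for "new testing": instead of a hand-maintained list of small/large checkpoints per pipeline, a metaclass walks the framework's model mappings and synthesizes one test per architecture, each building a tiny model and calling `run_pipeline_test`. Below is a simplified sketch of that pattern only, assuming nothing beyond the standard library; the real `PipelineTestCaseMeta` in `test_pipelines_common.py` does considerably more (tiny-model construction, framework dispatch, skips).

    import unittest


    class SketchPipelineTestCaseMeta(type):
        # Simplified stand-in for PipelineTestCaseMeta: synthesize one test
        # method per architecture listed in `model_mapping`.
        def __new__(mcs, name, bases, namespace):
            for arch in namespace.get("model_mapping", ()):
                def make_test(arch=arch):
                    def test(self):
                        # The real version instantiates a tiny model + tokenizer
                        # for `arch` and calls self.run_pipeline_test(...).
                        self.assertIsInstance(arch, str)
                    return test

                namespace[f"test_pipeline_{arch}"] = make_test()
            return super().__new__(mcs, name, bases, namespace)


    class ExampleTests(unittest.TestCase, metaclass=SketchPipelineTestCaseMeta):
        # Hypothetical stand-in for MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
        model_mapping = ("bert", "distilbert")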
@@ -1329,7 +1329,7 @@ def nested_simplify(obj, decimals=3):
         return nested_simplify(obj.numpy().tolist())
     elif isinstance(obj, float):
         return round(obj, decimals)
-    elif isinstance(obj, np.float32):
+    elif isinstance(obj, (np.int32, np.float32)):
         return nested_simplify(obj.item(), decimals)
     else:
         raise Exception(f"Not supported: {type(obj)}")
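The one-line change above widens `nested_simplify` to unwrap NumPy integer scalars as well as floats; previously an `np.int32` (e.g. an `index` field coming back from a TF pipeline) fell through to the `Not supported` branch and raised. A quick illustration of the intended behavior, assuming only NumPy:

    import numpy as np

    # np.float32 was already unwrapped with .item() and rounded;
    # np.int32 now takes the same path instead of raising.
    assert round(np.float32(0.11527).item(), 3) == 0.115
    assert np.int32(7).item() == 7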
@@ -127,7 +127,11 @@ class PipelineTestCaseMeta(type):
                 if tokenizer_class is not None:
                     try:
                         tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
-                        if hasattr(model.config, "max_position_embeddings"):
+                        # XLNet actually defines it as -1.
+                        if (
+                            hasattr(model.config, "max_position_embeddings")
+                            and model.config.max_position_embeddings > 0
+                        ):
                             tokenizer.model_max_length = model.config.max_position_embeddings
                     # Rust Panic exception are NOT Exception subclass
                     # Some test tokenizer contain broken vocabs or custom PreTokenizer, so we
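The extra `> 0` guard matters because XLNet reports `max_position_embeddings = -1` to mean "no fixed limit"; copying that sentinel into `tokenizer.model_max_length` would make the tokenizer truncate everything. A minimal sketch of the guard, using a hypothetical config object in place of `model.config`:

    class FakeConfig:
        # XLNet-style sentinel: -1 means the model has no positional limit.
        max_position_embeddings = -1


    config = FakeConfig()
    model_max_length = 512  # hypothetical tokenizer default
    if hasattr(config, "max_position_embeddings") and config.max_position_embeddings > 0:
        model_max_length = config.max_position_embeddings
    assert model_max_length == 512  # the -1 sentinel was ignored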
@@ -16,57 +16,171 @@ import unittest
 
 import numpy as np
 
-from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
-from transformers.pipelines import AggregationStrategy, Pipeline, TokenClassificationArgumentHandler
-from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
+from transformers import (
+    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    TokenClassificationPipeline,
+    pipeline,
+)
+from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
+from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
 
-from .test_pipelines_common import CustomInputPipelineCommonMixin
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
-VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
-
-
-class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "ner"
-    small_models = [
-        "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
-    ]  # Default model - Models tested without the @slow decorator
-    large_models = []  # Models tested with the @slow decorator
+@is_pipeline_test
+class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+    tf_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
 
-    def _test_pipeline(self, token_classifier: Pipeline):
-        output_keys = {"entity", "word", "score", "start", "end", "index"}
-        if token_classifier.aggregation_strategy != AggregationStrategy.NONE:
-            output_keys = {"entity_group", "word", "score", "start", "end"}
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
 
-        self.assertIsNotNone(token_classifier)
+        outputs = token_classifier("A simple string")
+        self.assertIsInstance(outputs, list)
+        n = len(outputs)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "entity": ANY(str),
+                    "score": ANY(float),
+                    "start": ANY(int),
+                    "end": ANY(int),
+                    "index": ANY(int),
+                    "word": ANY(str),
+                }
+                for i in range(n)
+            ],
+        )
+        outputs = token_classifier(["list of strings", "A simple string that is quite a bit longer"])
+        self.assertIsInstance(outputs, list)
+        self.assertEqual(len(outputs), 2)
+        n = len(outputs[0])
+        m = len(outputs[1])
 
-        mono_result = token_classifier(VALID_INPUTS[0])
-        self.assertIsInstance(mono_result, list)
-        self.assertIsInstance(mono_result[0], (dict, list))
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                [
+                    {
+                        "entity": ANY(str),
+                        "score": ANY(float),
+                        "start": ANY(int),
+                        "end": ANY(int),
+                        "index": ANY(int),
+                        "word": ANY(str),
+                    }
+                    for i in range(n)
+                ],
+                [
+                    {
+                        "entity": ANY(str),
+                        "score": ANY(float),
+                        "start": ANY(int),
+                        "end": ANY(int),
+                        "index": ANY(int),
+                        "word": ANY(str),
+                    }
+                    for i in range(m)
+                ],
+            ],
+        )
 
-        if isinstance(mono_result[0], list):
-            mono_result = mono_result[0]
+        self.run_aggregation_strategy(model, tokenizer)
 
-        for key in output_keys:
-            self.assertIn(key, mono_result[0])
+    def run_aggregation_strategy(self, model, tokenizer):
+        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
+        outputs = token_classifier("A simple string")
+        self.assertIsInstance(outputs, list)
+        n = len(outputs)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "entity_group": ANY(str),
+                    "score": ANY(float),
+                    "start": ANY(int),
+                    "end": ANY(int),
+                    "word": ANY(str),
+                }
+                for i in range(n)
+            ],
+        )
 
-        multi_result = [token_classifier(input) for input in VALID_INPUTS]
-        self.assertIsInstance(multi_result, list)
-        self.assertIsInstance(multi_result[0], (dict, list))
+        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first")
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
+        outputs = token_classifier("A simple string")
+        self.assertIsInstance(outputs, list)
+        n = len(outputs)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "entity_group": ANY(str),
+                    "score": ANY(float),
+                    "start": ANY(int),
+                    "end": ANY(int),
+                    "word": ANY(str),
+                }
+                for i in range(n)
+            ],
+        )
 
-        if isinstance(multi_result[0], list):
-            multi_result = multi_result[0]
+        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max")
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.MAX)
+        outputs = token_classifier("A simple string")
+        self.assertIsInstance(outputs, list)
+        n = len(outputs)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "entity_group": ANY(str),
+                    "score": ANY(float),
+                    "start": ANY(int),
+                    "end": ANY(int),
+                    "word": ANY(str),
+                }
+                for i in range(n)
+            ],
+        )
 
-        for result in multi_result:
-            for key in output_keys:
-                self.assertIn(key, result)
+        token_classifier = TokenClassificationPipeline(
+            model=model, tokenizer=tokenizer, aggregation_strategy="average"
+        )
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.AVERAGE)
+        outputs = token_classifier("A simple string")
+        self.assertIsInstance(outputs, list)
+        n = len(outputs)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "entity_group": ANY(str),
+                    "score": ANY(float),
+                    "start": ANY(int),
+                    "end": ANY(int),
+                    "word": ANY(str),
+                }
+                for i in range(n)
+            ],
+        )
 
-    @require_torch
-    def test_model_kwargs_passed_to_model_load(self):
-        ner_pipeline = pipeline(task="ner", model=self.small_models[0])
-        self.assertFalse(ner_pipeline.model.config.output_attentions)
-        ner_pipeline = pipeline(task="ner", model=self.small_models[0], model_kwargs={"output_attentions": True})
-        self.assertTrue(ner_pipeline.model.config.output_attentions)
+        with self.assertWarns(UserWarning):
+            token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True)
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
+        with self.assertWarns(UserWarning):
+            token_classifier = pipeline(
+                task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
+            )
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
 
     @require_torch
     @slow
@@ -206,7 +320,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
 
     @require_torch
     def test_aggregation_strategy(self):
-        model_name = self.small_models[0]
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
         token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
         # Just to understand scores indexes in this test
@@ -283,7 +397,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
 
     @require_torch
     def test_aggregation_strategy_example2(self):
-        model_name = self.small_models[0]
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
         token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
         # Just to understand scores indexes in this test
@@ -345,8 +459,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
 
     @require_torch
     def test_gather_pre_entities(self):
-
-        model_name = self.small_models[0]
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
         token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
 
@@ -389,42 +502,37 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
         model_name = "Narsil/small"  # This model only has a TensorFlow version
         # We test that if we don't specificy framework='tf', it gets detected automatically
         token_classifier = pipeline(task="ner", model=model_name)
-        self._test_pipeline(token_classifier)
+        self.assertEqual(token_classifier.framework, "tf")
 
     @require_tf
-    def test_tf_defaults(self):
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="tf")
-            self._test_pipeline(token_classifier)
+    def test_small_model_tf(self):
+        model_name = "Narsil/small2"
+        token_classifier = pipeline(task="token-classification", model=model_name, framework="tf")
+        outputs = token_classifier("This is a test !")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
+                {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
+            ],
+        )
 
-    @require_tf
-    def test_tf_small_ignore_subwords_available_for_fast_tokenizers(self):
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            token_classifier = pipeline(
-                task="ner",
-                model=model_name,
-                tokenizer=tokenizer,
-                framework="tf",
-                aggregation_strategy=AggregationStrategy.FIRST,
-            )
-            self._test_pipeline(token_classifier)
-
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            token_classifier = pipeline(
-                task="ner",
-                model=model_name,
-                tokenizer=tokenizer,
-                framework="tf",
-                aggregation_strategy=AggregationStrategy.SIMPLE,
-            )
-            self._test_pipeline(token_classifier)
+    @require_torch
+    def test_small_model_pt(self):
+        model_name = "Narsil/small2"
+        token_classifier = pipeline(task="token-classification", model=model_name, framework="pt")
+        outputs = token_classifier("This is a test !")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
+                {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
+            ],
+        )
 
     @require_torch
     def test_pt_ignore_subwords_slow_tokenizer_raises(self):
-        model_name = self.small_models[0]
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
         tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
 
         with self.assertRaises(ValueError):
@@ -436,31 +544,6 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
         with self.assertRaises(ValueError):
             pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.MAX)
 
-    @require_torch
-    def test_pt_defaults_slow_tokenizer(self):
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-            token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
-            self._test_pipeline(token_classifier)
-
-    @require_torch
-    def test_pt_defaults(self):
-        for model_name in self.small_models:
-            token_classifier = pipeline(task="ner", model=model_name)
-            self._test_pipeline(token_classifier)
-
-    @slow
-    @require_torch
-    def test_warnings(self):
-        with self.assertWarns(UserWarning):
-            token_classifier = pipeline(task="ner", model=self.small_models[0], grouped_entities=True)
-        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
-        with self.assertWarns(UserWarning):
-            token_classifier = pipeline(
-                task="ner", model=self.small_models[0], grouped_entities=True, ignore_subwords=True
-            )
-        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
-
     @slow
     @require_torch
     def test_simple(self):
@@ -501,23 +584,8 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
             ],
         )
 
-    @require_torch
-    def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            token_classifier = pipeline(
-                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
-            )
-            self._test_pipeline(token_classifier)
-
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            token_classifier = pipeline(
-                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
-            )
-            self._test_pipeline(token_classifier)
-
 
+@is_pipeline_test
 class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
     def setUp(self):
         self.args_parser = TokenClassificationArgumentHandler()
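A note on the `ANY(...)` placeholders used throughout the new expected outputs: they compare equal to any value of the given type, so `assertEqual` can check the shape of pipeline output without pinning model-dependent scores or words. A minimal sketch of such a helper (the real `ANY` is imported from `test_pipelines_common.py` and may differ in detail):

    class ANY:
        def __init__(self, *expected_types):
            self.expected_types = expected_types

        def __eq__(self, other):
            # Equal to any value of the expected type(s), so structural
            # comparisons succeed regardless of the concrete value.
            return isinstance(other, self.expected_types)

        def __repr__(self):
            return f"ANY({', '.join(t.__name__ for t in self.expected_types)})"


    assert {"entity": "I-MISC", "score": 0.115} == {"entity": ANY(str), "score": ANY(float)}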