Fixing mishandling of ignore_labels. (#14274)

Fixes #14272
This commit is contained in:
Nicolas Patry 2021-11-04 14:47:52 +01:00 committed by GitHub
parent 68427c9beb
commit d29baf69bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 4 deletions

View File

@@ -96,7 +96,6 @@ class TokenClassificationPipeline(Pipeline):
default_input_names = "sequences"
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
self.ignore_labels = ["O"]
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
@@ -216,7 +215,9 @@ class TokenClassificationPipeline(Pipeline):
**model_inputs,
}
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
if ignore_labels is None:
ignore_labels = ["O"]
logits = model_outputs["logits"][0].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
@@ -235,8 +236,8 @@ class TokenClassificationPipeline(Pipeline):
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in self.ignore_labels
and entity.get("entity_group", None) not in self.ignore_labels
if entity.get("entity", None) not in ignore_labels
and entity.get("entity_group", None) not in ignore_labels
]
return entities

View File

@@ -627,6 +627,15 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", ignore_labels=["O", "I-MISC"]
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[],
)
@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"