mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
parent
68427c9beb
commit
d29baf69bb
@ -96,7 +96,6 @@ class TokenClassificationPipeline(Pipeline):
|
||||
default_input_names = "sequences"
|
||||
|
||||
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
|
||||
self.ignore_labels = ["O"]
|
||||
super().__init__(*args, **kwargs)
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
@ -216,7 +215,9 @@ class TokenClassificationPipeline(Pipeline):
|
||||
**model_inputs,
|
||||
}
|
||||
|
||||
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
|
||||
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
|
||||
if ignore_labels is None:
|
||||
ignore_labels = ["O"]
|
||||
logits = model_outputs["logits"][0].numpy()
|
||||
sentence = model_outputs["sentence"]
|
||||
input_ids = model_outputs["input_ids"][0]
|
||||
@ -235,8 +236,8 @@ class TokenClassificationPipeline(Pipeline):
|
||||
entities = [
|
||||
entity
|
||||
for entity in grouped_entities
|
||||
if entity.get("entity", None) not in self.ignore_labels
|
||||
and entity.get("entity_group", None) not in self.ignore_labels
|
||||
if entity.get("entity", None) not in ignore_labels
|
||||
and entity.get("entity_group", None) not in ignore_labels
|
||||
]
|
||||
return entities
|
||||
|
||||
|
@ -627,6 +627,15 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline(
|
||||
task="token-classification", model=model_name, framework="pt", ignore_labels=["O", "I-MISC"]
|
||||
)
|
||||
outputs = token_classifier("This is a test !")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||
|
Loading…
Reference in New Issue
Block a user