mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixing backward compatiblity for non prefixed tokens (B-, I-). (#13493)
This commit is contained in:
parent
e59d4d0147
commit
db514a75d0
@ -411,7 +411,8 @@ class TokenClassificationPipeline(Pipeline):
|
||||
tag = entity_name[2:]
|
||||
else:
|
||||
# It's not in B-, I- format
|
||||
bi = "B"
|
||||
# Default to I- for continuation.
|
||||
bi = "I"
|
||||
tag = entity_name
|
||||
return bi, tag
|
||||
|
||||
|
@ -318,6 +318,59 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_aggregation_strategy_no_b_i_prefix(self):
|
||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
|
||||
# Just to understand scores indexes in this test
|
||||
token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
|
||||
example = [
|
||||
{
|
||||
# fmt : off
|
||||
"scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
|
||||
"index": 1,
|
||||
"is_subword": False,
|
||||
"word": "En",
|
||||
"start": 0,
|
||||
"end": 2,
|
||||
},
|
||||
{
|
||||
# fmt : off
|
||||
"scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
|
||||
"index": 2,
|
||||
"is_subword": True,
|
||||
"word": "##zo",
|
||||
"start": 2,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
# fmt: off
|
||||
"scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
|
||||
# fmt: on
|
||||
"index": 7,
|
||||
"word": "UN",
|
||||
"is_subword": False,
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
]
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
|
||||
[
|
||||
{"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
|
||||
{"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
|
||||
{"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
|
||||
[
|
||||
{"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_aggregation_strategy(self):
|
||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||
|
Loading…
Reference in New Issue
Block a user