Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 18:51:14 +06:00
Fixing backward compatibility for non-prefixed tokens (B-, I-). (#13493)
This commit is contained in:
parent e59d4d0147
commit db514a75d0
@@ -411,7 +411,8 @@ class TokenClassificationPipeline(Pipeline):
             tag = entity_name[2:]
         else:
             # It's not in B-, I- format
-            bi = "B"
+            # Default to I- for continuation.
+            bi = "I"
             tag = entity_name
         return bi, tag
 
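Note on the hunk above: the changed default only matters for models whose labels are not in the usual B-/I- scheme (e.g. a bare "LOC" instead of "B-LOC"/"I-LOC"). Treating such bare tags as "I" (continuation) lets the aggregation step merge consecutive tokens carrying the same tag back into a single entity, restoring the earlier grouping behaviour, whereas defaulting to "B" made every token look like the start of a new entity. A minimal sketch of the rule, using our own helper name rather than the pipeline's actual method:

    from typing import Tuple

    def split_tag(entity_name: str) -> Tuple[str, str]:
        # Illustrative re-implementation of the rule patched above.
        if entity_name.startswith("B-"):
            return "B", entity_name[2:]
        if entity_name.startswith("I-"):
            return "I", entity_name[2:]
        # Not in B-, I- format: default to "I" so adjacent tokens with the same
        # bare tag read as a continuation of one entity (this used to be "B").
        return "I", entity_name

    assert split_tag("B-LOC") == ("B", "LOC")
    assert split_tag("I-LOC") == ("I", "LOC")
    assert split_tag("LOC") == ("I", "LOC")  # previously ("B", "LOC")

The new test below feeds the aggregator pre-computed scores for bare "LOC"/"ORG" labels and checks that "En" + "##zo" are grouped into a single "Enzo" entity under the SIMPLE strategy.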
@@ -318,6 +318,59 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
             ],
         )
 
+    @require_torch
+    def test_aggregation_strategy_no_b_i_prefix(self):
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+        # Just to understand scores indexes in this test
+        token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
+        example = [
+            {
+                # fmt : off
+                "scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
+                "index": 1,
+                "is_subword": False,
+                "word": "En",
+                "start": 0,
+                "end": 2,
+            },
+            {
+                # fmt : off
+                "scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
+                "index": 2,
+                "is_subword": True,
+                "word": "##zo",
+                "start": 2,
+                "end": 4,
+            },
+            {
+                # fmt: off
+                "scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
+                # fmt: on
+                "index": 7,
+                "word": "UN",
+                "is_subword": False,
+                "start": 11,
+                "end": 13,
+            },
+        ]
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
+            [
+                {"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
+                {"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
+                {"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
+            [
+                {"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+
     @require_torch
     def test_aggregation_strategy(self):
         model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
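For context, this is roughly how the fix surfaces for a user of the pipeline. The sketch below is hedged: the checkpoint name is a placeholder for any model whose id2label holds bare tags such as "LOC" or "ORG" (no B-/I- prefixes), and the printed output is only indicative of the grouped shape, not of exact scores or offsets.

    from transformers import pipeline

    # Placeholder checkpoint name: stands in for any model whose labels are
    # bare tags ("LOC", "ORG", ...) without B-/I- prefixes.
    token_classifier = pipeline(
        task="ner",
        model="some-org/bare-tag-ner-model",
        aggregation_strategy="simple",
    )

    # With this commit, consecutive sub-words that share the same bare tag are
    # merged into one grouped entity again, e.g. something like:
    #   [{'entity_group': 'LOC', 'word': 'Enzo', ...},
    #    {'entity_group': 'ORG', 'word': 'UN', ...}]
    print(token_classifier("Enzo works at the UN"))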