Fix label attribution in token classification examples (#14055)

Sylvain Gugger 2021-10-20 07:55:14 -04:00 committed by GitHub
parent 31560f6397
commit f875fb0e5f
2 changed files with 24 additions and 2 deletions


@@ -303,6 +303,14 @@ def main():
     label_to_id = {l: i for i, l in enumerate(label_list)}
     num_labels = len(label_list)
 
+    # Map that sends B-Xxx label to its I-Xxx counterpart
+    b_to_i_label = []
+    for idx, label in enumerate(label_list):
+        if label.startswith("B-") and label.replace("B-", "I-") in label_list:
+            b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
+        else:
+            b_to_i_label.append(idx)
+
     # Load pretrained model and tokenizer
     #
     # Distributed training:
@@ -385,7 +393,10 @@ def main():
                 # For the other tokens in a word, we set the label to either the current label or -100, depending on
                 # the label_all_tokens flag.
                 else:
-                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
+                    if data_args.label_all_tokens:
+                        label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
+                    else:
+                        label_ids.append(-100)
                 previous_word_idx = word_idx
 
             labels.append(label_ids)


@@ -328,6 +328,14 @@ def main():
     label_to_id = {l: i for i, l in enumerate(label_list)}
     num_labels = len(label_list)
 
+    # Map that sends B-Xxx label to its I-Xxx counterpart
+    b_to_i_label = []
+    for idx, label in enumerate(label_list):
+        if label.startswith("B-") and label.replace("B-", "I-") in label_list:
+            b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
+        else:
+            b_to_i_label.append(idx)
+
     # Load pretrained model and tokenizer
     #
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
@@ -396,7 +404,10 @@ def main():
                 # For the other tokens in a word, we set the label to either the current label or -100, depending on
                 # the label_all_tokens flag.
                 else:
-                    label_ids.append(label_to_id[label[word_idx]] if args.label_all_tokens else -100)
+                    if args.label_all_tokens:
+                        label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
+                    else:
+                        label_ids.append(-100)
                 previous_word_idx = word_idx
 
             labels.append(label_ids)
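
For reference, a minimal standalone sketch of what this fix changes. The label list, word ids, and the example word below are made up for illustration; only the b_to_i_label mapping and the alignment logic mirror the diff above.

# Illustrative sketch; label_list, word_ids and the example word are invented.
label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]
label_to_id = {l: i for i, l in enumerate(label_list)}

# Map that sends a B-Xxx label id to its I-Xxx counterpart (as in the diff)
b_to_i_label = []
for idx, label in enumerate(label_list):
    if label.startswith("B-") and label.replace("B-", "I-") in label_list:
        b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
    else:
        b_to_i_label.append(idx)
# b_to_i_label == [0, 2, 2, 4, 4]

# Pretend one word labeled B-PER is split into two sub-word tokens, with special
# tokens around it: [CLS] Johan ##son [SEP]. word_ids maps each token back to its
# word index (None for special tokens), like tokenizer.word_ids() does.
word_ids = [None, 0, 0, None]
word_labels = ["B-PER"]
label_all_tokens = True

label_ids = []
previous_word_idx = None
for word_idx in word_ids:
    if word_idx is None:
        label_ids.append(-100)  # special tokens are ignored by the loss
    elif word_idx != previous_word_idx:
        label_ids.append(label_to_id[word_labels[word_idx]])  # first token keeps B-PER
    else:
        # The fix: continuation tokens of a B-Xxx word now get I-Xxx, not B-Xxx
        if label_all_tokens:
            label_ids.append(b_to_i_label[label_to_id[word_labels[word_idx]]])
        else:
            label_ids.append(-100)
    previous_word_idx = word_idx

print(label_ids)  # [-100, 1, 2, -100] i.e. B-PER for "Johan", I-PER for "##son"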