Use word_ids to get labels in run_ner (#8962)

* Use word_ids to get labels in run_ner

* Add sanity check
This commit is contained in:
Sylvain Gugger 2020-12-07 14:26:36 -05:00 committed by GitHub
parent de6befd41f
commit 7f9ccffc5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,6 +35,7 @@ from transformers import (
AutoTokenizer,
DataCollatorForTokenClassification,
HfArgumentParser,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
set_seed,
@ -250,6 +251,14 @@ def main():
cache_dir=model_args.cache_dir,
)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
"requirement"
)
# Preprocessing the dataset
# Padding strategy
padding = "max_length" if data_args.pad_to_max_length else False
@ -262,28 +271,25 @@ def main():
truncation=True,
# We use this argument because the texts in our dataset are lists of words (with a label for each word).
is_split_into_words=True,
return_offsets_mapping=True,
)
offset_mappings = tokenized_inputs.pop("offset_mapping")
labels = []
for label, offset_mapping in zip(examples[label_column_name], offset_mappings):
label_index = 0
current_label = -100
for i, label in enumerate(examples[label_column_name]):
word_ids = tokenized_inputs.word_ids(batch_index=i)
previous_word_idx = None
label_ids = []
for offset in offset_mapping:
# We set the label for the first token of each word. Special characters will have an offset of (0, 0)
# so the test ignores them.
if offset[0] == 0 and offset[1] != 0:
current_label = label_to_id[label[label_index]]
label_index += 1
label_ids.append(current_label)
# For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
elif offset[0] == 0 and offset[1] == 0:
for word_idx in word_ids:
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
# ignored in the loss function.
if word_idx is None:
label_ids.append(-100)
# We set the label for the first token of each word.
elif word_idx != previous_word_idx:
label_ids.append(label_to_id[label[word_idx]])
# For the other tokens in a word, we set the label to either the current label or -100, depending on
# the label_all_tokens flag.
else:
label_ids.append(current_label if data_args.label_all_tokens else -100)
label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
previous_word_idx = word_idx
labels.append(label_ids)
tokenized_inputs["labels"] = labels