Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Use word_ids to get labels in run_ner (#8962)
* Use word_ids to get labels in run_ner
* Add sanity check
parent de6befd41f
commit 7f9ccffc5b
@@ -35,6 +35,7 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForTokenClassification,
     HfArgumentParser,
+    PreTrainedTokenizerFast,
     Trainer,
     TrainingArguments,
     set_seed,
@@ -250,6 +251,14 @@ def main():
         cache_dir=model_args.cache_dir,
     )
+
+    # Tokenizer check: this script requires a fast tokenizer.
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        raise ValueError(
+            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
+            "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
+            "requirement"
+        )

     # Preprocessing the dataset
     # Padding strategy
     padding = "max_length" if data_args.pad_to_max_length else False
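The sanity check above exists because the label-alignment code added in the next hunk relies on BatchEncoding.word_ids(), which only fast (Rust-backed) tokenizers provide. Below is a minimal sketch of what the guard protects against; the bert-base-cased checkpoint is an illustrative assumption, not part of this commit:

from transformers import AutoTokenizer, PreTrainedTokenizerFast

# Fast tokenizers are backed by the `tokenizers` library and expose word_ids().
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
assert isinstance(tokenizer, PreTrainedTokenizerFast)

# Pre-split input, the same shape run_ner feeds in with is_split_into_words=True.
encoding = tokenizer(["Hello", "world"], is_split_into_words=True)
print(encoding.word_ids())  # e.g. [None, 0, 1, None] -- None marks special tokens

# A slow (pure-Python) tokenizer fails the isinstance check, and calling
# word_ids() on its output raises a ValueError, hence the hard error in the script.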
@@ -262,28 +271,25 @@ def main():
             truncation=True,
             # We use this argument because the texts in our dataset are lists of words (with a label for each word).
             is_split_into_words=True,
-            return_offsets_mapping=True,
         )
-        offset_mappings = tokenized_inputs.pop("offset_mapping")
         labels = []
-        for label, offset_mapping in zip(examples[label_column_name], offset_mappings):
-            label_index = 0
-            current_label = -100
+        for i, label in enumerate(examples[label_column_name]):
+            word_ids = tokenized_inputs.word_ids(batch_index=i)
+            previous_word_idx = None
             label_ids = []
-            for offset in offset_mapping:
-                # We set the label for the first token of each word. Special characters will have an offset of (0, 0)
-                # so the test ignores them.
-                if offset[0] == 0 and offset[1] != 0:
-                    current_label = label_to_id[label[label_index]]
-                    label_index += 1
-                    label_ids.append(current_label)
-                # For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
-                elif offset[0] == 0 and offset[1] == 0:
+            for word_idx in word_ids:
+                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+                # ignored in the loss function.
+                if word_idx is None:
                     label_ids.append(-100)
+                # We set the label for the first token of each word.
+                elif word_idx != previous_word_idx:
+                    label_ids.append(label_to_id[label[word_idx]])
                 # For the other tokens in a word, we set the label to either the current label or -100, depending on
                 # the label_all_tokens flag.
                 else:
-                    label_ids.append(current_label if data_args.label_all_tokens else -100)
+                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
+                previous_word_idx = word_idx

             labels.append(label_ids)
         tokenized_inputs["labels"] = labels
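For reference, the alignment logic introduced by this commit can be exercised on its own. The sketch below mirrors the new loop on a toy batch; the checkpoint name, example sentences, and label_to_id mapping are illustrative assumptions, not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)

# Toy batch in the shape run_ner sees: pre-split words, one label per word.
examples = {
    "tokens": [["EU", "rejects", "German", "call"], ["Peter", "Blackburn"]],
    "ner_tags": [["B-ORG", "O", "B-MISC", "O"], ["B-PER", "I-PER"]],
}
label_to_id = {"O": 0, "B-ORG": 1, "B-MISC": 2, "B-PER": 3, "I-PER": 4}
label_all_tokens = False

tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

labels = []
for i, label in enumerate(examples["ner_tags"]):
    word_ids = tokenized_inputs.word_ids(batch_index=i)
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens ([CLS], [SEP], padding) are masked out of the loss.
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # The first sub-token of each word carries the word-level label.
            label_ids.append(label_to_id[label[word_idx]])
        else:
            # Remaining sub-tokens: repeat the label or mask, per label_all_tokens.
            label_ids.append(label_to_id[label[word_idx]] if label_all_tokens else -100)
        previous_word_idx = word_idx
    labels.append(label_ids)

tokenized_inputs["labels"] = labels
print(tokenized_inputs.tokens(0))  # word pieces plus [CLS]/[SEP]
print(labels[0])                   # label ids aligned to word pieces, -100 where ignored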