Fix labels stored in model config for token classification examples (#15482)

* Playing

* Properly set labels in model config for token classification example

* Port to run_ner_no_trainer

* Quality
This commit is contained in:
Sylvain Gugger 2022-02-02 14:23:43 -05:00 committed by GitHub
parent c74f3d4c48
commit 45cac3fade
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 22 deletions

View File

@ -295,12 +295,15 @@ def main():
label_list.sort()
return label_list
if isinstance(features[label_column_name].feature, ClassLabel):
# If the labels are of type ClassLabel, they are already integers and the feature's `names` list maps each id to its label name.
# Otherwise, we have to get the list of labels manually.
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
if labels_are_int:
label_list = features[label_column_name].feature.names
label_keys = list(range(len(label_list)))
label_to_id = {i: i for i in range(len(label_list))}
else:
label_list = get_label_list(raw_datasets["train"][label_column_name])
label_keys = label_list
label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)
@ -354,21 +357,26 @@ def main():
"requirement"
)
# Model has labels -> use them.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
# Reorganize `label_list` to match the ordering of the model.
if labels_are_int:
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
label_list = [model.config.id2label[i] for i in range(num_labels)]
else:
label_list = [model.config.id2label[i] for i in range(num_labels)]
label_to_id = {l: i for i, l in enumerate(label_list)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result.",
)
else:
label_to_id = {k: i for i, k in enumerate(label_keys)}
model.config.label2id = label_to_id
model.config.id2label = {i: l for l, i in label_to_id.items()}
# Set the correspondences label/ID inside the model config
model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {i: l for i, l in enumerate(label_list)}
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []

View File

@ -320,12 +320,15 @@ def main():
label_list.sort()
return label_list
if isinstance(features[label_column_name].feature, ClassLabel):
# If the labels are of type ClassLabel, they are already integers and the feature's `names` list maps each id to its label name.
# Otherwise, we have to get the list of labels manually.
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
if labels_are_int:
label_list = features[label_column_name].feature.names
label_keys = list(range(len(label_list)))
label_to_id = {i: i for i in range(len(label_list))}
else:
label_list = get_label_list(raw_datasets["train"][label_column_name])
label_keys = label_list
label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)
@ -365,21 +368,26 @@ def main():
model.resize_token_embeddings(len(tokenizer))
# Model has labels -> use them.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
# Reorganize `label_list` to match the ordering of the model.
if labels_are_int:
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
label_list = [model.config.id2label[i] for i in range(num_labels)]
else:
label_list = [model.config.id2label[i] for i in range(num_labels)]
label_to_id = {l: i for i, l in enumerate(label_list)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result.",
)
else:
label_to_id = {k: i for i, k in enumerate(label_keys)}
model.config.label2id = label_to_id
model.config.id2label = {i: l for l, i in label_to_id.items()}
# Set the correspondences label/ID inside the model config
model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {i: l for i, l in enumerate(label_list)}
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []