Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix labels stored in model config for token classification examples (#15482)
* Playing

* Properly set labels in model config for token classification example

* Port to run_ner_no_trainer

* Quality
This commit is contained in:
parent c74f3d4c48
commit 45cac3fade
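The gist of the change: for datasets whose label column is a ClassLabel, the scripts previously wrote an integer-keyed `label_to_id` mapping straight into `model.config.label2id`, so the saved config lost the label names. The fix always stores name/ID correspondences built from `label_list`. A minimal before/after sketch, using made-up NER tags (not taken from the commit):

# Hypothetical tag set for illustration.
label_list = ["O", "B-PER", "I-PER"]

# Old behaviour: with a ClassLabel column, label_to_id maps int -> int...
label_to_id = {i: i for i in range(len(label_list))}
# ...and was stored verbatim, so the config no longer knew the tag names:
old_label2id = label_to_id  # {0: 0, 1: 1, 2: 2}

# New behaviour: the config always maps label names to ids:
new_label2id = {l: i for i, l in enumerate(label_list)}  # {'O': 0, 'B-PER': 1, 'I-PER': 2}
new_id2label = {i: l for i, l in enumerate(label_list)}  # {0: 'O', 1: 'B-PER', 2: 'I-PER'}

assert new_id2label[new_label2id["B-PER"]] == "B-PER"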
examples/pytorch/token-classification/run_ner.py

@@ -295,12 +295,15 @@ def main():
         label_list.sort()
         return label_list
 
-    if isinstance(features[label_column_name].feature, ClassLabel):
+    # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+    # Otherwise, we have to get the list of labels manually.
+    labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
+    if labels_are_int:
         label_list = features[label_column_name].feature.names
-        label_keys = list(range(len(label_list)))
+        label_to_id = {i: i for i in range(len(label_list))}
     else:
         label_list = get_label_list(raw_datasets["train"][label_column_name])
-        label_keys = label_list
+        label_to_id = {l: i for i, l in enumerate(label_list)}
 
     num_labels = len(label_list)
 
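For context, a small sketch (with an assumed tag set) of how the two branches above are triggered by the `datasets` feature types:

from datasets import ClassLabel, Features, Sequence, Value

# Token-classification labels come as a Sequence; what matters is the
# inner feature type. The tag names here are assumptions for the example.
int_features = Features({"ner_tags": Sequence(ClassLabel(names=["O", "B-PER", "I-PER"]))})
str_features = Features({"ner_tags": Sequence(Value("string"))})

# ClassLabel column: rows hold integers, names live on the feature itself.
feature = int_features["ner_tags"].feature
print(isinstance(feature, ClassLabel), feature.names)  # True ['O', 'B-PER', 'I-PER']

# Plain string column: the script must collect the unique labels itself,
# which is what get_label_list does over the train split.
print(isinstance(str_features["ner_tags"].feature, ClassLabel))  # False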
@@ -354,21 +357,26 @@ def main():
             "requirement"
         )
+
     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        label_name_to_id = {k: v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
-            label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
+        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+            # Reorganize `label_list` to match the ordering of the model.
+            if labels_are_int:
+                label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+            else:
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+                label_to_id = {l: i for i, l in enumerate(label_list)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
                 "\nIgnoring the model labels as a result.",
             )
-    else:
-        label_to_id = {k: i for i, k in enumerate(label_keys)}
 
-    model.config.label2id = label_to_id
-    model.config.id2label = {i: l for l, i in label_to_id.items()}
+    # Set the correspondences label/ID inside the model config
+    model.config.label2id = {l: i for i, l in enumerate(label_list)}
+    model.config.id2label = {i: l for i, l in enumerate(label_list)}
 
     # Map that sends B-Xxx label to its I-Xxx counterpart
     b_to_i_label = []
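The reorganization branch is the subtle part: when a fine-tuned checkpoint already carries labels, its ordering wins over the dataset's alphabetical one. A hedged walk-through with an invented checkpoint mapping:

# Invented checkpoint mapping for illustration.
model_label2id = {"O": 0, "B-PER": 1, "I-PER": 2}
model_id2label = {v: k for k, v in model_label2id.items()}
num_labels = len(model_label2id)

# get_label_list sorts, so the dataset-side list comes out alphabetical:
label_list = sorted(model_label2id)  # ['B-PER', 'I-PER', 'O']

# String-label case (labels_are_int is False): adopt the model's ordering.
label_list = [model_id2label[i] for i in range(num_labels)]  # ['O', 'B-PER', 'I-PER']
label_to_id = {l: i for i, l in enumerate(label_list)}       # {'O': 0, 'B-PER': 1, 'I-PER': 2}

# Int-label case: the dataset's integer ids stay the keys and the values
# follow the checkpoint, i.e. {i: model_label2id[l], ...} over label_list,
# before label_list itself is reordered to the model's ordering.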
examples/pytorch/token-classification/run_ner_no_trainer.py

@@ -320,12 +320,15 @@ def main():
         label_list.sort()
         return label_list
 
-    if isinstance(features[label_column_name].feature, ClassLabel):
+    # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+    # Otherwise, we have to get the list of labels manually.
+    labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
+    if labels_are_int:
         label_list = features[label_column_name].feature.names
-        label_keys = list(range(len(label_list)))
+        label_to_id = {i: i for i in range(len(label_list))}
     else:
         label_list = get_label_list(raw_datasets["train"][label_column_name])
-        label_keys = label_list
+        label_to_id = {l: i for i, l in enumerate(label_list)}
 
     num_labels = len(label_list)
 
@@ -365,21 +368,26 @@ def main():
 
     model.resize_token_embeddings(len(tokenizer))
+
     # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        label_name_to_id = {k: v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
-            label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
+        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+            # Reorganize `label_list` to match the ordering of the model.
+            if labels_are_int:
+                label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+            else:
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+                label_to_id = {l: i for i, l in enumerate(label_list)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
                 "\nIgnoring the model labels as a result.",
             )
-    else:
-        label_to_id = {k: i for i, k in enumerate(label_keys)}
 
-    model.config.label2id = label_to_id
-    model.config.id2label = {i: l for l, i in label_to_id.items()}
+    # Set the correspondences label/ID inside the model config
+    model.config.label2id = {l: i for i, l in enumerate(label_list)}
+    model.config.id2label = {i: l for i, l in enumerate(label_list)}
 
     # Map that sends B-Xxx label to its I-Xxx counterpart
     b_to_i_label = []
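Both files go on to build `b_to_i_label` (the trailing context above), whose body the diff does not show. A plausible sketch of such a map, assuming the usual BIO tag scheme (tags and construction are assumptions, not the commit's code):

# Assumed BIO tag set for illustration.
label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]

# For each label index, record the index of its I- counterpart so that
# continuation sub-tokens of a word can be relabelled from B- to I-.
b_to_i_label = []
for idx, label in enumerate(label_list):
    if label.startswith("B-") and label.replace("B-", "I-", 1) in label_list:
        b_to_i_label.append(label_list.index(label.replace("B-", "I-", 1)))
    else:
        b_to_i_label.append(idx)

print(b_to_i_label)  # [0, 2, 2, 4, 4]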