diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 8fb0a9ba6d5..c5718e82fcf 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -295,12 +295,15 @@ def main(): label_list.sort() return label_list - if isinstance(features[label_column_name].feature, ClassLabel): + # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. + # Otherwise, we have to get the list of labels manually. + labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) + if labels_are_int: label_list = features[label_column_name].feature.names - label_keys = list(range(len(label_list))) + label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_keys = label_list + label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) @@ -354,21 +357,26 @@ def main(): "requirement" ) + # Model has labels -> use them. if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: - label_name_to_id = {k: v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): + # Reorganize `label_list` to match the ordering of the model. + if labels_are_int: + label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} + label_list = [model.config.id2label[i] for i in range(num_labels)] + else: + label_list = [model.config.id2label[i] for i in range(num_labels)] + label_to_id = {l: i for i, l in enumerate(label_list)} else: logger.warning( "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}." "\nIgnoring the model labels as a result.", ) - else: - label_to_id = {k: i for i, k in enumerate(label_keys)} - model.config.label2id = label_to_id - model.config.id2label = {i: l for l, i in label_to_id.items()} + # Set the correspondences label/ID inside the model config + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {i: l for i, l in enumerate(label_list)} # Map that sends B-Xxx label to its I-Xxx counterpart b_to_i_label = [] diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index af1959aa511..e292331ea44 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -320,12 +320,15 @@ def main(): label_list.sort() return label_list - if isinstance(features[label_column_name].feature, ClassLabel): + # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. + # Otherwise, we have to get the list of labels manually. + labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) + if labels_are_int: label_list = features[label_column_name].feature.names - label_keys = list(range(len(label_list))) + label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_keys = label_list + label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) @@ -365,21 +368,26 @@ def main(): model.resize_token_embeddings(len(tokenizer)) + # Model has labels -> use them. if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: - label_name_to_id = {k: v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): + # Reorganize `label_list` to match the ordering of the model. + if labels_are_int: + label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} + label_list = [model.config.id2label[i] for i in range(num_labels)] + else: + label_list = [model.config.id2label[i] for i in range(num_labels)] + label_to_id = {l: i for i, l in enumerate(label_list)} else: logger.warning( "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}." "\nIgnoring the model labels as a result.", ) - else: - label_to_id = {k: i for i, k in enumerate(label_keys)} - model.config.label2id = label_to_id - model.config.id2label = {i: l for l, i in label_to_id.items()} + # Set the correspondences label/ID inside the model config + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {i: l for i, l in enumerate(label_list)} # Map that sends B-Xxx label to its I-Xxx counterpart b_to_i_label = []