mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix run_glue evaluation when model has a label correspondence (#10401)
This commit is contained in:
parent
26f8b2cb10
commit
17b6e0d474
@ -324,7 +324,7 @@ def main():
|
||||
# Some have all caps in their config, some don't.
|
||||
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
|
||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
|
||||
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
|
||||
else:
|
||||
logger.warn(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
@ -350,7 +350,7 @@ def main():
|
||||
|
||||
# Map labels to IDs (not necessary for GLUE tasks)
|
||||
if label_to_id is not None and "label" in examples:
|
||||
result["label"] = [label_to_id[l] for l in examples["label"]]
|
||||
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
||||
return result
|
||||
|
||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||
|
@ -289,8 +289,9 @@ class PretrainedConfig(object):
|
||||
|
||||
@num_labels.setter
|
||||
def num_labels(self, num_labels: int):
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
if self.id2label is None or len(self.id2label) != num_labels:
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user