mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Appending label2id and id2label to models to ensure inference works properly (#12102)
This commit is contained in:
parent
4cda08decb
commit
bebbdd0fc9
@ -370,6 +370,10 @@ def main():
|
|||||||
elif data_args.task_name is None and not is_regression:
|
elif data_args.task_name is None and not is_regression:
|
||||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||||
|
|
||||||
|
if label_to_id is not None:
|
||||||
|
model.config.label2id = label_to_id
|
||||||
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
@ -282,6 +282,10 @@ def main():
|
|||||||
elif args.task_name is None:
|
elif args.task_name is None:
|
||||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||||
|
|
||||||
|
if label_to_id is not None:
|
||||||
|
model.config.label2id = label_to_id
|
||||||
|
model.config.id2label = {id: label for label, id in config.label2id.items()}
|
||||||
|
|
||||||
padding = "max_length" if args.pad_to_max_length else False
|
padding = "max_length" if args.pad_to_max_length else False
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
|
Loading…
Reference in New Issue
Block a user