diff --git a/examples/run_glue.py b/examples/run_glue.py
index 53b46fc1025..89fb957b47f 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -128,7 +128,7 @@ def train(args, train_dataset, model, tokenizer):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':      batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
+                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM and RoBERTa don't use segment_ids
                       'labels':         batch[3]}
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
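
For context, here is a minimal sketch of the dispatch pattern the changed line implements: only BERT and XLNet receive real segment ids, while XLM and RoBERTa get `None` (BERT-style models in pytorch-transformers fall back to all-zero `token_type_ids` when `None` is passed). The `build_inputs` helper and the dummy batch below are illustrative assumptions, not part of run_glue.py.

```python
import torch

def build_inputs(model_type, batch):
    # Mirrors the inputs dict built in train(): segment ids are only
    # meaningful for models trained with them (BERT, XLNet).
    inputs = {'input_ids':      batch[0],
              'attention_mask': batch[1],
              'labels':         batch[3]}
    if model_type in ['bert', 'xlnet']:
        inputs['token_type_ids'] = batch[2]
    else:
        inputs['token_type_ids'] = None  # XLM and RoBERTa don't use segment_ids
    return inputs

# Example: a dummy batch of 2 sequences of length 8
batch = (torch.randint(0, 100, (2, 8)),          # input_ids
         torch.ones(2, 8, dtype=torch.long),     # attention_mask
         torch.zeros(2, 8, dtype=torch.long),    # token_type_ids
         torch.tensor([0, 1]))                   # labels
print(build_inputs('roberta', batch)['token_type_ids'])  # None
```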