mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Handling longformer model_type (#7990)
Update the run_squad training script to handle the "longformer" `model_type`. Longformer is trained in the same way as RoBERTa, so I've added the "longformer" `model_type` (that's the right Hugging Face name for the Longformer model, right?) everywhere there was a "roberta" `model_type` reference. Longformer, like RoBERTa, doesn't use `token_type_ids` (as I understand from the [longformer notebook](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb)), and dropping them from the model inputs is what this change updates. This fix might be related to [this issue](https://github.com/huggingface/transformers/issues/7249) with SQuAD training when using run_squad.py.
parent 5e323017a4
commit d39da5a2ab
@@ -187,7 +187,7 @@ def train(args, train_dataset, model, tokenizer):
                 "end_positions": batch[4],
             }

-            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
+            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                 del inputs["token_type_ids"]

             if args.model_type in ["xlnet", "xlm"]:
@@ -300,7 +300,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                 "token_type_ids": batch[2],
             }

-            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
+            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                 del inputs["token_type_ids"]

             feature_indices = batch[3]
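
Both hunks apply the same membership check, once in `train` and once in `evaluate`. A minimal sketch of that check in isolation is given here, under some assumptions: the constant `MODELS_WITHOUT_TOKEN_TYPE_IDS` and the helper `drop_token_type_ids` are hypothetical names that do not exist in run_squad.py, and plain Python lists stand in for the tensors the script would pass.

```python
# Hypothetical constant (not in run_squad.py): the model types from the diff
# above whose inputs have token_type_ids removed before the forward pass.
MODELS_WITHOUT_TOKEN_TYPE_IDS = frozenset(
    ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]
)


def drop_token_type_ids(model_type, inputs):
    """Hypothetical helper: return a copy of `inputs`, without
    token_type_ids when `model_type` is one of the listed types."""
    cleaned = dict(inputs)  # shallow copy, the caller's dict is left untouched
    if model_type in MODELS_WITHOUT_TOKEN_TYPE_IDS:
        cleaned.pop("token_type_ids", None)
    return cleaned


if __name__ == "__main__":
    # Plain lists stand in for tensors in this sketch.
    batch = {
        "input_ids": [0, 1, 2],
        "attention_mask": [1, 1, 1],
        "token_type_ids": [0, 0, 0],
    }
    print(sorted(drop_token_type_ids("longformer", batch)))
    print(sorted(drop_token_type_ids("bert", batch)))
```

Keeping the list in one place would mean a future model type is added once rather than in both functions; the commit itself keeps the two inline lists.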