mirror of https://github.com/huggingface/transformers.git
Added DistilBERT to run_lm_finetuning

commit 88368c2a16
parent 2d8ec5a684
@@ -39,7 +39,8 @@ from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
                                   BertConfig, BertForMaskedLM, BertTokenizer,
                                   GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                                   OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
-                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
+                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
+                                  DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
 
 
 logger = logging.getLogger(__name__)
@@ -49,7 +50,8 @@ MODEL_CLASSES = {
     'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
     'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
 }
 
 
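Each MODEL_CLASSES entry maps a --model_type value to a (config, model, tokenizer) triple that the script instantiates via from_pretrained. A minimal sketch of that lookup for the new entry; the 'distilbert-base-uncased' checkpoint name is an illustrative assumption, not part of this commit:

# Minimal sketch of how the script resolves classes from MODEL_CLASSES
# (assumes the 'distilbert-base-uncased' checkpoint; not the script's exact code).
from pytorch_transformers import (DistilBertConfig, DistilBertForMaskedLM,
                                  DistilBertTokenizer)

MODEL_CLASSES = {
    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
}

config_class, model_class, tokenizer_class = MODEL_CLASSES['distilbert']
config = config_class.from_pretrained('distilbert-base-uncased')
tokenizer = tokenizer_class.from_pretrained('distilbert-base-uncased')
model = model_class.from_pretrained('distilbert-base-uncased', config=config)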
@@ -380,7 +382,7 @@ def main():
     parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
     args = parser.parse_args()
 
-    if args.model_type in ["bert", "roberta"] and not args.mlm:
+    if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
         raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                          "flag (masked language modeling).")
     if args.eval_data_file is None and args.do_eval:
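DistilBERT, like BERT and RoBERTa, carries a masked LM head rather than a causal LM head, so the guard above now requires --mlm for it as well. A hedged sketch of the masked-LM objective that flag selects, with an illustrative sentence and a hand-picked mask position (the script itself masks tokens at random and reads batches from a corpus):

# Sketch of the masked-LM objective behind --mlm; the sentence, mask position,
# and checkpoint name are illustrative assumptions, not taken from this commit.
import torch
from pytorch_transformers import DistilBertForMaskedLM, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')

tokens = tokenizer.tokenize("Paris is the capital of France.")
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

labels = torch.full_like(input_ids, -1)  # -1 positions are ignored by the loss
labels[0, 5] = input_ids[0, 5]           # supervise only the token being masked
input_ids[0, 5] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

loss = model(input_ids, masked_lm_labels=labels)[0]  # predict the masked token

Without --mlm the script would train these models with a causal LM objective they were never built for; after this commit, passing --model_type distilbert without --mlm raises the ValueError instead.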