Fix RoBERTa model ordering for TFAutoModel (#5414)

This commit is contained in:
Pierric Cistac 2020-07-02 19:23:55 -04:00 committed by GitHub
parent 6b735a7253
commit 8438bab38e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -141,7 +141,6 @@ logger = logging.getLogger(__name__)
TF_MODEL_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertModel),
(BertConfig, TFBertModel),
(CamembertConfig, TFCamembertModel),
(CTRLConfig, TFCTRLModel),
(DistilBertConfig, TFDistilBertModel),
@@ -151,6 +150,7 @@ TF_MODEL_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(RobertaConfig, TFRobertaModel),
(BertConfig, TFBertModel),
(T5Config, TFT5Model),
(TransfoXLConfig, TFTransfoXLModel),
(XLMConfig, TFXLMModel),
@@ -162,7 +162,6 @@ TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForPreTraining),
(BertConfig, TFBertForPreTraining),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
@@ -172,6 +171,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
@@ -183,7 +183,6 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
@@ -193,6 +192,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
@@ -204,12 +204,12 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(MobileBertConfig, TFMobileBertForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
@@ -219,13 +219,13 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(ElectraConfig, TFElectraForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(MobileBertConfig, TFMobileBertForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
@@ -235,12 +235,12 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
@@ -250,13 +250,13 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),