From db0a9ee6e0ddcb9d634c3ab0ba3d25501c370d8c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 16 Dec 2019 14:08:08 +0100 Subject: [PATCH] adding albert to TF auto models cc @LysandreJik --- transformers/modeling_tf_auto.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformers/modeling_tf_auto.py b/transformers/modeling_tf_auto.py index 9c687d92352..3e9b4d120b2 100644 --- a/transformers/modeling_tf_auto.py +++ b/transformers/modeling_tf_auto.py @@ -27,6 +27,7 @@ from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceC from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP +from .modeling_tf_albert import TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP from .file_utils import add_start_docstrings @@ -46,7 +47,6 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value) TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, - TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, ] for key, value, in pretrained_map.items()) @@ -162,6 +162,8 @@ class TFAutoModel(object): return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'distilbert' in pretrained_model_name_or_path: return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'albert' in pretrained_model_name_or_path: + return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: @@ -298,6 +300,8 @@ class TFAutoModelWithLMHead(object): return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'distilbert' in pretrained_model_name_or_path: return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'albert' in pretrained_model_name_or_path: + return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: @@ -425,6 +429,8 @@ class TFAutoModelForSequenceClassification(object): """ if 'distilbert' in pretrained_model_name_or_path: return TFDistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'albert' in pretrained_model_name_or_path: + return TFAlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return TFRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: