mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
adding albert to TF auto models cc @LysandreJik
This commit is contained in:
parent
a4d07b983a
commit
db0a9ee6e0
@ -27,6 +27,7 @@ from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceC
|
||||
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_tf_albert import TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
@ -46,7 +47,6 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items())
|
||||
@ -162,6 +162,8 @@ class TFAutoModel(object):
|
||||
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
@ -298,6 +300,8 @@ class TFAutoModelWithLMHead(object):
|
||||
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'distilbert' in pretrained_model_name_or_path:
|
||||
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
@ -425,6 +429,8 @@ class TFAutoModelForSequenceClassification(object):
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
return TFDistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
return TFAlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
return TFRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
|
Loading…
Reference in New Issue
Block a user