mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix auto models
This commit is contained in:
parent
de203853cc
commit
28a30af6d1
@ -22,7 +22,7 @@ from .modeling_tf_bert import TFBertModel, TFBertForMaskedLM, TFBertForSequenceC
|
||||
from .modeling_tf_openai import TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel
|
||||
from .modeling_tf_gpt2 import TFGPT2Model, TFGPT2LMHeadModel
|
||||
from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
|
||||
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering
|
||||
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnsweringSimple
|
||||
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
|
||||
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
|
||||
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
|
||||
@ -493,9 +493,9 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
elif 'bert' in pretrained_model_name_or_path:
|
||||
return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlnet' in pretrained_model_name_or_path:
|
||||
return TFXLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
return TFXLNetForQuestionAnsweringSimple.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
return TFXLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
return TFXLMForQuestionAnsweringSimple.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))
|
||||
|
Loading…
Reference in New Issue
Block a user