mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
CTRL to tf automodels
This commit is contained in:
parent
d844db4005
commit
a701c9b321
@ -26,6 +26,7 @@ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSeque
|
||||
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
|
||||
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
|
||||
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
|
||||
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
@ -52,6 +53,7 @@ class TFAutoModel(object):
|
||||
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: TFXLNetModel (XLNet model)
|
||||
- contains `xlm`: TFXLMModel (XLM model)
|
||||
- contains `ctrl`: TFCTRLModel (CTRL model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
@ -73,7 +75,7 @@ class TFAutoModel(object):
|
||||
- contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
|
||||
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
|
||||
- contains `xlnet`: TFXLNetModel (XLNet model)
|
||||
- contains `xlm`: TFXLMModel (XLM model)
|
||||
- contains `ctrl`: TFCTRLModel (CTRL model)
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
@ -147,10 +149,12 @@ class TFAutoModel(object):
|
||||
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'ctrl' in pretrained_model_name_or_path:
|
||||
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
||||
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
|
||||
|
||||
|
||||
class TFAutoModelWithLMHead(object):
|
||||
@ -173,6 +177,7 @@ class TFAutoModelWithLMHead(object):
|
||||
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
|
||||
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
|
||||
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
|
||||
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
@ -198,6 +203,7 @@ class TFAutoModelWithLMHead(object):
|
||||
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
|
||||
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
|
||||
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
|
||||
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
@ -271,10 +277,12 @@ class TFAutoModelWithLMHead(object):
|
||||
return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'xlm' in pretrained_model_name_or_path:
|
||||
return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif 'ctrl' in pretrained_model_name_or_path:
|
||||
return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
||||
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
|
||||
|
||||
|
||||
class TFAutoModelForSequenceClassification(object):
|
||||
|
Loading…
Reference in New Issue
Block a user