Map configs to models and tokenizers
This commit is contained in:
parent 1fc855e456
commit 0304628590
configuration_auto.py

@@ -202,7 +202,7 @@ class AutoConfig:
         return config_class.from_dict(config_dict, **kwargs)
 
         raise ValueError(
-            "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
-                pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
-            )
+            "Unrecognized model in {}. "
+            "Should have a `model_type` key in its config.json, or contain one of the following strings "
+            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
         )
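The clearer message reflects how AutoConfig falls back to substring matching on the model name when config.json carries no `model_type` key. A minimal sketch of that resolution logic, using string stand-ins rather than the library's real table:

    from collections import OrderedDict

    # Toy stand-in for CONFIG_MAPPING: model-type substring -> config class.
    # Order matters: "distilbert" must be tried before "bert".
    CONFIG_MAPPING = OrderedDict([("distilbert", "DistilBertConfig"), ("bert", "BertConfig")])

    def resolve_config(pretrained_model_name_or_path):
        for key, config_class in CONFIG_MAPPING.items():
            if key in pretrained_model_name_or_path:
                return config_class
        raise ValueError(
            "Unrecognized model in {}. "
            "Should have a `model_type` key in its config.json, or contain one of the following strings "
            "in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
        )

    print(resolve_config("distilbert-base-uncased"))  # DistilBertConfig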
configuration_utils.py

@@ -47,8 +47,8 @@ class PretrainedConfig(object):
         ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
         ``torchscript``: string, default `False`. Is the model used with Torchscript.
     """
-    pretrained_config_archive_map = {} # type: Dict[str, str]
-    model_type = "" # type: str
+    pretrained_config_archive_map = {}  # type: Dict[str, str]
+    model_type = ""  # type: str
 
     def __init__(self, **kwargs):
         # Attributes with defaults

@@ -273,7 +273,7 @@ class PretrainedConfig(object):
         return self.__dict__ == other.__dict__
 
     def __repr__(self):
-        return str(self.to_json_string())
+        return "{} {}".format(self.__class__.__name__, self.to_json_string())
 
     def to_dict(self):
         """Serializes this instance to a Python dictionary."""
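The __repr__ change prefixes the class name to the JSON dump, so printed configs identify themselves. A hedged sketch of the effect with a minimal stand-in class, not the real PretrainedConfig:

    import json

    class PretrainedConfig(object):
        def to_json_string(self):
            return json.dumps(self.__dict__, indent=2, sort_keys=True)

        def __repr__(self):
            # New behavior: class name first, then the JSON payload.
            return "{} {}".format(self.__class__.__name__, self.to_json_string())

    config = PretrainedConfig()
    config.output_hidden_states = False
    print(repr(config))
    # PretrainedConfig {
    #   "output_hidden_states": false
    # }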
modeling_auto.py

@@ -17,7 +17,7 @@
 
 import logging
 from collections import OrderedDict
-from typing import Type
+from typing import Dict, Type
 
 from .configuration_auto import (
     AlbertConfig,

@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
     for key, value, in pretrained_map.items()
 )
 
-MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
+MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
     [
         (T5Config, T5Model),
         (DistilBertConfig, DistilBertModel),
         (AlbertConfig, AlbertModel),
         (CamembertConfig, CamembertModel),
-        (RobertaConfig, XLMRobertaModel),
-        (XLMRobertaConfig, RobertaModel),
+        (RobertaConfig, RobertaModel),
+        (XLMRobertaConfig, XLMRobertaModel),
         (BertConfig, BertModel),
         (OpenAIGPTConfig, OpenAIGPTModel),
         (GPT2Config, GPT2Model),

@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
     ]
 )
 
-MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
+MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
+    [
+        (T5Config, T5WithLMHeadModel),
+        (DistilBertConfig, DistilBertForMaskedLM),
+        (AlbertConfig, AlbertForMaskedLM),
+        (CamembertConfig, CamembertForMaskedLM),
+        (RobertaConfig, RobertaForMaskedLM),
+        (XLMRobertaConfig, XLMRobertaForMaskedLM),
+        (BertConfig, BertForMaskedLM),
+        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
+        (GPT2Config, GPT2LMHeadModel),
+        (TransfoXLConfig, TransfoXLLMHeadModel),
+        (XLNetConfig, XLNetLMHeadModel),
+        (XLMConfig, XLMWithLMHeadModel),
+        (CTRLConfig, CTRLLMHeadModel),
+    ]
+)
+
+MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, DistilBertForSequenceClassification),
+        (AlbertConfig, AlbertForSequenceClassification),
+        (CamembertConfig, CamembertForSequenceClassification),
+        (RobertaConfig, RobertaForSequenceClassification),
+        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
+        (BertConfig, BertForSequenceClassification),
+        (XLNetConfig, XLNetForSequenceClassification),
+        (XLMConfig, XLMForSequenceClassification),
+    ]
+)
+
+MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, DistilBertForQuestionAnswering),
+        (AlbertConfig, AlbertForQuestionAnswering),
+        (BertConfig, BertForQuestionAnswering),
+        (XLNetConfig, XLNetForQuestionAnswering),
+        (XLMConfig, XLMForQuestionAnswering),
+    ]
+)
+
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
     [
         (DistilBertConfig, DistilBertForTokenClassification),
         (CamembertConfig, CamembertForTokenClassification),
-        (RobertaConfig, XLMRobertaForTokenClassification),
-        (XLMRobertaConfig, RobertaForTokenClassification),
+        (RobertaConfig, RobertaForTokenClassification),
+        (XLMRobertaConfig, XLMRobertaForTokenClassification),
         (BertConfig, BertForTokenClassification),
         (XLNetConfig, XLNetForTokenClassification),
     ]
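Two fixes ride along with the new tables: the annotation moves from OrderedDict[...] to Dict[...] (subscripting collections.OrderedDict fails at runtime on the Python versions supported at the time), and the swapped RoBERTa/XLM-RoBERTa entries are corrected. The tables stay OrderedDicts because the Auto classes select an entry with isinstance(), so a subclass config must come before its base class (CamembertConfig derives from RobertaConfig in the library). A toy sketch of the dispatch and why order matters, with stand-in classes:

    from collections import OrderedDict

    class RobertaConfig(object): pass
    class CamembertConfig(RobertaConfig): pass  # subclass, as in the library

    MODEL_MAPPING = OrderedDict(
        [
            (CamembertConfig, "CamembertModel"),  # subclass listed first...
            (RobertaConfig, "RobertaModel"),      # ...or isinstance() would shadow it
        ]
    )

    def from_config(config):
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class
        raise ValueError("Unrecognized configuration class {}".format(config.__class__))

    print(from_config(CamembertConfig()))  # CamembertModel, not RobertaModel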
@@ -218,7 +259,12 @@ class AutoModel(object):
         for config_class, model_class in MODEL_MAPPING.items():
             if isinstance(config, config_class):
                 return model_class(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -309,10 +355,9 @@ class AutoModel(object):
             if isinstance(config, config_class):
                 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
             )
         )
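End-user calls are unchanged; only the failure path improves, naming the offending config class and the accepted classes instead of a hard-coded (and already stale) string list. A hedged usage sketch, assuming a transformers release of this era is installed:

    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")  # dispatched through MODEL_MAPPING
    print(type(model).__name__)  # BertModel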
@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelWithLMHead.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, DistilBertConfig):
-            return DistilBertForMaskedLM(config)
-        elif isinstance(config, RobertaConfig):
-            return RobertaForMaskedLM(config)
-        elif isinstance(config, BertConfig):
-            return BertForMaskedLM(config)
-        elif isinstance(config, OpenAIGPTConfig):
-            return OpenAIGPTLMHeadModel(config)
-        elif isinstance(config, GPT2Config):
-            return GPT2LMHeadModel(config)
-        elif isinstance(config, TransfoXLConfig):
-            return TransfoXLLMHeadModel(config)
-        elif isinstance(config, XLNetConfig):
-            return XLNetLMHeadModel(config)
-        elif isinstance(config, XLMConfig):
-            return XLMWithLMHeadModel(config)
-        elif isinstance(config, CTRLConfig):
-            return CTRLLMHeadModel(config)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForMaskedLM(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, T5Config):
-            return T5WithLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return AlbertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, CamembertConfig):
-            return CamembertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return RobertaForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, OpenAIGPTConfig):
-            return OpenAIGPTLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, GPT2Config):
-            return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, TransfoXLConfig):
-            return TransfoXLLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return XLNetLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return XLMWithLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, CTRLConfig):
-            return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
             )
         )
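With both methods driven by MODEL_WITH_LM_HEAD_MAPPING, wiring up a new architecture becomes one entry per task table instead of an elif pair per method. A hedged sketch with hypothetical placeholder names (NewConfig and NewModelWithLMHead are not real library classes):

    from collections import OrderedDict

    MODEL_WITH_LM_HEAD_MAPPING = OrderedDict()  # stand-in for the real table

    class NewConfig(object): pass            # hypothetical config class
    class NewModelWithLMHead(object): pass   # hypothetical model class

    # One registration replaces an elif branch in from_config and another in from_pretrained.
    MODEL_WITH_LM_HEAD_MAPPING[NewConfig] = NewModelWithLMHead
    # Caveat: appended entries land last, so a subclass of an existing config would
    # need to be inserted before its base class to win the isinstance() scan.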
@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, AlbertConfig):
-            return AlbertForSequenceClassification(config)
-        elif isinstance(config, CamembertConfig):
-            return CamembertForSequenceClassification(config)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertForSequenceClassification(config)
-        elif isinstance(config, RobertaConfig):
-            return RobertaForSequenceClassification(config)
-        elif isinstance(config, BertConfig):
-            return BertForSequenceClassification(config)
-        elif isinstance(config, XLNetConfig):
-            return XLNetForSequenceClassification(config)
-        elif isinstance(config, XLMConfig):
-            return XLMForSequenceClassification(config)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForSequenceClassification(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, DistilBertConfig):
-            return DistilBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return AlbertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, CamembertConfig):
-            return CamembertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return RobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return BertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return XLNetForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return XLMForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
 
+        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )
@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, AlbertConfig):
-            return AlbertForQuestionAnswering(config)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertForQuestionAnswering(config)
-        elif isinstance(config, BertConfig):
-            return BertForQuestionAnswering(config)
-        elif isinstance(config, XLNetConfig):
-            return XLNetForQuestionAnswering(config)
-        elif isinstance(config, XLMConfig):
-            return XLMForQuestionAnswering(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, DistilBertConfig):
-            return DistilBertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return AlbertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return BertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return XLNetForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return XLMForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
+            )
+        )
@@ -923,19 +866,18 @@ class AutoModelForTokenClassification:
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelForTokenClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, CamembertConfig):
-            return CamembertForTokenClassification(config)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertForTokenClassification(config)
-        elif isinstance(config, BertConfig):
-            return BertForTokenClassification(config)
-        elif isinstance(config, XLNetConfig):
-            return XLNetForTokenClassification(config)
-        elif isinstance(config, RobertaConfig):
-            return RobertaForTokenClassification(config)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForTokenClassification(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, CamembertConfig):
-            return CamembertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return RobertaForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return BertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return XLNetForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
+            )
+        )
modeling_tf_auto.py

@@ -16,6 +16,8 @@
 
 
 import logging
+from collections import OrderedDict
+from typing import Dict, Type
 
 from .configuration_auto import (
     AlbertConfig,

@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import (
     TFTransfoXLLMHeadModel,
     TFTransfoXLModel,
 )
+from .modeling_tf_utils import TFPreTrainedModel
 from .modeling_tf_xlm import (
     TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
     TFXLMForQuestionAnsweringSimple,

@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
     for key, value, in pretrained_map.items()
 )
 
+TF_MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, TFDistilBertModel),
+        (AlbertConfig, TFAlbertModel),
+        (RobertaConfig, TFRobertaModel),
+        (BertConfig, TFBertModel),
+        (OpenAIGPTConfig, TFOpenAIGPTModel),
+        (GPT2Config, TFGPT2Model),
+        (TransfoXLConfig, TFTransfoXLModel),
+        (XLNetConfig, TFXLNetModel),
+        (XLMConfig, TFXLMModel),
+        (CTRLConfig, TFCTRLModel),
+    ]
+)
+
+TF_MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, TFDistilBertForMaskedLM),
+        (AlbertConfig, TFAlbertForMaskedLM),
+        (RobertaConfig, TFRobertaForMaskedLM),
+        (BertConfig, TFBertForMaskedLM),
+        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
+        (GPT2Config, TFGPT2LMHeadModel),
+        (TransfoXLConfig, TFTransfoXLLMHeadModel),
+        (XLNetConfig, TFXLNetLMHeadModel),
+        (XLMConfig, TFXLMWithLMHeadModel),
+        (CTRLConfig, TFCTRLLMHeadModel),
+    ]
+)
+
+TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, TFDistilBertForSequenceClassification),
+        (AlbertConfig, TFAlbertForSequenceClassification),
+        (RobertaConfig, TFRobertaForSequenceClassification),
+        (BertConfig, TFBertForSequenceClassification),
+        (XLNetConfig, TFXLNetForSequenceClassification),
+        (XLMConfig, TFXLMForSequenceClassification),
+    ]
+)
+
+TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, TFDistilBertForQuestionAnswering),
+        (BertConfig, TFBertForQuestionAnswering),
+        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
+        (XLMConfig, TFXLMForQuestionAnsweringSimple),
+    ]
+)
+
+TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
+    [
+        (DistilBertConfig, TFDistilBertForTokenClassification),
+        (RobertaConfig, TFRobertaForTokenClassification),
+        (BertConfig, TFBertForTokenClassification),
+        (XLNetConfig, TFXLNetForTokenClassification),
+    ]
+)
+
 
 class TFAutoModel(object):
     r"""
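The TF tables mirror the PyTorch ones, minus architectures without TF ports at the time, so dispatch stays symmetric across backends. A hedged usage sketch, assuming TensorFlow 2.x and a transformers build of this vintage:

    from transformers import TFAutoModel

    tf_model = TFAutoModel.from_pretrained("bert-base-uncased")  # dispatched via TF_MODEL_MAPPING
    print(type(tf_model).__name__)  # TFBertModel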
@@ -165,25 +227,15 @@ class TFAutoModel(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = TFAutoModel.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertModel(config)
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaModel(config)
-        elif isinstance(config, BertConfig):
-            return TFBertModel(config)
-        elif isinstance(config, OpenAIGPTConfig):
-            return TFOpenAIGPTModel(config)
-        elif isinstance(config, GPT2Config):
-            return TFGPT2Model(config)
-        elif isinstance(config, TransfoXLConfig):
-            return TFTransfoXLModel(config)
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetModel(config)
-        elif isinstance(config, XLMConfig):
-            return TFXLMModel(config)
-        elif isinstance(config, CTRLConfig):
-            return TFCTRLModel(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in TF_MODEL_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -266,39 +318,14 @@ class TFAutoModel(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, T5Config):
-            return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, DistilBertConfig):
-            return TFDistilBertModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, BertConfig):
-            return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, OpenAIGPTConfig):
-            return TFOpenAIGPTModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, GPT2Config):
-            return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, TransfoXLConfig):
-            return TFTransfoXLModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, XLMConfig):
-            return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
-        elif isinstance(config, CTRLConfig):
-            return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
 
+        for config_class, model_class in TF_MODEL_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
+            )
        )
@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelWithLMHead.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertForMaskedLM(config)
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForMaskedLM(config)
-        elif isinstance(config, BertConfig):
-            return TFBertForMaskedLM(config)
-        elif isinstance(config, OpenAIGPTConfig):
-            return TFOpenAIGPTLMHeadModel(config)
-        elif isinstance(config, GPT2Config):
-            return TFGPT2LMHeadModel(config)
-        elif isinstance(config, TransfoXLConfig):
-            return TFTransfoXLLMHeadModel(config)
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetLMHeadModel(config)
-        elif isinstance(config, XLMConfig):
-            return TFXLMWithLMHeadModel(config)
-        elif isinstance(config, CTRLConfig):
-            return TFCTRLLMHeadModel(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, T5Config):
-            return TFT5WithLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, DistilBertConfig):
-            return TFDistilBertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return TFAlbertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return TFBertForMaskedLM.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, OpenAIGPTConfig):
-            return TFOpenAIGPTLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, GPT2Config):
-            return TFGPT2LMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, TransfoXLConfig):
-            return TFTransfoXLLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return TFXLMWithLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, CTRLConfig):
-            return TFCTRLLMHeadModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
 
+        for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
+            )
        )
@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertForSequenceClassification(config)
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForSequenceClassification(config)
-        elif isinstance(config, BertConfig):
-            return TFBertForSequenceClassification(config)
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetForSequenceClassification(config)
-        elif isinstance(config, XLMConfig):
-            return TFXLMForSequenceClassification(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, AlbertConfig):
-            return TFAlbertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return TFBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return TFXLMForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
 
+        for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'distilbert', 'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
+            )
        )
@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object):
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertForQuestionAnswering(config)
-        elif isinstance(config, BertConfig):
-            return TFBertForQuestionAnswering(config)
-        elif isinstance(config, XLNetConfig):
-            raise NotImplementedError("TFXLNetForQuestionAnswering isn't implemented")
-        elif isinstance(config, XLMConfig):
-            raise NotImplementedError("TFXLMForQuestionAnswering isn't implemented")
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, DistilBertConfig):
-            return TFDistilBertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, BertConfig):
-            return TFBertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetForQuestionAnsweringSimple.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLMConfig):
-            return TFXLMForQuestionAnsweringSimple.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
 
+        for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'distilbert', 'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
+            )
        )
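One behavioral change rides along here: TFAutoModelForQuestionAnswering.from_config previously raised NotImplementedError for XLNet and XLM, while the new table resolves them to the *ForQuestionAnsweringSimple heads, matching what from_pretrained already did. A hedged sketch, assuming a TF 2.x transformers environment of this era:

    from transformers import TFAutoModelForQuestionAnswering, XLNetConfig

    config = XLNetConfig()  # default hyperparameters; no weights are downloaded
    model = TFAutoModelForQuestionAnswering.from_config(config)
    print(type(model).__name__)  # TFXLNetForQuestionAnsweringSimple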
@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification:
             config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
             model = TFAutoModelForTokenClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
-        if isinstance(config, BertConfig):
-            return TFBertForTokenClassification(config)
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetForTokenClassification(config)
-        elif isinstance(config, DistilBertConfig):
-            return TFDistilBertForTokenClassification(config)
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForTokenClassification(config)
-        raise ValueError("Unrecognized configuration class {}".format(config))
+        for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
+            )
+        )
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if isinstance(config, BertConfig):
-            return TFBertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, XLNetConfig):
-            return TFXLNetForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, DistilBertConfig):
-            return TFDistilBertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        elif isinstance(config, RobertaConfig):
-            return TFRobertaForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
 
+        for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'xlnet', 'distilbert', 'roberta'".format(pretrained_model_name_or_path)
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
+            )
        )
tokenization_auto.py

@@ -16,6 +16,8 @@
 
 
 import logging
+from collections import OrderedDict
+from typing import Dict, Type
 
 from .configuration_auto import (
     AlbertConfig,

@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLTokenizer
+from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import XLNetTokenizer

@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
 logger = logging.getLogger(__name__)
 
 
+TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
+    [
+        (T5Config, T5Tokenizer),
+        (DistilBertConfig, DistilBertTokenizer),
+        (AlbertConfig, AlbertTokenizer),
+        (CamembertConfig, CamembertTokenizer),
+        (RobertaConfig, XLMRobertaTokenizer),
+        (XLMRobertaConfig, RobertaTokenizer),
+        (BertConfig, BertTokenizer),
+        (OpenAIGPTConfig, OpenAIGPTTokenizer),
+        (GPT2Config, GPT2Tokenizer),
+        (TransfoXLConfig, TransfoXLTokenizer),
+        (XLNetConfig, XLNetTokenizer),
+        (XLMConfig, XLMTokenizer),
+        (CTRLConfig, CTRLTokenizer),
+    ]
+)
+
+
 class AutoTokenizer(object):
     r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
         that will be instantiated as one of the tokenizer classes of the library

@@ -154,36 +176,13 @@ class AutoTokenizer(object):
         if "bert-base-japanese" in pretrained_model_name_or_path:
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 
-        if isinstance(config, T5Config):
-            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, DistilBertConfig):
-            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, AlbertConfig):
-            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CamembertConfig):
-            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMRobertaConfig):
-            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, RobertaConfig):
-            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, BertConfig):
-            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, OpenAIGPTConfig):
-            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, GPT2Config):
-            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, TransfoXLConfig):
-            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLNetConfig):
-            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, XLMConfig):
-            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif isinstance(config, CTRLConfig):
-            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
+            if isinstance(config, config_class):
+                return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized configuration class {} to build an AutoTokenizer.\n"
+            "Model type should be one of {}.".format(
+                config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )
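Tokenizer resolution now goes through the same config-keyed table, with the hard-coded carve-out for the Japanese BERT checkpoints kept in place. A hedged usage sketch, again assuming a transformers release of this era:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # dispatched via TOKENIZER_MAPPING
    print(tokenizer.tokenize("Hello world"))  # ['hello', 'world']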
test_configuration_auto.py

@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
         # no key string should be included in a later key string (typical failure case)
         keys = list(CONFIG_MAPPING.keys())
         for i, key in enumerate(keys):
-            self.assertFalse(any(key in later_key for later_key in keys[i+1:]))
+            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
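The test change is formatting only; the assertion itself is what guards the substring-based resolution in AutoConfig. In miniature, with hypothetical key lists:

    keys = ["distilbert", "bert"]      # safe: the longer key is tried first
    bad_keys = ["bert", "distilbert"]  # "bert" would shadow every "distilbert-*" name

    def ordering_ok(keys):
        return all(not any(key in later for later in keys[i + 1:]) for i, key in enumerate(keys))

    print(ordering_ok(keys))      # True
    print(ordering_ok(bad_keys))  # False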