AutoConfig + other Auto classes honor model_type

This commit is contained in:
Julien Chaumond 2020-01-11 02:46:17 +00:00
parent 2f32dfd33b
commit 4d1c98c012
9 changed files with 503 additions and 284 deletions

View File

@ -16,6 +16,7 @@
import logging
from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
@@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
@@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
)
class AutoConfig(object):
CONFIG_MAPPING = OrderedDict(
[
("t5", T5Config,),
("distilbert", DistilBertConfig,),
("albert", AlbertConfig,),
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("roberta", RobertaConfig,),
("bert", BertConfig,),
("openai-gpt", OpenAIGPTConfig,),
("gpt2", GPT2Config,),
("transfo-xl", TransfoXLConfig,),
("xlnet", XLNetConfig,),
("xlm", XLMConfig,),
("ctrl", CTRLConfig,),
]
)
class AutoConfig:
r""":class:`~transformers.AutoConfig` is a generic configuration class
that will be instantiated as one of the configuration classes of the library
when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method take care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
When using string matching, the configuration class is matched on
the `pretrained_model_name_or_path` string in the following order:
- contains `t5`: T5Config (T5 model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model)
@@ -90,41 +113,23 @@ class AutoConfig(object):
@classmethod
def for_model(cls, model_type, *args, **kwargs):
if "distilbert" in model_type:
return DistilBertConfig(*args, **kwargs)
elif "roberta" in model_type:
return RobertaConfig(*args, **kwargs)
elif "bert" in model_type:
return BertConfig(*args, **kwargs)
elif "openai-gpt" in model_type:
return OpenAIGPTConfig(*args, **kwargs)
elif "gpt2" in model_type:
return GPT2Config(*args, **kwargs)
elif "transfo-xl" in model_type:
return TransfoXLConfig(*args, **kwargs)
elif "xlnet" in model_type:
return XLNetConfig(*args, **kwargs)
elif "xlm" in model_type:
return XLMConfig(*args, **kwargs)
elif "ctrl" in model_type:
return CTRLConfig(*args, **kwargs)
elif "albert" in model_type:
return AlbertConfig(*args, **kwargs)
elif "camembert" in model_type:
return CamembertConfig(*args, **kwargs)
for pattern, config_class in CONFIG_MAPPING.items():
if pattern in model_type:
return config_class(*args, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
"Unrecognized model identifier in {}. Should contain one of {}".format(
model_type, ", ".join(CONFIG_MAPPING.keys())
)
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a one of the configuration classes of the library
r""" Instantiate one of the configuration classes of the library
from a pre-trained model configuration.
The configuration class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
The configuration class to instantiate is selected
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
- contains `t5`: T5Config (T5 model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
@@ -183,36 +188,21 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if "t5" in pretrained_model_name_or_path:
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "camembert" in pretrained_model_name_or_path:
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config_dict, _ = PretrainedConfig.resolved_config_dict(
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
)
if "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]]
return config_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, config_class in CONFIG_MAPPING.items():
if pattern in pretrained_model_name_or_path:
return config_class.from_dict(config_dict, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
"Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
)
)

View File

@@ -20,6 +20,7 @@ import copy
import json
import logging
import os
from typing import Dict, Optional, Tuple
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
@@ -36,7 +37,7 @@ class PretrainedConfig(object):
It only affects the model's configuration.
Class attributes (overridden by derived classes):
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
- ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
Parameters:
``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
@@ -154,14 +155,32 @@ class PretrainedConfig(object):
assert unused_kwargs == {'foo': False}
"""
config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(config_dict, **kwargs)
@classmethod
def resolved_config_dict(
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
for instantiating a Config using `from_dict`.
Parameters:
pretrained_config_archive_map: (`optional`) Dict:
A map of `shortcut names` to `url`.
By default, will use the current class attribute.
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
if pretrained_config_archive_map is None:
pretrained_config_archive_map = cls.pretrained_config_archive_map
if pretrained_model_name_or_path in pretrained_config_archive_map:
config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
@@ -178,23 +197,20 @@ class PretrainedConfig(object):
proxies=proxies,
resume_download=resume_download,
)
# Load config
config = cls.from_json_file(resolved_config_file)
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
if pretrained_model_name_or_path in pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file
)
else:
msg = (
"Model name '{}' was not found in model name list ({}). "
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url to a configuration file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
", ".join(cls.pretrained_config_archive_map.keys()),
config_file,
CONFIG_NAME,
pretrained_model_name_or_path, config_file, CONFIG_NAME,
)
)
raise EnvironmentError(msg)
@@ -212,6 +228,15 @@ class PretrainedConfig(object):
else:
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
return config_dict, kwargs
@classmethod
def from_dict(cls, config_dict: Dict, **kwargs):
"""Constructs a `Config` from a Python dictionary of parameters."""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
@@ -231,17 +256,16 @@ class PretrainedConfig(object):
return config
@classmethod
def from_dict(cls, json_object):
"""Constructs a `Config` from a Python dictionary of parameters."""
return cls(**json_object)
def from_json_file(cls, json_file: str):
"""Constructs a `Config` from a json file of parameters."""
config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `Config` from a json file of parameters."""
def _dict_from_json_file(cls, json_file: str):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
dict_obj = json.loads(text)
return cls(**dict_obj)
return json.loads(text)
def __eq__(self, other):
return self.__dict__ == other.__dict__

View File

@@ -19,6 +19,7 @@ import logging
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
@@ -26,11 +27,13 @@ from .configuration_auto import (
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
@@ -129,7 +132,8 @@ class AutoModel(object):
or the `AutoModel.from_config(config)` class methods.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -286,32 +290,36 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "t5" in pretrained_model_name_or_path:
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "camembert" in pretrained_model_name_or_path:
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config):
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, BertConfig):
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, GPT2Config):
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLNetConfig):
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMConfig):
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CTRLConfig):
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
@@ -329,7 +337,8 @@ class AutoModelWithLMHead(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -407,7 +416,8 @@ class AutoModelWithLMHead(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -484,32 +494,56 @@ class AutoModelWithLMHead(object):
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "t5" in pretrained_model_name_or_path:
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "camembert" in pretrained_model_name_or_path:
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config):
return T5WithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, DistilBertConfig):
return DistilBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return AlbertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CamembertConfig):
return CamembertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return RobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMWithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CTRLConfig):
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
@@ -527,7 +561,8 @@ class AutoModelForSequenceClassification(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -592,7 +627,8 @@ class AutoModelForSequenceClassification(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -665,32 +701,42 @@ class AutoModelForSequenceClassification(object):
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "distilbert" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig):
return DistilBertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "albert" in pretrained_model_name_or_path:
elif isinstance(config, AlbertConfig):
return AlbertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "camembert" in pretrained_model_name_or_path:
elif isinstance(config, CamembertConfig):
return CamembertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "xlm-roberta" in pretrained_model_name_or_path:
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
elif isinstance(config, RobertaConfig):
return RobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "bert" in pretrained_model_name_or_path:
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
@@ -708,7 +754,8 @@ class AutoModelForQuestionAnswering(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -763,7 +810,8 @@ class AutoModelForQuestionAnswering(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -830,16 +878,30 @@ class AutoModelForQuestionAnswering(object):
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "distilbert" in pretrained_model_name_or_path:
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig):
return DistilBertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return AlbertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return XLMForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
@@ -893,7 +955,8 @@ class AutoModelForTokenClassification:
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@@ -959,24 +1022,34 @@ class AutoModelForTokenClassification:
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "camembert" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, CamembertConfig):
return CamembertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "distilbert" in pretrained_model_name_or_path:
elif isinstance(config, DistilBertConfig):
return DistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "xlm-roberta" in pretrained_model_name_or_path:
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return RobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return BertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return XLNetForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "

View File

@ -18,16 +18,20 @@
import logging
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CTRLConfig,
DistilBertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .modeling_tf_albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TFAlbertForMaskedLM,
@ -113,7 +117,8 @@ class TFAutoModel(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -257,28 +262,38 @@ class TFAutoModel(object):
model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if "t5" in pretrained_model_name_or_path:
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return TFOpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TFTransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config):
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, DistilBertConfig):
return TFDistilBertModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, RobertaConfig):
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, BertConfig):
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return TFOpenAIGPTModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMConfig):
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CTRLConfig):
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
@ -295,7 +310,8 @@ class TFAutoModelWithLMHead(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -368,7 +384,8 @@ class TFAutoModelWithLMHead(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -443,28 +460,54 @@ class TFAutoModelWithLMHead(object):
model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if "t5" in pretrained_model_name_or_path:
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "bert" in pretrained_model_name_or_path:
return TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
return TFOpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
return TFGPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
return TFTransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, T5Config):
return TFT5WithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, AlbertConfig):
return TFAlbertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return TFRobertaForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, BertConfig):
return TFBertForMaskedLM.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, OpenAIGPTConfig):
return TFOpenAIGPTLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, GPT2Config):
return TFGPT2LMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, TransfoXLConfig):
return TFTransfoXLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return TFXLMWithLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, CTRLConfig):
return TFCTRLLMHeadModel.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
@ -481,7 +524,8 @@ class TFAutoModelForSequenceClassification(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -537,7 +581,8 @@ class TFAutoModelForSequenceClassification(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -610,28 +655,34 @@ class TFAutoModelForSequenceClassification(object):
model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if "distilbert" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig):
return TFDistilBertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "albert" in pretrained_model_name_or_path:
elif isinstance(config, AlbertConfig):
return TFAlbertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
elif isinstance(config, RobertaConfig):
return TFRobertaForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "bert" in pretrained_model_name_or_path:
elif isinstance(config, BertConfig):
return TFBertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "xlnet" in pretrained_model_name_or_path:
elif isinstance(config, XLNetConfig):
return TFXLNetForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLMConfig):
return TFXLMForSequenceClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "xlm" in pretrained_model_name_or_path:
return TFXLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
@ -647,7 +698,8 @@ class TFAutoModelForQuestionAnswering(object):
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -699,7 +751,8 @@ class TFAutoModelForQuestionAnswering(object):
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -771,19 +824,25 @@ class TFAutoModelForQuestionAnswering(object):
model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if "distilbert" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, DistilBertConfig):
return TFDistilBertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "bert" in pretrained_model_name_or_path:
return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
elif isinstance(config, BertConfig):
return TFBertForQuestionAnswering.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, XLNetConfig):
return TFXLNetForQuestionAnsweringSimple.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "xlm" in pretrained_model_name_or_path:
elif isinstance(config, XLMConfig):
return TFXLMForQuestionAnsweringSimple.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
@ -833,7 +892,8 @@ class TFAutoModelForTokenClassification:
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -898,17 +958,25 @@ class TFAutoModelForTokenClassification:
model = TFAutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "bert" in pretrained_model_name_or_path:
return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return TFDistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if isinstance(config, BertConfig):
return TFBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif "roberta" in pretrained_model_name_or_path:
elif isinstance(config, XLNetConfig):
return TFXLNetForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
elif isinstance(config, RobertaConfig):
return TFRobertaForTokenClassification.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(

View File

@ -17,6 +17,23 @@
import logging
from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
@ -43,7 +60,8 @@ class AutoTokenizer(object):
class method.
The `from_pretrained()` method take care of returning the correct tokenizer class instance
using pattern matching on the `pretrained_model_name_or_path` string.
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
@ -72,7 +90,7 @@ class AutoTokenizer(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r""" Instantiate a one of the tokenizer classes of the library
r""" Instantiate one of the tokenizer classes of the library
from a pre-trained model vocabulary.
The tokenizer class to instantiate is selected as the first pattern matching
@ -129,33 +147,38 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
"""
if "t5" in pretrained_model_name_or_path:
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "albert" in pretrained_model_name_or_path:
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "camembert" in pretrained_model_name_or_path:
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "roberta" in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert-base-japanese" in pretrained_model_name_or_path:
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "bert" in pretrained_model_name_or_path:
if isinstance(config, T5Config):
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif isinstance(config, BertConfig):
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "openai-gpt" in pretrained_model_name_or_path:
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "gpt2" in pretrained_model_name_or_path:
elif isinstance(config, GPT2Config):
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "transfo-xl" in pretrained_model_name_or_path:
elif isinstance(config, TransfoXLConfig):
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlnet" in pretrained_model_name_or_path:
elif isinstance(config, XLNetConfig):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "xlm" in pretrained_model_name_or_path:
elif isinstance(config, XLMConfig):
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif "ctrl" in pretrained_model_name_or_path:
elif isinstance(config, CTRLConfig):
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "

3
tests/fixtures/dummy-config.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"model_type": "roberta"
}

View File

@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.configuration_auto import AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase):
def test_config_from_model_shortcut(self):
config = AutoConfig.from_pretrained("bert-base-uncased")
self.assertIsInstance(config, BertConfig)
def test_config_from_model_type(self):
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
self.assertIsInstance(config, RobertaConfig)
def test_config_for_model_str(self):
config = AutoConfig.for_model("roberta")
self.assertIsInstance(config, RobertaConfig)

View File

@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForSequenceClassification)
@slow
# @slow
def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

View File

@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow
class AutoTokenizerTest(unittest.TestCase):
@slow
# @slow
def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: