AutoConfig + other Auto classes honor model_type

Julien Chaumond 2020-01-11 02:46:17 +00:00
parent 2f32dfd33b
commit 4d1c98c012
9 changed files with 503 additions and 284 deletions


@@ -16,6 +16,7 @@
 import logging
+from collections import OrderedDict

 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
@@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open
 from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
 from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
 from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
+from .configuration_utils import PretrainedConfig
 from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
 from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
 from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
@@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
 )


-class AutoConfig(object):
+CONFIG_MAPPING = OrderedDict(
+    [
+        ("t5", T5Config,),
+        ("distilbert", DistilBertConfig,),
+        ("albert", AlbertConfig,),
+        ("camembert", CamembertConfig,),
+        ("xlm-roberta", XLMRobertaConfig,),
+        ("roberta", RobertaConfig,),
+        ("bert", BertConfig,),
+        ("openai-gpt", OpenAIGPTConfig,),
+        ("gpt2", GPT2Config,),
+        ("transfo-xl", TransfoXLConfig,),
+        ("xlnet", XLNetConfig,),
+        ("xlm", XLMConfig,),
+        ("ctrl", CTRLConfig,),
+    ]
+)
+
+
+class AutoConfig:
     r""":class:`~transformers.AutoConfig` is a generic configuration class
         that will be instantiated as one of the configuration classes of the library
         when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
         class method.

        The `from_pretrained()` method take care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
+        When using string matching, the configuration class is matched on
+        the `pretrained_model_name_or_path` string in the following order:
+            - contains `t5`: T5Config (T5 model)
             - contains `distilbert`: DistilBertConfig (DistilBERT model)
             - contains `albert`: AlbertConfig (ALBERT model)
             - contains `camembert`: CamembertConfig (CamemBERT model)
@@ -90,41 +113,23 @@ class AutoConfig(object):
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
-        if "distilbert" in model_type:
-            return DistilBertConfig(*args, **kwargs)
-        elif "roberta" in model_type:
-            return RobertaConfig(*args, **kwargs)
-        elif "bert" in model_type:
-            return BertConfig(*args, **kwargs)
-        elif "openai-gpt" in model_type:
-            return OpenAIGPTConfig(*args, **kwargs)
-        elif "gpt2" in model_type:
-            return GPT2Config(*args, **kwargs)
-        elif "transfo-xl" in model_type:
-            return TransfoXLConfig(*args, **kwargs)
-        elif "xlnet" in model_type:
-            return XLNetConfig(*args, **kwargs)
-        elif "xlm" in model_type:
-            return XLMConfig(*args, **kwargs)
-        elif "ctrl" in model_type:
-            return CTRLConfig(*args, **kwargs)
-        elif "albert" in model_type:
-            return AlbertConfig(*args, **kwargs)
-        elif "camembert" in model_type:
-            return CamembertConfig(*args, **kwargs)
+        for pattern, config_class in CONFIG_MAPPING.items():
+            if pattern in model_type:
+                return config_class(*args, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
+            "Unrecognized model identifier in {}. Should contain one of {}".format(
+                model_type, ", ".join(CONFIG_MAPPING.keys())
+            )
         )

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        r""" Instantiate a one of the configuration classes of the library
+        r""" Instantiate one of the configuration classes of the library
         from a pre-trained model configuration.

-        The configuration class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
+        The configuration class to instantiate is selected
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.
            - contains `t5`: T5Config (T5 model)
            - contains `distilbert`: DistilBertConfig (DistilBERT model)
            - contains `albert`: AlbertConfig (ALBERT model)
@@ -183,36 +188,21 @@ class AutoConfig(object):
             assert unused_kwargs == {'foo': False}

         """
-        if "t5" in pretrained_model_name_or_path:
-            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "camembert" in pretrained_model_name_or_path:
-            return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "xlm-roberta" in pretrained_model_name_or_path:
-            return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
-            return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
-            return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
-            return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
-            return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        config_dict, _ = PretrainedConfig.resolved_config_dict(
+            pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
+        )
+
+        if "model_type" in config_dict:
+            config_class = CONFIG_MAPPING[config_dict["model_type"]]
+            return config_class.from_dict(config_dict, **kwargs)
+        else:
+            # Fallback: use pattern matching on the string.
+            for pattern, config_class in CONFIG_MAPPING.items():
+                if pattern in pretrained_model_name_or_path:
+                    return config_class.from_dict(config_dict, **kwargs)
+
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contains one of "
-            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-            "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
-                pretrained_model_name_or_path
+            "Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
+                pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
             )
         )
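Taken together, `CONFIG_MAPPING` and the `model_type` check give `AutoConfig.from_pretrained()` a two-step lookup: an explicit `model_type` in config.json wins outright, and only when it is absent does the old substring matching on the name apply, in the declared order of the mapping. A minimal standalone sketch of that lookup (the toy classes, mapping and names below are illustrative, not part of the patch):

from collections import OrderedDict

# Stand-ins for the real configuration classes, only to exercise the lookup order.
class BertConfig: ...
class RobertaConfig: ...

# More specific patterns are registered before their substrings, as in CONFIG_MAPPING.
TOY_MAPPING = OrderedDict([("roberta", RobertaConfig), ("bert", BertConfig)])

def pick_config_class(config_dict, name_or_path):
    # Step 1: an explicit model_type in config.json decides by itself.
    if "model_type" in config_dict:
        return TOY_MAPPING[config_dict["model_type"]]
    # Step 2: fall back to the first pattern contained in the checkpoint name.
    for pattern, config_class in TOY_MAPPING.items():
        if pattern in name_or_path:
            return config_class
    raise ValueError("Unrecognized model identifier in {}".format(name_or_path))

print(pick_config_class({"model_type": "bert"}, "my-roberta-checkpoint").__name__)  # BertConfig
print(pick_config_class({}, "my-roberta-checkpoint").__name__)                      # RobertaConfig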


@@ -20,6 +20,7 @@ import copy
 import json
 import logging
 import os
+from typing import Dict, Optional, Tuple

 from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
@@ -36,7 +37,7 @@ class PretrainedConfig(object):
         It only affects the model's configuration.

         Class attributes (overridden by derived classes):
-            - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
+            - ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.

         Parameters:
             ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
@@ -154,14 +155,32 @@ class PretrainedConfig(object):
             assert unused_kwargs == {'foo': False}

         """
+        config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
+        return cls.from_dict(config_dict, **kwargs)
+
+    @classmethod
+    def resolved_config_dict(
+        cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
+    ) -> Tuple[Dict, Dict]:
+        """
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
+        for instantiating a Config using `from_dict`.
+
+        Parameters:
+            pretrained_config_archive_map: (`optional`) Dict:
+                A map of `shortcut names` to `url`.
+                By default, will use the current class attribute.
+        """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
-        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

-        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
+        if pretrained_config_archive_map is None:
+            pretrained_config_archive_map = cls.pretrained_config_archive_map
+
+        if pretrained_model_name_or_path in pretrained_config_archive_map:
+            config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
         elif os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
@@ -178,23 +197,20 @@ class PretrainedConfig(object):
                 proxies=proxies,
                 resume_download=resume_download,
             )
-            # Load config
-            config = cls.from_json_file(resolved_config_file)
+            # Load config dict
+            config_dict = cls._dict_from_json_file(resolved_config_file)
         except EnvironmentError:
-            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
+            if pretrained_model_name_or_path in pretrained_config_archive_map:
                 msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                     config_file
                 )
             else:
                 msg = (
-                    "Model name '{}' was not found in model name list ({}). "
+                    "Model name '{}' was not found in model name list. "
                     "We assumed '{}' was a path or url to a configuration file named {} or "
                     "a directory containing such a file but couldn't find any such file at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ", ".join(cls.pretrained_config_archive_map.keys()),
-                        config_file,
-                        CONFIG_NAME,
+                        pretrained_model_name_or_path, config_file, CONFIG_NAME,
                     )
                 )
             raise EnvironmentError(msg)
@@ -212,6 +228,15 @@ class PretrainedConfig(object):
         else:
             logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

+        return config_dict, kwargs
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict, **kwargs):
+        """Constructs a `Config` from a Python dictionary of parameters."""
+        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+        config = cls(**config_dict)
+
         if hasattr(config, "pruned_heads"):
             config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
@@ -231,17 +256,16 @@ class PretrainedConfig(object):
         return config

     @classmethod
-    def from_dict(cls, json_object):
-        """Constructs a `Config` from a Python dictionary of parameters."""
-        return cls(**json_object)
+    def from_json_file(cls, json_file: str):
+        """Constructs a `Config` from a json file of parameters."""
+        config_dict = cls._dict_from_json_file(json_file)
+        return cls(**config_dict)

     @classmethod
-    def from_json_file(cls, json_file):
-        """Constructs a `Config` from a json file of parameters."""
+    def _dict_from_json_file(cls, json_file: str):
         with open(json_file, "r", encoding="utf-8") as reader:
             text = reader.read()
-        dict_obj = json.loads(text)
-        return cls(**dict_obj)
+        return json.loads(text)

     def __eq__(self, other):
         return self.__dict__ == other.__dict__
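The reason `from_pretrained()` is split into `resolved_config_dict()` plus `from_dict()` is that the raw JSON dictionary becomes inspectable before any concrete class is committed to, which is exactly what the `model_type` lookup in `AutoConfig` relies on. A compressed sketch of the two-step flow against a local file (the `DummyConfig` class and the temporary path are illustrative, not the library's API):

import json
import os
import tempfile

class DummyConfig:
    # Minimal stand-in for PretrainedConfig: keyword arguments become attributes.
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def _dict_from_json_file(cls, json_file):
        with open(json_file, "r", encoding="utf-8") as reader:
            return json.loads(reader.read())

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        return cls(**config_dict)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"model_type": "bert", "hidden_size": 768}, f)

    # Step 1: resolve the checkpoint to a plain dict (the role of resolved_config_dict) ...
    config_dict = DummyConfig._dict_from_json_file(path)
    # Step 2: ... so a caller can read config_dict["model_type"] first and only then
    # decide which class's from_dict() should build the actual config object.
    config = DummyConfig.from_dict(config_dict)
    print(config.model_type, config.hidden_size)  # bert 768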


@@ -19,6 +19,7 @@ import logging

 from .configuration_auto import (
     AlbertConfig,
+    AutoConfig,
     BertConfig,
     CamembertConfig,
     CTRLConfig,
@@ -26,11 +27,13 @@ from .configuration_auto import (
     GPT2Config,
     OpenAIGPTConfig,
     RobertaConfig,
+    T5Config,
     TransfoXLConfig,
     XLMConfig,
     XLMRobertaConfig,
     XLNetConfig,
 )
+from .configuration_utils import PretrainedConfig
 from .modeling_albert import (
     ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     AlbertForMaskedLM,
@@ -129,7 +132,8 @@ class AutoModel(object):
         or the `AutoModel.from_config(config)` class methods.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The base model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -286,32 +290,36 @@ class AutoModel(object):
             model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "t5" in pretrained_model_name_or_path:
-            return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "camembert" in pretrained_model_name_or_path:
-            return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm-roberta" in pretrained_model_name_or_path:
-            return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
-            return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
-            return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
-            return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
-            return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, T5Config):
+            return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, DistilBertConfig):
+            return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, AlbertConfig):
+            return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, CamembertConfig):
+            return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, RobertaConfig):
+            return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, BertConfig):
+            return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, OpenAIGPTConfig):
+            return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, GPT2Config):
+            return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, TransfoXLConfig):
+            return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, XLNetConfig):
+            return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, XLMConfig):
+            return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, CTRLConfig):
+            return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
             "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
@@ -329,7 +337,8 @@ class AutoModelWithLMHead(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -407,7 +416,8 @@ class AutoModelWithLMHead(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -484,32 +494,56 @@ class AutoModelWithLMHead(object):
             model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "t5" in pretrained_model_name_or_path:
-            return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "camembert" in pretrained_model_name_or_path:
-            return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm-roberta" in pretrained_model_name_or_path:
-            return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
-            return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
-            return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
-            return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
-            return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, T5Config):
+            return T5WithLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, DistilBertConfig):
+            return DistilBertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, AlbertConfig):
+            return AlbertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, CamembertConfig):
+            return CamembertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, RobertaConfig):
+            return RobertaForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, BertConfig):
+            return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, OpenAIGPTConfig):
+            return OpenAIGPTLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, GPT2Config):
+            return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, TransfoXLConfig):
+            return TransfoXLLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return XLNetLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMConfig):
+            return XLMWithLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, CTRLConfig):
+            return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
             "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
@@ -527,7 +561,8 @@ class AutoModelForSequenceClassification(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -592,7 +627,8 @@ class AutoModelForSequenceClassification(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -665,32 +701,42 @@ class AutoModelForSequenceClassification(object):
             model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "distilbert" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, DistilBertConfig):
             return DistilBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "albert" in pretrained_model_name_or_path:
+        elif isinstance(config, AlbertConfig):
             return AlbertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "camembert" in pretrained_model_name_or_path:
+        elif isinstance(config, CamembertConfig):
             return CamembertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "xlm-roberta" in pretrained_model_name_or_path:
+        elif isinstance(config, XLMRobertaConfig):
             return XLMRobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "roberta" in pretrained_model_name_or_path:
+        elif isinstance(config, RobertaConfig):
             return RobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, BertConfig):
+            return BertForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return XLNetForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMConfig):
+            return XLMForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "bert" in pretrained_model_name_or_path:
-            return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
@@ -708,7 +754,8 @@ class AutoModelForQuestionAnswering(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -763,7 +810,8 @@ class AutoModelForQuestionAnswering(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -830,16 +878,30 @@ class AutoModelForQuestionAnswering(object):
             model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "distilbert" in pretrained_model_name_or_path:
-            return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, DistilBertConfig):
+            return DistilBertForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, AlbertConfig):
+            return AlbertForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, BertConfig):
+            return BertForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return XLNetForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMConfig):
+            return XLMForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
@@ -893,7 +955,8 @@ class AutoModelForTokenClassification:
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -959,24 +1022,34 @@ class AutoModelForTokenClassification:
             model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "camembert" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, CamembertConfig):
             return CamembertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "distilbert" in pretrained_model_name_or_path:
+        elif isinstance(config, DistilBertConfig):
             return DistilBertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "xlm-roberta" in pretrained_model_name_or_path:
+        elif isinstance(config, XLMRobertaConfig):
             return XLMRobertaForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, RobertaConfig):
+            return RobertaForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, BertConfig):
+            return BertForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return XLNetForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "roberta" in pretrained_model_name_or_path:
-            return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "


@@ -18,16 +18,20 @@
 import logging

 from .configuration_auto import (
+    AlbertConfig,
+    AutoConfig,
     BertConfig,
     CTRLConfig,
     DistilBertConfig,
     GPT2Config,
     OpenAIGPTConfig,
     RobertaConfig,
+    T5Config,
     TransfoXLConfig,
     XLMConfig,
     XLNetConfig,
 )
+from .configuration_utils import PretrainedConfig
 from .modeling_tf_albert import (
     TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     TFAlbertForMaskedLM,
@@ -113,7 +117,8 @@ class TFAutoModel(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The base model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -257,28 +262,38 @@ class TFAutoModel(object):
             model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

         """
-        if "t5" in pretrained_model_name_or_path:
-            return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
-            return TFOpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
-            return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
-            return TFTransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
-            return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, T5Config):
+            return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, DistilBertConfig):
+            return TFDistilBertModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, AlbertConfig):
+            return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, RobertaConfig):
+            return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, BertConfig):
+            return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, OpenAIGPTConfig):
+            return TFOpenAIGPTModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, GPT2Config):
+            return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, TransfoXLConfig):
+            return TFTransfoXLModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, XLMConfig):
+            return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif isinstance(config, CTRLConfig):
+            return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
@@ -295,7 +310,8 @@ class TFAutoModelWithLMHead(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -368,7 +384,8 @@ class TFAutoModelWithLMHead(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -443,28 +460,54 @@ class TFAutoModelWithLMHead(object):
             model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

         """
-        if "t5" in pretrained_model_name_or_path:
-            return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
-            return TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
-            return TFOpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
-            return TFGPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
-            return TFTransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
-            return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
-            return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, T5Config):
+            return TFT5WithLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, DistilBertConfig):
+            return TFDistilBertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, AlbertConfig):
+            return TFAlbertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, RobertaConfig):
+            return TFRobertaForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, BertConfig):
+            return TFBertForMaskedLM.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, OpenAIGPTConfig):
+            return TFOpenAIGPTLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, GPT2Config):
+            return TFGPT2LMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, TransfoXLConfig):
+            return TFTransfoXLLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return TFXLNetLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMConfig):
+            return TFXLMWithLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, CTRLConfig):
+            return TFCTRLLMHeadModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
@@ -481,7 +524,8 @@ class TFAutoModelForSequenceClassification(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -537,7 +581,8 @@ class TFAutoModelForSequenceClassification(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -610,28 +655,34 @@ class TFAutoModelForSequenceClassification(object):
             model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

         """
-        if "distilbert" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, DistilBertConfig):
             return TFDistilBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "albert" in pretrained_model_name_or_path:
+        elif isinstance(config, AlbertConfig):
             return TFAlbertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "roberta" in pretrained_model_name_or_path:
+        elif isinstance(config, RobertaConfig):
             return TFRobertaForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "bert" in pretrained_model_name_or_path:
+        elif isinstance(config, BertConfig):
             return TFBertForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "xlnet" in pretrained_model_name_or_path:
+        elif isinstance(config, XLNetConfig):
             return TFXLNetForSequenceClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLMConfig):
+            return TFXLMForSequenceClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "xlm" in pretrained_model_name_or_path:
-            return TFXLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "
@@ -647,7 +698,8 @@ class TFAutoModelForQuestionAnswering(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@@ -699,7 +751,8 @@ class TFAutoModelForQuestionAnswering(object):
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@ -771,19 +824,25 @@ class TFAutoModelForQuestionAnswering(object):
             model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)

         """
-        if "distilbert" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, DistilBertConfig):
             return TFDistilBertForQuestionAnswering.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "bert" in pretrained_model_name_or_path:
-            return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
+        elif isinstance(config, BertConfig):
+            return TFBertForQuestionAnswering.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
             return TFXLNetForQuestionAnsweringSimple.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
-        elif "xlm" in pretrained_model_name_or_path:
+        elif isinstance(config, XLMConfig):
             return TFXLMForQuestionAnsweringSimple.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
         raise ValueError(
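The hunk above also tightens the contract for the config keyword argument: an explicit PretrainedConfig passed by the caller is used directly (and now forwarded to the concrete class), and only when it is absent is one derived from the checkpoint via AutoConfig. A short sketch of both paths, using the stock bert-base-uncased checkpoint purely for illustration:

from transformers import AutoConfig, TFAutoModelForQuestionAnswering

# Path 1: let the Auto class derive the config from the checkpoint itself.
model = TFAutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")

# Path 2: build the config first and pass it explicitly; from_pretrained()
# pops it from kwargs and skips the AutoConfig lookup.
config = AutoConfig.from_pretrained("bert-base-uncased")
model = TFAutoModelForQuestionAnswering.from_pretrained("bert-base-uncased", config=config)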
@ -833,7 +892,8 @@ class TFAutoModelForTokenClassification:
         from a pre-trained model configuration.

         The `from_pretrained()` method takes care of returning the correct model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@ -898,17 +958,25 @@ class TFAutoModelForTokenClassification:
             model = TFAutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

         """
-        if "bert" in pretrained_model_name_or_path:
-            return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
-            return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return TFDistilBertForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
-            )
-        elif "roberta" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if isinstance(config, BertConfig):
+            return TFBertForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, XLNetConfig):
+            return TFXLNetForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, DistilBertConfig):
+            return TFDistilBertForTokenClassification.from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        elif isinstance(config, RobertaConfig):
             return TFRobertaForTokenClassification.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
             )
         raise ValueError(

View File

@ -17,6 +17,23 @@
 import logging

+from .configuration_auto import (
+    AlbertConfig,
+    AutoConfig,
+    BertConfig,
+    CamembertConfig,
+    CTRLConfig,
+    DistilBertConfig,
+    GPT2Config,
+    OpenAIGPTConfig,
+    RobertaConfig,
+    T5Config,
+    TransfoXLConfig,
+    XLMConfig,
+    XLMRobertaConfig,
+    XLNetConfig,
+)
+from .configuration_utils import PretrainedConfig
 from .tokenization_albert import AlbertTokenizer
 from .tokenization_bert import BertTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer
@ -43,7 +60,8 @@ class AutoTokenizer(object):
         class method.

         The `from_pretrained()` method takes care of returning the correct tokenizer class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

         The tokenizer class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
@ -72,7 +90,7 @@ class AutoTokenizer(object):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        r""" Instantiate a one of the tokenizer classes of the library
+        r""" Instantiate one of the tokenizer classes of the library
         from a pre-trained model vocabulary.

         The tokenizer class to instantiate is selected as the first pattern matching
@ -129,33 +147,38 @@ class AutoTokenizer(object):
             tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

         """
-        if "t5" in pretrained_model_name_or_path:
-            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "distilbert" in pretrained_model_name_or_path:
-            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "albert" in pretrained_model_name_or_path:
-            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "camembert" in pretrained_model_name_or_path:
-            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "xlm-roberta" in pretrained_model_name_or_path:
-            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "roberta" in pretrained_model_name_or_path:
-            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "bert-base-japanese" in pretrained_model_name_or_path:
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        if "bert-base-japanese" in pretrained_model_name_or_path:
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "bert" in pretrained_model_name_or_path:
+
+        if isinstance(config, T5Config):
+            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, DistilBertConfig):
+            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, AlbertConfig):
+            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, CamembertConfig):
+            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, RobertaConfig):
+            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        elif isinstance(config, BertConfig):
             return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "openai-gpt" in pretrained_model_name_or_path:
+        elif isinstance(config, OpenAIGPTConfig):
             return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "gpt2" in pretrained_model_name_or_path:
+        elif isinstance(config, GPT2Config):
             return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "transfo-xl" in pretrained_model_name_or_path:
+        elif isinstance(config, TransfoXLConfig):
             return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "xlnet" in pretrained_model_name_or_path:
+        elif isinstance(config, XLNetConfig):
             return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "xlm" in pretrained_model_name_or_path:
+        elif isinstance(config, XLMConfig):
             return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-        elif "ctrl" in pretrained_model_name_or_path:
+        elif isinstance(config, CTRLConfig):
             return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         raise ValueError(
             "Unrecognized model identifier in {}. Should contains one of "

tests/fixtures/dummy-config.json (new file, 3 lines)
View File

@ -0,0 +1,3 @@
{
"model_type": "roberta"
}
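This three-line fixture is the whole point of the change: nothing in the file name says "roberta", only the model_type field does. The same works for any user-created directory; a hedged sketch (the ./some-model directory is made up, and a real checkpoint would carry more fields):

import json
import os

from transformers import AutoConfig

# Write a minimal config whose only architectural hint is model_type.
os.makedirs("./some-model", exist_ok=True)
with open("./some-model/config.json", "w") as f:
    json.dump({"model_type": "roberta"}, f)

config = AutoConfig.from_pretrained("./some-model")
print(type(config).__name__)  # RobertaConfig, with default values elsewhere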

View File

@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.configuration_auto import AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase):
    def test_config_from_model_shortcut(self):
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_from_model_type(self):
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)
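The last test exercises AutoConfig.for_model, whose body is not part of this excerpt. Conceptually it is a keyed lookup from the model-type string to a config class, followed by instantiating that class with default values; a hypothetical stand-in that mirrors the assumed behaviour (not the library's actual implementation):

from transformers import BertConfig, RobertaConfig

# Made-up mapping and helper, only to illustrate the lookup idea.
_SKETCH_MAPPING = {"bert": BertConfig, "roberta": RobertaConfig}

def for_model_sketch(model_type, **kwargs):
    return _SKETCH_MAPPING[model_type](**kwargs)

assert isinstance(for_model_sketch("roberta"), RobertaConfig)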

View File

@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, BertForSequenceClassification)

-    @slow
+    # @slow
     def test_question_answering_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

View File

@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow
 class AutoTokenizerTest(unittest.TestCase):
-    @slow
+    # @slow
     def test_tokenizer_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
         for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: