mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
AutoConfig + other Auto classes honor model_type
This commit is contained in:
parent
2f32dfd33b
commit
4d1c98c012
@ -16,6 +16,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
@ -27,6 +28,7 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||
@ -56,17 +58,38 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
)
|
||||
|
||||
|
||||
class AutoConfig(object):
|
||||
# Ordered registry of (name pattern, configuration class) pairs.
# Order is significant: lookups walk the entries in sequence and take the first
# pattern that is contained in the model identifier, so more specific patterns
# (e.g. "xlm-roberta", "roberta") must come before the ones they contain ("bert", "xlm").
CONFIG_MAPPING = OrderedDict(
    (
        ("t5", T5Config),
        ("distilbert", DistilBertConfig),
        ("albert", AlbertConfig),
        ("camembert", CamembertConfig),
        ("xlm-roberta", XLMRobertaConfig),
        ("roberta", RobertaConfig),
        ("bert", BertConfig),
        ("openai-gpt", OpenAIGPTConfig),
        ("gpt2", GPT2Config),
        ("transfo-xl", TransfoXLConfig),
        ("xlnet", XLNetConfig),
        ("xlm", XLMConfig),
        ("ctrl", CTRLConfig),
    )
)
|
||||
|
||||
|
||||
class AutoConfig:
|
||||
r""":class:`~transformers.AutoConfig` is a generic configuration class
|
||||
that will be instantiated as one of the configuration classes of the library
|
||||
when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
When using string matching, the configuration class is matched on
|
||||
the `pretrained_model_name_or_path` string in the following order:
|
||||
- contains `t5`: T5Config (T5 model)
|
||||
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
||||
- contains `albert`: AlbertConfig (ALBERT model)
|
||||
- contains `camembert`: CamembertConfig (CamemBERT model)
|
||||
@ -90,41 +113,23 @@ class AutoConfig(object):
|
||||
|
||||
@classmethod
def for_model(cls, model_type, *args, **kwargs):
    r"""Instantiate the configuration class that matches ``model_type``.

    The entries of ``CONFIG_MAPPING`` are tried in order and the first one whose
    pattern is contained in ``model_type`` is used.

    Parameters:
        model_type: string identifying the model; it is matched by substring
            against the patterns registered in ``CONFIG_MAPPING``.
        *args, **kwargs: forwarded unchanged to the selected configuration class.

    Returns:
        A new instance of the selected configuration class.

    Raises:
        ValueError: if no registered pattern is contained in ``model_type``.
    """
    for pattern, config_class in CONFIG_MAPPING.items():
        if pattern in model_type:
            return config_class(*args, **kwargs)

    # The list of accepted patterns is built from the registry so the message
    # cannot drift out of sync with the supported models.
    raise ValueError(
        "Unrecognized model identifier in {}. Should contain one of {}".format(
            model_type, ", ".join(CONFIG_MAPPING.keys())
        )
    )
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
r""" Instantiate a one of the configuration classes of the library
|
||||
r""" Instantiate one of the configuration classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The configuration class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
The configuration class to instantiate is selected
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
- contains `t5`: T5Config (T5 model)
|
||||
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
||||
- contains `albert`: AlbertConfig (ALBERT model)
|
||||
@ -183,36 +188,21 @@ class AutoConfig(object):
|
||||
assert unused_kwargs == {'foo': False}
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config_dict, _ = PretrainedConfig.resolved_config_dict(
|
||||
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
|
||||
)
|
||||
|
||||
if "model_type" in config_dict:
|
||||
config_class = CONFIG_MAPPING[config_dict["model_type"]]
|
||||
return config_class.from_dict(config_dict, **kwargs)
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
for pattern, config_class in CONFIG_MAPPING.items():
|
||||
if pattern in pretrained_model_name_or_path:
|
||||
return config_class.from_dict(config_dict, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
|
||||
pretrained_model_name_or_path
|
||||
"Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
|
||||
pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
|
||||
@ -36,7 +37,7 @@ class PretrainedConfig(object):
|
||||
It only affects the model's configuration.
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
- ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
|
||||
- ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
|
||||
|
||||
Parameters:
|
||||
``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
|
||||
@ -154,14 +155,32 @@ class PretrainedConfig(object):
|
||||
assert unused_kwargs == {'foo': False}
|
||||
|
||||
"""
|
||||
config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def resolved_config_dict(
|
||||
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
|
||||
for instantiating a Config using `from_dict`.
|
||||
|
||||
Parameters:
|
||||
pretrained_config_archive_map: (`optional`) Dict:
|
||||
A map of `shortcut names` to `url`.
|
||||
By default, will use the current class attribute.
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||
if pretrained_config_archive_map is None:
|
||||
pretrained_config_archive_map = cls.pretrained_config_archive_map
|
||||
|
||||
if pretrained_model_name_or_path in pretrained_config_archive_map:
|
||||
config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
@ -178,23 +197,20 @@ class PretrainedConfig(object):
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
)
|
||||
# Load config
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
if pretrained_model_name_or_path in pretrained_config_archive_map:
|
||||
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
|
||||
config_file
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"Model name '{}' was not found in model name list. "
|
||||
"We assumed '{}' was a path or url to a configuration file named {} or "
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
", ".join(cls.pretrained_config_archive_map.keys()),
|
||||
config_file,
|
||||
CONFIG_NAME,
|
||||
pretrained_model_name_or_path, config_file, CONFIG_NAME,
|
||||
)
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
@ -212,6 +228,15 @@ class PretrainedConfig(object):
|
||||
else:
|
||||
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict, **kwargs):
|
||||
"""Constructs a `Config` from a Python dictionary of parameters."""
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||
|
||||
@ -231,17 +256,16 @@ class PretrainedConfig(object):
|
||||
return config
|
||||
|
||||
@classmethod
def from_json_file(cls, json_file: str):
    """Constructs a `Config` from a json file of parameters.

    Parameters:
        json_file: path of the JSON file holding the configuration parameters.

    Returns:
        An instance of ``cls`` created by passing the parsed parameters as
        keyword arguments to the constructor.
    """
    # Parsing is delegated to `_dict_from_json_file` so that file reading lives in one place.
    config_dict = cls._dict_from_json_file(json_file)
    return cls(**config_dict)
|
||||
|
||||
@classmethod
def _dict_from_json_file(cls, json_file: str):
    """Read a JSON file and return its parsed content.

    Parameters:
        json_file: path of the JSON file to read; it is opened as UTF-8 text.

    Returns:
        The object produced by parsing the file as JSON.
    """
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)
|
||||
|
||||
def __eq__(self, other):
    """Compare two objects by their instance attribute dictionaries.

    Returns:
        ``True``/``False`` according to whether ``self.__dict__`` equals
        ``other.__dict__``; ``NotImplemented`` when ``other`` has no ``__dict__``
        (e.g. ``None`` or an ``int``), so that ``config == None`` evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    missing = object()
    other_state = getattr(other, "__dict__", missing)
    if other_state is missing:
        # Let Python fall back to its default comparison rather than crash.
        return NotImplemented
    return self.__dict__ == other_state
|
||||
|
@ -19,6 +19,7 @@ import logging
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
@ -26,11 +27,13 @@ from .configuration_auto import (
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLMRobertaConfig,
|
||||
XLNetConfig,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .modeling_albert import (
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM,
|
||||
@ -129,7 +132,8 @@ class AutoModel(object):
|
||||
or the `AutoModel.from_config(config)` class methods.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -286,32 +290,36 @@ class AutoModel(object):
|
||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, GPT2Config):
|
||||
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
@ -329,7 +337,8 @@ class AutoModelWithLMHead(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -407,7 +416,8 @@ class AutoModelWithLMHead(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -484,32 +494,56 @@ class AutoModelWithLMHead(object):
|
||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return T5WithLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return OpenAIGPTLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, GPT2Config):
|
||||
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TransfoXLLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMWithLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
@ -527,7 +561,8 @@ class AutoModelForSequenceClassification(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -592,7 +627,8 @@ class AutoModelForSequenceClassification(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -665,32 +701,42 @@ class AutoModelForSequenceClassification(object):
|
||||
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, DistilBertConfig):
|
||||
return DistilBertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
@ -708,7 +754,8 @@ class AutoModelForQuestionAnswering(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -763,7 +810,8 @@ class AutoModelForQuestionAnswering(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -830,16 +878,30 @@ class AutoModelForQuestionAnswering(object):
|
||||
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, DistilBertConfig):
|
||||
return DistilBertForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
@ -893,7 +955,8 @@ class AutoModelForTokenClassification:
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -959,24 +1022,34 @@ class AutoModelForTokenClassification:
|
||||
model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "camembert" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, CamembertConfig):
|
||||
return CamembertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
|
@ -18,16 +18,20 @@
|
||||
import logging
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLNetConfig,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .modeling_tf_albert import (
|
||||
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TFAlbertForMaskedLM,
|
||||
@ -113,7 +117,8 @@ class TFAutoModel(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The base model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -257,28 +262,38 @@ class TFAutoModel(object):
|
||||
model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return TFOpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TFTransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return TFDistilBertModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, BertConfig):
|
||||
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return TFOpenAIGPTModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, GPT2Config):
|
||||
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TFTransfoXLModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
@ -295,7 +310,8 @@ class TFAutoModelWithLMHead(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -368,7 +384,8 @@ class TFAutoModelWithLMHead(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -443,28 +460,54 @@ class TFAutoModelWithLMHead(object):
|
||||
model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
return TFOpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
return TFGPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
return TFTransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return TFT5WithLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return TFDistilBertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return TFAlbertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return TFRobertaForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, BertConfig):
|
||||
return TFBertForMaskedLM.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return TFOpenAIGPTLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, GPT2Config):
|
||||
return TFGPT2LMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TFTransfoXLLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return TFXLNetLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return TFXLMWithLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return TFCTRLLMHeadModel.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
@ -481,7 +524,8 @@ class TFAutoModelForSequenceClassification(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -537,7 +581,8 @@ class TFAutoModelForSequenceClassification(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -610,28 +655,34 @@ class TFAutoModelForSequenceClassification(object):
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, DistilBertConfig):
|
||||
return TFDistilBertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return TFAlbertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return TFRobertaForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, BertConfig):
|
||||
return TFBertForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return TFXLNetForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLMConfig):
|
||||
return TFXLMForSequenceClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
return TFXLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
@ -647,7 +698,8 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -699,7 +751,8 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -771,19 +824,25 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||
|
||||
"""
|
||||
if "distilbert" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, DistilBertConfig):
|
||||
return TFDistilBertForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, BertConfig):
|
||||
return TFBertForQuestionAnswering.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return TFXLNetForQuestionAnsweringSimple.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLMConfig):
|
||||
return TFXLMForQuestionAnsweringSimple.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
@ -833,7 +892,8 @@ class TFAutoModelForTokenClassification:
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The model class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -898,17 +958,25 @@ class TFAutoModelForTokenClassification:
|
||||
model = TFAutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
if "bert" in pretrained_model_name_or_path:
|
||||
return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return TFDistilBertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if isinstance(config, BertConfig):
|
||||
return TFBertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return TFXLNetForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return TFDistilBertForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return TFRobertaForTokenClassification.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
|
@ -17,6 +17,23 @@
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLMRobertaConfig,
|
||||
XLNetConfig,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bert import BertTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
@ -43,7 +60,8 @@ class AutoTokenizer(object):
|
||||
class method.
|
||||
|
||||
The `from_pretrained()` method take care of returning the correct tokenizer class instance
|
||||
using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
@ -72,7 +90,7 @@ class AutoTokenizer(object):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
r""" Instantiate a one of the tokenizer classes of the library
|
||||
r""" Instantiate one of the tokenizer classes of the library
|
||||
from a pre-trained model vocabulary.
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
@ -129,33 +147,38 @@ class AutoTokenizer(object):
|
||||
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
|
||||
|
||||
"""
|
||||
if "t5" in pretrained_model_name_or_path:
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "distilbert" in pretrained_model_name_or_path:
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "albert" in pretrained_model_name_or_path:
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "camembert" in pretrained_model_name_or_path:
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlm-roberta" in pretrained_model_name_or_path:
|
||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "roberta" in pretrained_model_name_or_path:
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "bert-base-japanese" in pretrained_model_name_or_path:
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if "bert-base-japanese" in pretrained_model_name_or_path:
|
||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "bert" in pretrained_model_name_or_path:
|
||||
|
||||
if isinstance(config, T5Config):
|
||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, DistilBertConfig):
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, AlbertConfig):
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, CamembertConfig):
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, XLMRobertaConfig):
|
||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, RobertaConfig):
|
||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif isinstance(config, BertConfig):
|
||||
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "openai-gpt" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, OpenAIGPTConfig):
|
||||
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "gpt2" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, GPT2Config):
|
||||
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "transfo-xl" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, TransfoXLConfig):
|
||||
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlnet" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLNetConfig):
|
||||
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "xlm" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, XLMConfig):
|
||||
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif "ctrl" in pretrained_model_name_or_path:
|
||||
elif isinstance(config, CTRLConfig):
|
||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized model identifier in {}. Should contains one of "
|
||||
|
3
tests/fixtures/dummy-config.json
vendored
Normal file
3
tests/fixtures/dummy-config.json
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"model_type": "roberta"
|
||||
}
|
38
tests/test_configuration_auto.py
Normal file
38
tests/test_configuration_auto.py
Normal file
@ -0,0 +1,38 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.configuration_auto import AutoConfig
|
||||
from transformers.configuration_bert import BertConfig
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
|
||||
|
||||
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||
|
||||
|
||||
class AutoConfigTest(unittest.TestCase):
|
||||
def test_config_from_model_shortcut(self):
|
||||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
def test_config_from_model_type(self):
|
||||
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
||||
def test_config_for_model_str(self):
|
||||
config = AutoConfig.for_model("roberta")
|
||||
self.assertIsInstance(config, RobertaConfig)
|
@ -83,7 +83,7 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForSequenceClassification)
|
||||
|
||||
@slow
|
||||
# @slow
|
||||
def test_question_answering_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
|
@ -29,7 +29,7 @@ from .utils import SMALL_MODEL_IDENTIFIER, slow
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
@slow
|
||||
# @slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
|
||||
|
Loading…
Reference in New Issue
Block a user