mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933)
* first commit * add new auto models * better naming * fix bert automodel * fix automodel for pretraining * add models to init * fix name typo * fix typo * better naming * future warning instead of depreciation warning
This commit is contained in:
parent
5620033115
commit
86578bb04c
@ -166,11 +166,17 @@ if is_torch_available():
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelWithLMHead,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForMultipleChoice,
|
||||
MODEL_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
@ -182,6 +188,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
BertForPreTraining,
|
||||
BertForMaskedLM,
|
||||
BertLMHeadModel,
|
||||
BertForNextSentencePrediction,
|
||||
BertForSequenceClassification,
|
||||
BertForMultipleChoice,
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import (
|
||||
@ -58,6 +59,7 @@ from .modeling_bert import (
|
||||
BertForQuestionAnswering,
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
BertLMHeadModel,
|
||||
BertModel,
|
||||
)
|
||||
from .modeling_camembert import (
|
||||
@ -210,6 +212,46 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
(BertConfig, BertLMHeadModel),
|
||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||
(GPT2Config, GPT2LMHeadModel),
|
||||
(TransfoXLConfig, TransfoXLLMHeadModel),
|
||||
(XLNetConfig, XLNetLMHeadModel),
|
||||
(
|
||||
XLMConfig,
|
||||
XLMWithLMHeadModel,
|
||||
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
|
||||
(CTRLConfig, CTRLLMHeadModel),
|
||||
(ReformerConfig, ReformerModelWithLMHead),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
(DistilBertConfig, DistilBertForMaskedLM),
|
||||
(AlbertConfig, AlbertForMaskedLM),
|
||||
(CamembertConfig, CamembertForMaskedLM),
|
||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||
(LongformerConfig, LongformerForMaskedLM),
|
||||
(RobertaConfig, RobertaForMaskedLM),
|
||||
(BertConfig, BertForMaskedLM),
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
(XLMConfig, XLMWithLMHeadModel),
|
||||
(ElectraConfig, ElectraForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
(T5Config, T5ForConditionalGeneration),
|
||||
(MarianConfig, MarianMTModel),
|
||||
(BartConfig, BartForConditionalGeneration),
|
||||
(EncoderDecoderConfig, EncoderDecoderModel),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
(DistilBertConfig, DistilBertForSequenceClassification),
|
||||
@ -620,6 +662,10 @@ class AutoModelWithLMHead:
|
||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
warnings.warn(
|
||||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",
|
||||
FutureWarning,
|
||||
)
|
||||
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
@ -638,7 +684,7 @@ class AutoModelWithLMHead:
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||
- `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
|
||||
- `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
|
||||
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
|
||||
@ -704,6 +750,10 @@ class AutoModelWithLMHead:
|
||||
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",
|
||||
FutureWarning,
|
||||
)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
@ -719,6 +769,412 @@ class AutoModelWithLMHead:
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForCausalLM:
|
||||
r"""
|
||||
:class:`~transformers.AutoModelForCausalLM` is a generic model class
|
||||
that will be instantiated as one of the language modeling model classes of the library
|
||||
when created with the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoModelForCausalLM is designed to be instantiated "
|
||||
"using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForCausalLM.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a configuration.
|
||||
|
||||
Note:
|
||||
Loading a model from its configuration file does **not** load the model weights.
|
||||
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||
the model weights
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
|
||||
- isInstance of `bert` configuration class: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
|
||||
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
|
||||
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||
- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
|
||||
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||
- isInstance of `reformer` configuration class: :class:`~transformers.ReformerModelWithLMHead` (Reformer model)
|
||||
|
||||
Examples::
|
||||
|
||||
config = GPT2Config.from_pretrained('gpt2') # Download configuration from S3 and cache.
|
||||
model = AutoModelForCausalLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the language modeling model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||
- `bert`: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||
- `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
|
||||
- `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
|
||||
- `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
|
||||
- `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||
- `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||
- `reformer`: :class:`~transformers.ReformerModelWithLMHead` (Google Reformer model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path:
|
||||
Either:
|
||||
|
||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||
|
||||
state_dict: (`optional`) dict:
|
||||
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
These arguments will be passed to the configuration and the model.
|
||||
|
||||
Examples::
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
|
||||
model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
assert model.config.output_attention == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json')
|
||||
model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForMaskedLM:
|
||||
r"""
|
||||
:class:`~transformers.AutoModelForMaskedLM` is a generic model class
|
||||
that will be instantiated as one of the language modeling model classes of the library
|
||||
when created with the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoModelForMaskedLM is designed to be instantiated "
|
||||
"using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForMaskedLM.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a configuration.
|
||||
|
||||
Note:
|
||||
Loading a model from its configuration file does **not** load the model weights.
|
||||
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||
the model weights
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||
- isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
|
||||
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
|
||||
- isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
|
||||
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-Roberta model)
|
||||
- isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (Camembert model)
|
||||
- isInstance of `albert` configuration class: :class:`~transformers.AlbertForMaskedLM` (Albert model)
|
||||
|
||||
|
||||
Examples::
|
||||
|
||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||
model = AutoModelForMaskedLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the language modeling model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
|
||||
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
|
||||
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
|
||||
- `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
|
||||
- `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
|
||||
- `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
|
||||
- `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||
- `bert`: :class:`~transformers.BertLMHeadModel` (Bert model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path:
|
||||
Either:
|
||||
|
||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||
|
||||
state_dict: (`optional`) dict:
|
||||
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
These arguments will be passed to the configuration and the model.
|
||||
|
||||
Examples::
|
||||
|
||||
model = AutoModelForMaskedLM.from_pretrained('bert') # Download model and configuration from S3 and cache.
|
||||
model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
assert model.config.output_attention == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
||||
model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForSeq2SeqLM:
|
||||
r"""
|
||||
:class:`~transformers.AutoModelForSeq2SeqLM` is a generic model class
|
||||
that will be instantiated as one of the language modeling model classes of the library
|
||||
when created with the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoModelForSeq2SeqLM is designed to be instantiated "
|
||||
"using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForSeq2SeqLM.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
r""" Instantiates one of the base model classes of the library
|
||||
from a configuration.
|
||||
|
||||
Note:
|
||||
Loading a model from its configuration file does **not** load the model weights.
|
||||
It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
|
||||
the model weights
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
|
||||
- isInstance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||
- isInstance of `bart` configuration class: :class:`~transformers.BartForConditionalGeneration` (Bart model)
|
||||
- isInstance of `marian` configuration class: :class:`~transformers.MarianMTModel` (Marian model)
|
||||
- isInstance of `encoder-decoder` configuration class: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)
|
||||
|
||||
Examples::
|
||||
|
||||
config = T5Config.from_pretrained('t5')
|
||||
model = AutoModelForSeq2SeqLM.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
"""
|
||||
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class(config)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__,
|
||||
cls.__name__,
|
||||
", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r""" Instantiates one of the language modeling model classes of the library
|
||||
from a pre-trained model configuration.
|
||||
|
||||
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||
based on the `model_type` property of the config object, or when it's missing,
|
||||
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
|
||||
- `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
|
||||
- `bart`: :class:`~transformers.BartForConditionalGeneration` (Bert model)
|
||||
- `marian`: :class:`~transformers.MarianMTModel` (Marian model)
|
||||
- `encoder-decoder`: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path:
|
||||
Either:
|
||||
|
||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
||||
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||
|
||||
state_dict: (`optional`) dict:
|
||||
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||
These arguments will be passed to the configuration and the model.
|
||||
|
||||
Examples::
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained('t5-base') # Download model and configuration from S3 and cache.
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||
assert model.config.output_attention == True
|
||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json')
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__,
|
||||
cls.__name__,
|
||||
", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForSequenceClassification:
|
||||
r"""
|
||||
:class:`~transformers.AutoModelForSequenceClassification` is a generic model class
|
||||
|
@ -987,6 +987,9 @@ class BertLMHeadModel(BertPreTrainedModel):
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
assert (
|
||||
not config.is_decoder
|
||||
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertOnlyMLMHead(config)
|
||||
|
@ -32,7 +32,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
instantiated as a transformer architecture with one of the base model
|
||||
classes of the library as encoder and another one as
|
||||
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
||||
class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
|
||||
"""
|
||||
config_class = EncoderDecoderConfig
|
||||
base_model_prefix = "encoder_decoder"
|
||||
@ -61,9 +61,9 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
encoder = AutoModel.from_config(config.encoder)
|
||||
|
||||
if decoder is None:
|
||||
from transformers import AutoModelWithLMHead
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
decoder = AutoModelWithLMHead.from_config(config.decoder)
|
||||
decoder = AutoModelForCausalLM.from_config(config.decoder)
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
@ -157,7 +157,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
assert (
|
||||
decoder_pretrained_model_name_or_path is not None
|
||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
from .modeling_auto import AutoModelWithLMHead
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
from transformers import AutoConfig
|
||||
@ -176,7 +176,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
return cls(encoder=encoder, decoder=decoder)
|
||||
|
||||
|
@ -26,13 +26,20 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
GPT2Config,
|
||||
T5Config,
|
||||
AutoModel,
|
||||
BertModel,
|
||||
AutoModelForPreTraining,
|
||||
BertForPreTraining,
|
||||
AutoModelForCausalLM,
|
||||
GPT2LMHeadModel,
|
||||
AutoModelWithLMHead,
|
||||
AutoModelForMaskedLM,
|
||||
BertForMaskedLM,
|
||||
RobertaForMaskedLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
T5ForConditionalGeneration,
|
||||
AutoModelForSequenceClassification,
|
||||
BertForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
@ -41,6 +48,8 @@ if is_torch_available():
|
||||
BertForTokenClassification,
|
||||
)
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_auto import (
|
||||
MODEL_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
@ -48,6 +57,9 @@ if is_torch_available():
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
)
|
||||
|
||||
|
||||
@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
@slow
|
||||
def test_model_for_causal_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, GPT2Config)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, GPT2LMHeadModel)
|
||||
|
||||
@slow
|
||||
def test_model_for_masked_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = AutoModelForMaskedLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
@slow
|
||||
def test_model_for_encoder_decoder_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, T5Config)
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, T5ForConditionalGeneration)
|
||||
|
||||
@slow
|
||||
def test_sequence_classification_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase):
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
)
|
||||
|
||||
for mapping in mappings:
|
||||
|
@ -27,6 +27,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertModel,
|
||||
BertLMHeadModel,
|
||||
BertForMaskedLM,
|
||||
BertForNextSentencePrediction,
|
||||
BertForPreTraining,
|
||||
@ -35,7 +36,7 @@ if is_torch_available():
|
||||
BertForTokenClassification,
|
||||
BertForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class BertModelTester:
|
||||
|
Loading…
Reference in New Issue
Block a user