Added multiple AutoModel classes: AutoModelWithLMHead, AutoModelForQuestionAnswering and AutoModelForSequenceClassification

This commit is contained in:
LysandreJik 2019-08-26 15:44:30 -04:00
parent df9d6effae
commit cb60ce59dd
2 changed files with 344 additions and 18 deletions

View File

@ -10,7 +10,8 @@ from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_utils import (PreTrainedTokenizer)
from .modeling_auto import (AutoConfig, AutoModel) from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead)
from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining, from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,

View File

@ -23,13 +23,13 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .modeling_bert import BertConfig, BertModel from .modeling_bert import BertConfig, BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel
from .modeling_gpt2 import GPT2Config, GPT2Model from .modeling_gpt2 import GPT2Config, GPT2Model, GPT2LMHeadModel
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetConfig, XLNetModel from .modeling_xlnet import XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
from .modeling_xlm import XLMConfig, XLMModel from .modeling_xlm import XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaConfig, RobertaModel from .modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_utils import PreTrainedModel, SequenceSummary from .modeling_utils import PreTrainedModel, SequenceSummary
@ -137,20 +137,20 @@ class AutoModel(object):
when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method. class method.
The `from_pretrained()` method take care of returning the correct model class instance The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string. using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model) - contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model) - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model) - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model) - contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model) - contains `xlm`: XLMModel (XLM model)
- contains `roberta`: RobertaModel (RoBERTa model)
This class cannot be instantiated using `__init__()` (throw an error). This class cannot be instantiated using `__init__()` (throws an error).
""" """
def __init__(self): def __init__(self):
raise EnvironmentError("AutoModel is designed to be instantiated " raise EnvironmentError("AutoModel is designed to be instantiated "
@ -158,18 +158,18 @@ class AutoModel(object):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiate a one of the base model classes of the library r""" Instantiates one of the base model classes of the library
from a pre-trained model configuration. from a pre-trained model configuration.
The base model class to instantiate is selected as the first pattern matching The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model) - contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model) - contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model) - contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model) - contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model) - contains `xlm`: XLMModel (XLM model)
- contains `roberta`: RobertaModel (RoBERTa model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()` To train the model, you should first set it back in training mode with `model.train()`
@ -186,12 +186,12 @@ class AutoModel(object):
checkpoint in a PyTorch model using the provided conversion scripts and loading checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards. the PyTorch model afterwards.
**model_args**: (`optional`) Sequence: **model_args**: (`optional`) Sequence:
All remaning positional arguments will be passed to the underlying model's __init__ function All remaining positional arguments will be passed to the underlying model's __init__ function
**config**: an optional configuration for the model to use instead of an automatically loaded configuation. **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
Configuration can be automatically loaded when: Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
- the model was saved using the `save_pretrained(save_directory)` (loaded by suppling the save directory). - the model was saved using the `save_pretrained(save_directory)` (loaded by supplying the save directory).
**state_dict**: an optional state dictionnary for the model to use instead of a state dictionary loaded **state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
from saved weights file. from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights. This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
@ -200,7 +200,7 @@ class AutoModel(object):
Path to a directory in which a downloaded pre-trained model Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used. configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean: **output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict: **kwargs**: (`optional`) dict:
Dictionary of key, values to update the configuration object after loading. Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``. Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
@ -243,3 +243,328 @@ class AutoModel(object):
raise ValueError("Unrecognized model identifier in {}. Should contains one of " raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path)) "'xlm', 'roberta'".format(pretrained_model_name_or_path))
class AutoModelWithLMHead(object):
r"""
:class:`~pytorch_transformers.AutoModelWithLMHead` is a generic model class
that will be instantiated as one of the language modeling model classes of the library
when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Params:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
In this case, ``from_tf`` should be set to True and a configuration object should be
provided as `config` argument. This loading option is slower than converting the TensorFlow
checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards.
**model_args**: (`optional`) Sequence:
All remaining positional arguments will be passed to the underlying model's __init__ function
**config**: an optional configuration for the model to use instead of an automatically loaded configuration.
Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
- the model was saved using the `save_pretrained(save_directory)` (loaded by supplying the save directory).
**state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
a simpler option.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict:
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
- If a configuration is provided with `config`, **kwargs will be directly passed
to the underlying model's __init__ method.
- If a configuration is not provided, **kwargs will be first passed to the pretrained
model configuration class loading function (`PretrainedConfig.from_pretrained`).
Each key of **kwargs that corresponds to a configuration attribute
will be used to override said attribute with the supplied **kwargs value.
Remaining keys that do not correspond to any configuration attribute will
be passed to the underlying model's __init__ function.
Examples::
model = AutoModelWithLMHead.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = AutoModelWithLMHead.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'roberta' in pretrained_model_name_or_path:
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
return OpenAIGPTLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
return TransfoXLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
class AutoModelForSequenceClassification(object):
r"""
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a generic model class
that will be instantiated as one of the sequence classification model classes of the library
when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
- contains `xlm`: XLMForSequenceClassification (XLM model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the sequence classification model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model)
- contains `xlm`: XLMForSequenceClassification (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Params:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
In this case, ``from_tf`` should be set to True and a configuration object should be
provided as `config` argument. This loading option is slower than converting the TensorFlow
checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards.
**model_args**: (`optional`) Sequence:
All remaining positional arguments will be passed to the underlying model's __init__ function
**config**: an optional configuration for the model to use instead of an automatically loaded configuration.
Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
- the model was saved using the `save_pretrained(save_directory)` (loaded by supplying the save directory).
**state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
a simpler option.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict:
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
- If a configuration is provided with `config`, **kwargs will be directly passed
to the underlying model's __init__ method.
- If a configuration is not provided, **kwargs will be first passed to the pretrained
model configuration class loading function (`PretrainedConfig.from_pretrained`).
Each key of **kwargs that corresponds to a configuration attribute
will be used to override said attribute with the supplied **kwargs value.
Remaining keys that do not correspond to any configuration attribute will
be passed to the underlying model's __init__ function.
Examples::
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'roberta' in pretrained_model_name_or_path:
return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return XLNetForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path))
class AutoModelForQuestionAnswering(object):
r"""
:class:`~pytorch_transformers.AutoModelForQuestionAnswering` is a generic model class
that will be instantiated as one of the question answering model classes of the library
when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError("AutoModelWithLMHead is designed to be instantiated "
"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Params:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
In this case, ``from_tf`` should be set to True and a configuration object should be
provided as `config` argument. This loading option is slower than converting the TensorFlow
checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards.
**model_args**: (`optional`) Sequence:
All remaining positional arguments will be passed to the underlying model's __init__ function
**config**: an optional configuration for the model to use instead of an automatically loaded configuration.
Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
- the model was saved using the `save_pretrained(save_directory)` (loaded by supplying the save directory).
**state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
a simpler option.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict:
Dictionary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
- If a configuration is provided with `config`, **kwargs will be directly passed
to the underlying model's __init__ method.
- If a configuration is not provided, **kwargs will be first passed to the pretrained
model configuration class loading function (`PretrainedConfig.from_pretrained`).
Each key of **kwargs that corresponds to a configuration attribute
will be used to override said attribute with the supplied **kwargs value.
Remaining keys that do not correspond to any configuration attribute will
be passed to the underlying model's __init__ function.
Examples::
model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'bert' in pretrained_model_name_or_path:
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return XLNetForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))