doc on base classes

This commit is contained in:
thomwolf 2019-07-15 12:08:06 +02:00
parent 44c985facd
commit e28d8bde0d

View File

@ -54,7 +54,8 @@ else:
class PretrainedConfig(object):
""" An abstract class to handle dowloading a model pretrained config.
""" Base class for all configuration classes.
Handle a few common parameters and methods for loading/downloading/saving configurations.
"""
pretrained_config_archive_map = {}
@ -66,7 +67,7 @@ class PretrainedConfig(object):
self.torchscript = kwargs.pop('torchscript', False)
def save_pretrained(self, save_directory):
""" Save a configuration file to a directory, so that it
""" Save a configuration object to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
@ -78,16 +79,30 @@ class PretrainedConfig(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
"""
Instantiate a PretrainedConfig from a pre-trained model configuration.
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a directory containing a configuration file `config.json` for the model,
- a path or url to a configuration file for the model.
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a saved configuration `file`.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters.
Examples::
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
>>> assert config.output_attention == True
"""
cache_dir = kwargs.pop('cache_dir', None)
@ -172,7 +187,7 @@ class PretrainedConfig(object):
class PreTrainedModel(nn.Module):
""" An abstract class to handle storing model config and
""" Base class for all models. Handle loading/storing model config and
a simple interface for dowloading and loading pretrained models.
"""
config_class = PretrainedConfig
@ -199,11 +214,12 @@ class PreTrainedModel(nn.Module):
Reducing the size will remove vectors from the end
Args:
new_num_tokens: (Optional) New number of tokens in the embedding matrix.
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return:
Return: ``torch.nn.Embeddings``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
if new_num_tokens is None:
@ -236,13 +252,15 @@ class PreTrainedModel(nn.Module):
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Args:
new_num_tokens: (Optional) New number of tokens in the embedding matrix.
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: does nothing.
Return:
Return: ``torch.nn.Embeddings``
Pointer to the input tokens Embedding Module of the model
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
@ -262,7 +280,8 @@ class PreTrainedModel(nn.Module):
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
Args:
heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)}
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
base_model._prune_heads(heads_to_prune)
@ -286,26 +305,47 @@ class PreTrainedModel(nn.Module):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load, or
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
config: an optional configuration for the model
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
In this case, ``from_tf`` should be set to True and a configuration object should be
provided as `config` argument. This loading option is slower than converting the TensorFlow
checkpoint in a PyTorch model using the provided conversion scripts and loading
the PyTorch model afterwards.
**config**: an optional configuration for the model to use instead of an automatically loaded configuation.
Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
- the model was saved using the `save_pretrained(save_directory)` (loaded by suppling the save directory).
**state_dict**: an optional state dictionnary for the model to use instead of a state dictionary loaded
from saved weights file.
This option can be used if you want to create a model from a pretrained configuraton but load your own weights.
In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
a simpler option.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
**output_loading_info**: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
**kwargs**: (`optional`) dict:
Dictionnary of key, values to update the configuration object after loading.
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
Examples::
>>> model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
>>> model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
>>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
>>> assert model.config.output_attention == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
@ -428,7 +468,7 @@ class PreTrainedModel(nn.Module):
class Conv1D(nn.Module):
def __init__(self, nf, nx):
""" Conv1D layer as defined by Alec for GPT (and also used in GPT-2)
""" Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
Basically works like a Linear layer but the weights are transposed
"""
super(Conv1D, self).__init__()
@ -612,7 +652,7 @@ class SQuADHead(nn.Module):
class SequenceSummary(nn.Module):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
- 'last' => [default] take the last token hidden state (like XLNet)