Mirror of https://github.com/huggingface/transformers.git
doc on base classes
This commit is contained in:
parent 44c985facd
commit e28d8bde0d
@@ -54,7 +54,8 @@ else:


 class PretrainedConfig(object):
-    """ An abstract class to handle dowloading a model pretrained config.
+    """ Base class for all configuration classes.
+        Handle a few common parameters and methods for loading/downloading/saving configurations.
     """
     pretrained_config_archive_map = {}

@@ -66,7 +67,7 @@ class PretrainedConfig(object):
         self.torchscript = kwargs.pop('torchscript', False)

     def save_pretrained(self, save_directory):
-        """ Save a configuration file to a directory, so that it
+        """ Save a configuration object to a directory, so that it
             can be re-loaded using the `from_pretrained(save_directory)` class method.
         """
         assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
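For context, a minimal round-trip sketch of the save/load pair documented above, assuming the package import name of this era (`pytorch_transformers`) and an illustrative `./my_config/` directory:

    import os
    from pytorch_transformers import BertConfig

    config = BertConfig.from_pretrained('bert-base-uncased')   # download or load from cache
    os.makedirs('./my_config/', exist_ok=True)                 # save_pretrained asserts the directory exists
    config.save_pretrained('./my_config/')                     # writes a config.json into the directory
    reloaded = BertConfig.from_pretrained('./my_config/')      # re-load with the class method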
@@ -78,16 +79,30 @@ class PretrainedConfig(object):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
-        """
-        Instantiate a PretrainedConfig from a pre-trained model configuration.
+        r""" Instantiate a PretrainedConfig from a pre-trained model configuration.

         Params:
-            pretrained_model_name_or_path: either:
-                - a str with the name of a pre-trained model to load selected in the list of:
-                    . `xlnet-large-cased`
-                - a path or url to a directory containing a configuration file `config.json` for the model,
-                - a path or url to a configuration file for the model.
-            cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
+            **pretrained_model_name_or_path**: either:
+                - a string with the `shortcut name` of a pre-trained model configuration to load from cache
+                    or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
+                - a path to a `directory` containing a configuration file saved
+                    using the `save_pretrained(save_directory)` method.
+                - a path or url to a saved configuration `file`.
+            **cache_dir**: (`optional`) string:
+                Path to a directory in which a downloaded pre-trained model
+                configuration should be cached if the standard cache should not be used.
+            **kwargs**: (`optional`) dict:
+                Dictionary of key, values to update the configuration object after loading.
+                Can be used to override selected configuration parameters.
+
+        Examples::
+
+            >>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
+            >>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
+            >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
+            >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
+            >>> assert config.output_attention == True
+
         """
         cache_dir = kwargs.pop('cache_dir', None)
@@ -172,7 +187,7 @@ class PretrainedConfig(object):


 class PreTrainedModel(nn.Module):
-    """ An abstract class to handle storing model config and
+    """ Base class for all models. Handle loading/storing model config and
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = PretrainedConfig
@@ -199,11 +214,12 @@ class PreTrainedModel(nn.Module):
                 Reducing the size will remove vectors from the end

         Args:
-            new_num_tokens: (Optional) New number of tokens in the embedding matrix.
+            new_num_tokens: (`optional`) int
+                New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
                 If not provided or None: return the provided token Embedding Module.
-        Return:
+        Return: ``torch.nn.Embeddings``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         if new_num_tokens is None:
@@ -236,13 +252,15 @@ class PreTrainedModel(nn.Module):

     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

         Args:
-            new_num_tokens: (Optional) New number of tokens in the embedding matrix.
+            new_num_tokens: (`optional`) int
+                New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
                 If not provided or None: does nothing.
-        Return:
+        Return: ``torch.nn.Embeddings``
             Pointer to the input tokens Embedding Module of the model
         """
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
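A minimal sketch of the resize described above, assuming the `pytorch_transformers` package name and the 30522-token vocabulary of `bert-base-uncased`; the target size is illustrative:

    from pytorch_transformers import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')

    # Grow the input embedding matrix by two slots; newly initialized vectors are
    # appended at the end, and tied weights are re-tied when the model defines `tie_weights()`.
    embeddings = model.resize_token_embeddings(30524)
    assert embeddings.num_embeddings == 30524
    assert model.config.vocab_size == 30524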
@@ -262,7 +280,8 @@ class PreTrainedModel(nn.Module):

     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
-            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        Args:
+            heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)}
         """
         base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
         base_model._prune_heads(heads_to_prune)
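The `heads_to_prune` mapping described above, in a short sketch with arbitrary layer and head indices (package name assumed as before):

    from pytorch_transformers import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')

    # {layer_num (int): list of head indices (list of int) to prune in that layer}
    model.prune_heads({0: [0, 2], 11: [1]})  # prune heads 0 and 2 of layer 0, head 1 of layer 11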
@@ -286,26 +305,47 @@ class PreTrainedModel(nn.Module):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
+        r""" Instantiate a pretrained PyTorch model from a pre-trained model configuration.

         Params:
-            pretrained_model_name_or_path: either:
-                - a str with the name of a pre-trained model to load, or
-                - a path or url to a pretrained model archive containing:
-                    . `config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
-                - a path or url to a tensorflow pretrained model checkpoint containing:
-                    . `config.json` a configuration file for the model
-                    . `model.chkpt` a TensorFlow checkpoint
-            config: an optional configuration for the model
-            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-            state_dict: an optional state dictionnary (collections.OrderedDict object) to use
-                instead of Google pre-trained models
-            *inputs, **kwargs: additional input for the specific XLNet class
-                (ex: num_labels for XLNetForSequenceClassification)
+            **pretrained_model_name_or_path**: either:
+                - a string with the `shortcut name` of a pre-trained model to load from cache
+                    or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
+                - a path to a `directory` containing a configuration file saved
+                    using the `save_pretrained(save_directory)` method.
+                - a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
+                    In this case, ``from_tf`` should be set to True and a configuration object should be
+                    provided as `config` argument. This loading option is slower than converting the TensorFlow
+                    checkpoint in a PyTorch model using the provided conversion scripts and loading
+                    the PyTorch model afterwards.
+            **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
+                Configuration can be automatically loaded when:
+                - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
+                - the model was saved using the `save_pretrained(save_directory)` method (loaded by supplying the save directory).
+            **state_dict**: an optional state dictionary for the model to use instead of a state dictionary loaded
+                from saved weights file.
+                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
+                In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
+                a simpler option.
+            **cache_dir**: (`optional`) string:
+                Path to a directory in which a downloaded pre-trained model
+                configuration should be cached if the standard cache should not be used.
+            **output_loading_info**: (`optional`) boolean:
+                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
+            **kwargs**: (`optional`) dict:
+                Dictionary of key, values to update the configuration object after loading.
+                Can be used to override selected configuration parameters. E.g. ``output_attention=True``
+
+        Examples::
+
+            >>> model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
+            >>> model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
+            >>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
+            >>> assert model.config.output_attention == True
+            >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+            >>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
+            >>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+
         """
         config = kwargs.pop('config', None)
         state_dict = kwargs.pop('state_dict', None)
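The `state_dict` argument above has no example in the docstring; a minimal sketch, assuming a hypothetical `./my_weights.bin` file holding a compatible PyTorch state dict:

    import torch
    from pytorch_transformers import BertModel

    # Build the model from the pre-trained configuration but load your own weights.
    state_dict = torch.load('./my_weights.bin', map_location='cpu')  # hypothetical weights file
    model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)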
@@ -428,7 +468,7 @@ class PreTrainedModel(nn.Module):

 class Conv1D(nn.Module):
     def __init__(self, nf, nx):
-        """ Conv1D layer as defined by Alec for GPT (and also used in GPT-2)
+        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
             Basically works like a Linear layer but the weights are transposed
         """
         super(Conv1D, self).__init__()
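A small shape sketch of the "Linear layer with transposed weights" remark; the import path assumes the layer lives in `modeling_utils` as shown in this diff:

    import torch
    from pytorch_transformers.modeling_utils import Conv1D

    # Conv1D(nf, nx) maps the last dimension from nx to nf features, like
    # nn.Linear(nx, nf), but stores its weight as (nx, nf) rather than (nf, nx).
    conv = Conv1D(nf=8, nx=4)
    x = torch.randn(2, 3, 4)            # (batch, seq_len, nx)
    assert conv(x).shape == (2, 3, 8)   # (batch, seq_len, nf)
    assert conv.weight.shape == (4, 8)  # transposed relative to nn.Linear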
@@ -612,7 +652,7 @@ class SQuADHead(nn.Module):


 class SequenceSummary(nn.Module):
-    """ Compute a single vector summary of a sequence hidden states according to various possibilities:
+    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
         Args of the config class:
             summary_type:
                 - 'last' => [default] take the last token hidden state (like XLNet)
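In plain PyTorch, the default 'last' summary shown above amounts to taking the hidden state at the final sequence position; a rough illustration of that option only, not of the full module (which may also apply projection and dropout):

    import torch

    hidden_states = torch.randn(2, 7, 16)  # (batch, seq_len, hidden_size)
    summary = hidden_states[:, -1]         # 'last': final token's hidden state, as in XLNet
    assert summary.shape == (2, 16)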