add GPT2 torchhub compatibility

This commit is contained in:
VictorSanh 2019-06-01 15:27:43 -04:00
parent 2a329c6186
commit a92b6dc3c1
3 changed files with 175 additions and 5 deletions
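For torch.hub.load to discover the new gpt2* entry points, they also need to be re-exported from a hubconf.py at the repository root. That file is not part of this diff, so the snippet below is only a sketch of what such a registration typically looks like (the dependency list is an assumption):

dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']

from hubconfs.gpt2_hubconf import (
    gpt2Tokenizer,
    gpt2Model,
    gpt2LMHeadModel,
    gpt2DoubleHeadsModel
)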

hubconfs/gpt2_hubconf.py Normal file

@@ -0,0 +1,165 @@
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
from pytorch_pretrained_bert.modeling_gpt2 import (
GPT2Model,
GPT2LMHeadModel,
GPT2DoubleHeadsModel
)
# A lot of models share the same param doc. Use a decorator
# to save typing
gpt2_docstring = """
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `gpt2`
- a path or url to a pretrained model archive containing:
. `gpt2_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
- a path or url to a pretrained model archive containing:
. `gpt2_config.json` a configuration file for the model
. a TensorFlow checkpoint with trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT-2 class
"""
def _append_from_pretrained_docstring(docstr):
def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + docstr
return fn
return docstring_decorator
def gpt2Tokenizer(*args, **kwargs):
"""
Instantiate a GPT-2 BPE tokenizer for OpenAI GPT-2 from a pre-trained/customized vocab file.
Peculiarities:
- Byte-level BPE
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* gpt2
Keyword args:
special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...)
Default: None
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying GPT-2 model's
sequence length.
Default: None
Example:
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
>>> text = "Who was Jim Henson ?"
>>> indexed_tokens = tokenizer.encode(text)
"""
tokenizer = GPT2Tokenizer.from_pretrained(*args, **kwargs)
return tokenizer
@_append_from_pretrained_docstring(gpt2_docstring)
def gpt2Model(*args, **kwargs):
"""
gpt2Model is the basic OpenAI GPT-2 Transformer model based on
identical stacked masked self-attention blocks and pre-trained
on a large-scale dataset with a language modeling objective.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> indexed_tokens_1 = tokenizer.encode(text_1)
>>> indexed_tokens_2 = tokenizer.encode(text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load gpt2Model
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Model', 'gpt2')
>>> model.eval()
# Predict hidden states features for each layer
# past can be used to reuse precomputed hidden states in subsequent predictions
>>> with torch.no_grad():
hidden_states_1, past = model(tokens_tensor_1)
hidden_states_2, past = model(tokens_tensor_2, past=past)
"""
model = GPT2Model.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(gpt2_docstring)
def gpt2LMHeadModel(*args, **kwargs):
"""
gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the
tied (pre-trained) language modeling head on top.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> indexed_tokens_1 = tokenizer.encode(text_1)
>>> indexed_tokens_2 = tokenizer.encode(text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load gpt2LMHeadModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2LMHeadModel', 'gpt2')
>>> model.eval()
# Predict hidden states features for each layer
# past can be used to reuse precomputed hidden states in subsequent predictions
>>> with torch.no_grad():
predictions_1, past = model(tokens_tensor_1)
predictions_2, past = model(tokens_tensor_2, past=past)
# Get the predicted last token
>>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
>>> predicted_token = tokenizer.decode([predicted_index])
>>> assert predicted_token == ' who'
"""
model = GPT2LMHeadModel.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(gpt2_docstring)
def gpt2DoubleHeadsModel(*args, **kwargs):
"""
gpt2DoubleHeadsModel is the OpenAI GPT-2 Transformer model with the
tied (pre-trained) language modeling head and a multiple choice
classification head (only initialized, not pre-trained).
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
# Prepare tokenized input
>>> text = "Who was Jim Henson ?"
>>> indexed_tokens = tokenizer.encode(text)
>>> tokens_tensor = torch.tensor([indexed_tokens])
>>> mc_token_ids = torch.LongTensor([[len(indexed_tokens) - 1]])  # index of the classification token (last position)
# Load gpt2DoubleHeadsModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2DoubleHeadsModel', 'gpt2')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
lm_logits, multiple_choice_logits, presents = model(tokens_tensor, mc_token_ids)
"""
model = GPT2DoubleHeadsModel.from_pretrained(*args, **kwargs)
return model
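The past value returned in the examples above caches the attention key/value states, so each subsequent step only needs to be fed the newest token. A minimal greedy-generation sketch built on these entry points (prompt, step count, and greedy decoding are illustrative choices, not part of this diff):

import torch

tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2LMHeadModel', 'gpt2')
model.eval()

generated = tokenizer.encode("Who was Jim Henson ? Jim Henson was a")
context, past = torch.tensor([generated]), None
with torch.no_grad():
    for _ in range(20):
        logits, past = model(context, past=past)            # reuse cached key/value states
        next_token = torch.argmax(logits[0, -1, :]).item()  # greedy pick of the next token
        generated.append(next_token)
        context = torch.tensor([[next_token]])               # feed only the new token from now on
print(tokenizer.decode(generated))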

pytorch_pretrained_bert/modeling_gpt2.py

@@ -362,9 +362,7 @@ class GPT2PreTrainedModel(nn.Module):
module.bias.data.zero_()
@classmethod
-def from_pretrained(
-    cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
-):
+def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
@@ -382,8 +380,15 @@ class GPT2PreTrainedModel(nn.Module):
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
-*inputs, **kwargs: additional input for the specific GPT class
+*inputs, **kwargs: additional input for the specific GPT2 class
"""
+state_dict = kwargs.get('state_dict', None)
+kwargs.pop('state_dict', None)
+cache_dir = kwargs.get('cache_dir', None)
+kwargs.pop('cache_dir', None)
+from_tf = kwargs.get('from_tf', False)
+kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
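The get-then-pop calls in the new from_pretrained body strip the loader-level options out of **kwargs so that only model-configuration arguments are forwarded further down. A hypothetical stand-alone helper showing the same split (dict.pop already returns the value it removes, so a single call per key would suffice):

def _split_loader_kwargs(kwargs):
    # Separate hub/loader options from the kwargs meant for the model itself.
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    return state_dict, cache_dir, from_tf, kwargs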

pytorch_pretrained_bert/tokenization_gpt2.py

@@ -91,7 +91,7 @@ class GPT2Tokenizer(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
-Instantiate a PreTrainedBertModel from a pre-trained model file.
+Instantiate a GPT2Tokenizer from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: