Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-25 23:38:59 +06:00)
good quality generation example for GPT, GPT-2, Transfo-XL, XLNet
This commit is contained in:
parent 7322c314a6
commit 7d4b200e40

examples/run_generation.py (new file, 198 lines)

@@ -0,0 +1,198 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Generation with GPT/GPT-2/Transformer-XL/XLNet models
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig

from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer


logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())

MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts, as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        for _ in trange(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default=None, required=True,
                        help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    args.model_type = ""
    for key in MODEL_CLASSES:
        if key in args.model_name.lower():
            args.model_type = key  # take the first match in model types
            break

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name)
    model = model_class.from_pretrained(args.model_name)
    model.to(args.device)
    model.eval()

    if args.length < 0 and model.config.max_position_embeddings > 0:
        args.length = model.config.max_position_embeddings
    elif 0 < model.config.max_position_embeddings < args.length:
        args.length = model.config.max_position_embeddings  # No generation bigger than model size
    elif args.length < 0:
        args.length = MAX_LENGTH  # avoid infinite loop

    print(args)
    while True:
        raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
        if args.model_type in ["transfo-xl", "xlnet"]:
            # Models with memory like to have a long prompt for short inputs.
            raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
        context_tokens = tokenizer.encode(raw_text)
        out = sample_sequence(
            model=model,
            context=context_tokens,
            length=args.length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            device=args.device,
            is_xlnet=bool(args.model_type == "xlnet"),
        )
        out = out[0, len(context_tokens):].tolist()
        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
        print(text)
        if args.prompt:
            break
    return text


if __name__ == '__main__':
    main()
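
For reference, a minimal sketch of how the per-step filtering and sampling above behave on a toy distribution. The snippet is not part of the commit; the logit values are invented, and it assumes top_k_top_p_filtering from the script is in scope.

import torch
import torch.nn.functional as F

# Toy next-token logits over a 5-token vocabulary (made-up values).
logits = torch.tensor([3.0, 2.0, 1.0, 0.5, -1.0])

# Keep only the two most likely tokens (top_p=0.0 disables nucleus filtering),
# then renormalize and sample -- the same steps sample_sequence() runs at every iteration.
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.0)
probs = F.softmax(filtered, dim=-1)                   # filtered-out entries get ~0 probability
next_token = torch.multinomial(probs, num_samples=1)  # index of the sampled token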

@@ -29,6 +29,7 @@ except ImportError:

 import run_glue
 import run_squad
+import run_generation

 logging.basicConfig(level=logging.DEBUG)

@@ -91,5 +92,18 @@ class ExamplesTests(unittest.TestCase):
         self.assertGreaterEqual(result['exact'], 30)

+    def test_generation(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        testargs = ["run_generation.py",
+                    "--prompt=Hello",
+                    "--seed=42"]
+        model_name = "--model_name=openai-gpt"
+        with patch.object(sys, 'argv', testargs + [model_name]):
+            result = run_generation.main()
+            self.assertGreaterEqual(len(result), 10)
+
 if __name__ == "__main__":
     unittest.main()

@@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)

 GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
 GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
                                       "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}

 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model

@@ -195,6 +195,10 @@ class GPT2Config(PretrainedConfig):
                              "or the path to a pretrained model config file (str)"
                              )

+    @property
+    def max_position_embeddings(self):
+        return self.n_positions
+
     @property
     def hidden_size(self):
         return self.n_embd

@@ -214,6 +214,10 @@ class OpenAIGPTConfig(PretrainedConfig):
                              "or the path to a pretrained model config file (str)"
                              )

+    @property
+    def max_position_embeddings(self):
+        return self.n_positions
+
     @property
     def hidden_size(self):
         return self.n_embd

@@ -287,6 +287,10 @@ class TransfoXLConfig(PretrainedConfig):
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")

+    @property
+    def max_position_embeddings(self):
+        return self.tgt_len + self.ext_len + self.mem_len
+
     @property
     def vocab_size(self):
         return self.n_token

@@ -211,9 +211,6 @@ class XLNetConfig(PretrainedConfig):
             layers in the embeddings, encoder, and pooler.
         dropatt: The dropout ratio for the attention
             probabilities.
-        max_position_embeddings: The maximum sequence length that this model might
-            ever be used with. Typically set this to something large just in case
-            (e.g., 512 or 1024 or 2048).
         initializer_range: The sttdev of the truncated_normal_initializer for
             initializing all weight matrices.
         layer_norm_eps: The epsilon used by LayerNorm.

@@ -247,7 +244,6 @@ class XLNetConfig(PretrainedConfig):
                  untie_r=True,
                  attn_type="bi",
-                 max_position_embeddings=512,
                  initializer_range=0.02,
                  layer_norm_eps=1e-12,

@@ -289,7 +285,6 @@ class XLNetConfig(PretrainedConfig):
             self.untie_r = untie_r
             self.attn_type = attn_type
-            self.max_position_embeddings = max_position_embeddings
             self.initializer_range = initializer_range
             self.layer_norm_eps = layer_norm_eps

@@ -312,6 +307,10 @@ class XLNetConfig(PretrainedConfig):
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")

+    @property
+    def max_position_embeddings(self):
+        return -1
+
     @property
     def vocab_size(self):
         return self.n_token
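
Taken together, these hunks give every configuration class a max_position_embeddings property, which is what run_generation.py queries to cap the requested generation length. A rough sketch of that uniform interface (not part of the commit; it assumes the config classes can be instantiated with their default arguments):

from pytorch_transformers import GPT2Config, OpenAIGPTConfig, TransfoXLConfig, XLNetConfig

for config_cls in (GPT2Config, OpenAIGPTConfig, TransfoXLConfig, XLNetConfig):
    config = config_cls()
    # GPT/GPT-2 report n_positions, Transformer-XL reports tgt_len + ext_len + mem_len,
    # and XLNet reports -1 (no fixed positional limit), so callers can branch on the sign alone.
    print(config_cls.__name__, config.max_position_embeddings)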

@@ -765,7 +764,7 @@ class XLNetModel(XLNetPreTrainedModel):
         return pos_emb

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
+                mems=None, perm_mask=None, target_mapping=None, head_mask=None):
         """
         Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

@@ -790,10 +789,6 @@ class XLNetModel(XLNetPreTrainedModel):
                 on the j-th token.
                 Only used during pretraining for partial prediction.
                 Set to None during finetuning.
-            inp_q: [optional] float32 Tensor in shape [bsz, len].
-                1 for tokens with losses and 0 for tokens without losses.
-                Only used during pretraining for two-stream attention.
-                Set to None during finetuning.
             head_mask: TODO Lysandre didn't fill

@@ -823,7 +818,6 @@ class XLNetModel(XLNetPreTrainedModel):
         attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
         perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
         target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
-        inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None

         qlen, bsz = input_ids.shape[0], input_ids.shape[1]
         mlen = mems[0].shape[0] if mems is not None else 0

@@ -878,12 +872,11 @@ class XLNetModel(XLNetPreTrainedModel):
         ##### Word embeddings and prepare h & g hidden states
         word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k)
-        if inp_q is not None:
-            if target_mapping is not None:
-                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
-            else:
-                inp_q_ext = inp_q[:, :, None]
-                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
+        if target_mapping is not None:
+            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
+            # else:  # We removed the inp_q input which was same as target mapping
+            #     inp_q_ext = inp_q[:, :, None]
+            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
             output_g = self.dropout(word_emb_q)
         else:
             output_g = None

@@ -994,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+                mems=None, perm_mask=None, target_mapping=None,
                 labels=None, head_mask=None):
         """
         all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

@@ -1020,11 +1013,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                 on the j-th token.
                 Only used during pretraining for partial prediction.
                 Set to None during finetuning.
-            inp_q: [optional] float32 Tensor in shape [bsz, len].
-                1 for tokens with losses and 0 for tokens without losses.
-                Only used during pretraining for two-stream attention.
-                Set to None during finetuning.
-
         Returns:
             A ``tuple(encoded_layers, pooled_output)``, with

@@ -1054,7 +1042,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
         """
         transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+                                               mems, perm_mask, target_mapping, head_mask)

         logits = self.lm_loss(transformer_outputs[0])

@@ -1103,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.apply(self.init_weights)

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+                mems=None, perm_mask=None, target_mapping=None,
                 labels=None, head_mask=None):
         """
         Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

@@ -1129,10 +1117,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                 on the j-th token.
                 Only used during pre-training for partial prediction.
                 Set to None during fine-tuning.
-            inp_q: float32 Tensor in shape [bsz, len].
-                1 for tokens with losses and 0 for tokens without losses.
-                Only used during pre-training for two-stream attention.
-                Set to None during fine-tuning.
             labels: TODO Lysandre didn't fill
             head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
                 It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.

@@ -1161,7 +1145,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
         """
         transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+                                               mems, perm_mask, target_mapping, head_mask)
         output = transformer_outputs[0]

         output = self.sequence_summary(output)

@@ -1215,7 +1199,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.apply(self.init_weights)

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+                mems=None, perm_mask=None, target_mapping=None,
                 start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
                 head_mask=None):

@@ -1266,7 +1250,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask)
         """
         transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+                                               mems, perm_mask, target_mapping, head_mask)
         hidden_states = transformer_outputs[0]
         start_logits = self.start_logits(hidden_states, p_mask)
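
With inp_q removed, generation-time callers pass only perm_mask and target_mapping to get a single-token prediction, which is exactly what sample_sequence() in the new example script does. Condensed from its is_xlnet branch (the tensor construction is shown there):

# A dummy token is appended to the input, hidden from all positions via perm_mask,
# and selected as the sole prediction target via target_mapping.
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
outputs = model(**inputs)
next_token_logits = outputs[0][0, -1, :]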

@@ -97,7 +97,6 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
             perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
             target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
             target_mapping[:, 0, -1] = 1.0  # predict last token
-            inp_q = target_mapping[:, 0, :].clone()  # predict last token

             sequence_labels = None
             lm_labels = None

@@ -124,14 +123,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 num_labels=self.type_sequence_label_size)

             return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                    target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
+                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

         def set_seed(self):
             random.seed(self.seed)
             torch.manual_seed(self.seed)

         def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                              target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                              target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
             model = XLNetModel(config)
             model.eval()

@@ -153,7 +152,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                           target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                           target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
             model = XLNetLMHeadModel(config)
             model.eval()

@@ -161,7 +160,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):

             loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)

-            logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
+            logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)

             result = {
                 "loss_1": loss_1,

@@ -193,7 +192,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                      target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                      target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
             model = XLNetForQuestionAnswering(config)
             model.eval()

@@ -243,7 +242,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                                    target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
             model = XLNetForSequenceClassification(config)
             model.eval()

@@ -269,7 +268,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-             target_mapping, inp_q, segment_ids, lm_labels,
+             target_mapping, segment_ids, lm_labels,
              sequence_labels, is_impossible_labels) = config_and_inputs
             inputs_dict = {'input_ids': input_ids_1}
             return config, inputs_dict

@@ -25,7 +25,6 @@ import os
 import sys
 from collections import Counter, OrderedDict
 from io import open
-import unicodedata

 import torch
 import numpy as np

@@ -343,7 +343,7 @@ class PreTrainedTokenizer(object):
             return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
                         for sub_text in split_text), [])[:-1]

-        added_tokens = list(self.added_tokens_encoder.keys())
+        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
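
Including the special tokens in the split list keeps markers such as the <eod> appended to PADDING_TEXT intact instead of letting the subword model break them into pieces. A rough sketch of the intent (not part of the commit; the 'xlnet-base-cased' shortcut and the set of registered special tokens are assumptions about the pretrained tokenizer):

from pytorch_transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')  # assumed shortcut name
# If '<eod>' is among tokenizer.all_special_tokens, tokenize() now splits it out
# as a single token rather than running SentencePiece over the raw characters.
print(tokenizer.tokenize("with people, even a bishop, begging for his blessing. <eod>"))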

@@ -466,7 +466,7 @@ class PreTrainedTokenizer(object):

 def clean_up_tokenization(out_string):
-    out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                        ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                        ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
     return out_string
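
The functional change here is that the chained replace() result is now assigned back to out_string; before the fix the function returned its input unchanged. A small illustration of the cleanup using plain string operations (not tied to the class):

s = "Hello , world ! It 's done ."
cleaned = s.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(" 's", "'s")
print(cleaned)  # -> "Hello, world! It's done."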

@@ -172,7 +172,7 @@ class XLNetTokenizer(PreTrainedTokenizer):

     def _convert_ids_to_string(self, tokens_ids):
         """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids)
+        out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
         return out_string

     def save_vocabulary(self, save_directory):
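
A quick illustration of the decoding fix: SPIECE_UNDERLINE is the '▁' marker SentencePiece uses for word boundaries, so replacing it with a space turns joined pieces back into readable text. The piece strings below are invented for illustration:

SPIECE_UNDERLINE = u'▁'
pieces = [u'▁Hello', u',', u'▁world']
print(''.join(pieces).replace(SPIECE_UNDERLINE, ' '))  # -> " Hello, world" (note the leading space)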