good quality generation example for GPT, GPT-2, Transfo-XL, XLNet

This commit is contained in:
thomwolf 2019-07-13 15:25:03 +02:00
parent 7322c314a6
commit 7d4b200e40
10 changed files with 252 additions and 46 deletions

198
examples/run_generation.py Normal file
View File

@ -0,0 +1,198 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Generation with GPT/GPT-2/Transformer-XL/XLNet models
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
return logits
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
with torch.no_grad():
for _ in trange(length):
inputs = {'input_ids': generated}
if is_xlnet:
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
target_mapping[0, 0, -1] = 1.0 # predict last token
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
return generated
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default=None, required=True,
help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
set_seed(args)
args.model_type = ""
for key in MODEL_CLASSES:
if key in args.model_name.lower():
args.model_type = key # take the first match in model types
break
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name)
model = model_class.from_pretrained(args.model_name)
model.to(args.device)
model.eval()
if args.length < 0 and model.config.max_position_embeddings > 0:
args.length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < args.length:
args.length = model.config.max_position_embeddings # No generation bigger than model size
elif args.length < 0:
args.length = MAX_LENGTH # avoid infinite loop
print(args)
while True:
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
if args.model_type in ["transfo-xl", "xlnet"]:
# Models with memory likes to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
context_tokens = tokenizer.encode(raw_text)
out = sample_sequence(
model=model,
context=context_tokens,
length=args.length,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
device=args.device,
is_xlnet=bool(args.model_type == "xlnet"),
)
out = out[0, len(context_tokens):].tolist()
text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
print(text)
if args.prompt:
break
return text
if __name__ == '__main__':
main()

View File

@ -29,6 +29,7 @@ except ImportError:
import run_glue
import run_squad
import run_generation
logging.basicConfig(level=logging.DEBUG)
@ -91,5 +92,18 @@ class ExamplesTests(unittest.TestCase):
self.assertGreaterEqual(result['exact'], 30)
def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = ["run_generation.py",
"--prompt=Hello",
"--seed=42"]
model_name = "--model_name=openai-gpt"
with patch.object(sys, 'argv', testargs + [model_name]):
result = run_generation.main()
self.assertGreaterEqual(result['f1'], 30)
self.assertGreaterEqual(result['exact'], 30)
if __name__ == "__main__":
unittest.main()

View File

@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
@ -195,6 +195,10 @@ class GPT2Config(PretrainedConfig):
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):
return self.n_positions
@property
def hidden_size(self):
return self.n_embd

View File

@ -214,6 +214,10 @@ class OpenAIGPTConfig(PretrainedConfig):
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):
return self.n_positions
@property
def hidden_size(self):
return self.n_embd

View File

@ -287,6 +287,10 @@ class TransfoXLConfig(PretrainedConfig):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@property
def max_position_embeddings(self):
return self.tgt_len + self.ext_len + self.mem_len
@property
def vocab_size(self):
return self.n_token

View File

@ -211,9 +211,6 @@ class XLNetConfig(PretrainedConfig):
layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
@ -247,7 +244,6 @@ class XLNetConfig(PretrainedConfig):
untie_r=True,
attn_type="bi",
max_position_embeddings=512,
initializer_range=0.02,
layer_norm_eps=1e-12,
@ -289,7 +285,6 @@ class XLNetConfig(PretrainedConfig):
self.untie_r = untie_r
self.attn_type = attn_type
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
@ -312,6 +307,10 @@ class XLNetConfig(PretrainedConfig):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@property
def max_position_embeddings(self):
return -1
@property
def vocab_size(self):
return self.n_token
@ -765,7 +764,7 @@ class XLNetModel(XLNetPreTrainedModel):
return pos_emb
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
mems=None, perm_mask=None, target_mapping=None, head_mask=None):
"""
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
@ -790,10 +789,6 @@ class XLNetModel(XLNetPreTrainedModel):
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
head_mask: TODO Lysandre didn't fill
@ -823,7 +818,6 @@ class XLNetModel(XLNetPreTrainedModel):
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
@ -878,12 +872,11 @@ class XLNetModel(XLNetPreTrainedModel):
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
else:
inp_q_ext = inp_q[:, :, None]
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q)
else:
output_g = None
@ -994,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding)
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
mems=None, perm_mask=None, target_mapping=None,
labels=None, head_mask=None):
"""
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@ -1020,11 +1013,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
Returns:
A ``tuple(encoded_layers, pooled_output)``, with
@ -1054,7 +1042,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
"""
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
mems, perm_mask, target_mapping, head_mask)
logits = self.lm_loss(transformer_outputs[0])
@ -1103,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
mems=None, perm_mask=None, target_mapping=None,
labels=None, head_mask=None):
"""
Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**
@ -1129,10 +1117,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
on the j-th token.
Only used during pre-training for partial prediction.
Set to None during fine-tuning.
inp_q: float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pre-training for two-stream attention.
Set to None during fine-tuning.
labels: TODO Lysandre didn't fill
head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
@ -1161,7 +1145,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
"""
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
mems, perm_mask, target_mapping, head_mask)
output = transformer_outputs[0]
output = self.sequence_summary(output)
@ -1215,7 +1199,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
mems=None, perm_mask=None, target_mapping=None,
start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
head_mask=None):
@ -1266,7 +1250,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
start_logits, end_logits = model.forward(input_ids, token_type_ids, input_mask)
"""
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
mems, perm_mask, target_mapping, head_mask)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask)

View File

@ -97,7 +97,6 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0 # predict last token
inp_q = target_mapping[:, 0, :].clone() # predict last token
sequence_labels = None
lm_labels = None
@ -124,14 +123,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
num_labels=self.type_sequence_label_size)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
def set_seed(self):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetModel(config)
model.eval()
@ -153,7 +152,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetLMHeadModel(config)
model.eval()
@ -161,7 +160,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
result = {
"loss_1": loss_1,
@ -193,7 +192,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForQuestionAnswering(config)
model.eval()
@ -243,7 +242,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForSequenceClassification(config)
model.eval()
@ -269,7 +268,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, inp_q, segment_ids, lm_labels,
target_mapping, segment_ids, lm_labels,
sequence_labels, is_impossible_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict

View File

@ -25,7 +25,6 @@ import os
import sys
from collections import Counter, OrderedDict
from io import open
import unicodedata
import torch
import numpy as np

View File

@ -343,7 +343,7 @@ class PreTrainedTokenizer(object):
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
for sub_text in split_text), [])[:-1]
added_tokens = list(self.added_tokens_encoder.keys())
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
tokenized_text = split_on_tokens(added_tokens, text)
return tokenized_text
@ -466,7 +466,7 @@ class PreTrainedTokenizer(object):
def clean_up_tokenization(out_string):
out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string

View File

@ -172,7 +172,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
out_string = ''.join(tokens_ids)
out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
return out_string
def save_vocabulary(self, save_directory):