mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
add tokenizer and tests
This commit is contained in:
parent
45709d7532
commit
32da75486b
@ -3,6 +3,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
from .tokenization_xlnet import XLNetTokenizer
|
||||
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
@ -24,4 +25,5 @@ from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path,
|
||||
WEIGHTS_NAME, CONFIG_NAME)
|
||||
|
@ -1034,14 +1034,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
Only used during pretraining for two-stream attention.
|
||||
Set to None during finetuning.
|
||||
|
||||
mem_len: int, the number of tokens to cache.
|
||||
reuse_len: int, the number of tokens in the currect batch to be cached
|
||||
and reused in the future.
|
||||
bi_data: bool, whether to use bidirectional input pipeline.
|
||||
Usually set to True during pretraining and False during finetuning.
|
||||
clamp_len: int, clamp all relative distances larger than clamp_len.
|
||||
-1 means no clamping.
|
||||
same_length: bool, whether to use the same attention length for each token.
|
||||
summary_type: str, "last", "first", "mean", or "attn". The method
|
||||
to pool the input to get a vector representation.
|
||||
"""
|
||||
@ -1068,4 +1060,4 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
# encoded_layers = encoded_layers[-1]
|
||||
# if self.output_attentions:
|
||||
# return all_attentions, encoded_layers, pooled_output
|
||||
return output, new_mems
|
||||
return logits, new_mems
|
||||
|
111
pytorch_pretrained_bert/modeling_xlnet_utilities.py
Normal file
111
pytorch_pretrained_bert/modeling_xlnet_utilities.py
Normal file
@ -0,0 +1,111 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Utilities for PyTorch XLNet model.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
special_symbols = {
|
||||
"<unk>" : 0,
|
||||
"<s>" : 1,
|
||||
"</s>" : 2,
|
||||
"<cls>" : 3,
|
||||
"<sep>" : 4,
|
||||
"<pad>" : 5,
|
||||
"<mask>" : 6,
|
||||
"<eod>" : 7,
|
||||
"<eop>" : 8,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 32000
|
||||
UNK_ID = special_symbols["<unk>"]
|
||||
CLS_ID = special_symbols["<cls>"]
|
||||
SEP_ID = special_symbols["<sep>"]
|
||||
MASK_ID = special_symbols["<mask>"]
|
||||
EOD_ID = special_symbols["<eod>"]
|
||||
|
||||
|
||||
def permutation_mask(inputs, targets, is_masked, perm_size, seq_len):
|
||||
"""
|
||||
Sample a permutation of the factorization order, and create an
|
||||
attention mask accordingly.
|
||||
Args:
|
||||
inputs: int64 Tensor in shape [seq_len], input ids.
|
||||
targets: int64 Tensor in shape [seq_len], target ids.
|
||||
is_masked: bool Tensor in shape [seq_len]. True means being selected
|
||||
for partial prediction.
|
||||
perm_size: the length of longest permutation. Could be set to be reuse_len.
|
||||
Should not be larger than reuse_len or there will be data leaks.
|
||||
seq_len: int, sequence length.
|
||||
"""
|
||||
|
||||
# Generate permutation indices
|
||||
index = np.arange(10)
|
||||
index = np.transpose(np.reshape(index, [-1, perm_size]))
|
||||
index = np.random.shuffle(index)
|
||||
index = np.reshape(np.transpose(index), [-1])
|
||||
|
||||
# `perm_mask` and `target_mask`
|
||||
# non-functional tokens
|
||||
non_func_tokens = tf.logical_not(tf.logical_or(
|
||||
tf.equal(inputs, SEP_ID),
|
||||
tf.equal(inputs, CLS_ID)))
|
||||
|
||||
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
|
||||
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
|
||||
|
||||
# Set the permutation indices of non-masked (& non-funcional) tokens to the
|
||||
# smallest index (-1):
|
||||
# (1) they can be seen by all other positions
|
||||
# (2) they cannot see masked positions, so there won"t be information leak
|
||||
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
|
||||
rev_index = tf.where(non_mask_tokens, smallest_index, index)
|
||||
|
||||
# Create `target_mask`: non-funcional and maksed tokens
|
||||
# 1: use mask as input and have loss
|
||||
# 0: use token (or [SEP], [CLS]) as input and do not have loss
|
||||
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
|
||||
target_mask = tf.cast(target_tokens, tf.float32)
|
||||
|
||||
# Create `perm_mask`
|
||||
# `target_tokens` cannot see themselves
|
||||
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
|
||||
|
||||
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
|
||||
# 0: can attend if i > j or j is non-masked
|
||||
perm_mask = tf.logical_and(
|
||||
self_rev_index[:, None] <= rev_index[None, :],
|
||||
masked_or_func_tokens)
|
||||
perm_mask = tf.cast(perm_mask, tf.float32)
|
||||
|
||||
# new target: [next token] for LM and [curr token] (self) for PLM
|
||||
new_targets = tf.concat([inputs[0: 1], targets[: -1]],
|
||||
axis=0)
|
||||
|
||||
# construct inputs_k
|
||||
inputs_k = inputs
|
||||
|
||||
# construct inputs_q
|
||||
inputs_q = target_mask
|
||||
|
||||
return perm_mask, new_targets, target_mask, inputs_k, inputs_q
|
||||
|
@ -0,0 +1,254 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization classes for XLNet model."""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from shutil import copyfile
|
||||
from io import open
|
||||
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
|
||||
}
|
||||
VOCAB_NAME = 'spiece.model'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
SPIECE_UNDERLINE = '▁'
|
||||
|
||||
class XLNetTokenizer(object):
|
||||
"""
|
||||
SentencePiece based tokenizer. Peculiarities:
|
||||
- requires SentencePiece: https://github.com/google/sentencepiece
|
||||
"""
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
if not os.path.exists(special_tokens_file):
|
||||
special_tokens_file = None
|
||||
else:
|
||||
logger.info("loading special tokens file {}".format(special_tokens_file))
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
logger.error(
|
||||
"Couldn't reach server at '{}' to download vocabulary.".format(
|
||||
vocab_file))
|
||||
else:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {}"
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
vocab_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
# Instantiate tokenizer.
|
||||
if special_tokens_file and 'special_tokens' not in kwargs:
|
||||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
else:
|
||||
special_tokens = kwargs.pop('special_tokens', [])
|
||||
tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, special_tokens=None, max_len=None,
|
||||
do_lower_case=False, remove_space=True, keep_accents=False):
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||
"pip install sentencepiece")
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.do_lower_case = do_lower_case
|
||||
self.remove_space = remove_space
|
||||
self.keep_accents = keep_accents
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(vocab_file)
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
logger.info("Special tokens: %s", str(self.special_tokens))
|
||||
|
||||
def preprocess_text(self, inputs):
|
||||
if self.remove_space:
|
||||
outputs = ' '.join(inputs.strip().split())
|
||||
else:
|
||||
outputs = inputs
|
||||
outputs = outputs.replace("``", '"').replace("''", '"')
|
||||
|
||||
if six.PY2 and isinstance(outputs, str):
|
||||
outputs = outputs.decode('utf-8')
|
||||
|
||||
if not self.keep_accents:
|
||||
outputs = unicodedata.normalize('NFKD', outputs)
|
||||
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
|
||||
if self.do_lower_case:
|
||||
outputs = outputs.lower()
|
||||
|
||||
return outputs
|
||||
|
||||
def tokenize(self, text, return_unicode=True, sample=False):
|
||||
""" Tokenize a string.
|
||||
return_unicode is used only for py2
|
||||
"""
|
||||
text = self.preprocess_text(text)
|
||||
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
|
||||
if six.PY2 and isinstance(text, unicode):
|
||||
text = text.encode('utf-8')
|
||||
|
||||
if not sample:
|
||||
pieces = self.sp_model.EncodeAsPieces(text)
|
||||
else:
|
||||
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
|
||||
new_pieces = []
|
||||
for piece in pieces:
|
||||
if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
|
||||
cur_pieces = self.sp_model.EncodeAsPieces(
|
||||
piece[:-1].replace(SPIECE_UNDERLINE, ''))
|
||||
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
|
||||
if len(cur_pieces[0]) == 1:
|
||||
cur_pieces = cur_pieces[1:]
|
||||
else:
|
||||
cur_pieces[0] = cur_pieces[0][1:]
|
||||
cur_pieces.append(piece[-1])
|
||||
new_pieces.extend(cur_pieces)
|
||||
else:
|
||||
new_pieces.append(piece)
|
||||
|
||||
# note(zhiliny): convert back to unicode for py2
|
||||
if six.PY2 and return_unicode:
|
||||
ret_pieces = []
|
||||
for piece in new_pieces:
|
||||
if isinstance(piece, str):
|
||||
piece = piece.decode('utf-8')
|
||||
ret_pieces.append(piece)
|
||||
new_pieces = ret_pieces
|
||||
|
||||
return new_pieces
|
||||
|
||||
def convert_tokens_to_ids(self, tokens, sample=False):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.sp_model.PieceToId(tokens)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.sp_model.PieceToId(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this XLNet model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in tokens."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.sp_model.IdToPiece(i))
|
||||
return tokens
|
||||
|
||||
def encode(self, text, sample=False):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.strip().replace('<unk>', '')
|
||||
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
|
||||
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
|
||||
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
"""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
index = len(self.sp_model)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return out_vocab_file, special_tokens_file
|
@ -7,4 +7,6 @@ boto3
|
||||
# Used for downloading models over HTTP
|
||||
requests
|
||||
# For OpenAI GPT
|
||||
regex
|
||||
regex
|
||||
# For XLNet
|
||||
sentencepiece
|
BIN
samples/test_sentencepiece.model
Normal file
BIN
samples/test_sentencepiece.model
Normal file
Binary file not shown.
3
setup.py
3
setup.py
@ -54,7 +54,8 @@ setup(
|
||||
'boto3',
|
||||
'requests',
|
||||
'tqdm',
|
||||
'regex'],
|
||||
'regex',
|
||||
'sentencepiece'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
|
||||
|
@ -35,8 +35,8 @@ class XLNetModelTest(unittest.TestCase):
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
mem_len=30,
|
||||
clamp_len=15,
|
||||
mem_len=10,
|
||||
clamp_len=-1,
|
||||
reuse_len=15,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
@ -78,6 +78,27 @@ class XLNetModelTest(unittest.TestCase):
|
||||
input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
|
||||
segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)
|
||||
|
||||
# inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
|
||||
# seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
|
||||
# input_mask: float32 Tensor in shape [len, bsz], the input mask.
|
||||
# 0 for real tokens and 1 for padding.
|
||||
# mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
# from previous batches. The length of the list equals n_layer.
|
||||
# If None, no memory is used.
|
||||
# perm_mask: float32 Tensor in shape [len, len, bsz].
|
||||
# If perm_mask[i, j, k] = 0, i attend to j in batch k;
|
||||
# if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
|
||||
# If None, each position attends to all the others.
|
||||
# target_mapping: float32 Tensor in shape [num_predict, len, bsz].
|
||||
# If target_mapping[i, j, k] = 1, the i-th predict in batch k is
|
||||
# on the j-th token.
|
||||
# Only used during pretraining for partial prediction.
|
||||
# Set to None during finetuning.
|
||||
# inp_q: float32 Tensor in shape [len, bsz].
|
||||
# 1 for tokens with losses and 0 for tokens without losses.
|
||||
# Only used during pretraining for two-stream attention.
|
||||
# Set to None during finetuning.
|
||||
|
||||
lm_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
|
||||
@ -106,44 +127,15 @@ class XLNetModelTest(unittest.TestCase):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
|
||||
hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)
|
||||
outputs = {
|
||||
"hidden_states_1": hidden_states_1,
|
||||
"mems_1": mems_1,
|
||||
"hidden_states_2": hidden_states_2,
|
||||
"mems_2": mems_2,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()),
|
||||
[self.seq_length, self.batch_size, self.d_model])
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()),
|
||||
[self.seq_length, self.batch_size, self.d_model])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
|
||||
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
loss_1, mems_1a = model(input_ids_1, target=lm_labels)
|
||||
lm_logits_1, mems_1b = model(input_ids_1)
|
||||
loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
|
||||
lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
|
||||
|
||||
loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
|
||||
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
|
||||
loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
|
||||
lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
|
||||
|
||||
outputs = {
|
||||
"loss_1": loss_1,
|
||||
@ -160,23 +152,23 @@ class XLNetModelTest(unittest.TestCase):
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_1"].size()),
|
||||
[self.seq_length, self.batch_size])
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()),
|
||||
[self.seq_length, self.batch_size, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1a"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1b"]),
|
||||
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
|
||||
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
|
||||
self.parent.assertListEqual(
|
||||
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
|
||||
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_2"].size()),
|
||||
[self.seq_length, self.batch_size])
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()),
|
||||
[self.seq_length, self.batch_size, self.vocab_size])
|
||||
@ -218,10 +210,6 @@ class XLNetModelTest(unittest.TestCase):
|
||||
def run_tester(self, tester):
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
|
||||
tester.set_seed()
|
||||
output_result = tester.create_transfo_xl_model(*config_and_inputs)
|
||||
tester.check_transfo_xl_model_output(output_result)
|
||||
|
||||
tester.set_seed()
|
||||
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
tester.check_transfo_xl_lm_head_output(output_result)
|
||||
@ -242,6 +230,22 @@ class XLNetModelTest(unittest.TestCase):
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||
|
||||
@classmethod
|
||||
def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||
"""Creates a tensor with padding on the right (0.0 for )."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -49,7 +49,7 @@ class TokenizationTest(unittest.TestCase):
|
||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained(vocab_file)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_file)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
|
@ -44,7 +44,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
|
||||
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained(vocab_file)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_file)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||
|
88
tests/tokenization_xlnet_test.py
Normal file
88
tests/tokenization_xlnet_test.py
Normal file
@ -0,0 +1,88 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP)
|
||||
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(__file__))),
|
||||
'samples/test_sentencepiece.model')
|
||||
|
||||
class XLNetTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
|
||||
|
||||
tokens = tokenizer.tokenize('This is a test')
|
||||
self.assertListEqual(tokens, ['▁This', '▁is', '▁a', '▁t', 'est'])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
|
||||
|
||||
vocab_path = "/tmp/"
|
||||
vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_path,
|
||||
keep_accents=True)
|
||||
os.remove(vocab_file)
|
||||
os.remove(special_tokens_file)
|
||||
|
||||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '▁',
|
||||
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
|
||||
'▁is', '▁f', 'al', 's', 'é', '.'])
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids, [8, 21, 84, 55, 24, 19, 7, 0,
|
||||
602, 347, 347, 347, 3, 12, 66,
|
||||
46, 72, 80, 6, 0, 4])
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(back_tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in',
|
||||
'▁', '<unk>', '2', '0', '0', '0', ',',
|
||||
'▁and', '▁this', '▁is', '▁f', 'al', 's',
|
||||
'<unk>', '.'])
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
def test_tokenizer_lower(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, ['▁', 'i', '▁was', '▁b', 'or', 'n', '▁in', '▁',
|
||||
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
|
||||
'▁is', '▁f', 'al', 'se', '.'])
|
||||
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["▁he", "ll", "o"])
|
||||
|
||||
def test_tokenizer_no_lower(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, ['▁I', '▁was', '▁b', 'or', 'n', '▁in', '▁',
|
||||
'9', '2', '0', '0', '0', ',', '▁and', '▁this',
|
||||
'▁is', '▁f', 'al', 'se', '.'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user