mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-20 13:08:21 +06:00
unified tokenizer api and serialization + tests
This commit is contained in:
parent
3d5f291386
commit
b19786985d
@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_transformers.modeling_bert import BertForSequenceClassification
|
||||
from pytorch_transformers.tokenization_bert import BertTokenizer
|
||||
from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
|
||||
XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
||||
XLMTokenizer)
|
||||
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
|
||||
|
||||
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
|
||||
@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': BertForSequenceClassification,
|
||||
'xlnet': XLNetForSequenceClassification,
|
||||
'xlm': XLMForSequenceClassification,
|
||||
}
|
||||
|
||||
TOKENIZER_CLASSES = {
|
||||
'bert': BertTokenizer,
|
||||
'xlnet': XLNetTokenizer,
|
||||
'xlm': XLMTokenizer,
|
||||
}
|
||||
|
||||
def train(args, train_features, model):
|
||||
""" Train the model """
|
||||
@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Num examples = %d", len(eval_features))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
|
||||
examples = processor.get_dev_examples(args.data_dir)
|
||||
cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
|
||||
'dev' if eval else 'train',
|
||||
list(filter(None, args.bert_model.split('/'))).pop(),
|
||||
list(filter(None, args.model_name.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task)))
|
||||
|
||||
@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
|
||||
cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
@ -230,12 +252,10 @@ def main():
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
|
||||
"bert-base-multilingual-cased, bert-base-chinese.")
|
||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||
help="The name of the task to train.")
|
||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
@ -243,9 +263,8 @@ def main():
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.")
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
@ -263,8 +282,7 @@ def main():
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
@ -331,8 +349,11 @@ def main():
|
||||
# Make sure only the first process in distributed training will download model & vocab
|
||||
torch.distributed.barrier()
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
args.model_type = args.model_name.lower().split('-')[0]
|
||||
args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
|
||||
args.model_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
|
||||
model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
@ -359,26 +380,15 @@ def main():
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForSequenceClassification.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
|
||||
torch.save(args, output_args_file)
|
||||
else:
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model)
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = args.model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
|
@ -211,8 +211,8 @@ def main():
|
||||
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
@ -369,8 +369,8 @@ def main():
|
||||
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||
cls_token_at_end=True, cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
|
||||
pad_on_left=True, pad_token_segment_id=4)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||
|
@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
mask_padding_with_zero=True):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
||||
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
|
||||
"""
|
||||
@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
[str(x) for x in tokens]))
|
||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("label: %s (id = %d)" % (example.label, label_id))
|
||||
|
||||
features.append(
|
||||
|
@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
BertForTokenClassification, BertForQuestionAnswering,
|
||||
load_tf_weights_in_bert)
|
||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||
load_tf_weights_in_openai_gpt)
|
||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
|
||||
load_tf_weights_in_transfo_xl)
|
||||
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||
load_tf_weights_in_gpt2)
|
||||
load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlnet import (XLNetConfig,
|
||||
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering,
|
||||
load_tf_weights_in_xlnet)
|
||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMConfig, XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering)
|
||||
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
||||
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
||||
|
||||
|
@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
|
||||
TransfoXLConfig,
|
||||
TransfoXLLMHeadModel,
|
||||
load_tf_weights_in_transfo_xl)
|
||||
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME,
|
||||
VOCAB_NAME)
|
||||
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
with open(transfo_xl_dataset_file, "rb") as fp:
|
||||
corpus = pickle.load(fp, encoding="latin1")
|
||||
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
|
||||
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
|
||||
corpus_vocab_dict = corpus.vocab.__dict__
|
||||
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
|
||||
|
@ -24,7 +24,7 @@ import torch
|
||||
import numpy
|
||||
|
||||
from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
|
||||
from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME
|
||||
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
||||
@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
|
||||
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model, pytorch_weights_dump_path)
|
||||
|
@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
||||
@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
||||
}
|
||||
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
||||
@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
class BertConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `BertModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=30522,
|
||||
@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = BertConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
base_model_prefix = "bert"
|
||||
|
||||
|
@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
|
||||
|
||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
@ -103,7 +103,7 @@ def gelu(x):
|
||||
class GPT2Config(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `GPT2Model`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = GPT2Config
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_gpt2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
|
||||
@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||
class OpenAIGPTConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = OpenAIGPTConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_openai_gpt
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
|
||||
}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||
}
|
||||
|
||||
@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
class TransfoXLConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=267735,
|
||||
@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = TransfoXLConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_transfo_xl
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
|
||||
model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_to_prune._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model with its configuration file to a directory, so that it
|
||||
can be re-loaded using the `from_pretrained(save_directory)` class method.
|
||||
"""
|
||||
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
||||
|
||||
# Only save the model it-self if we are using distributed training
|
||||
model_to_save = self.module if hasattr(self, 'module') else self
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
"""
|
||||
|
@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
|
||||
}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||
}
|
||||
|
||||
@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
class XLMConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `XLMModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=30145,
|
||||
@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = XLMConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
|
||||
}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
class XLNetConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `XLNetModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=32000,
|
||||
@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = XLNetConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_xlnet
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
BertForNextSentencePrediction, BertForPreTraining,
|
||||
BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertForMultipleChoice)
|
||||
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
@ -413,7 +413,7 @@ class GPTModelTester(object):
|
||||
|
||||
def create_and_check_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.parent.assertIsNotNone(model)
|
||||
|
@ -26,7 +26,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
|
||||
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
@ -20,12 +20,12 @@ import unittest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, PretrainedConfig)
|
||||
|
@ -21,7 +21,7 @@ import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
|
||||
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
@ -26,7 +26,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
WordpieceTokenizer,
|
||||
_is_control, _is_punctuation,
|
||||
_is_whitespace)
|
||||
_is_whitespace, VOCAB_FILES_NAMES)
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
"##ing", ",", "low", "lowest",
|
||||
]
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_directory = "/tmp/"
|
||||
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = WordpieceTokenizer(vocab=vocab)
|
||||
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
|
||||
|
||||
self.assertListEqual(tokenizer.tokenize(""), [])
|
||||
|
||||
|
@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -28,32 +29,32 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"lo", "low", "er",
|
||||
"low", "lowest", "newer", "wider"]
|
||||
"low", "lowest", "newer", "wider", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [13, 12, 16]
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [13, 12, 17]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
|
||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -32,21 +31,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
import tempfile
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
unicode = str
|
||||
@ -28,22 +29,19 @@ else:
|
||||
|
||||
|
||||
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
|
||||
vocab_path="/tmp/"
|
||||
output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_path)
|
||||
|
||||
for f in output_files:
|
||||
os.remove(f)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
||||
|
||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
tester.assertListEqual(before_tokens, after_tokens)
|
||||
|
||||
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"Munich and Berlin are nice cities"
|
||||
filename = u"/tmp/tokenizer.bin"
|
||||
@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
|
||||
tester.assertListEqual(subwords, subwords_loaded)
|
||||
|
||||
|
||||
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
vocab_size = tokenizer.vocab_size
|
||||
all_size = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size, 0)
|
||||
tester.assertEqual(vocab_size, all_size)
|
||||
|
||||
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
||||
added_toks = tokenizer.add_tokens(new_toks)
|
||||
vocab_size_2 = tokenizer.vocab_size
|
||||
all_size_2 = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size_2, 0)
|
||||
tester.assertEqual(vocab_size, vocab_size_2)
|
||||
tester.assertEqual(added_toks, len(new_toks))
|
||||
tester.assertEqual(all_size_2, all_size + len(new_toks))
|
||||
|
||||
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
||||
tester.assertGreaterEqual(len(tokens), 4)
|
||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||
|
||||
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
||||
'pad_token': "<<<<<|||>|>>>>|>"}
|
||||
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||
vocab_size_3 = tokenizer.vocab_size
|
||||
all_size_3 = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size_3, 0)
|
||||
tester.assertEqual(vocab_size, vocab_size_3)
|
||||
tester.assertEqual(added_toks_2, len(new_toks_2))
|
||||
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||
|
||||
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
||||
|
||||
tester.assertGreaterEqual(len(tokens), 6)
|
||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[0], tokens[1])
|
||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[-2], tokens[-3])
|
||||
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
|
||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
|
@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -28,16 +27,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
|
||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
||||
"running", ",", "low", "l",
|
||||
]
|
||||
with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||
|
@ -17,6 +17,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import six
|
||||
|
||||
from pytorch_transformers import PreTrainedTokenizer
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||
@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
for model_name in s3_models[:1]:
|
||||
tokenizer = tokenizer_class.from_pretrained(model_name)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
self.assertIsInstance(tokenizer, tokenizer_class)
|
||||
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
for special_tok in tokenizer.all_special_tokens:
|
||||
if six.PY2:
|
||||
self.assertIsInstance(special_tok, unicode)
|
||||
else:
|
||||
self.assertIsInstance(special_tok, str)
|
||||
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
||||
self.assertIsInstance(special_tok_id, int)
|
||||
|
||||
def test_pretrained_tokenizers(self):
|
||||
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
|
||||
|
||||
|
@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer
|
||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -31,21 +30,21 @@ class XLMTokenizationTest(unittest.TestCase):
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
|
@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
|
||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@ -29,10 +28,13 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
class XLNetTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
|
||||
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
|
@ -22,7 +22,6 @@ import os
|
||||
import unicodedata
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -45,7 +44,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
|
||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'bert-base-uncased': 512,
|
||||
@ -93,8 +93,9 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
|
||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
|
||||
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
||||
mask_token="[MASK]", **kwargs):
|
||||
"""Constructs a BertTokenizer.
|
||||
|
||||
Args:
|
||||
@ -102,17 +103,18 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
do_lower_case: Whether to lower case the input
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
|
||||
max_len: An artificial maximum length to truncate tokenized sequences to;
|
||||
Effective maximum length is always the minimum of this
|
||||
value (if specified) and the underlying BERT model's
|
||||
sequence length.
|
||||
never_split: List of tokens which will never be split during tokenization.
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
"""
|
||||
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
||||
pad_token=pad_token, cls_token=cls_token,
|
||||
mask_token=mask_token, **kwargs)
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
|
||||
if never_split is None:
|
||||
never_split = self.all_special_tokens
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict(
|
||||
[(ids, tok) for tok, ids in self.vocab.items()])
|
||||
@ -120,90 +122,34 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
if do_basic_tokenize:
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
||||
never_split=never_split)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
||||
|
||||
@property
|
||||
def UNK_TOKEN(self):
|
||||
return "[UNK]"
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
@property
|
||||
def SEP_TOKEN(self):
|
||||
return "[SEP]"
|
||||
|
||||
@property
|
||||
def PAD_TOKEN(self):
|
||||
return "[PAD]"
|
||||
|
||||
@property
|
||||
def CLS_TOKEN(self):
|
||||
return "[CLS]"
|
||||
|
||||
@property
|
||||
def MASK_TOKEN(self):
|
||||
return "[MASK]"
|
||||
|
||||
@property
|
||||
def UNK_ID(self):
|
||||
return self.vocab["[UNK]"]
|
||||
|
||||
@property
|
||||
def SEP_ID(self):
|
||||
return self.vocab["[SEP]"]
|
||||
|
||||
@property
|
||||
def PAD_ID(self):
|
||||
return self.vocab["[PAD]"]
|
||||
|
||||
@property
|
||||
def CLS_ID(self):
|
||||
return self.vocab["[CLS]"]
|
||||
|
||||
@property
|
||||
def MASK_ID(self):
|
||||
return self.vocab["[MASK]"]
|
||||
|
||||
def tokenize(self, text):
|
||||
def _tokenize(self, text):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
else:
|
||||
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self.vocab[token])
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this BERT model ({} > {}). Running this"
|
||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.ids_to_tokens[i])
|
||||
return tokens
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.ids_to_tokens.get(index, self.unk_token)
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, token_ids, clean_up_tokenization_spaces=True):
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(token_ids)
|
||||
tokens = self.convert_ids_to_tokens(tokens_ids)
|
||||
out_string = ''.join(tokens).replace(' ##', '').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
|
||||
out_string = out_string.replace(special_tok, '')
|
||||
out_string = clean_up_tokenization(out_string)
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
@ -245,17 +191,20 @@ class BasicTokenizer(object):

def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
never_split=None):
"""Constructs a BasicTokenizer.

Args:
do_lower_case: Whether to lower case the input.
"""
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = never_split

def tokenize(self, text):
def tokenize(self, text, never_split=None):
"""Tokenizes a piece of text."""
never_split = self.never_split + (never_split if never_split is not None else [])
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
@ -267,7 +216,7 @@ class BasicTokenizer(object):
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
if self.do_lower_case and token not in never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
@ -286,9 +235,9 @@ class BasicTokenizer(object):
output.append(char)
return "".join(output)

def _run_split_on_punc(self, text):
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
@ -360,7 +309,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""

def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
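A short sketch of the reworked never_split handling, assuming arbitrary example input: the protected tokens are no longer frozen at construction time but merged from self.never_split and a per-call argument.

# Illustrative sketch: never_split can now be supplied per call.
from pytorch_transformers.tokenization_bert import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)
basic.tokenize("Hello [CLS] World!")                          # everything is lower-cased and split on punctuation
basic.tokenize("Hello [CLS] World!", never_split=["[CLS]"])   # "[CLS]" is shielded from lower-casing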
@ -38,7 +38,6 @@ logger = logging.getLogger(__name__)
|
||||
VOCAB_FILES_NAMES = {
|
||||
'vocab_file': 'vocab.json',
|
||||
'merges_file': 'merges.txt',
|
||||
'special_tokens_file': 'special_tokens.txt'
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
@ -52,11 +51,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||
},
|
||||
'special_tokens_file':
|
||||
{
|
||||
'gpt2': None,
|
||||
'gpt2-medium': None,
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
@ -108,8 +102,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
def __init__(self, vocab_file, merges_file, errors='replace',
|
||||
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
|
||||
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
|
||||
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
@ -123,32 +119,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
all_special_tokens = []
|
||||
if special_tokens_file is not None:
|
||||
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
all_special_tokens.extend(special_tokens_to_add)
|
||||
if special_tokens is not None and special_tokens:
|
||||
all_special_tokens.extend(special_tokens)
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(all_special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
@ -191,7 +164,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, text):
|
||||
def _tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
@ -202,57 +175,27 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
text = ''.join(tokens_ids)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = text.replace('<unk>', '')
|
||||
text = clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
|
||||
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
|
||||
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
@ -268,14 +211,4 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
||||
return vocab_file, merge_file
|
||||
|
@ -20,13 +20,9 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_bert import BasicTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -34,7 +30,6 @@ logger = logging.getLogger(__name__)
|
||||
VOCAB_FILES_NAMES = {
|
||||
'vocab_file': 'vocab.json',
|
||||
'merges_file': 'merges.txt',
|
||||
'special_tokens_file': 'special_tokens.txt'
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
@ -46,10 +41,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
{
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
|
||||
},
|
||||
'special_tokens_file':
|
||||
{
|
||||
'openai-gpt': None,
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
@ -88,14 +79,14 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
BPE tokenizer. Peculiarities:
|
||||
- lower case all inputs
|
||||
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
|
||||
- argument special_tokens and function set_special_tokens:
|
||||
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
|
||||
"""
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
|
||||
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
||||
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
import spacy
|
||||
@ -103,11 +94,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
self.fix_text = ftfy.fix_text
|
||||
except ImportError:
|
||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||
self.nlp = BasicTokenizer(do_lower_case=True,
|
||||
never_split=special_tokens if special_tokens is not None else [])
|
||||
self.nlp = BasicTokenizer(do_lower_case=True)
|
||||
self.fix_text = None
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
@ -115,35 +104,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
|
||||
all_special_tokens = []
|
||||
if special_tokens_file is not None:
|
||||
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
all_special_tokens.extend(special_tokens_to_add)
|
||||
if special_tokens is not None and special_tokens:
|
||||
all_special_tokens.extend(special_tokens)
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(all_special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
if self.fix_text is None:
|
||||
# Using BERT's BasicTokenizer: we can update the tokenizer
|
||||
self.nlp.never_split = special_tokens
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def bpe(self, token):
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
||||
@ -188,7 +151,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, text):
|
||||
def _tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
split_tokens = []
|
||||
if self.fix_text is None:
|
||||
@ -203,58 +166,26 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an id in a token (BPE) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.replace('<unk>', '')
|
||||
out_string = clean_up_tokenization(out_string)
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
|
||||
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
|
||||
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
@ -270,14 +201,4 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
||||
return vocab_file, merge_file
|
||||
|
@ -41,7 +41,7 @@ else:

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'}
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file':
@ -67,9 +67,17 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, pretrained_vocab_file=None,
never_split=("<unk>", "<eos>", "<formula>")):
never_split=None, unk_token="<unk>", eos_token="<eos>",
additional_special_tokens=["<formula>"], **kwargs):
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
if never_split is None:
never_split = self.all_special_tokens
if special is None:
special = []
self.counter = Counter()
self.special = special
self.min_freq = min_freq
@ -200,11 +208,13 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1

def get_sym(self, idx):
def _convert_id_to_token(self, idx):
"""Converts an id in a token (BPE) using the vocab."""
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
return self.idx2sym[idx]

def get_idx(self, sym):
def _convert_token_to_id(self, sym):
""" Converts a token (str/unicode) in an id using the vocab. """
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
@ -220,36 +230,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

def convert_ids_to_tokens(self, indices):
"""Converts a sequence of indices in symbols using the vocab."""
return [self.get_sym(idx) for idx in indices]

def convert_tokens_to_ids(self, symbols):
"""Converts a sequence of symbols into ids using the vocab."""
return [self.get_idx(sym) for sym in symbols]
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
out_string = ' '.join(tokens_ids).strip()
return out_string

def convert_to_tensor(self, symbols):
return torch.LongTensor(self.convert_tokens_to_ids(symbols))

def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))

def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
"""Converts a sequence of indices in a string."""
if exclude is None:
out_string = ' '.join([self.get_sym(idx) for idx in indices])
else:
out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

if clean_up_tokenization_spaces:
out_string = clean_up_tokenization(out_string)

return out_string

def __len__(self):
@property
def vocab_size(self):
return len(self.idx2sym)

def tokenize(self, line, add_eos=False, add_double_eos=False):
def _tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
# convert to lower case
if self.lower_case:
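As a hedged usage sketch for this word-level tokenizer, assuming the 'transfo-xl-wt103' shortcut this file already maps:

# Illustrative sketch: TransfoXL works on whole words, so _convert_ids_to_string()
# simply re-joins the symbols with single spaces during decode().
from pytorch_transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
ids = tokenizer.encode("the quick brown fox")
text = tokenizer.decode(ids)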
@ -16,37 +16,145 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)

import sys
import json
import logging
import os
import regex as re
import json
import six
from io import open

try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func

from .file_utils import cached_path

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'

class PreTrainedTokenizer(object):
""" An abstract class to handle dowloading and loading pretrained tokenizers.
""" An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.

Derived class can set up a few special tokens to be used in common scripts and internals:
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
additional_special_tokens = []

We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...).
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
max_model_input_sizes = {}

SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
"pad_token", "cls_token", "mask_token",
"additional_special_tokens"]

@property
def bos_token(self):
if self._bos_token is None:
logger.error("Using bos_token, but it is not set yet.")
return self._bos_token

@property
def eos_token(self):
if self._eos_token is None:
logger.error("Using eos_token, but it is not set yet.")
return self._eos_token

@property
def unk_token(self):
if self._unk_token is None:
logger.error("Using unk_token, but it is not set yet.")
return self._unk_token

@property
def sep_token(self):
if self._sep_token is None:
logger.error("Using sep_token, but it is not set yet.")
return self._sep_token

@property
def pad_token(self):
if self._pad_token is None:
logger.error("Using pad_token, but it is not set yet.")
return self._pad_token

@property
def cls_token(self):
if self._cls_token is None:
logger.error("Using cls_token, but it is not set yet.")
return self._cls_token

@property
def mask_token(self):
if self._mask_token is None:
logger.error("Using mask_token, but it is not set yet.")
return self._mask_token

@property
def additional_special_tokens(self):
if self._additional_special_tokens is None:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens

@bos_token.setter
def bos_token(self, value):
self._bos_token = value

@eos_token.setter
def eos_token(self, value):
self._eos_token = value

@unk_token.setter
def unk_token(self, value):
self._unk_token = value

@sep_token.setter
def sep_token(self, value):
self._sep_token = value

@pad_token.setter
def pad_token(self, value):
self._pad_token = value

@cls_token.setter
def cls_token(self, value):
self._cls_token = value

@mask_token.setter
def mask_token(self, value):
self._mask_token = value

@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value

def __init__(self, max_len=None, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._additional_special_tokens = []

self.max_len = max_len if max_len is not None else int(1e12)
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}

for key, value in kwargs.items():
if key not in self.SPECIAL_TOKENS_ATTRIBUTES:
raise ValueError(
"PreTrainedTokenizer.__init__() argument {} should be in {}".format(
key, ', '.join(self.SPECIAL_TOKENS_ATTRIBUTES)))
else:
setattr(self, key, value)


@classmethod
def from_pretrained(cls, *inputs, **kwargs):
return cls._from_pretrained(*inputs, **kwargs)


@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
@ -59,16 +167,20 @@ class PreTrainedTokenizer(object):
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
else:
for file_id, file_name in cls.vocab_files_names.items():
all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
all_vocab_files_names.update(cls.vocab_files_names)
for file_id, file_name in all_vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
else:
full_file_name = pretrained_model_name_or_path
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
# redirect to the cache, if necessary

# Get files from url, cache, or disk depending on the case
try:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
@ -95,6 +207,7 @@ class PreTrainedTokenizer(object):
logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]))

# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings
@ -102,31 +215,255 @@ class PreTrainedTokenizer(object):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

# Merge resolved_vocab_files arguments in kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in kwargs:
kwargs[args_name] = file_path
if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items():
if key not in kwargs:
kwargs[key] = value

# Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs)

# Add supplementary tokens.
if added_tokens_file is not None:
added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)

return tokenizer
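For what this buys in practice, a small round-trip sketch, assuming './my_tokenizer/' is a hypothetical directory that already exists:

# Illustrative sketch: from_pretrained() also picks up special_tokens_map.json and
# added_tokens.json when they sit next to the vocabulary files in a saved directory.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.save_pretrained('./my_tokenizer/')                  # the directory must already exist
reloaded = BertTokenizer.from_pretrained('./my_tokenizer/')   # restores vocab, special tokens and added tokens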

def tokenize(self, text):

def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files (with added tokens) and the
special-tokens-to-class-attributes-mapping to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
return

special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)

with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

with open(added_tokens_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))

vocab_files = self.save_vocabulary(save_directory)

return vocab_files + (special_tokens_map_file, added_tokens_file)


def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
and special token mappings.

Please use `save_pretrained()` to save the full Tokenizer state so that it can be
reloaded using the `from_pretrained(save_directory)` class method.
"""
raise NotImplementedError


def vocab_size(self):
raise NotImplementedError


def __len__(self):
return self.vocab_size + len(self.added_tokens_encoder)


def add_tokens(self, new_tokens):
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to the added_tokens_encoder with indices starting from
the last index of the current vocabulary.

Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not new_tokens:
return 0

to_add_tokens = []
for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token)

added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)

return len(to_add_tokens)

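A short sketch of add_tokens() as described in the docstring above, assuming arbitrary token strings:

# Illustrative sketch: tokens that map to unk in the fixed vocabulary are appended
# to added_tokens_encoder, starting right after the current vocabulary size.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
num_added = tokenizer.add_tokens(['new_tok_1', 'my_new-tok_2'])   # returns 2 here
print(len(tokenizer))   # vocab_size plus the number of added tokens
# The returned count is what you would use to grow the model's token embedding matrix.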
def add_special_tokens(self, special_tokens_dict):
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If the special tokens are not in the vocabulary, they are added
to it and indexed starting from the last index of the current vocabulary.

Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not special_tokens_dict:
return 0

added_special_tokens = self.add_tokens(special_tokens_dict.values())
for key, value in special_tokens_dict.items():
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)

return added_special_tokens

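And a matching sketch for add_special_tokens(), assuming '<CLS>' as an arbitrary replacement token:

# Illustrative sketch: the token is added to the vocabulary if needed and bound to the
# cls_token attribute, so it shows up in special_tokens_map and all_special_tokens.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'cls_token': '<CLS>'})
assert tokenizer.cls_token == '<CLS>'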
def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).

Take care of added tokens.
"""
def split_on_tokens(tok_list, text):
if not text:
return []
if not tok_list:
return self._tokenize(text, **kwargs)
tok = tok_list[0]
split_text = text.split(tok)
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
for sub_text in split_text), [])[:-1]

added_tokens = list(self.added_tokens_encoder.keys())
tokenized_text = split_on_tokens(added_tokens, text)
return tokenized_text

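A quick sketch of the behaviour described in the docstring above, assuming '[NEW_TOK]' as an arbitrary added token:

# Illustrative sketch: tokenize() splits on added tokens first, so they are never broken
# up by the underlying WordPiece/BPE/SentencePiece step.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_tokens(['[NEW_TOK]'])
print(tokenizer.tokenize("hello [NEW_TOK] world"))   # roughly ['hello', '[NEW_TOK]', 'world']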
def _tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).

Don't take care of added tokens.
"""
raise NotImplementedError

def convert_tokens_to_ids(self, tokens):
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id
(resp.) a sequence of ids, using the vocabulary.
"""
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self.convert_token_to_id_with_added_voc(tokens)

ids = []
for token in tokens:
ids.append(self.convert_token_to_id_with_added_voc(token))
if len(ids) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
return ids


def convert_token_to_id_with_added_voc(self, token):
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)


def _convert_token_to_id(self, token):
raise NotImplementedError

def convert_ids_to_tokens(self, ids):

def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.

Args:
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
"""
if isinstance(ids, int):
return self.convert_id_to_token(ids)
tokens = []
for index in ids:
if index in self.all_special_ids and skip_special_tokens:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
else:
tokens.append(self._convert_id_to_token(index))
return tokens


def _convert_id_to_token(self, index):
raise NotImplementedError


def encode(self, text):
raise NotImplementedError
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
same as self.convert_tokens_to_ids(self.tokenize(text)).
"""
return self.convert_tokens_to_ids(self.tokenize(text))

def decode(self, token_ids, *input, **kwargs):
raise NotImplementedError

def save_vocabulary(self, vocab_path):
raise NotImplementedError
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self._convert_ids_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
return text

def _convert_ids_to_string(self, tokens_ids):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
"""
return ' '.join(self.convert_ids_to_tokens(tokens_ids))

@property
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr

@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
(cls_token, unk_token...).
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
return all_toks

@property
def all_special_ids(self):
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
return all_ids



def clean_up_tokenization(out_string):
@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
VOCAB_FILES_NAMES = {
|
||||
'vocab_file': 'vocab.json',
|
||||
'merges_file': 'merges.txt',
|
||||
'special_tokens_file': 'special_tokens.txt'
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
@ -46,24 +45,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
{
|
||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
||||
},
|
||||
'special_tokens_file':
|
||||
{
|
||||
'xlm-mlm-en-2048': None,
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'xlm-mlm-en-2048': 512,
|
||||
}
|
||||
|
||||
INDEX = {
|
||||
"bos_index": 0,
|
||||
"eos_index": 1,
|
||||
"pad_index": 2,
|
||||
"unk_index": 3,
|
||||
"mask_index": 5
|
||||
}
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
@ -103,7 +90,16 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
|
||||
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
|
||||
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
|
||||
mask_token="<special1>", additional_special_tokens=["<special0>",
|
||||
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
|
||||
"<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
|
||||
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
|
||||
sep_token=sep_token, pad_token=pad_token,
|
||||
cls_token=cls_token, mask_token=mask_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs)
|
||||
try:
|
||||
import ftfy
|
||||
import spacy
|
||||
@ -111,11 +107,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
self.fix_text = ftfy.fix_text
|
||||
except ImportError:
|
||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||
self.nlp = BasicTokenizer(do_lower_case=True,
|
||||
never_split=special_tokens if special_tokens is not None else [])
|
||||
self.nlp = BasicTokenizer(do_lower_case=True)
|
||||
self.fix_text = None
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
@ -123,35 +117,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
|
||||
all_special_tokens = []
|
||||
if special_tokens_file is not None:
|
||||
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
all_special_tokens.extend(special_tokens_to_add)
|
||||
if special_tokens is not None and special_tokens:
|
||||
all_special_tokens.extend(special_tokens)
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(all_special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||
if self.fix_text is None:
|
||||
# Using BERT's BasicTokenizer: we can update the tokenizer
|
||||
self.nlp.never_split = special_tokens
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def bpe(self, token):
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
||||
@ -196,7 +164,7 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, text):
|
||||
def _tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
split_tokens = []
|
||||
if self.fix_text is None:
|
||||
@ -211,58 +179,26 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.replace('<unk>', '')
|
||||
out_string = clean_up_tokenization(out_string)
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
|
||||
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
|
||||
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
@ -277,14 +213,4 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
||||
return vocab_file, merge_file
|
||||
|
@ -16,17 +16,13 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)

import json
import logging
import os
import sys
from shutil import copyfile
from io import open

import unicodedata
import six

from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

logger = logging.getLogger(__name__)
@ -44,8 +40,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-large-cased': 512,
}

VOCAB_NAME = 'spiece.model'

SPIECE_UNDERLINE = u'▁'

# Segments (not really needed)
@ -60,31 +54,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
SentencePiece based tokenizer. Peculiarities:
- requires SentencePiece: https://github.com/google/sentencepiece
"""
# Tokens
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(self, vocab_file, max_len=None,
do_lower_case=False, remove_space=True, keep_accents=False):
do_lower_case=False, remove_space=True, keep_accents=False,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, additional_special_tokens=
additional_special_tokens, **kwargs)
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")

self.max_len = max_len if max_len is not None else int(1e12)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
@ -94,46 +83,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
self.sp_model.Load(vocab_file)

@property
def UNK_TOKEN(self):
return "<unk>"

@property
def SEP_TOKEN(self):
return "<sep>"

@property
def PAD_TOKEN(self):
return "<pad>"

@property
def CLS_TOKEN(self):
return "<cls>"

@property
def MASK_TOKEN(self):
return "<mask>"

@property
def UNK_ID(self):
return self.special_symbols["<unk>"]

@property
def SEP_ID(self):
return self.special_symbols["<sep>"]

@property
def PAD_ID(self):
return self.special_symbols["<pad>"]

@property
def CLS_ID(self):
return self.special_symbols["<cls>"]

@property
def MASK_ID(self):
return self.special_symbols["<mask>"]

def __len__(self):
def vocab_size(self):
return len(self.sp_model)

def __getstate__(self):
@ -169,7 +119,7 @@ class XLNetTokenizer(PreTrainedTokenizer):

return outputs

def tokenize(self, text, return_unicode=True, sample=False):
def _tokenize(self, text, return_unicode=True, sample=False):
""" Tokenize a string.
return_unicode is used only for py2
"""
@ -208,56 +158,30 @@ class XLNetTokenizer(PreTrainedTokenizer):

return new_pieces

def convert_tokens_to_ids(self, tokens, sample=False):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
return self.sp_model.PieceToId(tokens)
for token in tokens:
ids.append(self.sp_model.PieceToId(token))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this XLNet model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.sp_model.PieceToId(token)

def convert_ids_to_tokens(self, ids, return_unicode=True):
"""Converts a sequence of ids in tokens."""
tokens = []
for i in ids:
tokens.append(self.sp_model.IdToPiece(i))
def _convert_id_to_token(self, index, return_unicode=True):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
token = self.sp_model.IdToPiece(index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
return token

if six.PY2 and return_unicode:
ret_pieces = []
for piece in tokens:
if isinstance(piece, str):
piece = piece.decode('utf-8')
ret_pieces.append(piece)
tokens = ret_pieces
return tokens

def encode(self, text, sample=False):
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))

def decode(self, ids, clean_up_tokenization_spaces=True):
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids)
out_string = ''.join(tokens)
if clean_up_tokenization_spaces:
out_string = out_string.strip().replace('<unk>', '')
out_string = clean_up_tokenization(out_string)
out_string = ''.join(tokens_ids)
return out_string

def save_vocabulary(self, vocab_path):
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])

copyfile(self.vocab_file, out_vocab_file)

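Finally, a hedged end-to-end sketch for the SentencePiece-backed XLNet tokenizer through the same unified API, assuming the 'xlnet-large-cased' shortcut listed in this file:

# Illustrative sketch: requires the sentencepiece package to be installed.
from pytorch_transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
ids = tokenizer.encode("Hello, how are you?")
text = tokenizer.decode(ids, skip_special_tokens=True)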