Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-20 13:08:21 +06:00)

Commit b19786985d (parent 3d5f291386): unified tokenizer api and serialization + tests
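The hunks below (file headers were lost in the page extraction) rework the example scripts, the model classes, and the tokenizers around a single serialization pattern. As a rough, hedged sketch of that pattern, not taken verbatim from the diff (model name and output directory are placeholders):

    # Minimal sketch of the unified pattern this commit moves toward; names/paths are illustrative.
    import os
    from pytorch_transformers import BertForSequenceClassification, BertTokenizer

    save_dir = "./finetuned"                      # placeholder path
    os.makedirs(save_dir, exist_ok=True)          # save_pretrained expects an existing directory

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    model.save_pretrained(save_dir)               # writes pytorch_model.bin and config.json (method added below)
    tokenizer.save_vocabulary(save_dir)           # writes the vocabulary under its canonical file name

    model = BertForSequenceClassification.from_pretrained(save_dir)
    tokenizer = BertTokenizer.from_pretrained(save_dir)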
@@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
 
 from tensorboardX import SummaryWriter
 
-from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_transformers.modeling_bert import BertForSequenceClassification
-from pytorch_transformers.tokenization_bert import BertTokenizer
+from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
+                                  XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                  XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
+from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
+                                  XLMTokenizer)
 from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
 
 from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
@@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c
 
 logger = logging.getLogger(__name__)
 
+ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
+
+MODEL_CLASSES = {
+    'bert': BertForSequenceClassification,
+    'xlnet': XLNetForSequenceClassification,
+    'xlm': XLMForSequenceClassification,
+}
+
+TOKENIZER_CLASSES = {
+    'bert': BertTokenizer,
+    'xlnet': XLNetTokenizer,
+    'xlm': XLMTokenizer,
+}
+
 def train(args, train_features, model):
     """ Train the model """
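The ALL_MODELS expression added above flattens the per-architecture archive maps into one tuple of checkpoint names for the --model_name help text. A toy illustration of what it evaluates to (the dictionaries here are stand-ins, not the real maps):

    maps = ({'bert-base-uncased': 'url-a', 'bert-large-uncased': 'url-b'},   # stand-in archive maps
            {'xlnet-large-cased': 'url-c'},
            {'xlm-mlm-en-2048': 'url-d'})
    all_models = sum((tuple(m.keys()) for m in maps), ())
    # -> ('bert-base-uncased', 'bert-large-uncased', 'xlnet-large-cased', 'xlm-mlm-en-2048')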
@@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
 
     # Eval!
     logger.info("***** Running evaluation *****")
-    logger.info("  Num examples = %d", len(eval_examples))
+    logger.info("  Num examples = %d", len(eval_features))
     logger.info("  Batch size = %d", args.eval_batch_size)
     model.eval()
     eval_loss = 0
@@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         examples = processor.get_dev_examples(args.data_dir)
     cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
         'dev' if eval else 'train',
-        list(filter(None, args.bert_model.split('/'))).pop(),
+        list(filter(None, args.model_name.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
 
@@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         features = torch.load(cached_features_file)
     else:
         features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
+        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
+                                                cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
+                                                cls_token=tokenizer.cls_token,
+                                                sep_token=tokenizer.sep_token, cls_token_segment_id=2,
+                                                pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
@@ -230,12 +252,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--bert_model", default=None, type=str, required=True,
-                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
-                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
-                             "bert-base-multilingual-cased, bert-base-chinese.")
+    parser.add_argument("--model_name", default=None, type=str, required=True,
+                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
-                        help="The name of the task to train.")
+                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
 
@@ -243,9 +263,8 @@ def main():
     parser.add_argument("--cache_dir", default="", type=str,
                         help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after WordPiece tokenization. \n"
-                             "Sequences longer than this will be truncated, and sequences shorter \n"
-                             "than this will be padded.")
+                        help="The maximum total input sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action='store_true',
@@ -263,8 +282,7 @@ def main():
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                             "E.g., 0.1 = 10%% of training.")
+                        help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--overwrite_output_dir', action='store_true',
@@ -331,8 +349,11 @@ def main():
         # Make sure only the first process in distributed training will download model & vocab
         torch.distributed.barrier()
 
-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
+    args.model_type = args.model_name.lower().split('-')[0]
+    args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
+    args.model_class = MODEL_CLASSES[args.model_type]
+    tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
+    model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
 
     if args.local_rank == 0:
         torch.distributed.barrier()
@@ -359,27 +380,16 @@ def main():
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Save a trained model, configuration and tokenizer
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-
-        # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-
-        torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
+        model.save_pretrained(args.output_dir)
         tokenizer.save_vocabulary(args.output_dir)
 
-        # Load a trained model and vocabulary that you have fine-tuned
-        model = BertForSequenceClassification.from_pretrained(args.output_dir)
-        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
-
         # Good practice: save your training arguments together with the trained model
-        output_args_file = os.path.join(args.output_dir, 'training_args.bin')
-        torch.save(args, output_args_file)
-    else:
-        model = BertForSequenceClassification.from_pretrained(args.bert_model)
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
 
-    model.to(args.device)
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = args.model_class.from_pretrained(args.output_dir)
+        tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
+        model.to(args.device)
 
     # Evaluation
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
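The example script above now picks its classes from the first dash-separated component of --model_name instead of hard-coding BERT. A hedged illustration of that dispatch (the checkpoint name is a placeholder; MODEL_CLASSES and TOKENIZER_CLASSES are the dictionaries defined earlier in this diff):

    model_name = "xlnet-large-cased"                 # placeholder value for --model_name
    model_type = model_name.lower().split('-')[0]    # -> "xlnet"
    tokenizer = TOKENIZER_CLASSES[model_type].from_pretrained(model_name)
    model = MODEL_CLASSES[model_type].from_pretrained(model_name, num_labels=2)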
@@ -211,8 +211,8 @@ def main():
         logger.info("No cache file at %s, preparing train features", cached_train_features_file)
         train_features = convert_examples_to_features(
             train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-            cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-            sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+            cls_token_at_end=True, cls_token=tokenizer.cls_token,
+            sep_token=tokenizer.sep_token, cls_token_segment_id=2,
             pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info("  Saving train features into cached file %s", cached_train_features_file)
@@ -369,8 +369,8 @@ def main():
         logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
         eval_features = convert_examples_to_features(
             eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-            cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-            sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+            cls_token_at_end=True, cls_token=tokenizer.cls_token,
+            sep_token=tokenizer.sep_token, cls_token_segment_id=2,
             pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info("  Saving eval features into cached file %s", cached_eval_features_file)
@@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                                  mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
         `cls_token_at_end` define the location of the CLS token:
-            - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
+            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
         `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
     """
@@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                 [str(x) for x in tokens]))
             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info(
-                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
             logger.info("label: %s (id = %d)" % (example.label, label_id))
 
         features.append(
@@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
                             BertForTokenClassification, BertForQuestionAnswering,
-                            load_tf_weights_in_bert)
+                            load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                            BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
-                              load_tf_weights_in_openai_gpt)
+                              load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                              OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
-                                  load_tf_weights_in_transfo_xl)
+                                  load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                  TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             GPT2LMHeadModel, GPT2DoubleHeadsModel,
-                            load_tf_weights_in_gpt2)
+                            load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                            GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlnet import (XLNetConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
                              XLNetForSequenceClassification, XLNetForQuestionAnswering,
-                             load_tf_weights_in_xlnet)
+                             load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                             XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlm import (XLMConfig, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
-                           XLMForQuestionAnswering)
+                           XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                           XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
 
@@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
                                                        TransfoXLConfig,
                                                        TransfoXLLMHeadModel,
                                                        load_tf_weights_in_transfo_xl)
-from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME,
-                                                           VOCAB_NAME)
+from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
 
 if sys.version_info[0] == 2:
     import cPickle as pickle
@@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
         with open(transfo_xl_dataset_file, "rb") as fp:
             corpus = pickle.load(fp, encoding="latin1")
         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
-        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
         print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
         corpus_vocab_dict = corpus.vocab.__dict__
         torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

@@ -24,7 +24,7 @@ import torch
 import numpy
 
 from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
-from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME
+from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
 
 
 def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
@@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
     # Save pytorch-model
     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
 
     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(model, pytorch_weights_dump_path)
@@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
@@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
 }
 
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
@@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class BertConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `BertModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=30522,
@@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = BertConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
 
@@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
                                  "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
 
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
@@ -103,7 +103,7 @@ def gelu(x):
 class GPT2Config(PretrainedConfig):
     """Configuration class to store the configuration of a `GPT2Model`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(
         self,
@@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = GPT2Config
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
 
@@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
+OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
+OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
 
 
 def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
@@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
 class OpenAIGPTConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `OpenAIGPTModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(
         self,
@@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = OpenAIGPTConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
 
@@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrai
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
 }
 
@@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
 class TransfoXLConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `TransfoXLModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=267735,
@@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = TransfoXLConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"
 
@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
         model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
         model_to_prune._prune_heads(heads_to_prune)
 
+    def save_pretrained(self, save_directory):
+        """ Save a model with its configuration file to a directory, so that it
+            can be re-loaded using the `from_pretrained(save_directory)` class method.
+        """
+        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+
+        # Only save the model it-self if we are using distributed training
+        model_to_save = self.module if hasattr(self, 'module') else self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+        output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
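With save_pretrained on the base class, any model can be serialized to a directory under the predefined weight and config file names, after unwrapping a wrapped .module (distributed/DataParallel training). A hedged usage sketch (the checkpoint name and target directory are placeholders; note the method asserts that the directory already exists):

    import os
    from pytorch_transformers import BertForSequenceClassification

    output_dir = "./my_model"                      # placeholder directory
    os.makedirs(output_dir, exist_ok=True)         # required: save_pretrained asserts the directory exists

    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    model.save_pretrained(output_dir)              # writes WEIGHTS_NAME and CONFIG_NAME into output_dir
    model = BertForSequenceClassification.from_pretrained(output_dir)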
@@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
 }
 
@@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
 class XLMConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLMModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=30145,
@@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLMConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
     base_model_prefix = "transformer"
 
@@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
 }
 
@@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class XLNetConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLNetModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=32000,
@@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLNetConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"
 
@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForNextSentencePrediction, BertForPreTraining,
                                   BertForQuestionAnswering, BertForSequenceClassification,
                                   BertForTokenClassification, BertForMultipleChoice)
-from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 
@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

@@ -413,7 +413,7 @@ class GPTModelTester(object):
 
     def create_and_check_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
             model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.parent.assertIsNotNone(model)
@@ -26,7 +26,7 @@ import pytest
 import torch
 
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 
@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

@@ -20,12 +20,12 @@ import unittest
 import logging
 
 from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 
 class ModelUtilsTest(unittest.TestCase):
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             config = BertConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, PretrainedConfig)
@@ -21,7 +21,7 @@ import shutil
 import pytest
 
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
-from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 
@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

@@ -26,7 +26,7 @@ import pytest
 import torch
 
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
-from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 
@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 from io import open
-import shutil
-import pytest
 
 from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                      BertTokenizer,
                                                      WordpieceTokenizer,
                                                      _is_control, _is_punctuation,
-                                                     _is_whitespace)
+                                                     _is_whitespace, VOCAB_FILES_NAMES)
 
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing", ","
+            "##ing", ",", "low", "lowest",
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
+        vocab_directory = "/tmp/"
+        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
+        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
             vocab_file = vocab_writer.name
 
-        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
+        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
 
         tokenizer = BertTokenizer(vocab_file)
 
@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
-        tokenizer = WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
 
         self.assertListEqual(tokenizer.tokenize(""), [])
 
@@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
+import tempfile
 
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -28,31 +29,31 @@ class GPT2TokenizationTest(unittest.TestCase):
         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "lo", "low", "er",
-                 "low", "lowest", "newer", "wider"]
+                 "low", "lowest", "newer", "wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
+        special_tokens_map = {"unk_token": "<unk>"}
 
-        create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
 
-        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        text = "lower"
-        bpe_tokens = ["low", "er"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
+            create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
 
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [13, 12, 16]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+            text = "lower"
+            bpe_tokens = ["low", "er"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
 
-        os.remove(vocab_file)
-        os.remove(merges_file)
+            input_tokens = tokens + [tokenizer.unk_token]
+            input_bpe_tokens = [13, 12, 17]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
 
 if __name__ == '__main__':
@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import shutil
-import pytest
+import tempfile
 
-from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
+from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
 
 from.tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -32,31 +31,31 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
 
-        create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
 
-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
+            create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
 
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
+            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
 
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
 
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
 
 if __name__ == '__main__':
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import sys
 from io import open
+import tempfile
 
 if sys.version_info[0] == 3:
     unicode = str
@@ -28,22 +29,19 @@ else:
 
 
 def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class(*inputs, **kwargs)
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
     before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 
-    vocab_path="/tmp/"
-    output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
-    tokenizer = tokenizer.from_pretrained(vocab_path)
-
-    for f in output_files:
-        os.remove(f)
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        tokenizer.save_pretrained(tmpdirname)
+        tokenizer = tokenizer.from_pretrained(tmpdirname)
 
     after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
     tester.assertListEqual(before_tokens, after_tokens)
 
 def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class(*inputs, **kwargs)
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
     text = u"Munich and Berlin are nice cities"
     filename = u"/tmp/tokenizer.bin"
@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
|
|||||||
tester.assertListEqual(subwords, subwords_loaded)
|
tester.assertListEqual(subwords, subwords_loaded)
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||||
|
|
||||||
|
vocab_size = tokenizer.vocab_size
|
||||||
|
all_size = len(tokenizer)
|
||||||
|
|
||||||
|
tester.assertNotEqual(vocab_size, 0)
|
||||||
|
tester.assertEqual(vocab_size, all_size)
|
||||||
|
|
||||||
|
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
||||||
|
added_toks = tokenizer.add_tokens(new_toks)
|
||||||
|
vocab_size_2 = tokenizer.vocab_size
|
||||||
|
all_size_2 = len(tokenizer)
|
||||||
|
|
||||||
|
tester.assertNotEqual(vocab_size_2, 0)
|
||||||
|
tester.assertEqual(vocab_size, vocab_size_2)
|
||||||
|
tester.assertEqual(added_toks, len(new_toks))
|
||||||
|
tester.assertEqual(all_size_2, all_size + len(new_toks))
|
||||||
|
|
||||||
|
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
||||||
|
tester.assertGreaterEqual(len(tokens), 4)
|
||||||
|
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
|
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
|
|
||||||
|
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
||||||
|
'pad_token': "<<<<<|||>|>>>>|>"}
|
||||||
|
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||||
|
vocab_size_3 = tokenizer.vocab_size
|
||||||
|
all_size_3 = len(tokenizer)
|
||||||
|
|
||||||
|
tester.assertNotEqual(vocab_size_3, 0)
|
||||||
|
tester.assertEqual(vocab_size, vocab_size_3)
|
||||||
|
tester.assertEqual(added_toks_2, len(new_toks_2))
|
||||||
|
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||||
|
|
||||||
|
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
||||||
|
|
||||||
|
tester.assertGreaterEqual(len(tokens), 6)
|
||||||
|
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||||
|
tester.assertGreater(tokens[0], tokens[1])
|
||||||
|
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
|
tester.assertGreater(tokens[-2], tokens[-3])
|
||||||
|
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
|
||||||
|
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||||
|
|
||||||
text = u"He is very happy, UNwant\u00E9d,running"
|
text = u"He is very happy, UNwant\u00E9d,running"
|
||||||
tokens = tokenizer.tokenize(text)
|
tokens = tokenizer.tokenize(text)
|
||||||
@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
|
|||||||
|
|
||||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
|
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||||
|
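The common checks above are meant to be driven through create_and_check_tokenizer_commons, with a directory that from_pretrained can load as the single positional input. A minimal usage sketch (not part of this commit; the concrete tokenizer class and the 'gpt2' shortcut name are only illustrative, and resolving the shortcut needs network access on first use):

    import tempfile
    import unittest

    from pytorch_transformers import GPT2Tokenizer
    from .tokenization_tests_commons import create_and_check_tokenizer_commons


    class MyTokenizerCommonsTest(unittest.TestCase):
        def test_commons(self):
            # serialize any concrete tokenizer into a temporary directory, then run
            # the shared required-methods / add-tokens / save-load / pickle checks on it
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            with tempfile.TemporaryDirectory() as tmpdirname:
                tokenizer.save_pretrained(tmpdirname)
                create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname)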
@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals

 import os
 import unittest
 from io import open
-import shutil
-import pytest
+import tempfile

-from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
+from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -28,22 +27,23 @@ class TransfoXLTokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
         vocab_tokens = [
-            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
+            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
+            "running", ",", "low", "l",
         ]
-        with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
+            create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)

         tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
-        os.remove(vocab_file)

         tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function

 import unittest
+import six

 from pytorch_transformers import PreTrainedTokenizer
 from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
@@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase):
         for model_name in s3_models[:1]:
             tokenizer = tokenizer_class.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
+            self.assertIsInstance(tokenizer, tokenizer_class)
             self.assertIsInstance(tokenizer, PreTrainedTokenizer)

+            for special_tok in tokenizer.all_special_tokens:
+                if six.PY2:
+                    self.assertIsInstance(special_tok, unicode)
+                else:
+                    self.assertIsInstance(special_tok, str)
+                special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
+                self.assertIsInstance(special_tok_id, int)

     def test_pretrained_tokenizers(self):
         self.check_tokenizer_from_pretrained(GPT2Tokenizer)

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest
+import tempfile

-from pytorch_transformers.tokenization_xlm import XLMTokenizer
+from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -31,31 +30,31 @@ class XLMTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-
-        tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
+
+            tokenizer = XLMTokenizer(vocab_file, merges_file)
+
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


 if __name__ == '__main__':
@@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals

 import os
 import unittest
-import shutil
-import pytest
+import tempfile

-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
+from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)

 from .tokenization_tests_commons import create_and_check_tokenizer_commons

@@ -29,34 +28,37 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
 class XLNetTokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
-        create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
-
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tokenizer.save_pretrained(tmpdirname)
+
+            create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
+
         tokens = tokenizer.tokenize(u'This is a test')
         self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                       u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                       u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
         ids = tokenizer.convert_tokens_to_ids(tokens)
         self.assertListEqual(
             ids, [8, 21, 84, 55, 24, 19, 7, 0,
                   602, 347, 347, 347, 3, 12, 66,
                   46, 72, 80, 6, 0, 4])

         back_tokens = tokenizer.convert_ids_to_tokens(ids)
         self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                            u'or', u'n', SPIECE_UNDERLINE + u'in',
                                            SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
                                            SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                            SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                            u'<unk>', u'.'])

     def test_tokenizer_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
@@ -22,7 +22,6 @@ import os
 import unicodedata
 from io import open

-from .file_utils import cached_path
 from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)
@@ -32,20 +31,21 @@ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
 PRETRAINED_VOCAB_FILES_MAP = {
     'vocab_file':
     {
         'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
         'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
         'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
         'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
         'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
         'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
         'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
         'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
         'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
         'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
         'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
         'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
         'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
-    }}
+    }
+}

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-uncased': 512,
@@ -93,8 +93,9 @@ class BertTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
-                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
+                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
+                 mask_token="[MASK]", **kwargs):
         """Constructs a BertTokenizer.

         Args:
@@ -102,17 +103,18 @@ class BertTokenizer(PreTrainedTokenizer):
             do_lower_case: Whether to lower case the input
                            Only has an effect when do_wordpiece_only=False
             do_basic_tokenize: Whether to do basic tokenization before wordpiece.
-            max_len: An artificial maximum length to truncate tokenized sequences to;
-                     Effective maximum length is always the minimum of this
-                     value (if specified) and the underlying BERT model's
-                     sequence length.
             never_split: List of tokens which will never be split during tokenization.
                          Only has an effect when do_wordpiece_only=False
         """
+        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
+                                            pad_token=pad_token, cls_token=cls_token,
+                                            mask_token=mask_token, **kwargs)
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                 "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+        if never_split is None:
+            never_split = self.all_special_tokens
         self.vocab = load_vocab(vocab_file)
         self.ids_to_tokens = collections.OrderedDict(
             [(ids, tok) for tok, ids in self.vocab.items()])
@@ -120,90 +122,34 @@ class BertTokenizer(PreTrainedTokenizer):
         if do_basic_tokenize:
             self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                   never_split=never_split)
-        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-        self.max_len = max_len if max_len is not None else int(1e12)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

     @property
-    def UNK_TOKEN(self):
-        return "[UNK]"
-
-    @property
-    def SEP_TOKEN(self):
-        return "[SEP]"
-
-    @property
-    def PAD_TOKEN(self):
-        return "[PAD]"
-
-    @property
-    def CLS_TOKEN(self):
-        return "[CLS]"
-
-    @property
-    def MASK_TOKEN(self):
-        return "[MASK]"
-
-    @property
-    def UNK_ID(self):
-        return self.vocab["[UNK]"]
-
-    @property
-    def SEP_ID(self):
-        return self.vocab["[SEP]"]
-
-    @property
-    def PAD_ID(self):
-        return self.vocab["[PAD]"]
-
-    @property
-    def CLS_ID(self):
-        return self.vocab["[CLS]"]
-
-    @property
-    def MASK_ID(self):
-        return self.vocab["[MASK]"]
-
-    def tokenize(self, text):
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def _tokenize(self, text):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text):
+            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                 for sub_token in self.wordpiece_tokenizer.tokenize(token):
                     split_tokens.append(sub_token)
         else:
             split_tokens = self.wordpiece_tokenizer.tokenize(text)
         return split_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """Converts a sequence of tokens into ids using the vocab."""
-        ids = []
-        for token in tokens:
-            ids.append(self.vocab[token])
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.vocab.get(token, self.vocab.get(self.unk_token))

-    def convert_ids_to_tokens(self, ids):
-        """Converts a sequence of ids in wordpiece tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            tokens.append(self.ids_to_tokens[i])
-        return tokens
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)

-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
-
-    def decode(self, token_ids, clean_up_tokenization_spaces=True):
+    def _convert_ids_to_string(self, tokens_ids):
         """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(token_ids)
+        tokens = self.convert_ids_to_tokens(tokens_ids)
         out_string = ''.join(tokens).replace(' ##', '').strip()
-        if clean_up_tokenization_spaces:
-            for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
-                out_string = out_string.replace(special_tok, '')
-            out_string = clean_up_tokenization(out_string)
         return out_string

     def save_vocabulary(self, vocab_path):
@@ -245,17 +191,20 @@ class BasicTokenizer(object):

     def __init__(self,
                  do_lower_case=True,
-                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+                 never_split=None):
         """Constructs a BasicTokenizer.

         Args:
             do_lower_case: Whether to lower case the input.
         """
+        if never_split is None:
+            never_split = []
         self.do_lower_case = do_lower_case
         self.never_split = never_split

-    def tokenize(self, text):
+    def tokenize(self, text, never_split=None):
         """Tokenizes a piece of text."""
+        never_split = self.never_split + (never_split if never_split is not None else [])
         text = self._clean_text(text)
         # This was added on November 1st, 2018 for the multilingual and Chinese
         # models. This is also applied to the English models now, but it doesn't
@@ -267,7 +216,7 @@ class BasicTokenizer(object):
         orig_tokens = whitespace_tokenize(text)
         split_tokens = []
         for token in orig_tokens:
-            if self.do_lower_case and token not in self.never_split:
+            if self.do_lower_case and token not in never_split:
                 token = token.lower()
                 token = self._run_strip_accents(token)
             split_tokens.extend(self._run_split_on_punc(token))
@@ -286,9 +235,9 @@ class BasicTokenizer(object):
             output.append(char)
         return "".join(output)

-    def _run_split_on_punc(self, text):
+    def _run_split_on_punc(self, text, never_split=None):
         """Splits punctuation on a piece of text."""
-        if text in self.never_split:
+        if never_split is not None and text in never_split:
             return [text]
         chars = list(text)
         i = 0
@@ -360,7 +309,7 @@ class BasicTokenizer(object):
 class WordpieceTokenizer(object):
     """Runs WordPiece tokenization."""

-    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
         self.vocab = vocab
         self.unk_token = unk_token
         self.max_input_chars_per_word = max_input_chars_per_word
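The BertTokenizer changes above illustrate the new subclass contract introduced by this commit: the public tokenize/encode/decode entry points move to the shared PreTrainedTokenizer base class, and a concrete tokenizer only supplies vocab_size plus the private _tokenize, _convert_token_to_id and _convert_id_to_token hooks. A simplified, hypothetical sketch of that contract (whitespace splitting stands in for WordPiece; this is not the actual BertTokenizer):

    class ToyVocabTokenizer(PreTrainedTokenizer):
        def __init__(self, vocab, unk_token="[UNK]", **kwargs):
            super(ToyVocabTokenizer, self).__init__(unk_token=unk_token, **kwargs)
            self.vocab = vocab                                    # dict: token -> id
            self.ids_to_tokens = {i: t for t, i in vocab.items()}

        @property
        def vocab_size(self):
            return len(self.vocab)

        def _tokenize(self, text):
            return text.split()                                   # toy stand-in for WordPiece

        def _convert_token_to_id(self, token):
            return self.vocab.get(token, self.vocab.get(self.unk_token))

        def _convert_id_to_token(self, index):
            return self.ids_to_tokens.get(index, self.unk_token)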
@@ -38,7 +38,6 @@ logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
     'vocab_file': 'vocab.json',
     'merges_file': 'merges.txt',
-    'special_tokens_file': 'special_tokens.txt'
 }

 PRETRAINED_VOCAB_FILES_MAP = {
@@ -52,11 +51,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
         'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
     },
-    'special_tokens_file':
-    {
-        'gpt2': None,
-        'gpt2-medium': None,
-    }
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
@@ -108,8 +102,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None):
-        self.max_len = max_len if max_len is not None else int(1e12)
+    def __init__(self, vocab_file, merges_file, errors='replace',
+                 bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
+        super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
+
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
         self.errors = errors # how to handle errors in decoding
@@ -123,32 +119,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

-        all_special_tokens = []
-        if special_tokens_file is not None:
-            special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-            all_special_tokens.extend(special_tokens_to_add)
-        if special_tokens is not None and special_tokens:
-            all_special_tokens.extend(special_tokens)
-
-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(all_special_tokens)
-
-    def __len__(self):
-        return len(self.encoder) + len(self.special_tokens)
-
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        logger.info("Special tokens {}".format(self.special_tokens))
+    @property
+    def vocab_size(self):
+        return len(self.encoder)

     def bpe(self, token):
         if token in self.cache:
@@ -191,7 +164,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def tokenize(self, text):
+    def _tokenize(self, text):
         """ Tokenize a string. """
         bpe_tokens = []
         for token in re.findall(self.pat, text):
@@ -202,57 +175,27 @@ class GPT2Tokenizer(PreTrainedTokenizer):
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
         return bpe_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.encoder.get(token, self.encoder.get(self.unk_token))

-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        return self.decoder.get(index, self.unk_token)

-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
-
-    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
-        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
+    def _convert_ids_to_string(self, tokens_ids):
+        """Converts a sequence of ids in a string."""
+        text = ''.join(tokens_ids)
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
-        if clean_up_tokenization_spaces:
-            text = text.replace('<unk>', '')
-            text = clean_up_tokenization(text)
         return text

-    def save_vocabulary(self, vocab_path):
+    def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
-        merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
-        special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
@@ -268,14 +211,4 @@ class GPT2Tokenizer(PreTrainedTokenizer):
             writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1

-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
+        return vocab_file, merge_file
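With save_vocabulary now taking a directory and returning only (vocab_file, merge_file), serialization is meant to go through save_pretrained/from_pretrained, which also handle the special-tokens and added-tokens files. A round-trip sketch mirroring what the updated save/load test exercises (the 'gpt2' shortcut name is illustrative and needs network access on first use):

    import tempfile
    from pytorch_transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    ids_before = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

    with tempfile.TemporaryDirectory() as tmpdirname:
        # writes vocab.json, merges.txt and the special/added tokens maps
        tokenizer.save_pretrained(tmpdirname)
        reloaded = GPT2Tokenizer.from_pretrained(tmpdirname)

    assert reloaded.encode(u"He is very happy, UNwant\u00E9d,running") == ids_before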
@@ -20,13 +20,9 @@ import json
 import logging
 import os
 import re
-import sys
 from io import open

-from tqdm import tqdm
+from .tokenization_utils import PreTrainedTokenizer

-from .file_utils import cached_path
-from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
 from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)
@@ -34,7 +30,6 @@ logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
     'vocab_file': 'vocab.json',
     'merges_file': 'merges.txt',
-    'special_tokens_file': 'special_tokens.txt'
 }

 PRETRAINED_VOCAB_FILES_MAP = {
@@ -46,10 +41,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
     {
         'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
     },
-    'special_tokens_file':
-    {
-        'openai-gpt': None,
-    }
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
@@ -88,14 +79,14 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     BPE tokenizer. Peculiarities:
         - lower case all inputs
         - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
-        - argument special_tokens and function set_special_tokens:
-            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
+        super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
+
         try:
             import ftfy
             import spacy
@@ -103,11 +94,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
             self.fix_text = ftfy.fix_text
         except ImportError:
             logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
-            self.nlp = BasicTokenizer(do_lower_case=True,
-                                      never_split=special_tokens if special_tokens is not None else [])
+            self.nlp = BasicTokenizer(do_lower_case=True)
             self.fix_text = None

-        self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
@@ -115,35 +104,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}

-        all_special_tokens = []
-        if special_tokens_file is not None:
-            special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-            all_special_tokens.extend(special_tokens_to_add)
-        if special_tokens is not None and special_tokens:
-            all_special_tokens.extend(special_tokens)
-
-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(all_special_tokens)
-
-    def __len__(self):
-        return len(self.encoder) + len(self.special_tokens)
-
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        if self.fix_text is None:
-            # Using BERT's BasicTokenizer: we can update the tokenizer
-            self.nlp.never_split = special_tokens
-        logger.info("Special tokens {}".format(self.special_tokens))
+    @property
+    def vocab_size(self):
+        return len(self.encoder)

     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + '</w>',)
@@ -188,7 +151,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def tokenize(self, text):
+    def _tokenize(self, text):
         """ Tokenize a string. """
         split_tokens = []
         if self.fix_text is None:
@@ -203,58 +166,26 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
                 split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
         return split_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.encoder.get(token, self.encoder.get(self.unk_token))

-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
+    def _convert_id_to_token(self, index):
+        """Converts an id in a token (BPE) using the vocab."""
+        return self.decoder.get(index, self.unk_token)

-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
-
-    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+    def _convert_ids_to_string(self, tokens_ids):
         """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
-        out_string = ''.join(tokens).replace('</w>', ' ').strip()
-        if clean_up_tokenization_spaces:
-            out_string = out_string.replace('<unk>', '')
-            out_string = clean_up_tokenization(out_string)
+        out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
         return out_string

-    def save_vocabulary(self, vocab_path):
+    def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
-        merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
-        special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
@@ -270,14 +201,4 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
             writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1

-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
+        return vocab_file, merge_file
@@ -41,7 +41,7 @@ else:

 logger = logging.getLogger(__name__)

-VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'}
+VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}

 PRETRAINED_VOCAB_FILES_MAP = {
     'pretrained_vocab_file':
@@ -67,9 +67,17 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
+    def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
                  delimiter=None, vocab_file=None, pretrained_vocab_file=None,
-                 never_split=("<unk>", "<eos>", "<formula>")):
+                 never_split=None, unk_token="<unk>", eos_token="<eos>",
+                 additional_special_tokens=["<formula>"], **kwargs):
+        super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
+                                                 additional_special_tokens=additional_special_tokens,
+                                                 **kwargs)
+        if never_split is None:
+            never_split = self.all_special_tokens
+        if special is None:
+            special = []
         self.counter = Counter()
         self.special = special
         self.min_freq = min_freq
@@ -200,11 +208,13 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             self.idx2sym.append(sym)
             self.sym2idx[sym] = len(self.idx2sym) - 1

-    def get_sym(self, idx):
+    def _convert_id_to_token(self, idx):
+        """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
         return self.idx2sym[idx]

-    def get_idx(self, sym):
+    def _convert_token_to_id(self, sym):
+        """ Converts a token (str/unicode) in an id using the vocab. """
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
@@ -220,36 +230,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

-    def convert_ids_to_tokens(self, indices):
-        """Converts a sequence of indices in symbols using the vocab."""
-        return [self.get_sym(idx) for idx in indices]
-
-    def convert_tokens_to_ids(self, symbols):
-        """Converts a sequence of symbols into ids using the vocab."""
-        return [self.get_idx(sym) for sym in symbols]
+    def _convert_ids_to_string(self, tokens_ids):
+        """Converts a sequence of ids in a string."""
+        out_string = ' '.join(tokens_ids).strip()
+        return out_string

     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))

-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
-
-    def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
-        """Converts a sequence of indices in a string."""
-        if exclude is None:
-            out_string = ' '.join([self.get_sym(idx) for idx in indices])
-        else:
-            out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
-
-        if clean_up_tokenization_spaces:
-            out_string = clean_up_tokenization(out_string)
-
-        return out_string
-
-    def __len__(self):
+    @property
+    def vocab_size(self):
         return len(self.idx2sym)

-    def tokenize(self, line, add_eos=False, add_double_eos=False):
+    def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case
         if self.lower_case:
@ -16,37 +16,145 @@
|
|||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import regex as re
|
import json
|
||||||
|
import six
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
try:
|
|
||||||
from functools import lru_cache
|
|
||||||
except ImportError:
|
|
||||||
# Just a dummy decorator to get the checks to run on python2
|
|
||||||
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
|
|
||||||
def lru_cache():
|
|
||||||
return lambda func: func
|
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
|
||||||
|
ADDED_TOKENS_FILE = 'added_tokens.json'
|
||||||
|
|
||||||
class PreTrainedTokenizer(object):
|
class PreTrainedTokenizer(object):
|
||||||
""" An abstract class to handle dowloading and loading pretrained tokenizers.
|
""" An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.
|
||||||
|
|
||||||
|
Derived classes can set up a few special tokens to be used in common scripts and internals:
|
||||||
|
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
|
||||||
|
additional_special_tokens = []
|
||||||
|
|
||||||
|
We define an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
|
||||||
|
specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
|
||||||
"""
|
"""
|
||||||
vocab_files_names = {}
|
vocab_files_names = {}
|
||||||
pretrained_vocab_files_map = {}
|
pretrained_vocab_files_map = {}
|
||||||
max_model_input_sizes = {}
|
max_model_input_sizes = {}
|
||||||
|
|
||||||
|
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
|
||||||
|
"pad_token", "cls_token", "mask_token",
|
||||||
|
"additional_special_tokens"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_token(self):
|
||||||
|
if self._bos_token is None:
|
||||||
|
logger.error("Using bos_token, but it is not set yet.")
|
||||||
|
return self._bos_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token(self):
|
||||||
|
if self._eos_token is None:
|
||||||
|
logger.error("Using eos_token, but it is not set yet.")
|
||||||
|
return self._eos_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unk_token(self):
|
||||||
|
if self._unk_token is None:
|
||||||
|
logger.error("Using unk_token, but it is not set yet.")
|
||||||
|
return self._unk_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sep_token(self):
|
||||||
|
if self._sep_token is None:
|
||||||
|
logger.error("Using sep_token, but it is not set yet.")
|
||||||
|
return self._sep_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_token(self):
|
||||||
|
if self._pad_token is None:
|
||||||
|
logger.error("Using pad_token, but it is not set yet.")
|
||||||
|
return self._pad_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cls_token(self):
|
||||||
|
if self._cls_token is None:
|
||||||
|
logger.error("Using cls_token, but it is not set yet.")
|
||||||
|
return self._cls_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mask_token(self):
|
||||||
|
if self._mask_token is None:
|
||||||
|
logger.error("Using mask_token, but it is not set yet.")
|
||||||
|
return self._mask_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def additional_special_tokens(self):
|
||||||
|
if self._additional_special_tokens is None:
|
||||||
|
logger.error("Using additional_special_tokens, but it is not set yet.")
|
||||||
|
return self._additional_special_tokens
|
||||||
|
|
||||||
|
@bos_token.setter
|
||||||
|
def bos_token(self, value):
|
||||||
|
self._bos_token = value
|
||||||
|
|
||||||
|
@eos_token.setter
|
||||||
|
def eos_token(self, value):
|
||||||
|
self._eos_token = value
|
||||||
|
|
||||||
|
@unk_token.setter
|
||||||
|
def unk_token(self, value):
|
||||||
|
self._unk_token = value
|
||||||
|
|
||||||
|
@sep_token.setter
|
||||||
|
def sep_token(self, value):
|
||||||
|
self._sep_token = value
|
||||||
|
|
||||||
|
@pad_token.setter
|
||||||
|
def pad_token(self, value):
|
||||||
|
self._pad_token = value
|
||||||
|
|
||||||
|
@cls_token.setter
|
||||||
|
def cls_token(self, value):
|
||||||
|
self._cls_token = value
|
||||||
|
|
||||||
|
@mask_token.setter
|
||||||
|
def mask_token(self, value):
|
||||||
|
self._mask_token = value
|
||||||
|
|
||||||
|
@additional_special_tokens.setter
|
||||||
|
def additional_special_tokens(self, value):
|
||||||
|
self._additional_special_tokens = value
|
||||||
|
|
||||||
|
def __init__(self, max_len=None, **kwargs):
|
||||||
|
self._bos_token = None
|
||||||
|
self._eos_token = None
|
||||||
|
self._unk_token = None
|
||||||
|
self._sep_token = None
|
||||||
|
self._pad_token = None
|
||||||
|
self._cls_token = None
|
||||||
|
self._mask_token = None
|
||||||
|
self._additional_special_tokens = []
|
||||||
|
|
||||||
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
self.added_tokens_encoder = {}
|
||||||
|
self.added_tokens_decoder = {}
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key not in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||||
|
raise ValueError(
|
||||||
|
"PreTrainedTokenizer.__init__() argument {} should be in {}".format(
|
||||||
|
key, ', '.join(self.SPECIAL_TOKENS_ATTRIBUTES)))
|
||||||
|
else:
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
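The special tokens are stored in underscore-backed attributes and exposed through the properties above, and __init__ only accepts the names listed in SPECIAL_TOKENS_ATTRIBUTES. A short illustration (instantiating the base class directly is done here only to exercise the attribute machinery; in practice a subclass is used, and the import path is assumed):

from pytorch_transformers.tokenization_utils import PreTrainedTokenizer

t = PreTrainedTokenizer(max_len=512, unk_token="<unk>", cls_token="<cls>")
print(t.unk_token)           # '<unk>'
print(t.special_tokens_map)  # {'unk_token': '<unk>', 'cls_token': '<cls>'}
t.mask_token                 # logs "Using mask_token, but it is not set yet." and returns None

PreTrainedTokenizer(foo="bar")   # raises ValueError: argument foo should be in bos_token, eos_token, ...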
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *inputs, **kwargs):
|
def from_pretrained(cls, *inputs, **kwargs):
|
||||||
return cls._from_pretrained(*inputs, **kwargs)
|
return cls._from_pretrained(*inputs, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -59,16 +167,20 @@ class PreTrainedTokenizer(object):
|
|||||||
for file_id, map_list in cls.pretrained_vocab_files_map.items():
|
for file_id, map_list in cls.pretrained_vocab_files_map.items():
|
||||||
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
|
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
for file_id, file_name in cls.vocab_files_names.items():
|
all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
|
||||||
|
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
|
||||||
|
all_vocab_files_names.update(cls.vocab_files_names)
|
||||||
|
for file_id, file_name in all_vocab_files_names.items():
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||||
else:
|
else:
|
||||||
full_file_name = pretrained_model_name_or_path
|
full_file_name = pretrained_model_name_or_path
|
||||||
if not os.path.exists(full_file_name):
|
if not os.path.exists(full_file_name):
|
||||||
logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
|
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||||
full_file_name = None
|
full_file_name = None
|
||||||
vocab_files[file_id] = full_file_name
|
vocab_files[file_id] = full_file_name
|
||||||
# redirect to the cache, if necessary
|
|
||||||
|
# Get files from url, cache, or disk depending on the case
|
||||||
try:
|
try:
|
||||||
resolved_vocab_files = {}
|
resolved_vocab_files = {}
|
||||||
for file_id, file_path in vocab_files.items():
|
for file_id, file_path in vocab_files.items():
|
||||||
@ -95,6 +207,7 @@ class PreTrainedTokenizer(object):
|
|||||||
logger.info("loading file {} from cache at {}".format(
|
logger.info("loading file {} from cache at {}".format(
|
||||||
file_path, resolved_vocab_files[file_id]))
|
file_path, resolved_vocab_files[file_id]))
|
||||||
|
|
||||||
|
# Set max length if needed
|
||||||
if pretrained_model_name_or_path in cls.max_model_input_sizes:
|
if pretrained_model_name_or_path in cls.max_model_input_sizes:
|
||||||
# if we're using a pretrained model, ensure the tokenizer
|
# if we're using a pretrained model, ensure the tokenizer
|
||||||
# won't index sequences longer than the number of positional embeddings
|
# won't index sequences longer than the number of positional embeddings
|
||||||
@ -102,31 +215,255 @@ class PreTrainedTokenizer(object):
|
|||||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||||
|
|
||||||
# Merge resolved_vocab_files arguments in kwargs.
|
# Merge resolved_vocab_files arguments in kwargs.
|
||||||
|
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
|
||||||
|
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
|
||||||
for args_name, file_path in resolved_vocab_files.items():
|
for args_name, file_path in resolved_vocab_files.items():
|
||||||
kwargs[args_name] = file_path
|
if args_name not in kwargs:
|
||||||
|
kwargs[args_name] = file_path
|
||||||
|
if special_tokens_map_file is not None:
|
||||||
|
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
|
||||||
|
for key, value in special_tokens_map.items():
|
||||||
|
if key not in kwargs:
|
||||||
|
kwargs[key] = value
|
||||||
|
|
||||||
# Instantiate tokenizer.
|
# Instantiate tokenizer.
|
||||||
tokenizer = cls(*inputs, **kwargs)
|
tokenizer = cls(*inputs, **kwargs)
|
||||||
|
|
||||||
|
# Add supplementary tokens.
|
||||||
|
if added_tokens_file is not None:
|
||||||
|
added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
|
||||||
|
added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
|
||||||
|
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||||
|
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||||
|
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
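Two details of the loading path above are easy to miss: the resolved vocabulary files and the stored special_tokens_map.json only fill kwargs the caller did not pass, so explicit arguments win, and added_tokens.json is re-applied on top of the freshly built vocabulary. A hedged sketch ('./my-tokenizer' is a placeholder directory previously written by save_pretrained, shown further below):

from pytorch_transformers import XLNetTokenizer

# The stored special_tokens_map.json may define unk_token='<unk>', but the
# explicit kwarg takes precedence ("if key not in kwargs" above).
tok = XLNetTokenizer.from_pretrained('./my-tokenizer', unk_token='<oov>')
print(tok.unk_token)                  # '<oov>'
print(len(tok.added_tokens_encoder))  # tokens restored from added_tokens.json, if any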
def tokenize(self, text):
|
|
||||||
|
def save_pretrained(self, save_directory):
|
||||||
|
""" Save the tokenizer vocabulary files (with added tokens) and the
|
||||||
|
special-tokens-to-class-attributes-mapping to a directory, so that it
|
||||||
|
can be re-loaded using the `from_pretrained(save_directory)` class method.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error("Saving directory ({}) should be a directory".format(save_directory))
|
||||||
|
return
|
||||||
|
|
||||||
|
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
|
||||||
|
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
|
||||||
|
|
||||||
|
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
||||||
|
|
||||||
|
with open(added_tokens_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
|
||||||
|
|
||||||
|
vocab_files = self.save_vocabulary(save_directory)
|
||||||
|
|
||||||
|
return vocab_files + (special_tokens_map_file, added_tokens_file)
|
||||||
|
|
||||||
|
|
||||||
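save_pretrained therefore persists three kinds of state: the subclass vocabulary files (via save_vocabulary below), the special-tokens mapping, and the runtime-added tokens. A hedged round-trip sketch ('bert-base-uncased' and './my-tokenizer' are assumed names; the target directory must already exist):

import os
from pytorch_transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased')
tok.add_tokens(['<new_tok>'])

os.makedirs('./my-tokenizer', exist_ok=True)
tok.save_pretrained('./my-tokenizer')   # vocab file(s) + special_tokens_map.json + added_tokens.json

reloaded = BertTokenizer.from_pretrained('./my-tokenizer')
assert len(reloaded) == len(tok)        # added tokens survive the round trip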
|
def save_vocabulary(self, save_directory):
|
||||||
|
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
|
||||||
|
and special token mappings.
|
||||||
|
|
||||||
|
Please use `save_pretrained()` to save the full Tokenizer state so that it can be
|
||||||
|
reloaded using the `from_pretrained(save_directory)` class method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def vocab_size(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.vocab_size + len(self.added_tokens_encoder)
|
||||||
|
|
||||||
|
|
||||||
|
def add_tokens(self, new_tokens):
|
||||||
|
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
|
||||||
|
vocabulary, they are added to the added_tokens_encoder with indices starting from
|
||||||
|
the last index of the current vocabulary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens added to the vocabulary which can be used to correspondingly
|
||||||
|
increase the size of the associated model embedding matrices.
|
||||||
|
"""
|
||||||
|
if not new_tokens:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
to_add_tokens = []
|
||||||
|
for token in new_tokens:
|
||||||
|
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
|
||||||
|
to_add_tokens.append(token)
|
||||||
|
logger.info("Adding %s to the vocabulary", token)
|
||||||
|
|
||||||
|
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
|
||||||
|
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||||
|
self.added_tokens_encoder.update(added_tok_encoder)
|
||||||
|
self.added_tokens_decoder.update(added_tok_decoder)
|
||||||
|
|
||||||
|
return len(to_add_tokens)
|
||||||
|
|
||||||
|
|
||||||
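add_tokens skips anything that already resolves to a real vocabulary entry (i.e. does not map to unk_token) and appends the rest after the existing ids, so the return value is exactly the number of embedding rows a model would have to grow by. A brief hedged sketch (tokenizer stands for any PreTrainedTokenizer subclass instance):

num_added = tokenizer.add_tokens(['<machine>', '<learning>'])
print(num_added)                                      # 2, or 0 if both already had real ids
print(len(tokenizer))                                 # vocab_size + len(added_tokens_encoder)
print(tokenizer.convert_tokens_to_ids('<machine>'))   # an id just past the base vocabulary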
|
def add_special_tokens(self, special_tokens_dict):
|
||||||
|
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
|
||||||
|
to class attributes. If the special tokens are not in the vocabulary, they are added
|
||||||
|
to it and indexed starting from the last index of the current vocabulary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens added to the vocabulary which can be used to correspondingly
|
||||||
|
increase the size of the associated model embedding matrices.
|
||||||
|
"""
|
||||||
|
if not special_tokens_dict:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
added_special_tokens = self.add_tokens(special_tokens_dict.values())
|
||||||
|
for key, value in special_tokens_dict.items():
|
||||||
|
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return added_special_tokens
|
||||||
|
|
||||||
|
|
||||||
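add_special_tokens is a thin wrapper over add_tokens: the values are added to the vocabulary and each key is then assigned through the property setters above. A brief hedged sketch (again with a generic tokenizer instance; keys should be names from SPECIAL_TOKENS_ATTRIBUTES):

num_added = tokenizer.add_special_tokens({'cls_token': '<CLS>', 'sep_token': '<SEP>'})
print(tokenizer.cls_token)            # '<CLS>', now also a vocabulary entry
print(tokenizer.all_special_tokens)   # includes '<CLS>' and '<SEP>'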
|
def tokenize(self, text, **kwargs):
|
||||||
|
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
||||||
|
Split into words for word-based vocabularies or sub-words for sub-word-based
|
||||||
|
vocabularies (BPE/SentencePieces/WordPieces).
|
||||||
|
|
||||||
|
Takes care of added tokens.
|
||||||
|
"""
|
||||||
|
def split_on_tokens(tok_list, text):
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
if not tok_list:
|
||||||
|
return self._tokenize(text, **kwargs)
|
||||||
|
tok = tok_list[0]
|
||||||
|
split_text = text.split(tok)
|
||||||
|
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
|
||||||
|
for sub_text in split_text), [])[:-1]
|
||||||
|
|
||||||
|
added_tokens = list(self.added_tokens_encoder.keys())
|
||||||
|
tokenized_text = split_on_tokens(added_tokens, text)
|
||||||
|
return tokenized_text
|
||||||
|
|
||||||
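The split_on_tokens recursion is what keeps added tokens from being fed to the underlying BPE/WordPiece/SentencePiece model: the text is split on each added token in turn, and only the remaining plain-text fragments reach _tokenize. The same logic, pulled out as a standalone function for illustration (base_tokenize stands in for self._tokenize):

def split_on_tokens(tok_list, text, base_tokenize):
    # Recursively isolate the tokens in tok_list so they are emitted verbatim;
    # only the text between them is handed to the underlying tokenizer.
    if not text:
        return []
    if not tok_list:
        return base_tokenize(text)
    tok = tok_list[0]
    split_text = text.split(tok)
    return sum((split_on_tokens(tok_list[1:], sub_text.strip(), base_tokenize) + [tok]
                for sub_text in split_text), [])[:-1]

print(split_on_tokens(['<new_tok>'], "hello <new_tok> world", str.split))
# ['hello', '<new_tok>', 'world']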
|
def _tokenize(self, text, **kwargs):
|
||||||
|
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
||||||
|
Split into words for word-based vocabularies or sub-words for sub-word-based
|
||||||
|
vocabularies (BPE/SentencePieces/WordPieces).
|
||||||
|
|
||||||
|
Does not take care of added tokens.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id
|
||||||
|
(resp. a sequence of ids), using the vocabulary.
|
||||||
|
"""
|
||||||
|
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||||
|
return self.convert_token_to_id_with_added_voc(tokens)
|
||||||
|
|
||||||
|
ids = []
|
||||||
|
for token in tokens:
|
||||||
|
ids.append(self.convert_token_to_id_with_added_voc(token))
|
||||||
|
if len(ids) > self.max_len:
|
||||||
|
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||||
|
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||||
|
"indexing errors".format(len(ids), self.max_len))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def convert_token_to_id_with_added_voc(self, token):
|
||||||
|
if token in self.added_tokens_encoder:
|
||||||
|
return self.added_tokens_encoder[token]
|
||||||
|
return self._convert_token_to_id(token)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids):
|
|
||||||
|
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||||
|
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||||
|
(resp. a sequence of tokens, str/unicode), using the vocabulary and added tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||||
|
"""
|
||||||
|
if isinstance(ids, int):
|
||||||
|
return self.convert_id_to_token(ids)
|
||||||
|
tokens = []
|
||||||
|
for index in ids:
|
||||||
|
if index in self.all_special_ids and skip_special_tokens:
|
||||||
|
continue
|
||||||
|
if index in self.added_tokens_decoder:
|
||||||
|
tokens.append(self.added_tokens_decoder[index])
|
||||||
|
else:
|
||||||
|
tokens.append(self._convert_id_to_token(index))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
raise NotImplementedError
|
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||||
|
Same as self.convert_tokens_to_ids(self.tokenize(text)).
|
||||||
|
"""
|
||||||
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
def decode(self, token_ids, *input, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
raise NotImplementedError
|
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||||
|
with options to remove special tokens and clean up tokenization spaces.
|
||||||
|
"""
|
||||||
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
text = self._convert_ids_to_string(filtered_tokens)
|
||||||
|
if clean_up_tokenization_spaces:
|
||||||
|
text = clean_up_tokenization(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _convert_ids_to_string(self, tokens_ids):
|
||||||
|
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
|
||||||
|
Roughly the same as ' '.join(self.convert_ids_to_tokens(token_ids)).
|
||||||
|
"""
|
||||||
|
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
|
||||||
|
|
||||||
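With these two methods, encoding and decoding are defined once for every tokenizer in terms of the private hooks. A brief hedged usage sketch (tokenizer is any subclass instance):

ids = tokenizer.encode("Hello world")   # == tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
text = tokenizer.decode(ids,
                        skip_special_tokens=True,           # drop ids listed in all_special_ids
                        clean_up_tokenization_spaces=True)  # run clean_up_tokenization on the result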
|
@property
|
||||||
|
def special_tokens_map(self):
|
||||||
|
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
||||||
|
values ('<unk>', '<cls>'...)
|
||||||
|
"""
|
||||||
|
set_attr = {}
|
||||||
|
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||||
|
attr_value = getattr(self, "_" + attr)
|
||||||
|
if attr_value:
|
||||||
|
set_attr[attr] = attr_value
|
||||||
|
return set_attr
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_special_tokens(self):
|
||||||
|
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
|
||||||
|
(cls_token, unk_token...).
|
||||||
|
"""
|
||||||
|
all_toks = []
|
||||||
|
set_attr = self.special_tokens_map
|
||||||
|
for attr_value in set_attr.values():
|
||||||
|
all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
|
||||||
|
all_toks = list(set(all_toks))
|
||||||
|
return all_toks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_special_ids(self):
|
||||||
|
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
|
||||||
|
class attributes (cls_token, unk_token...).
|
||||||
|
"""
|
||||||
|
all_toks = self.all_special_tokens
|
||||||
|
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
|
||||||
|
return all_ids
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
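special_tokens_map, all_special_tokens and all_special_ids give every subclass the same introspection surface over whichever special tokens it configured. A brief hedged sketch (output shapes only; the concrete values depend on the subclass defaults):

print(tokenizer.special_tokens_map)   # e.g. {'unk_token': '<unk>', 'sep_token': '<sep>', ...}
print(tokenizer.all_special_tokens)   # flat, de-duplicated list of the values above
print(tokenizer.all_special_ids)      # the same tokens passed through convert_tokens_to_ids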
def clean_up_tokenization(out_string):
|
def clean_up_tokenization(out_string):
|
||||||
|
@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
|
|||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
'vocab_file': 'vocab.json',
|
'vocab_file': 'vocab.json',
|
||||||
'merges_file': 'merges.txt',
|
'merges_file': 'merges.txt',
|
||||||
'special_tokens_file': 'special_tokens.txt'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
@ -46,24 +45,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
{
|
{
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
||||||
},
|
},
|
||||||
'special_tokens_file':
|
|
||||||
{
|
|
||||||
'xlm-mlm-en-2048': None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
'xlm-mlm-en-2048': 512,
|
'xlm-mlm-en-2048': 512,
|
||||||
}
|
}
|
||||||
|
|
||||||
INDEX = {
|
|
||||||
"bos_index": 0,
|
|
||||||
"eos_index": 1,
|
|
||||||
"pad_index": 2,
|
|
||||||
"unk_index": 3,
|
|
||||||
"mask_index": 5
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""
|
"""
|
||||||
Return set of symbol pairs in a word.
|
Return set of symbol pairs in a word.
|
||||||
@ -103,7 +90,16 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
|
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
|
||||||
|
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
|
||||||
|
mask_token="<special1>", additional_special_tokens=["<special0>",
|
||||||
|
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
|
||||||
|
"<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
|
||||||
|
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
|
||||||
|
sep_token=sep_token, pad_token=pad_token,
|
||||||
|
cls_token=cls_token, mask_token=mask_token,
|
||||||
|
additional_special_tokens=additional_special_tokens,
|
||||||
|
**kwargs)
|
||||||
try:
|
try:
|
||||||
import ftfy
|
import ftfy
|
||||||
import spacy
|
import spacy
|
||||||
@ -111,11 +107,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.fix_text = ftfy.fix_text
|
self.fix_text = ftfy.fix_text
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
|
||||||
self.nlp = BasicTokenizer(do_lower_case=True,
|
self.nlp = BasicTokenizer(do_lower_case=True)
|
||||||
never_split=special_tokens if special_tokens is not None else [])
|
|
||||||
self.fix_text = None
|
self.fix_text = None
|
||||||
|
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
|
||||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
|
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
|
||||||
@ -123,35 +117,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
all_special_tokens = []
|
@property
|
||||||
if special_tokens_file is not None:
|
def vocab_size(self):
|
||||||
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
return len(self.encoder)
|
||||||
all_special_tokens.extend(special_tokens_to_add)
|
|
||||||
if special_tokens is not None and special_tokens:
|
|
||||||
all_special_tokens.extend(special_tokens)
|
|
||||||
|
|
||||||
self.special_tokens = {}
|
|
||||||
self.special_tokens_decoder = {}
|
|
||||||
self.set_special_tokens(all_special_tokens)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.encoder) + len(self.special_tokens)
|
|
||||||
|
|
||||||
def set_special_tokens(self, special_tokens):
|
|
||||||
""" Add a list of additional tokens to the encoder.
|
|
||||||
The additional tokens are indexed starting from the last index of the
|
|
||||||
current vocabulary in the order of the `special_tokens` list.
|
|
||||||
"""
|
|
||||||
if not special_tokens:
|
|
||||||
self.special_tokens = {}
|
|
||||||
self.special_tokens_decoder = {}
|
|
||||||
return
|
|
||||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
|
||||||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
|
||||||
if self.fix_text is None:
|
|
||||||
# Using BERT's BasicTokenizer: we can update the tokenizer
|
|
||||||
self.nlp.never_split = special_tokens
|
|
||||||
logger.info("Special tokens {}".format(self.special_tokens))
|
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
||||||
@ -196,7 +164,7 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
def tokenize(self, text):
|
def _tokenize(self, text):
|
||||||
""" Tokenize a string. """
|
""" Tokenize a string. """
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.fix_text is None:
|
if self.fix_text is None:
|
||||||
@ -211,58 +179,26 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def _convert_token_to_id(self, token):
|
||||||
""" Converts a sequence of tokens into ids using the vocab. """
|
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||||
ids = []
|
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
|
||||||
if tokens in self.special_tokens:
|
|
||||||
return self.special_tokens[tokens]
|
|
||||||
else:
|
|
||||||
return self.encoder.get(tokens, 0)
|
|
||||||
for token in tokens:
|
|
||||||
if token in self.special_tokens:
|
|
||||||
ids.append(self.special_tokens[token])
|
|
||||||
else:
|
|
||||||
ids.append(self.encoder.get(token, 0))
|
|
||||||
if len(ids) > self.max_len:
|
|
||||||
logger.warning(
|
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
|
||||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
|
||||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
|
||||||
)
|
|
||||||
return ids
|
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
def _convert_id_to_token(self, index):
|
||||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
tokens = []
|
return self.decoder.get(index, self.unk_token)
|
||||||
for i in ids:
|
|
||||||
if i in self.special_tokens_decoder:
|
|
||||||
if not skip_special_tokens:
|
|
||||||
tokens.append(self.special_tokens_decoder[i])
|
|
||||||
else:
|
|
||||||
tokens.append(self.decoder[i])
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def encode(self, text):
|
def _convert_ids_to_string(self, tokens_ids):
|
||||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
|
||||||
|
|
||||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
|
||||||
"""Converts a sequence of ids in a string."""
|
"""Converts a sequence of ids in a string."""
|
||||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
|
||||||
if clean_up_tokenization_spaces:
|
|
||||||
out_string = out_string.replace('<unk>', '')
|
|
||||||
out_string = clean_up_tokenization(out_string)
|
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, save_directory):
|
||||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||||
if not os.path.isdir(vocab_path):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
return
|
return
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
|
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
|
||||||
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
|
|
||||||
|
|
||||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||||
@ -277,14 +213,4 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
index = len(self.encoder)
|
return vocab_file, merge_file
|
||||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
|
||||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
|
||||||
if index != token_index:
|
|
||||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
|
||||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
|
||||||
index = token_index
|
|
||||||
writer.write(token + u'\n')
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
return vocab_file, merge_file, special_tokens_file
|
|
||||||
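XLM's _convert_ids_to_string receives the already-converted BPE tokens from decode and only has to undo the '</w>' end-of-word markers produced by bpe(). The string operation in isolation:

tokens = ['hel', 'lo</w>', 'wor', 'ld</w>']          # illustrative BPE pieces
print(''.join(tokens).replace('</w>', ' ').strip())  # 'hello world'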
|
@ -16,17 +16,13 @@
|
|||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from io import open
|
|
||||||
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from .file_utils import cached_path
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -44,8 +40,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
'xlnet-large-cased': 512,
|
'xlnet-large-cased': 512,
|
||||||
}
|
}
|
||||||
|
|
||||||
VOCAB_NAME = 'spiece.model'
|
|
||||||
|
|
||||||
SPIECE_UNDERLINE = u'▁'
|
SPIECE_UNDERLINE = u'▁'
|
||||||
|
|
||||||
# Segments (not really needed)
|
# Segments (not really needed)
|
||||||
@ -60,31 +54,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
SentencePiece based tokenizer. Peculiarities:
|
SentencePiece based tokenizer. Peculiarities:
|
||||||
- requires SentencePiece: https://github.com/google/sentencepiece
|
- requires SentencePiece: https://github.com/google/sentencepiece
|
||||||
"""
|
"""
|
||||||
# Tokens
|
|
||||||
special_symbols = {
|
|
||||||
"<unk>" : 0,
|
|
||||||
"<s>" : 1,
|
|
||||||
"</s>" : 2,
|
|
||||||
"<cls>" : 3,
|
|
||||||
"<sep>" : 4,
|
|
||||||
"<pad>" : 5,
|
|
||||||
"<mask>" : 6,
|
|
||||||
"<eod>" : 7,
|
|
||||||
"<eop>" : 8,
|
|
||||||
}
|
|
||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def __init__(self, vocab_file, max_len=None,
|
def __init__(self, vocab_file, max_len=None,
|
||||||
do_lower_case=False, remove_space=True, keep_accents=False):
|
do_lower_case=False, remove_space=True, keep_accents=False,
|
||||||
|
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
|
||||||
|
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
|
||||||
|
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
|
||||||
|
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
|
||||||
|
unk_token=unk_token, sep_token=sep_token,
|
||||||
|
pad_token=pad_token, cls_token=cls_token,
|
||||||
|
mask_token=mask_token, additional_special_tokens=
|
||||||
|
additional_special_tokens, **kwargs)
|
||||||
try:
|
try:
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||||
"pip install sentencepiece")
|
"pip install sentencepiece")
|
||||||
|
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
self.remove_space = remove_space
|
self.remove_space = remove_space
|
||||||
self.keep_accents = keep_accents
|
self.keep_accents = keep_accents
|
||||||
@ -94,46 +83,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
self.sp_model.Load(vocab_file)
|
self.sp_model.Load(vocab_file)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def UNK_TOKEN(self):
|
def vocab_size(self):
|
||||||
return "<unk>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def SEP_TOKEN(self):
|
|
||||||
return "<sep>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def PAD_TOKEN(self):
|
|
||||||
return "<pad>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def CLS_TOKEN(self):
|
|
||||||
return "<cls>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def MASK_TOKEN(self):
|
|
||||||
return "<mask>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def UNK_ID(self):
|
|
||||||
return self.special_symbols["<unk>"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def SEP_ID(self):
|
|
||||||
return self.special_symbols["<sep>"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def PAD_ID(self):
|
|
||||||
return self.special_symbols["<pad>"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def CLS_ID(self):
|
|
||||||
return self.special_symbols["<cls>"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def MASK_ID(self):
|
|
||||||
return self.special_symbols["<mask>"]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.sp_model)
|
return len(self.sp_model)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
@ -169,7 +119,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def tokenize(self, text, return_unicode=True, sample=False):
|
def _tokenize(self, text, return_unicode=True, sample=False):
|
||||||
""" Tokenize a string.
|
""" Tokenize a string.
|
||||||
return_unicode is used only for py2
|
return_unicode is used only for py2
|
||||||
"""
|
"""
|
||||||
@ -208,56 +158,30 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return new_pieces
|
return new_pieces
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens, sample=False):
|
def _convert_token_to_id(self, token):
|
||||||
""" Converts a sequence of tokens into ids using the vocab. """
|
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||||
ids = []
|
return self.sp_model.PieceToId(token)
|
||||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
|
||||||
return self.sp_model.PieceToId(tokens)
|
|
||||||
for token in tokens:
|
|
||||||
ids.append(self.sp_model.PieceToId(token))
|
|
||||||
if len(ids) > self.max_len:
|
|
||||||
logger.warning(
|
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
|
||||||
" sequence length for this XLNet model ({} > {}). Running this"
|
|
||||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
|
||||||
)
|
|
||||||
return ids
|
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids, return_unicode=True):
|
def _convert_id_to_token(self, index, return_unicode=True):
|
||||||
"""Converts a sequence of ids in tokens."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
tokens = []
|
token = self.sp_model.IdToPiece(index)
|
||||||
for i in ids:
|
if six.PY2 and return_unicode and isinstance(token, str):
|
||||||
tokens.append(self.sp_model.IdToPiece(i))
|
token = token.decode('utf-8')
|
||||||
|
return token
|
||||||
|
|
||||||
if six.PY2 and return_unicode:
|
def _convert_ids_to_string(self, tokens_ids):
|
||||||
ret_pieces = []
|
|
||||||
for piece in tokens:
|
|
||||||
if isinstance(piece, str):
|
|
||||||
piece = piece.decode('utf-8')
|
|
||||||
ret_pieces.append(piece)
|
|
||||||
tokens = ret_pieces
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def encode(self, text, sample=False):
|
|
||||||
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
|
|
||||||
|
|
||||||
def decode(self, ids, clean_up_tokenization_spaces=True):
|
|
||||||
"""Converts a sequence of ids in a string."""
|
"""Converts a sequence of ids in a string."""
|
||||||
tokens = self.convert_ids_to_tokens(ids)
|
out_string = ''.join(tokens_ids)
|
||||||
out_string = ''.join(tokens)
|
|
||||||
if clean_up_tokenization_spaces:
|
|
||||||
out_string = out_string.strip().replace('<unk>', '')
|
|
||||||
out_string = clean_up_tokenization(out_string)
|
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, save_directory):
|
||||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||||
to a directory.
|
to a directory.
|
||||||
"""
|
"""
|
||||||
if not os.path.isdir(vocab_path):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
return
|
return
|
||||||
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
|
|
||||||
copyfile(self.vocab_file, out_vocab_file)
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
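With XLM and XLNet both on the shared base class, loading, tokenizing, decoding and saving follow one pattern. A hedged end-to-end sketch (requires the sentencepiece package; 'xlnet-large-cased' is the checkpoint named in the pretrained maps above, and './xlnet-tok' is a placeholder directory):

import os
from pytorch_transformers import XLNetTokenizer

tok = XLNetTokenizer.from_pretrained('xlnet-large-cased')   # fetches the SentencePiece model
tokens = tok.tokenize("Hello world")                        # SentencePiece pieces
ids = tok.convert_tokens_to_ids(tokens)
print(tok.decode(ids, clean_up_tokenization_spaces=True))

os.makedirs('./xlnet-tok', exist_ok=True)
tok.save_pretrained('./xlnet-tok')   # SentencePiece model + special_tokens_map.json + added_tokens.json
tok = XLNetTokenizer.from_pretrained('./xlnet-tok')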