# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os
import json
import six
from io import open

from .file_utils import cached_path

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'

class PreTrainedTokenizer(object):
    """ Base class for all tokenizers.
    Handles all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).

    Class attributes (overridden by derived classes):

        - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
        - ``pretrained_vocab_files_map``: a python ``dict`` of ``dict`` with, as high-level keys, the ``__init__`` keyword name of each vocabulary file required by the model, as low-level keys the `short-cut-names` (string) of the pretrained models, and as associated values the `url` (string) of the associated pretrained vocabulary file.
        - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.

    Parameters:

        - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token``

        - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token``

        - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token``

        - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token``

        - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token``

        - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token``

        - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token``

        - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens``
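
    Example (an illustrative sketch only; it assumes the derived class :class:`BertTokenizer`, whose vocabulary already contains ``'[CLS]'`` and ``'[SEP]'``)::

        from pytorch_transformers import BertTokenizer

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                  cls_token='[CLS]', sep_token='[SEP]')
        assert tokenizer.cls_token == '[CLS]' and tokenizer.sep_token == '[SEP]'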
    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
    max_model_input_sizes = {}

    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
                                 "pad_token", "cls_token", "mask_token",
                                 "additional_special_tokens"]

    @property
    def bos_token(self):
        """ Beginning of sentence token (string). Logs an error if used while not having been set. """
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
        return self._bos_token

    @property
    def eos_token(self):
        """ End of sentence token (string). Logs an error if used while not having been set. """
        if self._eos_token is None:
            logger.error("Using eos_token, but it is not set yet.")
        return self._eos_token

    @property
    def unk_token(self):
        """ Unknown token (string). Logs an error if used while not having been set. """
        if self._unk_token is None:
            logger.error("Using unk_token, but it is not set yet.")
        return self._unk_token

    @property
    def sep_token(self):
        """ Separation token (string), e.g. to separate context and query in an input sequence. Logs an error if used while not having been set. """
        if self._sep_token is None:
            logger.error("Using sep_token, but it is not set yet.")
        return self._sep_token

    @property
    def pad_token(self):
        """ Padding token (string). Logs an error if used while not having been set. """
        if self._pad_token is None:
            logger.error("Using pad_token, but it is not set yet.")
        return self._pad_token

    @property
    def cls_token(self):
        """ Classification token (string), e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Logs an error if used while not having been set. """
        if self._cls_token is None:
            logger.error("Using cls_token, but it is not set yet.")
        return self._cls_token

    @property
    def mask_token(self):
        """ Mask token (string), e.g. when training a model with masked-language modeling. Logs an error if used while not having been set. """
        if self._mask_token is None:
            logger.error("Using mask_token, but it is not set yet.")
        return self._mask_token

    @property
    def additional_special_tokens(self):
        """ All the additional special tokens you may want to use (list of strings). Logs an error if used while not having been set. """
        if self._additional_special_tokens is None:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return self._additional_special_tokens

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._additional_special_tokens = []

        self.max_len = max_len if max_len is not None else int(1e12)
        self.added_tokens_encoder = {}
        self.added_tokens_decoder = {}

        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == 'additional_special_tokens':
                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                else:
                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)

    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        r"""
        Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.

        Examples::

            # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer

            # Download vocabulary from S3 and cache.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')

            # If the tokenizer uses a single vocabulary file, you can point directly to this file
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')

            # You can link tokens to special vocabulary when instantiating
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
            # You should be sure '<unk>' is in the vocabulary when doing that.
            # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead.
            assert tokenizer.unk_token == '<unk>'

        """
        return cls._from_pretrained(*inputs, **kwargs)

    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)

        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        if pretrained_model_name_or_path in s3_models:
            # Get the vocabulary from AWS S3 bucket
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
        else:
            # Get the vocabulary from local files
            logger.info(
                "Model name '{}' not found in model shortcut name list ({}). "
                "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
                    pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path))

            # Look for the tokenizer main vocabulary files
            for file_id, file_name in cls.vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    # If a directory is provided we look for the standard filenames
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    # If a path to a file is provided we use it (will only work for non-BPE tokenizers using a single vocabulary file)
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            # Look for the additional tokens files
            all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
                                     'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}

            # If a path to a file was provided, get the parent directory
            saved_directory = pretrained_model_name_or_path
            if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
                saved_directory = os.path.dirname(saved_directory)

            for file_id, file_name in all_vocab_files_names.items():
                full_file_name = os.path.join(saved_directory, file_name)
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            if all(full_file_name is None for full_file_name in vocab_files.values()):
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find tokenizer files "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, ))
                return None

        # Get files from url, cache, or disk depending on the case
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in s3_models:
                logger.error("Couldn't reach server to download vocabulary.")
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, str(vocab_files.keys())))
            raise e

        for file_id, file_path in vocab_files.items():
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        # Set max length if needed
        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # won't index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            if max_len is not None and isinstance(max_len, (int, float)):
                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

        # Merge resolved_vocab_files arguments in kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
        special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
        for args_name, file_path in resolved_vocab_files.items():
            if args_name not in kwargs:
                kwargs[args_name] = file_path
        if special_tokens_map_file is not None:
            special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
            for key, value in special_tokens_map.items():
                if key not in kwargs:
                    kwargs[key] = value

        # Instantiate tokenizer.
        tokenizer = cls(*inputs, **kwargs)

        # Add supplementary tokens.
        if added_tokens_file is not None:
            added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
            added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
            tokenizer.added_tokens_encoder.update(added_tok_encoder)
            tokenizer.added_tokens_decoder.update(added_tok_decoder)

        return tokenizer

    def save_pretrained(self, save_directory):
        """ Save the tokenizer vocabulary files (with added tokens) and the
            special-tokens-to-class-attributes mapping to a directory.

            This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
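
            Example (an illustrative sketch; ``./my_tokenizer/`` is a hypothetical directory and :class:`BertTokenizer` a derived class)::

                import os
                from pytorch_transformers import BertTokenizer

                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                if not os.path.isdir('./my_tokenizer/'):
                    os.mkdir('./my_tokenizer/')
                tokenizer.save_pretrained('./my_tokenizer/')

                # The full tokenizer state can then be reloaded
                reloaded_tokenizer = BertTokenizer.from_pretrained('./my_tokenizer/')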
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
            return

        special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
        added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)

        with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

        with open(added_tokens_file, 'w', encoding='utf-8') as f:
            if self.added_tokens_encoder:
                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
            else:
                out_str = u"{}"
            f.write(out_str)

        vocab_files = self.save_vocabulary(save_directory)

        return vocab_files + (special_tokens_map_file, added_tokens_file)

    def save_vocabulary(self, save_directory):
        """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
            and special token mappings.

            Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` to save the full tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
        """
        raise NotImplementedError

    @property
    def vocab_size(self):
        """ Size of the base vocabulary (without the added tokens). Implemented in derived classes. """
        raise NotImplementedError

    def __len__(self):
        """ Size of the full vocabulary with the added tokens """
        return self.vocab_size + len(self.added_tokens_encoder)

    def add_tokens(self, new_tokens):
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        vocabulary, they are added to it with indices starting from the length of the current vocabulary.

        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')

            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        """
        if not new_tokens:
            return 0

        to_add_tokens = []
        for token in new_tokens:
            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
            if token != self.unk_token and \
                    self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
                to_add_tokens.append(token)
                logger.info("Adding %s to the vocabulary", token)

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        return len(to_add_tokens)

    def add_special_tokens(self, special_tokens_dict):
        """
        Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
        to class attributes. If special tokens are NOT in the vocabulary, they are added
        to it (indexed starting from the last index of the current vocabulary).

        Args:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
                [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
                ``additional_special_tokens``].

                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to add a new classification token to GPT-2
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2Model.from_pretrained('gpt2')

            special_tokens_dict = {'cls_token': '<CLS>'}

            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
            if key == 'additional_special_tokens':
                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                added_tokens += self.add_tokens(value)
            else:
                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

        return added_tokens

    def tokenize(self, text, **kwargs):
        """ Converts a string into a sequence of tokens (string), using the tokenizer.
            Splits in words for word-based vocabularies or sub-words for sub-word-based
            vocabularies (BPE/SentencePiece/WordPiece).

            Takes care of added tokens.
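
            Example (an illustrative sketch; the exact sub-word splits depend on the derived tokenizer, assumed here to be :class:`BertTokenizer`)::

                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                tokenizer.add_tokens(['new_tok'])

                tokens = tokenizer.tokenize("Hello new_tok world!")
                # 'new_tok' is kept as a single token because it was added to the
                # vocabulary; the remaining text is split by the underlying algorithm
                # (WordPiece for Bert).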
        """
        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                sub_text = sub_text.strip()
                if i == 0 and not sub_text:
                    result += [tok]
                elif i == len(split_text) - 1:
                    if sub_text:
                        result += [sub_text]
                    else:
                        pass
                else:
                    if sub_text:
                        result += [sub_text]
                    result += [tok]
            return result

        def split_on_tokens(tok_list, text):
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.added_tokens_encoder \
                            and sub_text not in self.all_special_tokens:
                        tokenized_text += split_on_token(tok, sub_text)
                    else:
                        tokenized_text += [sub_text]
                text_list = tokenized_text

            return sum((self._tokenize(token, **kwargs) if token not in self.added_tokens_encoder
                        and token not in self.all_special_tokens
                        else [token] for token in tokenized_text), [])

        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
        """ Converts a string into a sequence of tokens (string), using the tokenizer.
            Splits in words for word-based vocabularies or sub-words for sub-word-based
            vocabularies (BPE/SentencePiece/WordPiece).

            Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        """ Converts a single token, or a sequence of tokens (str/unicode), into a single integer id
            (resp. a sequence of ids), using the vocabulary.
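
            Example (illustrative only; the returned ids depend on the vocabulary of the derived tokenizer, assumed here to be :class:`BertTokenizer`)::

                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
                # `ids` is a list with one vocabulary index per token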
        """
        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                           "for this model ({} > {}). Running this sequence through the model will result in "
                           "indexing errors".format(len(ids), self.max_len))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        raise NotImplementedError

    def encode(self, text, text_pair=None, add_special_tokens=False):
        """
        Converts a string into a sequence of ids (integer), using the tokenizer and vocabulary.

        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

        Args:
            text: The first sequence to be encoded.
            text_pair: Optional second sequence to be encoded.
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
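
        Example (an illustrative sketch; assumes the derived :class:`BertTokenizer`, whose single-sentence scheme adds ``[CLS]`` and ``[SEP]``)::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            ids = tokenizer.encode("Hello world")  # same as convert_tokens_to_ids(tokenize(...))
            ids_with_special = tokenizer.encode("Hello world", add_special_tokens=True)
            # the second call additionally inserts the ids of the model-specific
            # special tokens via add_special_tokens_single_sentence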
        """
        if text_pair is None:
            if add_special_tokens:
                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
            else:
                return self.convert_tokens_to_ids(self.tokenize(text))

        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]

        if add_special_tokens:
            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
        else:
            return first_sentence_tokens, second_sentence_tokens

    def add_special_tokens_single_sentence(self, token_ids):
        raise NotImplementedError

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        raise NotImplementedError

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) into a token
            (resp. a sequence of tokens (str/unicode)), using the vocabulary and added tokens.

            Args:
                skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            if index in self.all_special_ids and skip_special_tokens:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) into a single string.
            The simplest way to do it is ``' '.join(tokens)``
            but we often want to remove sub-word tokenization artifacts at the same time.
        """
        return ' '.join(tokens)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """
        Converts a sequence of ids (integer) into a string, using the tokenizer and vocabulary
        with options to remove special tokens and clean up tokenization spaces.
        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
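
        Example (an illustrative round trip; assumes the derived :class:`BertTokenizer`)::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            ids = tokenizer.encode("Hello world!", add_special_tokens=True)
            text = tokenizer.decode(ids, skip_special_tokens=True)
            # with clean_up_tokenization_spaces=True (the default), artifacts such as
            # spaces before punctuation are removed from the returned string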
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        text = self.convert_tokens_to_string(filtered_tokens)

        if self.sep_token is not None and self.sep_token in text:
            text = text.replace(self.cls_token, self.sep_token)
            split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
            if clean_up_tokenization_spaces:
                clean_text = [self.clean_up_tokenization(text) for text in split_text]
                return clean_text
            else:
                return split_text
        else:
            if clean_up_tokenization_spaces:
                clean_text = self.clean_up_tokenization(text)
                return clean_text
            else:
                return text

    @property
    def special_tokens_map(self):
        """ A dictionary mapping special token class attributes (cls_token, unk_token...) to their
            values ('<unk>', '<cls>'...)
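
            Example (illustrative; the exact content depends on which special tokens the derived tokenizer defines)::

                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                tokenizer.special_tokens_map
                # e.g. something like {'unk_token': '[UNK]', 'sep_token': '[SEP]', ...}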
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_tokens(self):
        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
            (cls_token, unk_token...).
        """
        all_toks = []
        set_attr = self.special_tokens_map
        for attr_value in set_attr.values():
            all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
        all_toks = list(set(all_toks))
        return all_toks

    @property
    def all_special_ids(self):
        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
            class attributes (cls_token, unk_token...).
        """
        all_toks = self.all_special_tokens
        all_ids = list(self._convert_token_to_id(t) for t in all_toks)
        return all_ids

    @staticmethod
    def clean_up_tokenization(out_string):
        """ Clean up simple English tokenization artifacts such as spaces before punctuation and abbreviated forms.
        """
        out_string = (out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',')
                      .replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't")
                      .replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re"))
        return out_string