# coding=utf-8
# Copyright 2018 XXX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model XXX."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open

from .tokenization_utils import PreTrainedTokenizer


logger = logging.getLogger(__name__)

####################################################
# In this template, replace all the XXX (various casings) with your model name
####################################################

####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

####################################################
# Mapping from the keyword argument names of Tokenizer `__init__`
# to the pretrained vocabulary URLs for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
        "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
    }
}

####################################################
# Mapping from model shortcut names to the max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "xxx-base-uncased": 512,
    "xxx-large-uncased": 512,
}

####################################################
# Mapping from model shortcut names to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint-specific configurations.
####################################################
PRETRAINED_INIT_CONFIGURATION = {
    "xxx-base-uncased": {"do_lower_case": True},
    "xxx-large-uncased": {"do_lower_case": True},
}

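
# Illustrative sketch (hypothetical usage, not part of the template): once the XXX placeholders
# are filled in, `PreTrainedTokenizer.from_pretrained` is expected to resolve the mappings above,
# roughly like:
#
#     tokenizer = XxxTokenizer.from_pretrained("xxx-base-uncased")
#     # downloads/caches the vocab file from PRETRAINED_VOCAB_FILES_MAP["vocab_file"],
#     # applies PRETRAINED_INIT_CONFIGURATION (here do_lower_case=True), and
#     # caps model inputs via PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES (512 positions).
#
# "xxx-base-uncased" is the placeholder shortcut name defined above, not a real checkpoint.
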
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab
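

# Illustrative example (hypothetical file contents, not from the template): a vocab file with one
# token per line, e.g.
#
#     [PAD]
#     [UNK]
#     hello
#
# loads as OrderedDict([("[PAD]", 0), ("[UNK]", 1), ("hello", 2)]); the line number is the token id.
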

class XxxTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a XxxTokenizer.

    :class:`~transformers.XxxTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece.

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file.
        do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        """Constructs a XxxTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input.
                Only has an effect when do_basic_tokenize=True.
        """
        super(XxxTokenizer, self).__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs
        )
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping (id -> token), used by `_convert_id_to_token` below.
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        """Take as input a string and return a list of strings (tokens) for words/sub-words."""
        # NOTE: `self.do_basic_tokenize`, `self.basic_tokenizer` and `self.wordpiece_tokenizer` are
        # assumed to be set up by the model-specific `__init__` (e.g. BERT-style basic + wordpiece
        # tokenizers); this template does not create them.
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str/unicode) into an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string
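
    # Illustrative example (wordpiece-style tokens, not from the template): joining on spaces and
    # dropping the "##" continuation markers gives
    #
    #     convert_tokens_to_string(["un", "##aff", "##able"])  ->  "unaffable"
    #     convert_tokens_to_string(["hello", "world"])         ->  "hello world"
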

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
            single sequence: [CLS] X [SEP]
            pair of sequences: [CLS] A [SEP] B [SEP]
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
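
    # Illustrative example (assuming BERT-style ids where cls_token_id=101 and sep_token_id=102,
    # which this template does not guarantee):
    #
    #     build_inputs_with_special_tokens([7, 8])       ->  [101, 7, 8, 102]
    #     build_inputs_with_special_tokens([7, 8], [9])  ->  [101, 7, 8, 102, 9, 102]
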

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]
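
    # Illustrative example (the mask depends only on the sequence lengths):
    #
    #     get_special_tokens_mask([7, 8])       ->  [1, 0, 0, 1]
    #     get_special_tokens_mask([7, 8], [9])  ->  [1, 0, 0, 1, 0, 1]
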

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
            0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence

        if token_ids_1 is None, only returns the first portion of the mask (0's).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
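
    # Illustrative example (token type ids mark which segment each position belongs to):
    #
    #     create_token_type_ids_from_sequences([7, 8])       ->  [0, 0, 0, 0]
    #     create_token_type_ids_from_sequences([7, 8], [9])  ->  [0, 0, 0, 0, 1, 1]
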

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Tokens are written one per line in index order; warn if indices are not consecutive,
            # since the line number is what encodes the token id when the file is reloaded.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
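

# Illustrative usage sketch (hypothetical, not part of the template): after replacing the
# XXX/Xxx/xxx placeholders for a real model, the tokenizer would typically be exercised like:
#
#     tokenizer = XxxTokenizer("path/to/vocab.txt", do_lower_case=True)
#     ids = tokenizer.encode("Hello world")         # uses _tokenize + _convert_token_to_id
#     tokenizer.save_vocabulary("path/to/output/")  # writes vocab.txt, returns its path
#
# The shortcut names above ("xxx-base-uncased", ...) are placeholders, not downloadable checkpoints.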