# coding=utf-8
# Copyright 2018 XXX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model XXX."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open

from .tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)

####################################################
# In this template, replace all the XXX (various casings) with your model name
####################################################

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
        'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'xxx-base-uncased': 512,
    'xxx-large-uncased': 512,
}

####################################################
# Mapping from model shortcut names to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint specific configurations.
####################################################
PRETRAINED_INIT_CONFIGURATION = {
    'xxx-base-uncased': {'do_lower_case': True},
    'xxx-large-uncased': {'do_lower_case': True},
}


def load_vocab(vocab_file):
    """Load a one-token-per-line vocabulary file into an ordered dictionary.

    Args:
        vocab_file: Path to a UTF-8 text file holding one token per line.

    Returns:
        A ``collections.OrderedDict`` mapping each token (trailing newline
        removed) to its zero-based line number in the file.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # Stream the file line by line instead of materializing every line first.
        for index, token in enumerate(reader):
            vocab[token.rstrip('\n')] = index
    return vocab


class XxxTokenizer(PreTrainedTokenizer):
    r"""
    Constructs an XxxTokenizer.
    :class:`~transformers.XxxTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", **kwargs):
        """Constructs an XxxTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input.
            **unk_token**, **sep_token**, **pad_token**, **cls_token**, **mask_token**:
                Special tokens, forwarded to the base class constructor.
            **kwargs**: Additional keyword arguments forwarded to the base class constructor.

        Raises:
            ValueError: if ``vocab_file`` is not an existing file.
        """
        super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                           pad_token=pad_token, cls_token=cls_token,
                                           mask_token=mask_token, **kwargs)
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        # Reverse lookup table (id -> token) required by `_convert_id_to_token`.
        self.ids_to_tokens = collections.OrderedDict(
            [(index, token) for token, index in self.vocab.items()])
        # NOTE(review): `_tokenize` reads `self.do_basic_tokenize`, `self.basic_tokenizer`
        # and `self.wordpiece_tokenizer`; none of them is assigned anywhere in this file,
        # and `do_lower_case` is accepted but not used. Presumably the model-specific
        # tokenizers are meant to be created here when the template is filled in --
        # TODO confirm and wire them up.

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def _tokenize(self, text):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words.

        When ``self.do_basic_tokenize`` is true the text is first split by
        ``self.basic_tokenizer`` (never splitting the special tokens) and every
        resulting token is passed through ``self.wordpiece_tokenizer``; otherwise
        the whole text goes to ``self.wordpiece_tokenizer`` directly.
        """
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)

        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
            split_tokens.extend(self.wordpiece_tokenizer.tokenize(token))
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab.

        Tokens missing from the vocabulary map to the id of the unknown token.
        """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab.

        Indices missing from the vocabulary map to the unknown token.
        """
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (strings) to a single string.

        Tokens are joined with spaces and every ``' ##'`` continuation marker is
        removed so that word pieces are glued back onto the preceding piece.
        """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
            single sequence: [CLS] X [SEP]
            pair of sequences: [CLS] A [SEP] B [SEP]

        Args:
            token_ids_0: list of ids of the first sequence
            token_ids_1: Optional list of ids of the second sequence

        Returns:
            A new list of ids with the special token ids inserted.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Builds a mask marking the positions of special tokens for a sequence or a pair of sequences.
        This method is called when adding special tokens using the tokenizer ``prepare_for_model``
        or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens unless
                ``already_has_special_tokens`` is True)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching
                sequence ids for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted
                with special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

        Raises:
            ValueError: if ``already_has_special_tokens`` is True and ``token_ids_1`` is given.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError("You should not supply a second sequence if the provided sequence of "
                                 "ids is already formated with special tokens for the model.")
            # Build the lookup once rather than once per token id.
            special_ids = (self.sep_token_id, self.cls_token_id)
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        # Layout mirrors `build_inputs_with_special_tokens`: [CLS] A [SEP] (B [SEP]).
        mask = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            mask += ([0] * len(token_ids_1)) + [1]
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence

        if token_ids_1 is None, only returns the first portion of the mask (0's).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        # [CLS] + first sequence + [SEP] all belong to segment 0.
        first_segment = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is None:
            return first_segment
        # Second sequence + its trailing [SEP] belong to segment 1.
        return first_segment + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file.

        Args:
            vocab_path: An existing directory (the vocabulary is then written to
                the default vocabulary file name inside it) or a file path.

        Returns:
            A one-element tuple holding the path of the written vocabulary file.
        """
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        else:
            vocab_file = vocab_path

        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Tokens are written in id order, one per line, so that the line
            # number of each token equals its id when the file is loaded again.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)