diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 49aa64f8154..7d72e234ab3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -494,6 +494,8 @@ title: MT5 - local: model_doc/mvp title: MVP + - local: model_doc/myt5 + title: myt5 - local: model_doc/nemotron title: Nemotron - local: model_doc/nezha diff --git a/docs/source/en/model_doc/myt5.md b/docs/source/en/model_doc/myt5.md new file mode 100644 index 00000000000..c8b46f43512 --- /dev/null +++ b/docs/source/en/model_doc/myt5.md @@ -0,0 +1,46 @@ + + +# myt5 + +## Overview + +The myt5 model was proposed in [MYTE: Morphology-Driven Byte Encoding for Better and Fairer Multilingual Language Modeling](https://arxiv.org/pdf/2403.10691.pdf) by Tomasz Limisiewicz, Terra Blevins, Hila Gonen, Orevaoghene Ahia, and Luke Zettlemoyer. +MyT5 (**My**te **T5**) is a multilingual language model based on T5 architecture. +The model uses a **m**orphologically-driven **byte** (**MYTE**) representation described in our paper. +**MYTE** uses codepoints corresponding to morphemes in contrast to characters used in UTF-8 encoding. +As a pre-requisite, we used unsupervised morphological segmentation ([Morfessor](https://aclanthology.org/E14-2006.pdf)) to obtain morpheme inventories for 99 languages. +However, the morphological segmentation step is not needed when using the pre-defined morpheme inventory from the hub (see: [Tomli/myt5-base](https://huggingface.co/Tomlim/myt5-base)). + +The abstract from the paper is the following: + +*A major consideration in multilingual language modeling is how to best represent languages with diverse vocabularies and scripts. Although contemporary text encoding methods cover most of the world’s writing systems, they exhibit bias towards the high-resource languages of the Global West. As a result, texts of underrepresented languages tend to be segmented into long sequences of linguistically meaningless units. To address the disparities, we introduce a new paradigm that encodes the same information with segments of consistent size across diverse languages. Our encoding convention (MYTE) is based on morphemes, as their inventories are more balanced across languages than characters, which are used in previous methods. We show that MYTE produces shorter encodings for all 99 analyzed languages, with the most notable improvements for non-European languages and non-Latin scripts. This, in turn, improves multilingual LM performance and diminishes the perplexity gap throughout diverse languages.* + +This model was contributed by [Tomasz Limisiewicz](https://huggingface.co/Tomlim). +The original code can be found [here](https://github.com/tomlimi/MYTE). + +## MyT5Tokenizer + +[[autodoc]] MyT5Tokenizer + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + +## MyT5Tokenizer + +[[autodoc]] MyT5Tokenizer + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 667d51cb2a9..e4382e04c37 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -607,6 +607,7 @@ _import_structure = { "MusicgenMelodyDecoderConfig", ], "models.mvp": ["MvpConfig", "MvpTokenizer"], + "models.myt5": ["MyT5Tokenizer"], "models.nemotron": ["NemotronConfig"], "models.nllb": [], "models.nllb_moe": ["NllbMoeConfig"], @@ -5457,6 +5458,7 @@ if TYPE_CHECKING: MusicgenMelodyDecoderConfig, ) from .models.mvp import MvpConfig, MvpTokenizer + from .models.myt5 import MyT5Tokenizer from .models.nemotron import NemotronConfig from .models.nllb_moe import NllbMoeConfig from .models.nougat import NougatProcessor diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 12333c76a5d..804957c0a55 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -168,6 +168,7 @@ from . import ( musicgen, musicgen_melody, mvp, + myt5, nemotron, nllb, nllb_moe, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b974daebfd0..17219570684 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -497,6 +497,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("musicgen", "MusicGen"), ("musicgen_melody", "MusicGen Melody"), ("mvp", "MVP"), + ("myt5", "myt5"), ("nat", "NAT"), ("nemotron", "Nemotron"), ("nezha", "Nezha"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index f5b029414d0..8c3a7a82a60 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -322,6 +322,7 @@ else: ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)), + ("myt5", ("MyT5Tokenizer", None)), ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( "nllb", diff --git a/src/transformers/models/myt5/__init__.py b/src/transformers/models/myt5/__init__.py new file mode 100644 index 00000000000..9579f723a00 --- /dev/null +++ b/src/transformers/models/myt5/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + + +_import_structure = {"tokenization_myt5": ["MyT5Tokenizer"]} + + +if TYPE_CHECKING: + from .tokenization_myt5 import MyT5Tokenizer + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..39653e4b1c7 --- /dev/null +++ b/src/transformers/models/myt5/convert_myt5_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2024 The MyT5 authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert MyT5 checkpoint.""" + +import argparse + +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +# Copied from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained MyT5 model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/myt5/tokenization_myt5.py b/src/transformers/models/myt5/tokenization_myt5.py new file mode 100644 index 00000000000..69cb14b0cc9 --- /dev/null +++ b/src/transformers/models/myt5/tokenization_myt5.py @@ -0,0 +1,377 @@ +# coding=utf-8 +# Copyright 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model MyT5.""" + +import json +import os +import warnings +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"} + + +class ByteRewriter: + """ + Byte rewriter class for MyT5 tokenizer. + This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules. + + Args: + rewriting_rules (`str` or `Dict[str, str]`): + A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules. + + """ + + LEAF = "[LEAF]" + + def __init__(self, rewriting_rules: Union[str, Dict[str, str]]): + if isinstance(rewriting_rules, str): + with open(rewriting_rules, "r") as f: + rewriting_rules = json.load(f) + elif not isinstance(rewriting_rules, dict): + raise ValueError( + f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}" + ) + + self.hash_tree = self.construct_hash_tree(rewriting_rules) + reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()} + self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules) + + def add_leaf(self, hash_tree: Dict[str, Union[dict, List[str]]], byte_in_sequence: str, byte_out_sequence: str): + """ + Add a leaf with the output byte sequence to the hash tree. + """ + byte_in_list = byte_in_sequence.split(" ") + byte_out_list = byte_out_sequence.split(" ") + + tree_pointer = hash_tree + for b in byte_in_list: + if b not in tree_pointer: + tree_pointer[b] = {} + tree_pointer = tree_pointer[b] + + tree_pointer[self.LEAF] = byte_out_list + + def construct_hash_tree(self, rewriting_rules: Dict[str, str]) -> Dict[str, Union[dict, List[str]]]: + """ + Construct a hash tree for rewritten byte sequences. + """ + hash_tree = defaultdict(dict) + for b in (f"{x:02x}" for x in range(256)): + hash_tree[b][self.LEAF] = [b] + + for in_sequence, out_sequence in rewriting_rules.items(): + self.add_leaf(hash_tree, in_sequence, out_sequence) + + return hash_tree + + def search_hash_tree(self, byte_sequence: List[str]) -> Union[None, List[str]]: + """ + Search the hash tree and return the rewritten byte sequence if found. + """ + tree_pointer = self.hash_tree + for b in byte_sequence: + if b in tree_pointer: + tree_pointer = tree_pointer[b] + else: + return None + + return tree_pointer[self.LEAF] + + def rewrite_bytes(self, in_bytes: List[str], reverse=False) -> List[str]: + """ + Rewrite a sequence of bytes using the hash tree. + + Args: + in_bytes (`List[str]`): A list of bytes to be rewritten. + reverse (`bool`): If True, decoding is performed with the reverse hash tree. + Returns: + `List[str]`: The rewritten byte sequence. + """ + out_bytes = [] + b_start = 0 + b_end = 0 + + while b_start < len(in_bytes): + tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree + for j in range(b_start, len(in_bytes)): + b = in_bytes[j] + if b in tree_pointer: + tree_pointer = tree_pointer[b] + elif j == b_start: + cur_leaf = [b] + b_end = j + break + else: + break + if self.LEAF in tree_pointer: + cur_leaf = tree_pointer[self.LEAF] + b_end = j + out_bytes.extend(cur_leaf) + b_start = b_end + 1 + + return out_bytes + + +class MyT5Tokenizer(PreTrainedTokenizer): + """ + Construct a MyT5 tokenizer. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): The file containing the byte rewriting rules. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 125): + Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are + accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are + indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary + like in ByT5 preprocessing see + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + """ + + model_input_names = ["input_ids", "attention_mask"] + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + extra_ids=125, + additional_special_tokens=None, + **kwargs, + ) -> None: + # Add extra_ids to the special token list + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = [f"" for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0: + # Check that we have the right number of extra_id special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" + " provided to MyT5Tokenizer. In this case the additional_special_tokens must include the" + " extra_ids tokens" + ) + + pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token + eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token + # unk token needs to be in the vocab with correct index + self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token} + self.offset = len(self._added_tokens_decoder) + self._utf_vocab_size = 2**8 # utf is 8 bits + + # Load byte maps + self.byte_maps = json.load(open(vocab_file, "r")) + + self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"]) + self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"]) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + extra_ids=0, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + return self._utf_vocab_size + + # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + def _tokenize(self, text: str, **kwargs) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words. + Represents tokens in two character hex format""" + + tokens = [f"{i:02x}" for i in text.encode("utf-8")] + tokens = self.morphological_encode(tokens) + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + + if len(token) != 2: + token_id = None + else: + token_id = int(token, 16) + self.offset + + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = f"{index - self.offset:02x}" + return token + + def morphological_encode(self, indices: List[str]) -> List[str]: + # Decompose and merge morphological sequences + indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False) + indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False) + return indices + + def morphological_decode(self, indices: List[str]) -> List[str]: + # Demerge and compose morphological sequences + indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True) + indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True) + return indices + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + + out_tokens = [] + for token in tokens: + if token in self.added_tokens_decoder: + out_tokens.append(self.added_tokens_decoder[token]) + elif token in self.added_tokens_encoder: + out_tokens.append(token) + else: + out_tokens.append(token) + + out_tokens = self.morphological_decode(out_tokens) + _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder) + for token in out_tokens: + if token in _added_tokens: + bstring += bytes(token, "utf-8") + else: + bstring += bytes.fromhex(token) + string = bstring.decode("utf-8", errors="ignore") + return string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False)) + return (vocab_file,) diff --git a/tests/models/myt5/__init__.py b/tests/models/myt5/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/myt5/test_tokenization_myt5.py b/tests/models/myt5/test_tokenization_myt5.py new file mode 100644 index 00000000000..49e765ee3ea --- /dev/null +++ b/tests/models/myt5/test_tokenization_myt5.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import binascii +import unittest + +from transformers import MyT5Tokenizer +from transformers.utils import is_tf_available, is_torch_available + +from ...test_tokenization_common import TokenizerTesterMixin + + +if is_torch_available(): + FRAMEWORK = "pt" +elif is_tf_available(): + FRAMEWORK = "tf" +else: + FRAMEWORK = "jax" + + +def bytes_to_hex(bline: bytes, sep: str = " ") -> str: + return str(binascii.hexlify(bline, sep), "utf-8") + + +def str_to_hex(line: str, sep: str = " ") -> str: + return bytes_to_hex(bytes(line, "utf-8"), sep) + + +class TestByteRewriter(unittest.TestCase): + tokenizer = MyT5Tokenizer.from_pretrained("Tomlim/myt5-base") + + def test_simple_decompose(self): + decompose_rewriter = self.tokenizer.decompose_rewriter + + # test rewriting + in_str = "Hello WorlD" + out_str = "hAello wAorldA" + + in_hex = str_to_hex(in_str).split(" ") + out_hex = str_to_hex(out_str).split(" ") + + self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex) + + def test_simple_decompose_reversible(self): + decompose_rewriter = self.tokenizer.decompose_rewriter + + in_str = "Hello WorlD" + out_str = "Hello WorlD" + + in_hex = str_to_hex(in_str).split(" ") + out_hex = str_to_hex(out_str).split(" ") + + self.assertEqual( + decompose_rewriter.rewrite_bytes(decompose_rewriter.rewrite_bytes(in_hex), reverse=True), out_hex + ) + + def test_simple_decompose_non_latin(self): + decompose_rewriter = self.tokenizer.decompose_rewriter + + in_str = "你好世界 Hello WorlD" + out_str = "你好世界 hAello wAorldA" + + in_hex = str_to_hex(in_str).split(" ") + out_hex = str_to_hex(out_str).split(" ") + + self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex) + + def test_unrecognized_byte(self): + decompose_rewriter = self.tokenizer.decompose_rewriter + + in_hex = ["00", "01", "xx", "03", "61"] + out_hex = ["00", "01", "xx", "03", "61"] + + self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex) + + +class MyT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = MyT5Tokenizer + test_rust_tokenizer = False + + def setUp(self): + super().setUp() + + def get_tokenizer(self, **kwargs) -> MyT5Tokenizer: + return self.tokenizer_class.from_pretrained("Tomlim/myt5-base", **kwargs) + + @unittest.skip(reason="inputs cannot be pretokenized as ids depend on whole input string") + def test_pretokenized_inputs(self): + pass + + def test_convert_tokens_to_string_format(self): + tokenizer = self.get_tokenizer() + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokens = ["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", ""] + string = tokenizer.convert_tokens_to_string(tokens) + + self.assertIsInstance(string, str) + + def test_simple_tokenize(self): + tokenizer = self.get_tokenizer() + + in_str = "Hello World" + out_tokens = ["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90"] + + self.assertEqual(tokenizer.tokenize(in_str), out_tokens) + + in_pl_str = "Witaj świecie" + out_tokens = ["77", "41", "69", "74", "61", "6a", "20", "4b", "a5", "97", "63", "69", "65"] + + self.assertEqual(tokenizer.tokenize(in_pl_str), out_tokens) + + in_jp_str = "こんにちは世界" + out_tokens = ["58", "80", "91", "a1", "e4", "b8", "96", "e7", "95", "8c"] + + self.assertEqual(tokenizer.tokenize(in_jp_str), out_tokens) + + def test_batch_tokenize(self): + tokenizer = self.get_tokenizer() + + in_batch = ["Hello World", "Witaj świecie", "こんにちは世界"] + + out_tokens = [ + ["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", ""], + ["77", "41", "69", "74", "61", "6a", "20", "4b", "a5", "97", "63", "69", "65", ""], + ["58", "80", "91", "a1", "e4", "b8", "96", "e7", "95", "8c", ""], + ] + + self.assertListEqual( + [tokenizer.convert_ids_to_tokens(ids) for ids in tokenizer(in_batch)["input_ids"]], out_tokens + ) + + def test_special_bytes(self): + tokenizer = self.get_tokenizer() + + in_str_special = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09" + out_tokens = ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09"] + + self.assertEqual(tokenizer.tokenize(in_str_special), out_tokens) + + in_str_mixed = "\x00Hello\x01 World\x02" + out_tokens = ["00", "52", "85", "91", "9f", "6f", "01", "20", "52", "85", "9f", "90", "02"] + + self.assertEqual(tokenizer.tokenize(in_str_mixed), out_tokens) + + def test_special_tokens(self): + tokenizer = self.get_tokenizer() + + in_str_special = "" + out_tokens = ["", "", ""] + + self.assertEqual(tokenizer.tokenize(in_str_special), out_tokens) + + in_str_not_special = "" + out_tokens = ["3c", "73", "3e"] + + self.assertEqual(tokenizer.tokenize(in_str_not_special), out_tokens) + + in_str_mixed = "Hello World" + out_tokens = ["3c", "73", "3e", "52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", ""] + + self.assertEqual(tokenizer.tokenize(in_str_mixed), out_tokens) + + def test_token_ids_conversion(self): + tokenizer = self.get_tokenizer() + + tokens_range = [f"{x:02x}" for x in range(256)] + indices_range = list(range(3, 256 + 3)) + + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens_range), indices_range) + self.assertListEqual(tokenizer.convert_ids_to_tokens(indices_range), tokens_range) + + special_tokens = ["", "", ""] + special_indices = [0, 1, 2] + + self.assertListEqual(tokenizer.convert_tokens_to_ids(special_tokens), special_indices) + self.assertListEqual(tokenizer.convert_ids_to_tokens(special_indices), special_tokens)