mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
[WIP] Add Tokenizer for MyT5 Model (#31286)
* Initial commit for MyT5 model * custom implementation of MyT5 tokenizer, unused files deleted * unittest for myt5 tokenizer * upadate of import structure and style * removed remmanents of MyT5Config * fixed docstrings * Updates after review: filled documentaion file, new docstrings and tests added * Fixed code style issues * fixed copied from to refer to function * updated loading myt5 tokenizer in tests, added sample byte map file to fixtures * changes after review * removed redundant copied from * removed redundant copied from * optimalization and loading model from hf * [run_slow] myt5 * [run-slow] myt5 * Updated en documentation for myt5 Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
5ef432e474
commit
1bd604d11c
@ -494,6 +494,8 @@
|
|||||||
title: MT5
|
title: MT5
|
||||||
- local: model_doc/mvp
|
- local: model_doc/mvp
|
||||||
title: MVP
|
title: MVP
|
||||||
|
- local: model_doc/myt5
|
||||||
|
title: myt5
|
||||||
- local: model_doc/nemotron
|
- local: model_doc/nemotron
|
||||||
title: Nemotron
|
title: Nemotron
|
||||||
- local: model_doc/nezha
|
- local: model_doc/nezha
|
||||||
|
46
docs/source/en/model_doc/myt5.md
Normal file
46
docs/source/en/model_doc/myt5.md
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# myt5
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The myt5 model was proposed in [MYTE: Morphology-Driven Byte Encoding for Better and Fairer Multilingual Language Modeling](https://arxiv.org/pdf/2403.10691.pdf) by Tomasz Limisiewicz, Terra Blevins, Hila Gonen, Orevaoghene Ahia, and Luke Zettlemoyer.
|
||||||
|
MyT5 (**My**te **T5**) is a multilingual language model based on T5 architecture.
|
||||||
|
The model uses a **m**orphologically-driven **byte** (**MYTE**) representation described in our paper.
|
||||||
|
**MYTE** uses codepoints corresponding to morphemes in contrast to characters used in UTF-8 encoding.
|
||||||
|
As a pre-requisite, we used unsupervised morphological segmentation ([Morfessor](https://aclanthology.org/E14-2006.pdf)) to obtain morpheme inventories for 99 languages.
|
||||||
|
However, the morphological segmentation step is not needed when using the pre-defined morpheme inventory from the hub (see: [Tomli/myt5-base](https://huggingface.co/Tomlim/myt5-base)).
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*A major consideration in multilingual language modeling is how to best represent languages with diverse vocabularies and scripts. Although contemporary text encoding methods cover most of the world’s writing systems, they exhibit bias towards the high-resource languages of the Global West. As a result, texts of underrepresented languages tend to be segmented into long sequences of linguistically meaningless units. To address the disparities, we introduce a new paradigm that encodes the same information with segments of consistent size across diverse languages. Our encoding convention (MYTE) is based on morphemes, as their inventories are more balanced across languages than characters, which are used in previous methods. We show that MYTE produces shorter encodings for all 99 analyzed languages, with the most notable improvements for non-European languages and non-Latin scripts. This, in turn, improves multilingual LM performance and diminishes the perplexity gap throughout diverse languages.*
|
||||||
|
|
||||||
|
This model was contributed by [Tomasz Limisiewicz](https://huggingface.co/Tomlim).
|
||||||
|
The original code can be found [here](https://github.com/tomlimi/MYTE).
|
||||||
|
|
||||||
|
## MyT5Tokenizer
|
||||||
|
|
||||||
|
[[autodoc]] MyT5Tokenizer
|
||||||
|
- build_inputs_with_special_tokens
|
||||||
|
- get_special_tokens_mask
|
||||||
|
- create_token_type_ids_from_sequences
|
||||||
|
- save_vocabulary
|
||||||
|
|
||||||
|
## MyT5Tokenizer
|
||||||
|
|
||||||
|
[[autodoc]] MyT5Tokenizer
|
||||||
|
|
@ -607,6 +607,7 @@ _import_structure = {
|
|||||||
"MusicgenMelodyDecoderConfig",
|
"MusicgenMelodyDecoderConfig",
|
||||||
],
|
],
|
||||||
"models.mvp": ["MvpConfig", "MvpTokenizer"],
|
"models.mvp": ["MvpConfig", "MvpTokenizer"],
|
||||||
|
"models.myt5": ["MyT5Tokenizer"],
|
||||||
"models.nemotron": ["NemotronConfig"],
|
"models.nemotron": ["NemotronConfig"],
|
||||||
"models.nllb": [],
|
"models.nllb": [],
|
||||||
"models.nllb_moe": ["NllbMoeConfig"],
|
"models.nllb_moe": ["NllbMoeConfig"],
|
||||||
@ -5457,6 +5458,7 @@ if TYPE_CHECKING:
|
|||||||
MusicgenMelodyDecoderConfig,
|
MusicgenMelodyDecoderConfig,
|
||||||
)
|
)
|
||||||
from .models.mvp import MvpConfig, MvpTokenizer
|
from .models.mvp import MvpConfig, MvpTokenizer
|
||||||
|
from .models.myt5 import MyT5Tokenizer
|
||||||
from .models.nemotron import NemotronConfig
|
from .models.nemotron import NemotronConfig
|
||||||
from .models.nllb_moe import NllbMoeConfig
|
from .models.nllb_moe import NllbMoeConfig
|
||||||
from .models.nougat import NougatProcessor
|
from .models.nougat import NougatProcessor
|
||||||
|
@ -168,6 +168,7 @@ from . import (
|
|||||||
musicgen,
|
musicgen,
|
||||||
musicgen_melody,
|
musicgen_melody,
|
||||||
mvp,
|
mvp,
|
||||||
|
myt5,
|
||||||
nemotron,
|
nemotron,
|
||||||
nllb,
|
nllb,
|
||||||
nllb_moe,
|
nllb_moe,
|
||||||
|
@ -497,6 +497,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("musicgen", "MusicGen"),
|
("musicgen", "MusicGen"),
|
||||||
("musicgen_melody", "MusicGen Melody"),
|
("musicgen_melody", "MusicGen Melody"),
|
||||||
("mvp", "MVP"),
|
("mvp", "MVP"),
|
||||||
|
("myt5", "myt5"),
|
||||||
("nat", "NAT"),
|
("nat", "NAT"),
|
||||||
("nemotron", "Nemotron"),
|
("nemotron", "Nemotron"),
|
||||||
("nezha", "Nezha"),
|
("nezha", "Nezha"),
|
||||||
|
@ -322,6 +322,7 @@ else:
|
|||||||
("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
|
("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
|
("myt5", ("MyT5Tokenizer", None)),
|
||||||
("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
(
|
(
|
||||||
"nllb",
|
"nllb",
|
||||||
|
29
src/transformers/models/myt5/__init__.py
Normal file
29
src/transformers/models/myt5/__init__.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
|
||||||
|
|
||||||
|
_import_structure = {"tokenization_myt5": ["MyT5Tokenizer"]}
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .tokenization_myt5 import MyT5Tokenizer
|
||||||
|
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
@ -0,0 +1,60 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The MyT5 authors and HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert MyT5 checkpoint."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch
|
||||||
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
|
# Initialise PyTorch model
|
||||||
|
config = T5Config.from_json_file(config_file)
|
||||||
|
print(f"Building PyTorch model from configuration: {config}")
|
||||||
|
model = T5ForConditionalGeneration(config)
|
||||||
|
|
||||||
|
# Load weights from tf checkpoint
|
||||||
|
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||||
|
|
||||||
|
# Save pytorch-model
|
||||||
|
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||||
|
model.save_pretrained(pytorch_dump_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument(
|
||||||
|
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_file",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help=(
|
||||||
|
"The config json file corresponding to the pre-trained MyT5 model. \nThis specifies the model architecture."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
377
src/transformers/models/myt5/tokenization_myt5.py
Normal file
377
src/transformers/models/myt5/tokenization_myt5.py
Normal file
@ -0,0 +1,377 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization class for model MyT5."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"}
|
||||||
|
|
||||||
|
|
||||||
|
class ByteRewriter:
|
||||||
|
"""
|
||||||
|
Byte rewriter class for MyT5 tokenizer.
|
||||||
|
This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewriting_rules (`str` or `Dict[str, str]`):
|
||||||
|
A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
LEAF = "[LEAF]"
|
||||||
|
|
||||||
|
def __init__(self, rewriting_rules: Union[str, Dict[str, str]]):
|
||||||
|
if isinstance(rewriting_rules, str):
|
||||||
|
with open(rewriting_rules, "r") as f:
|
||||||
|
rewriting_rules = json.load(f)
|
||||||
|
elif not isinstance(rewriting_rules, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hash_tree = self.construct_hash_tree(rewriting_rules)
|
||||||
|
reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()}
|
||||||
|
self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules)
|
||||||
|
|
||||||
|
def add_leaf(self, hash_tree: Dict[str, Union[dict, List[str]]], byte_in_sequence: str, byte_out_sequence: str):
|
||||||
|
"""
|
||||||
|
Add a leaf with the output byte sequence to the hash tree.
|
||||||
|
"""
|
||||||
|
byte_in_list = byte_in_sequence.split(" ")
|
||||||
|
byte_out_list = byte_out_sequence.split(" ")
|
||||||
|
|
||||||
|
tree_pointer = hash_tree
|
||||||
|
for b in byte_in_list:
|
||||||
|
if b not in tree_pointer:
|
||||||
|
tree_pointer[b] = {}
|
||||||
|
tree_pointer = tree_pointer[b]
|
||||||
|
|
||||||
|
tree_pointer[self.LEAF] = byte_out_list
|
||||||
|
|
||||||
|
def construct_hash_tree(self, rewriting_rules: Dict[str, str]) -> Dict[str, Union[dict, List[str]]]:
|
||||||
|
"""
|
||||||
|
Construct a hash tree for rewritten byte sequences.
|
||||||
|
"""
|
||||||
|
hash_tree = defaultdict(dict)
|
||||||
|
for b in (f"{x:02x}" for x in range(256)):
|
||||||
|
hash_tree[b][self.LEAF] = [b]
|
||||||
|
|
||||||
|
for in_sequence, out_sequence in rewriting_rules.items():
|
||||||
|
self.add_leaf(hash_tree, in_sequence, out_sequence)
|
||||||
|
|
||||||
|
return hash_tree
|
||||||
|
|
||||||
|
def search_hash_tree(self, byte_sequence: List[str]) -> Union[None, List[str]]:
|
||||||
|
"""
|
||||||
|
Search the hash tree and return the rewritten byte sequence if found.
|
||||||
|
"""
|
||||||
|
tree_pointer = self.hash_tree
|
||||||
|
for b in byte_sequence:
|
||||||
|
if b in tree_pointer:
|
||||||
|
tree_pointer = tree_pointer[b]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return tree_pointer[self.LEAF]
|
||||||
|
|
||||||
|
def rewrite_bytes(self, in_bytes: List[str], reverse=False) -> List[str]:
|
||||||
|
"""
|
||||||
|
Rewrite a sequence of bytes using the hash tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_bytes (`List[str]`): A list of bytes to be rewritten.
|
||||||
|
reverse (`bool`): If True, decoding is performed with the reverse hash tree.
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The rewritten byte sequence.
|
||||||
|
"""
|
||||||
|
out_bytes = []
|
||||||
|
b_start = 0
|
||||||
|
b_end = 0
|
||||||
|
|
||||||
|
while b_start < len(in_bytes):
|
||||||
|
tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
|
||||||
|
for j in range(b_start, len(in_bytes)):
|
||||||
|
b = in_bytes[j]
|
||||||
|
if b in tree_pointer:
|
||||||
|
tree_pointer = tree_pointer[b]
|
||||||
|
elif j == b_start:
|
||||||
|
cur_leaf = [b]
|
||||||
|
b_end = j
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if self.LEAF in tree_pointer:
|
||||||
|
cur_leaf = tree_pointer[self.LEAF]
|
||||||
|
b_end = j
|
||||||
|
out_bytes.extend(cur_leaf)
|
||||||
|
b_start = b_end + 1
|
||||||
|
|
||||||
|
return out_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class MyT5Tokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Construct a MyT5 tokenizer.
|
||||||
|
|
||||||
|
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||||
|
this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`): The file containing the byte rewriting rules.
|
||||||
|
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
|
||||||
|
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
extra_ids (`int`, *optional*, defaults to 125):
|
||||||
|
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
|
||||||
|
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
|
||||||
|
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
|
||||||
|
like in ByT5 preprocessing see
|
||||||
|
[here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
|
||||||
|
additional_special_tokens (`List[str]`, *optional*):
|
||||||
|
Additional special tokens used by the tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
eos_token="</s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
pad_token="<pad>",
|
||||||
|
extra_ids=125,
|
||||||
|
additional_special_tokens=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
# Add extra_ids to the special token list
|
||||||
|
if extra_ids > 0 and additional_special_tokens is None:
|
||||||
|
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
|
||||||
|
elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
|
||||||
|
# Check that we have the right number of extra_id special tokens
|
||||||
|
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
|
||||||
|
if extra_tokens != extra_ids:
|
||||||
|
raise ValueError(
|
||||||
|
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
|
||||||
|
" provided to MyT5Tokenizer. In this case the additional_special_tokens must include the"
|
||||||
|
" extra_ids tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
|
||||||
|
eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
|
||||||
|
unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
|
||||||
|
# unk token needs to be in the vocab with correct index
|
||||||
|
self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
|
||||||
|
self.offset = len(self._added_tokens_decoder)
|
||||||
|
self._utf_vocab_size = 2**8 # utf is 8 bits
|
||||||
|
|
||||||
|
# Load byte maps
|
||||||
|
self.byte_maps = json.load(open(vocab_file, "r"))
|
||||||
|
|
||||||
|
self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"])
|
||||||
|
self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"])
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
eos_token=eos_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
extra_ids=0,
|
||||||
|
additional_special_tokens=additional_special_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return self._utf_vocab_size
|
||||||
|
|
||||||
|
# Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab
|
||||||
|
def get_vocab(self):
|
||||||
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
|
||||||
|
vocab.update(self.added_tokens_encoder)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
# Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer `prepare_for_model` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the token list is already formatted with special tokens for the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
if already_has_special_tokens:
|
||||||
|
return super().get_special_tokens_mask(
|
||||||
|
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# normal case: some special tokens
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return ([0] * len(token_ids_0)) + [1]
|
||||||
|
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||||
|
|
||||||
|
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
|
||||||
|
"""Do not add eos again if user already added it."""
|
||||||
|
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
|
||||||
|
warnings.warn(
|
||||||
|
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
|
||||||
|
" eos tokens being added."
|
||||||
|
)
|
||||||
|
return token_ids
|
||||||
|
else:
|
||||||
|
return token_ids + [self.eos_token_id]
|
||||||
|
|
||||||
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not
|
||||||
|
make use of token type ids, therefore a list of zeros is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: List of zeros.
|
||||||
|
"""
|
||||||
|
eos = [self.eos_token_id]
|
||||||
|
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return len(token_ids_0 + eos) * [0]
|
||||||
|
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||||
|
|
||||||
|
# Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||||
|
adding special tokens. A sequence has the following format:
|
||||||
|
|
||||||
|
- single sequence: `X </s>`
|
||||||
|
- pair of sequences: `A </s> B </s>`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (`List[int]`):
|
||||||
|
List of IDs to which the special tokens will be added.
|
||||||
|
token_ids_1 (`List[int]`, *optional*):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return token_ids_0
|
||||||
|
else:
|
||||||
|
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
|
||||||
|
return token_ids_0 + token_ids_1
|
||||||
|
|
||||||
|
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
||||||
|
"""Take as input a string and return a list of strings (tokens) for words/sub-words.
|
||||||
|
Represents tokens in two character hex format"""
|
||||||
|
|
||||||
|
tokens = [f"{i:02x}" for i in text.encode("utf-8")]
|
||||||
|
tokens = self.morphological_encode(tokens)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token):
|
||||||
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
|
||||||
|
if len(token) != 2:
|
||||||
|
token_id = None
|
||||||
|
else:
|
||||||
|
token_id = int(token, 16) + self.offset
|
||||||
|
|
||||||
|
return token_id
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
token = f"{index - self.offset:02x}"
|
||||||
|
return token
|
||||||
|
|
||||||
|
def morphological_encode(self, indices: List[str]) -> List[str]:
|
||||||
|
# Decompose and merge morphological sequences
|
||||||
|
indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
|
||||||
|
indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def morphological_decode(self, indices: List[str]) -> List[str]:
|
||||||
|
# Demerge and compose morphological sequences
|
||||||
|
indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
|
||||||
|
indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
"""Converts a sequence of tokens (string) in a single string."""
|
||||||
|
bstring = b""
|
||||||
|
|
||||||
|
out_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
if token in self.added_tokens_decoder:
|
||||||
|
out_tokens.append(self.added_tokens_decoder[token])
|
||||||
|
elif token in self.added_tokens_encoder:
|
||||||
|
out_tokens.append(token)
|
||||||
|
else:
|
||||||
|
out_tokens.append(token)
|
||||||
|
|
||||||
|
out_tokens = self.morphological_decode(out_tokens)
|
||||||
|
_added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
|
||||||
|
for token in out_tokens:
|
||||||
|
if token in _added_tokens:
|
||||||
|
bstring += bytes(token, "utf-8")
|
||||||
|
else:
|
||||||
|
bstring += bytes.fromhex(token)
|
||||||
|
string = bstring.decode("utf-8", errors="ignore")
|
||||||
|
return string
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if os.path.isdir(save_directory):
|
||||||
|
vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||||
|
writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False))
|
||||||
|
return (vocab_file,)
|
0
tests/models/myt5/__init__.py
Normal file
0
tests/models/myt5/__init__.py
Normal file
187
tests/models/myt5/test_tokenization_myt5.py
Normal file
187
tests/models/myt5/test_tokenization_myt5.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import binascii
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import MyT5Tokenizer
|
||||||
|
from transformers.utils import is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
from ...test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
FRAMEWORK = "pt"
|
||||||
|
elif is_tf_available():
|
||||||
|
FRAMEWORK = "tf"
|
||||||
|
else:
|
||||||
|
FRAMEWORK = "jax"
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_hex(bline: bytes, sep: str = " ") -> str:
|
||||||
|
return str(binascii.hexlify(bline, sep), "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_hex(line: str, sep: str = " ") -> str:
|
||||||
|
return bytes_to_hex(bytes(line, "utf-8"), sep)
|
||||||
|
|
||||||
|
|
||||||
|
class TestByteRewriter(unittest.TestCase):
|
||||||
|
tokenizer = MyT5Tokenizer.from_pretrained("Tomlim/myt5-base")
|
||||||
|
|
||||||
|
def test_simple_decompose(self):
|
||||||
|
decompose_rewriter = self.tokenizer.decompose_rewriter
|
||||||
|
|
||||||
|
# test rewriting
|
||||||
|
in_str = "Hello WorlD"
|
||||||
|
out_str = "hAello wAorldA"
|
||||||
|
|
||||||
|
in_hex = str_to_hex(in_str).split(" ")
|
||||||
|
out_hex = str_to_hex(out_str).split(" ")
|
||||||
|
|
||||||
|
self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex)
|
||||||
|
|
||||||
|
def test_simple_decompose_reversible(self):
|
||||||
|
decompose_rewriter = self.tokenizer.decompose_rewriter
|
||||||
|
|
||||||
|
in_str = "Hello WorlD"
|
||||||
|
out_str = "Hello WorlD"
|
||||||
|
|
||||||
|
in_hex = str_to_hex(in_str).split(" ")
|
||||||
|
out_hex = str_to_hex(out_str).split(" ")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
decompose_rewriter.rewrite_bytes(decompose_rewriter.rewrite_bytes(in_hex), reverse=True), out_hex
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_simple_decompose_non_latin(self):
|
||||||
|
decompose_rewriter = self.tokenizer.decompose_rewriter
|
||||||
|
|
||||||
|
in_str = "你好世界 Hello WorlD"
|
||||||
|
out_str = "你好世界 hAello wAorldA"
|
||||||
|
|
||||||
|
in_hex = str_to_hex(in_str).split(" ")
|
||||||
|
out_hex = str_to_hex(out_str).split(" ")
|
||||||
|
|
||||||
|
self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex)
|
||||||
|
|
||||||
|
def test_unrecognized_byte(self):
|
||||||
|
decompose_rewriter = self.tokenizer.decompose_rewriter
|
||||||
|
|
||||||
|
in_hex = ["00", "01", "xx", "03", "61"]
|
||||||
|
out_hex = ["00", "01", "xx", "03", "61"]
|
||||||
|
|
||||||
|
self.assertEqual(decompose_rewriter.rewrite_bytes(in_hex), out_hex)
|
||||||
|
|
||||||
|
|
||||||
|
class MyT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
tokenizer_class = MyT5Tokenizer
|
||||||
|
test_rust_tokenizer = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs) -> MyT5Tokenizer:
|
||||||
|
return self.tokenizer_class.from_pretrained("Tomlim/myt5-base", **kwargs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="inputs cannot be pretokenized as ids depend on whole input string")
|
||||||
|
def test_pretokenized_inputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_convert_tokens_to_string_format(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
tokens = ["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", "</s>"]
|
||||||
|
string = tokenizer.convert_tokens_to_string(tokens)
|
||||||
|
|
||||||
|
self.assertIsInstance(string, str)
|
||||||
|
|
||||||
|
def test_simple_tokenize(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
in_str = "Hello World"
|
||||||
|
out_tokens = ["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str), out_tokens)
|
||||||
|
|
||||||
|
in_pl_str = "Witaj świecie"
|
||||||
|
out_tokens = ["77", "41", "69", "74", "61", "6a", "20", "4b", "a5", "97", "63", "69", "65"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_pl_str), out_tokens)
|
||||||
|
|
||||||
|
in_jp_str = "こんにちは世界"
|
||||||
|
out_tokens = ["58", "80", "91", "a1", "e4", "b8", "96", "e7", "95", "8c"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_jp_str), out_tokens)
|
||||||
|
|
||||||
|
def test_batch_tokenize(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
in_batch = ["Hello World", "Witaj świecie", "こんにちは世界"]
|
||||||
|
|
||||||
|
out_tokens = [
|
||||||
|
["52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", "</s>"],
|
||||||
|
["77", "41", "69", "74", "61", "6a", "20", "4b", "a5", "97", "63", "69", "65", "</s>"],
|
||||||
|
["58", "80", "91", "a1", "e4", "b8", "96", "e7", "95", "8c", "</s>"],
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
[tokenizer.convert_ids_to_tokens(ids) for ids in tokenizer(in_batch)["input_ids"]], out_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_special_bytes(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
in_str_special = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09"
|
||||||
|
out_tokens = ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str_special), out_tokens)
|
||||||
|
|
||||||
|
in_str_mixed = "\x00Hello\x01 World\x02"
|
||||||
|
out_tokens = ["00", "52", "85", "91", "9f", "6f", "01", "20", "52", "85", "9f", "90", "02"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str_mixed), out_tokens)
|
||||||
|
|
||||||
|
def test_special_tokens(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
in_str_special = "<unk></s><pad>"
|
||||||
|
out_tokens = ["<unk>", "</s>", "<pad>"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str_special), out_tokens)
|
||||||
|
|
||||||
|
in_str_not_special = "<s>"
|
||||||
|
out_tokens = ["3c", "73", "3e"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str_not_special), out_tokens)
|
||||||
|
|
||||||
|
in_str_mixed = "<s>Hello World</s>"
|
||||||
|
out_tokens = ["3c", "73", "3e", "52", "85", "91", "9f", "6f", "20", "52", "85", "9f", "90", "</s>"]
|
||||||
|
|
||||||
|
self.assertEqual(tokenizer.tokenize(in_str_mixed), out_tokens)
|
||||||
|
|
||||||
|
def test_token_ids_conversion(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
tokens_range = [f"{x:02x}" for x in range(256)]
|
||||||
|
indices_range = list(range(3, 256 + 3))
|
||||||
|
|
||||||
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens_range), indices_range)
|
||||||
|
self.assertListEqual(tokenizer.convert_ids_to_tokens(indices_range), tokens_range)
|
||||||
|
|
||||||
|
special_tokens = ["<pad>", "</s>", "<unk>"]
|
||||||
|
special_indices = [0, 1, 2]
|
||||||
|
|
||||||
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(special_tokens), special_indices)
|
||||||
|
self.assertListEqual(tokenizer.convert_ids_to_tokens(special_indices), special_tokens)
|
Loading…
Reference in New Issue
Block a user