Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
fix CLIP fast tokenizer and change some properties of the slow version (#15067)
Large changes to the CLIP fast tokenizer, which did not match the behavior of the CLIP slow tokenizer.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 240cc6cbdc
commit e93763d420
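To make the intent of the change concrete, a minimal hedged sketch (not taken from the commit): after this fix the fast tokenizer should agree with the slow one. The sketch assumes the `openai/clip-vit-base-patch32` checkpoint can be downloaded.

```python
# Hedged sketch: after this fix the fast tokenizer is expected to match the slow one.
from transformers import CLIPTokenizer, CLIPTokenizerFast

slow = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
fast = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

text = "A photo of a CAT!"
assert slow.tokenize(text) == fast.tokenize(text)
assert slow(text)["input_ids"] == fast(text)["input_ids"]
```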
@@ -559,6 +559,10 @@ jobs:
          if [ -f test_list.txt ]; then
            python -m pytest -s --make-reports=tests_custom_tokenizers ./tests/test_tokenization_bert_japanese.py ./tests/test_tokenization_openai.py | tee tests_output.txt
          fi
      - run: |
          if [ -f test_list.txt ]; then
            python -m pytest -n 1 tests/test_tokenization_clip.py --dist=loadfile -s --make-reports=tests_tokenization_clip --durations=100 | tee tests_output.txt
          fi
      - store_artifacts:
          path: ~/transformers/tests_output.txt
      - store_artifacts:
setup.py (2 changes)
@@ -105,6 +105,7 @@ _deps = [
    "filelock",
    "flake8>=3.8.3",
    "flax>=0.3.5",
    "ftfy",
    "fugashi>=1.0",
    "GitPython<3.1.19",
    "huggingface-hub>=0.1.0,<1.0",

@@ -242,6 +243,7 @@ else:
    extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")

extras["tokenizers"] = deps_list("tokenizers")
extras["ftfy"] = deps_list("ftfy")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter")
@@ -823,6 +823,7 @@ class CLIPConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        unk_token = self.original_tokenizer.unk_token

        tokenizer = Tokenizer(
            BPE(
@@ -832,13 +833,32 @@ class CLIPConverter(Converter):
                continuing_subword_prefix="",
                end_of_word_suffix="</w>",
                fuse_unk=False,
                unk_token=str(unk_token),
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
                    behavior="removed",
                    invert=True,
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        # Hack to get both a ByteLevel and a TemplateProcessing post-processor
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
            cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
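To see what the new normalizer and pre-tokenizer built in the converter above do in isolation, here is a small standalone sketch using the `tokenizers` library; the sample sentence and printed output are illustrative assumptions, not taken from the commit.

```python
# Standalone sketch of the normalization / pre-tokenization pipeline the converter assembles.
from tokenizers import Regex, normalizers, pre_tokenizers

normalizer = normalizers.Sequence(
    [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
)
pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Split(
            Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
            behavior="removed",
            invert=True,
        ),
        pre_tokenizers.ByteLevel(add_prefix_space=False),
    ]
)

text = "A  photo\tof a CAT!"
normalized = normalizer.normalize_str(text)          # "a photo of a cat!"
pieces = pre_tokenizer.pre_tokenize_str(normalized)  # list of (piece, (start, end)) tuples
print([piece for piece, _ in pieces])
```

The whitespace `Replace` collapses runs of spaces, tabs and newlines into a single space before lowercasing, mirroring the whitespace cleanup the slow tokenizer applies.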
@@ -15,6 +15,7 @@ deps = {
    "filelock": "filelock",
    "flake8": "flake8>=3.8.3",
    "flax": "flax>=0.3.5",
    "ftfy": "ftfy",
    "fugashi": "fugashi>=1.0",
    "GitPython": "GitPython<3.1.19",
    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
@@ -158,6 +158,13 @@ except importlib_metadata.PackageNotFoundError:
except importlib_metadata.PackageNotFoundError:
    _faiss_available = False

_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False


coloredlogs = importlib.util.find_spec("coloredlogs") is not None
try:

@@ -441,6 +448,10 @@ def is_flax_available():
    return _flax_available


def is_ftfy_available():
    return _ftfy_available


def is_torch_tpu_available():
    if not _torch_available:
        return False

@@ -516,10 +527,6 @@ def is_spacy_available():
    return importlib.util.find_spec("spacy") is not None


def is_ftfy_available():
    return importlib.util.find_spec("ftfy") is not None


def is_in_notebook():
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py

@@ -722,6 +729,13 @@ FLAX_IMPORT_ERROR = """
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# docstyle-ignore
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment.
"""


# docstyle-ignore
SCATTER_IMPORT_ERROR = """

@@ -801,6 +815,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
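The ftfy checks added above follow the file's usual optional-dependency pattern. A self-contained sketch of that pattern, written here for illustration with only the standard library:

```python
# Illustration of the optional-dependency pattern: find_spec detects the module,
# importlib.metadata confirms an installed distribution and gives its version.
import importlib.util
from importlib import metadata

_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = metadata.version("ftfy")
    print(f"ftfy {_ftfy_version} is available")
except metadata.PackageNotFoundError:
    _ftfy_available = False


def is_ftfy_available() -> bool:
    return _ftfy_available
```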
@@ -48,7 +48,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {


PRETRAINED_INIT_CONFIGURATION = {
    "openai/clip-vit-base-patch32": {"do_lower_case": True},
    "openai/clip-vit-base-patch32": {},
}
@@ -101,19 +101,6 @@ class CLIPTokenizer(PreTrainedTokenizer):
    """
    Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

@@ -132,9 +119,6 @@ class CLIPTokenizer(PreTrainedTokenizer):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The end of sequence token.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The CLIP tokenizer detects the beginning of words by the preceding space.)
    """

    vocab_files_names = VOCAB_FILES_NAMES
@@ -151,8 +135,6 @@ class CLIPTokenizer(PreTrainedTokenizer):
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",  # hack to enable padding
        add_prefix_space=False,
        do_lower_case=True,
        **kwargs
    ):
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
@@ -165,8 +147,6 @@ class CLIPTokenizer(PreTrainedTokenizer):
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            do_lower_case=do_lower_case,
            **kwargs,
        )

@@ -190,21 +170,12 @@ class CLIPTokenizer(PreTrainedTokenizer):
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
        self.add_prefix_space = add_prefix_space

        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    # Very ugly hack to enable padding
    @property
    def pad_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
        """
        return 0

    @property
    def vocab_size(self):
        return len(self.encoder)
@@ -232,9 +203,12 @@ class CLIPTokenizer(PreTrainedTokenizer):
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        bos_token = [self.bos_token_id]
        eos_token = [self.eos_token_id]

        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + token_ids_1 + [self.eos_token_id]
            return bos_token + token_ids_0 + eos_token
        return bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
@@ -262,7 +236,30 @@ class CLIPTokenizer(PreTrainedTokenizer):

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed. CLIP does not make use of token type ids, therefore a list of
        zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        bos_token = [self.bos_token_id]
        eos_token = [self.eos_token_id]

        if token_ids_1 is None:
            return len(bos_token + token_ids_0 + eos_token) * [0]
        return len(bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token) * [0]

    def bpe(self, token):
        if token in self.cache:
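As a quick check of the sequence format defined above, a hedged usage sketch; it assumes the `openai/clip-vit-base-patch32` checkpoint can be downloaded, and the ids 1-4 stand in for arbitrary token ids.

```python
# Hedged sketch: single sequences are wrapped as <|startoftext|> ... <|endoftext|>,
# and pairs are joined by a doubled <|endoftext|> with no dedicated separator.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

ids = tokenizer("a photo of a cat")["input_ids"]
assert ids[0] == tokenizer.bos_token_id and ids[-1] == tokenizer.eos_token_id

pair = tokenizer.build_inputs_with_special_tokens([1, 2], [3, 4])
assert pair == [tokenizer.bos_token_id, 1, 2, tokenizer.eos_token_id,
                tokenizer.eos_token_id, 3, 4, tokenizer.eos_token_id]
```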
@@ -332,7 +329,8 @@ class CLIPTokenizer(PreTrainedTokenizer):
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors).replace("</w>", " ")
        byte_array = bytearray([self.byte_decoder[c] for c in text])
        text = byte_array.decode("utf-8", errors=self.errors).replace("</w>", " ").strip()
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
@@ -363,9 +361,3 @@ class CLIPTokenizer(PreTrainedTokenizer):
            index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if is_split_into_words or add_prefix_space:
            text = " " + text
        return (text, kwargs)
@@ -15,12 +15,10 @@
"""Tokenization classes for OpenAI GPT."""


import json
from typing import Optional, Tuple
from typing import List, Optional, Tuple

from tokenizers import pre_tokenizers

from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_clip import CLIPTokenizer
@@ -52,27 +50,6 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
    Construct a "fast" CLIP tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import CLIPTokenizerFast
    >>> tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer("Hello world")['input_ids']
    [15496, 995]
    >>> tokenizer(" Hello world")['input_ids']
    [18435, 995]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

@@ -81,9 +58,6 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
@@ -91,11 +65,6 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The end of sequence token.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The CLIP tokenizer detects the beginning of words by the preceding space.)
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether or not the post-processing step should trim offsets to avoid including whitespaces.
    """

    vocab_files_names = VOCAB_FILES_NAMES
@@ -113,7 +82,6 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",  # hack to enable padding
        add_prefix_space=False,
        **kwargs
    ):
        super().__init__(
@@ -124,44 +92,81 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
        if not isinstance(self.backend_tokenizer.pre_tokenizer, pre_tokenizers.Sequence):
            raise ValueError(
                "The `backend_tokenizer` provided does not match the expected format. The CLIP tokenizer has been "
                "heavily modified from transformers version 4.17.0. You need to convert the tokenizer you are using "
                "to be compatible with this version. The easiest way to do so is "
                '`CLIPTokenizerFast.from_pretrained("path_to_local_folder_or_hub_repo", from_slow=True)`. '
                "If you want to use your existing tokenizer, you will have to revert to a version prior to "
                "4.17.0 of transformers."
            )

        self.add_prefix_space = add_prefix_space
        self._wrap_decode_method_backend_tokenizer()

    # Very ugly hack to enable padding
    @property
    def pad_token_id(self) -> Optional[int]:
    # Very ugly hack to enable padding to have a correct decoding see https://github.com/huggingface/tokenizers/issues/872
    def _wrap_decode_method_backend_tokenizer(self):
        orig_decode_method = self.backend_tokenizer.decode

        def new_decode_method(*args, **kwargs):
            text = orig_decode_method(*args, **kwargs)
            text = text.replace(self.backend_tokenizer.model.end_of_word_suffix, " ").strip()
            return text

        self.backend_tokenizer.decode = new_decode_method
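When the ValueError above fires, the migration it points to is a fresh conversion from the slow tokenizer files. A hedged sketch of that path; the local directory names are placeholders, not real checkpoints.

```python
# Hedged sketch of the suggested migration: force a re-conversion from the slow files.
from transformers import CLIPTokenizerFast

tokenizer = CLIPTokenizerFast.from_pretrained("path/to/old_clip_tokenizer", from_slow=True)
tokenizer.save_pretrained("path/to/converted_clip_tokenizer")  # persist the rebuilt fast tokenizer
```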
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A CLIP sequence has the following format:

        - single sequence: `<|startoftext|> X <|endoftext|>`

        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        return 0
        bos_token = [self.bos_token_id]
        eos_token = [self.eos_token_id]

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )
        if token_ids_1 is None:
            return bos_token + token_ids_0 + eos_token
        return bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token

        return super()._batch_encode_plus(*args, **kwargs)
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed. CLIP does not make use of token type ids, therefore a list of
        zeros is returned.

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )
        Returns:
            `List[int]`: List of zeros.
        """
        bos_token = [self.bos_token_id]
        eos_token = [self.eos_token_id]

        return super()._encode_plus(*args, **kwargs)
        if token_ids_1 is None:
            return len(bos_token + token_ids_0 + eos_token) * [0]
        return len(bos_token + token_ids_0 + eos_token + eos_token + token_ids_1 + eos_token) * [0]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
@@ -20,7 +20,7 @@ import unittest

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
from transformers.testing_utils import require_ftfy, require_tokenizers

from .test_tokenization_common import TokenizerTesterMixin


@@ -30,18 +30,20 @@ class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = CLIPTokenizer
    rust_tokenizer_class = CLIPTokenizerFast
    test_rust_tokenizer = False
    from_pretrained_kwargs = {"add_prefix_space": True}
    test_rust_tokenizer = True
    from_pretrained_kwargs = {}
    test_seq2seq = False

    def setUp(self):
        super().setUp()
        # temporary addition: to test the new slow to fast converter
        self.tokenizers_list = [(CLIPTokenizerFast, "SaulLu/clip-vit-base-patch32", {})]

        # fmt: off
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|endoftext|>"]
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]
        # fmt: on
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>"]
        self.special_tokens_map = {"unk_token": "<unk>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
@@ -61,148 +63,126 @@ class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer "
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = CLIPTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower newer"
        bpe_tokens = ["lo", "w", "er</w>", "n", "e", "w", "er</w>"]
        tokens = tokenizer.tokenize(text, add_prefix_space=True)
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [10, 2, 12, 9, 3, 2, 12, 16]
        input_bpe_tokens = [10, 2, 16, 9, 3, 2, 16, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)

        sequence = "lower newer"

        # Testing tokenization
        tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        # Testing conversion to ids without special tokens
        ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        # Testing conversion to ids with special tokens
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
        ids = tokenizer.encode(sequence, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # Testing the unknown token
        input_tokens = tokens + [rust_tokenizer.unk_token]
        input_bpe_tokens = [10, 2, 12, 9, 3, 2, 12, 16]
        self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_pretokenized_inputs(self, *args, **kwargs):
        # It's very difficult to mix/test pretokenization with byte-level
        # And get both CLIP and Roberta to work at the same time (mostly an issue of adding a space before the string)
        pass

    def test_padding(self, max_length=15):
    @require_ftfy
    def test_check_encoding_slow_fast(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_s = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                # Simple input
                s = "This is a simple input"
                s2 = ["This is a simple input 1", "This is a simple input 2"]
                p = ("This is a simple input", "This is a pair")
                p2 = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                text = "A\n'll 11p223RF☆ho!!to?'d'd''d of a cat"
                text_tokenized_s = tokenizer_s.tokenize(text)
                text_tokenized_r = tokenizer_r.tokenize(text)

                self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on an example containing a character (Latin Small Letter A
                # with Tilde) encoded in 2 different ways
                text = "xa\u0303y" + " " + "x\xe3y"
                text_tokenized_s = tokenizer_s.tokenize(text)
                text_tokenized_r = tokenizer_r.tokenize(text)

                self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on unicode of space type
                spaces_unicodes = [
                    "\u0009",  # (horizontal tab, '\t')
                    "\u000B",  # (vertical tab)
                    "\u000C",  # (form feed)
                    "\u0020",  # (space, ' ')
                    "\u200E",  # (left-to-right mark)
                    "\u200F",  # (right-to-left mark)
                ]
                for unicode_seq in spaces_unicodes:
                    text_tokenized_s = tokenizer_s.tokenize(unicode_seq)
                    text_tokenized_r = tokenizer_r.tokenize(unicode_seq)

                    self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on unicode of line break type
                line_break_unicodes = [
                    "\u000A",  # (line feed, '\n')
                    "\r\n",  # (carriage return and line feed, '\r\n')
                    "\u000D",  # (carriage return, '\r')
                    "\r",  # (carriage return, '\r')
                    "\u000D",  # (carriage return, '\r')
                    "\u2028",  # (line separator)
                    "\u2029",  # (paragraph separator)
                    # "\u0085",  # (next line)
                ]

                # Simple input tests
                self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
                # The tokenization is not identical for the character "\u0085" (next line). The slow version transforms
                # it into the Horizontal Ellipsis character "…" ("\u2026") while the fast version transforms it into a
                # space (and thus into an empty list).

                # Simple input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
                for unicode_seq in line_break_unicodes:
                    text_tokenized_s = tokenizer_s.tokenize(unicode_seq)
                    text_tokenized_r = tokenizer_r.tokenize(unicode_seq)

                # Simple input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    s2,
                    max_length=max_length,
                    padding="max_length",
                    self.assertListEqual(text_tokenized_s, text_tokenized_r)

    def test_offsets_mapping_with_different_add_prefix_space_argument(self):
        # Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space`
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                text_of_1_token = "hello"  # `hello` is a token in the vocabulary of `pretrained_name`
                text = f"{text_of_1_token} {text_of_1_token}"

                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name,
                    use_fast=True,
                )
                encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
                self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
                self.assertEqual(
                    encoding.offset_mapping[1],
                    (len(text_of_1_token) + 1, len(text_of_1_token) + 1 + len(text_of_1_token)),
                )

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
                text = f" {text}"

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    p2,
                    max_length=max_length,
                    padding="max_length",
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name,
                    use_fast=True,
                )
                encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
                self.assertEqual(encoding.offset_mapping[0], (1, 1 + len(text_of_1_token)))
                self.assertEqual(
                    encoding.offset_mapping[1],
                    (1 + len(text_of_1_token) + 1, 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
                )

    def test_add_tokens_tokenizer(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                vocab_size = tokenizer.vocab_size
                all_size = len(tokenizer)
    def test_log_warning(self):
        # Test related to the breaking change introduced in transformers v4.17.0
        # We need to check that an error is raised when the user tries to load a previous version of the tokenizer.
        with self.assertRaises(ValueError) as context:
            self.rust_tokenizer_class.from_pretrained("robot-test/old-clip-tokenizer")

                self.assertNotEqual(vocab_size, 0)
        self.assertTrue(
            context.exception.args[0].startswith(
                "The `backend_tokenizer` provided does not match the expected format."
            )
        )

                # We usually have added tokens from the start in tests because our vocab fixtures are
                # smaller than the original vocabs - let's not assert this
                # self.assertEqual(vocab_size, all_size)
    @require_ftfy
    def test_tokenization_python_rust_equals(self):
        super().test_tokenization_python_rust_equals()

                new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
                added_toks = tokenizer.add_tokens(new_toks)
                vocab_size_2 = tokenizer.vocab_size
                all_size_2 = len(tokenizer)

                self.assertNotEqual(vocab_size_2, 0)
                self.assertEqual(vocab_size, vocab_size_2)
                self.assertEqual(added_toks, len(new_toks))
                self.assertEqual(all_size_2, all_size + len(new_toks))

                tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

                self.assertGreaterEqual(len(tokens), 4)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

                new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
                added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
                vocab_size_3 = tokenizer.vocab_size
                all_size_3 = len(tokenizer)

                self.assertNotEqual(vocab_size_3, 0)
                self.assertEqual(vocab_size, vocab_size_3)
                self.assertEqual(added_toks_2, len(new_toks_2))
                self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

                tokens = tokenizer.encode(
                    ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
                )

                self.assertGreaterEqual(len(tokens), 6)
                self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[0], tokens[1])
                self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
                self.assertGreater(tokens[-2], tokens[-3])
                self.assertEqual(tokens[0], tokenizer.eos_token_id)
                # padding is very hacky in CLIPTokenizer, pad_token_id is always 0
                # so skip this check
                # self.assertEqual(tokens[-2], tokenizer.pad_token_id)
    # overwrite common test
    def test_added_tokens_do_lower_case(self):
        # CLIP always lower cases letters
        pass