Add clean_up_tokenization_spaces to config (#22341)

* add draft changes

* fix failing wav2vec

* style

* make sure that the argument is saved + add tests

* style

* fixup

* update test

* default clean_up_tokenization_spaces to False for Bloom and Llama

* Update code based on review

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>

* style

* quality

---------

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>
Arthur 2023-03-29 13:21:07 +02:00 committed by GitHub
parent b29fd6971d
commit 8d9c3836be
16 changed files with 150 additions and 42 deletions
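
At a glance: `clean_up_tokenization_spaces` becomes a tokenizer-level setting that is stored on the instance and written to `tokenizer_config.json`, while the `decode`/`batch_decode` argument now defaults to `None` and falls back to that setting when not passed explicitly. A minimal sketch of the resulting behavior, based on the new test at the bottom of this diff (the checkpoint and the expected strings are taken from that test; the last call is not asserted there but follows from the fallback logic):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokens = tokenizer.encode("This shouldn't be! He'll go.")

    # cleanup is still on by default
    tokenizer.decode(tokens)
    # -> "[CLS] this shouldn't be! he'll go. [SEP]"

    # turn it off once on the tokenizer instead of on every decode call
    tokenizer.clean_up_tokenization_spaces = False
    tokenizer.decode(tokens)
    # -> "[CLS] this shouldn ' t be ! he ' ll go . [SEP]"

    # an explicit argument still takes precedence over the instance attribute
    tokenizer.decode(tokens, clean_up_tokenization_spaces=True)
    # -> "[CLS] this shouldn't be! he'll go. [SEP]"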


@@ -204,7 +204,7 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         spaces_between_special_tokens: bool = True,
         **kwargs,
     ) -> str:
@@ -237,6 +237,11 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         else:
             text = "".join(sub_texts)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text


@@ -115,6 +115,7 @@ class BloomTokenizerFast(PreTrainedTokenizerFast):
         eos_token="</s>",
         pad_token="<pad>",
         add_prefix_space=False,
+        clean_up_tokenization_spaces=False,
         **kwargs,
     ):
         super().__init__(
@@ -126,6 +127,7 @@ class BloomTokenizerFast(PreTrainedTokenizerFast):
             eos_token=eos_token,
             pad_token=pad_token,
             add_prefix_space=add_prefix_space,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             **kwargs,
         )
         pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())


@@ -320,7 +320,7 @@ class CodeGenTokenizer(PreTrainedTokenizer):
         self,
         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         truncate_before_pattern: Optional[List[str]] = None,
         **kwargs,
     ) -> str:
@@ -335,8 +335,9 @@ class CodeGenTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
                 A list of regular expression strings that will be used to truncate the returned string. This can be
                 used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning


@@ -187,7 +187,7 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
         self,
         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         truncate_before_pattern: Optional[List[str]] = None,
         **kwargs,
     ) -> str:
@@ -202,8 +202,9 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
                 A list of regular expression strings that will be used to truncate the returned string. This can be
                 used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning


@@ -236,7 +236,7 @@ class FNetTokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         spaces_between_special_tokens: bool = True,
         **kwargs,
     ) -> str:
@@ -269,6 +269,11 @@ class FNetTokenizer(PreTrainedTokenizer):
         else:
             text = "".join(sub_texts)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text


@@ -59,10 +59,17 @@ class LlamaTokenizer(PreTrainedTokenizer):
         add_bos_token=True,
         add_eos_token=False,
         decode_with_prefix_space=False,
+        clean_up_tokenization_spaces=False,
         **kwargs,
     ):
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
-        super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
         self.vocab_file = vocab_file
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
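
For Bloom and Llama the flag is defaulted to `False` in the constructors above, presumably because their decoded text already reproduces the original spacing and the cleanup pass would alter it. A hedged sketch of what that means in practice; the checkpoint name is an assumption for illustration and is not part of this diff:

    from transformers import BloomTokenizerFast

    # assumed checkpoint, for illustration only
    tok = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    print(tok.clean_up_tokenization_spaces)  # expected False, via the new constructor default

    ids = tok("Hello , world !")["input_ids"]
    print(tok.decode(ids))                                     # spacing preserved as written
    print(tok.decode(ids, clean_up_tokenization_spaces=True))  # old cleanup behavior, on request

The same default is applied to `LlamaTokenizer` above.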


@@ -225,8 +225,9 @@ class MarianTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             use_source_tokenizer (`bool`, *optional*, defaults to `False`):
                 Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
                 problems).
@@ -250,8 +251,9 @@ class MarianTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             use_source_tokenizer (`bool`, *optional*, defaults to `False`):
                 Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
                 problems).


@@ -373,7 +373,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         group_tokens: bool = True,
         spaces_between_special_tokens: bool = False,
         output_word_offsets: Optional[bool] = False,
@@ -402,6 +402,11 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
 
         text = string_output["text"]
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             text = self.clean_up_tokenization(text)
 
@@ -421,7 +426,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         self,
         sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_char_offsets: bool = False,
         output_word_offsets: bool = False,
         **kwargs,
@@ -434,7 +439,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces.
             output_char_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output character offsets. Character offsets can be used in combination with the
@@ -491,7 +496,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         self,
         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_char_offsets: bool = False,
         output_word_offsets: bool = False,
         **kwargs,
@@ -507,7 +512,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces.
             output_char_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output character offsets. Character offsets can be used in combination with the
@@ -887,7 +892,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         **kwargs,
     ) -> str:
         """
@@ -905,6 +910,11 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
 
         text = self.convert_tokens_to_string(result)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text


@@ -409,7 +409,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         group_tokens: bool = True,
         filter_word_delimiter_token: bool = True,
         spaces_between_special_tokens: bool = False,
@@ -438,6 +438,11 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
 
         text = string_output["text"]
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             text = self.clean_up_tokenization(text)
 
@@ -451,7 +456,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
         self,
         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_char_offsets: bool = False,
         **kwargs,
     ) -> str:
@@ -466,7 +471,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces.
             output_char_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output character offsets. Character offsets can be used in combination with the
@@ -507,7 +512,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
         self,
         sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_char_offsets: bool = False,
         **kwargs,
     ) -> List[str]:
@@ -519,7 +524,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces.
             output_char_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output character offsets. Character offsets can be used in combination with the


@@ -556,7 +556,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         self,
         token_ids,
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
@@ -573,8 +573,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             kwargs (additional keyword arguments, *optional*):
                 Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):


@@ -266,7 +266,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         self,
         token_ids,
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
@@ -283,8 +283,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
             kwargs (additional keyword arguments, *optional*):
                 Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):


@@ -254,7 +254,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         spaces_between_special_tokens: bool = True,
         **kwargs,
     ) -> str:
@@ -284,6 +284,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
         # By default, there are no spaces between special tokens
         text = "".join(sub_texts)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text


@@ -922,7 +922,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         self,
         token_ids: List[int],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         spaces_between_special_tokens: bool = True,
         **kwargs,
     ) -> str:
@@ -953,6 +953,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         else:
             text = "".join(sub_texts)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text
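
The same fallback idiom now appears in every `_decode` touched above. Factored out on its own (a hypothetical helper, shown only to make the pattern explicit; the commit deliberately inlines it in each tokenizer):

    def resolve_cleanup_flag(explicit_value, tokenizer):
        # Hypothetical helper, not part of this commit: an explicit decode() argument
        # wins; otherwise fall back to the value configured on the tokenizer
        # (and persisted in tokenizer_config.json).
        return explicit_value if explicit_value is not None else tokenizer.clean_up_tokenization_spaces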


@@ -1470,6 +1470,9 @@ INIT_TOKENIZER_DOCSTRING = r"""
             A tuple or a list of additional special tokens. Add them here to ensure they won't be split by the
             tokenization process. Will be associated to `self.additional_special_tokens` and
             `self.additional_special_tokens_ids`.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+            tokenization process.
 """
 
 
@@ -1521,6 +1524,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
 
+        # By default, cleaning tokenization spaces for both fast and slow tokenizers
+        self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
@@ -1576,7 +1582,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             f"{self.__class__.__name__}(name_or_path='{self.name_or_path}',"
             f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast},"
             f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
-            f" special_tokens={self.special_tokens_map_extended})"
+            f" special_tokens={self.special_tokens_map_extended}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces})"
         )
 
     def __len__(self) -> int:
@@ -2112,7 +2118,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
         # target_keys = self.init_kwargs.keys()
-        target_keys = ["model_max_length"]
+        target_keys = ["model_max_length", "clean_up_tokenization_spaces"]
         for k in target_keys:
             if hasattr(self, k):
                 tokenizer_config[k] = getattr(self, k)
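
With `clean_up_tokenization_spaces` added to `target_keys`, the attribute is written out by `save_pretrained` and picked up again by `from_pretrained`. A short sketch of that round trip, mirroring the new test below (the checkpoint name is taken from that test):

    import tempfile

    from transformers import BertTokenizer, BertTokenizerFast

    slow = BertTokenizer.from_pretrained("bert-base-uncased")
    slow.clean_up_tokenization_spaces = False

    with tempfile.TemporaryDirectory() as tmp_dir:
        slow.save_pretrained(tmp_dir)  # the flag lands in tokenizer_config.json
        fast = BertTokenizerFast.from_pretrained(tmp_dir)

    assert fast.clean_up_tokenization_spaces is False  # survives the slow -> fast round trip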
@@ -3416,7 +3422,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         **kwargs,
     ) -> List[str]:
         """
@@ -3427,8 +3433,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
             kwargs (additional keyword arguments, *optional*):
                 Will be passed to the underlying model specific decode method.
 
@@ -3449,7 +3456,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         **kwargs,
     ) -> str:
         """
@@ -3463,8 +3470,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 List of tokenized input ids. Can be obtained using the `__call__` method.
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens in the decoding.
-            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean up the tokenization spaces.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
             kwargs (additional keyword arguments, *optional*):
                 Will be passed to the underlying model specific decode method.
 
@@ -3485,7 +3493,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         **kwargs,
     ) -> str:
         raise NotImplementedError


@@ -539,7 +539,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         self,
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
+        clean_up_tokenization_spaces: bool = None,
         **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
@@ -548,6 +548,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             token_ids = [token_ids]
         text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
 
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
             return clean_text
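
One consequence the new test below calls out: once cleanup is disabled, slow and fast tokenizers are no longer guaranteed to produce byte-identical strings, because their underlying decoders space punctuation differently (`be !` vs `be!`). Roughly, using the strings asserted in that test:

    from transformers import BertTokenizer, BertTokenizerFast

    slow = BertTokenizer.from_pretrained("bert-base-uncased")
    fast = BertTokenizerFast.from_pretrained("bert-base-uncased")
    slow.clean_up_tokenization_spaces = False
    fast.clean_up_tokenization_spaces = False

    tokens = slow.encode("This shouldn't be! He'll go.")
    slow.decode(tokens)  # "[CLS] this shouldn ' t be ! he ' ll go . [SEP]"
    fast.decode(tokens)  # "[CLS] this shouldn ' t be! he ' ll go. [SEP]"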


@@ -3895,6 +3895,51 @@ class TokenizerTesterMixin:
             # Should not raise an error
             self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
 
+    def test_clean_up_tokenization_spaces(self):
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        assert tokenizer.clean_up_tokenization_spaces is True
+
+        tokens = tokenizer.encode("This shouldn't be! He'll go.")
+        decoded = tokenizer.decode(tokens)
+        assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+        tokenizer.clean_up_tokenization_spaces = False
+        decoded = tokenizer.decode(tokens)
+        assert decoded == "[CLS] this shouldn ' t be ! he ' ll go . [SEP]"
+        assert decoded == tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
+
+        # Fast from slow
+        with tempfile.TemporaryDirectory() as tmp_dir_2:
+            tokenizer.save_pretrained(tmp_dir_2)
+            tokenizer_fast = BertTokenizerFast.from_pretrained(tmp_dir_2)
+            del tokenizer
+            assert tokenizer_fast.clean_up_tokenization_spaces is False
+
+            decoded = tokenizer_fast.decode(tokens)
+            # fast and slow don't have the same output when we don't cleanup
+            # tokenization space. Here `be!` vs `be !` and `go.` vs `go .`
+            assert decoded == "[CLS] this shouldn ' t be! he ' ll go. [SEP]"
+
+            tokenizer_fast.clean_up_tokenization_spaces = True
+            assert tokenizer_fast.clean_up_tokenization_spaces is True
+
+            decoded = tokenizer_fast.decode(tokens)
+            assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+        # Slow from fast
+        with tempfile.TemporaryDirectory() as tmp_dir_2:
+            tokenizer_fast.clean_up_tokenization_spaces = False
+            tokenizer_fast.save_pretrained(tmp_dir_2)
+            tokenizer = BertTokenizer.from_pretrained(tmp_dir_2)
+            assert tokenizer_fast.clean_up_tokenization_spaces is False
+
+            decoded = tokenizer.decode(tokens)
+            assert decoded == "[CLS] this shouldn ' t be ! he ' ll go . [SEP]"
+
+            tokenizer.clean_up_tokenization_spaces = True
+            decoded = tokenizer.decode(tokens)
+            assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
 
 class TokenizerUtilTester(unittest.TestCase):
     def test_cached_files_are_used_when_internet_is_down(self):