added cached tokenizer

This commit is contained in:
Steven 2024-12-11 18:29:26 +00:00
parent 6181c6b095
commit 96945e2079

View File

@ -457,10 +457,27 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
@property
def added_tokens_encoder(self) -> Dict[str, int]:
"""
Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
optimisation in `self._added_tokens_encoder` for the slow tokenizers.
Returns the sorted mapping from string to index. The cache is dynamically invalidated if `_added_tokens_decoder`
has changed since the last computation.
"""
return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
# Check if cache exists and is valid
if not hasattr(self, "_cached_added_tokens_encoder") or self._is_decoder_modified():
# Recompute and cache the added tokens encoder
self._cached_added_tokens_encoder = {
k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])
}
# Store the current state of `_added_tokens_decoder` for future validity checks
self._cached_decoder_state = self._added_tokens_decoder.copy()
return self._cached_added_tokens_encoder
def _is_decoder_modified(self) -> bool:
"""
Check if `_added_tokens_decoder` has been modified since the last computation.
Returns:
bool: True if modified, False otherwise.
"""
# Compare the current state with the cached state
return not hasattr(self, "_cached_decoder_state") or self._cached_decoder_state != self._added_tokens_decoder
@property
def added_tokens_decoder(self) -> Dict[int, AddedToken]: