mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-29 17:22:25 +06:00
Added feature to move added tokens in vocabulary for Transformer-XL (#4953)
* Fixed resize_token_embeddings for transfo_xl model * Fixed resize_token_embeddings for transfo_xl. Added custom methods to TransfoXLPreTrainedModel for resizing layers of the AdaptiveEmbedding. * Updated docstring * Fixed resizing cutoffs; added check for new size of embedding layer. * Added test for resize_token_embeddings * Fixed code quality * Fixed unchanged cutoffs in model.config * Added feature to move added tokens in tokenizer. * Fixed code quality * Added feature to move added tokens in tokenizer. * Fixed code quality * Fixed docstring, renamed sym to token. Co-authored-by: Rafael Weingartner <rweingartner.its-b2015@fh-salzburg.ac.at>
This commit is contained in:
parent
eb0ca71ef6
commit
b99ad457f4
@ -272,6 +272,33 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
self.idx2sym.append(sym)
|
||||
self.sym2idx[sym] = len(self.idx2sym) - 1
|
||||
|
||||
def move_added_token(self, token: str, target_idx: int):
    """
    Moves an added token to a specific position in the vocab.

    This method should be used when resizing an embedding layer other than the last one in the `AdaptiveEmbedding`
    in order to move the token in the tokenizer from the default position (at the very end) to the desired one.

    Args:
        token: The token to move to a specific position in the vocab.
        target_idx: The position where the token should be moved to.

    Raises:
        ValueError: If ``token`` is not currently an added token, or if it is
            already present in the base vocab (which would create a duplicate).
    """
    # Validate with real exceptions instead of `assert`: asserts are stripped
    # under `python -O`, which would let a bad call silently corrupt the vocab.
    if token not in self.added_tokens_encoder:
        raise ValueError("Token which should be moved has to be an added token")
    if token in self.idx2sym:
        raise ValueError("Token which should be moved is already in vocab")

    # Insert token into the base vocab at the requested position.
    self.idx2sym.insert(target_idx, token)
    self.sym2idx[token] = target_idx

    # Every symbol after the insertion point shifted right by one; refresh
    # its entry in the reverse mapping so sym2idx stays consistent with idx2sym.
    for idx in range(target_idx + 1, len(self.idx2sym)):
        current_sym = self.idx2sym[idx]
        self.sym2idx[current_sym] = idx

    # The token now lives in the base vocab, so drop it from the
    # added-tokens maps to avoid a duplicate id for the same token.
    old_index = self.added_tokens_encoder[token]
    del self.added_tokens_decoder[old_index]
    del self.added_tokens_encoder[token]
|
||||
|
||||
def _convert_id_to_token(self, idx):
|
||||
"""Converts an id in a token (BPE) using the vocab."""
|
||||
assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
|
||||
|
@ -82,3 +82,16 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
|
||||
)
|
||||
|
||||
def test_move_added_token(self):
    """Moving an added token must keep the vocab size and relocate it to the target id."""
    tokenizer = self.get_tokenizer()
    vocab_size_before = len(tokenizer)

    tokenizer.add_tokens(["new1", "new2"])
    tokenizer.move_added_token("new1", 1)

    # The move must relocate the token, not duplicate it: net growth is
    # exactly the two tokens added above.
    self.assertEqual(len(tokenizer), vocab_size_before + 2)
    # The moved token now encodes to, and decodes from, the requested id.
    self.assertEqual(tokenizer.encode("new1"), [1])
    self.assertEqual(tokenizer.decode([1]), "new1")
|
||||
|
Loading…
Reference in New Issue
Block a user