diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py
index 5b174131bac..91d9ae1fab2 100644
--- a/src/transformers/tokenization_transfo_xl.py
+++ b/src/transformers/tokenization_transfo_xl.py
@@ -272,6 +272,33 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.idx2sym.append(sym)
         self.sym2idx[sym] = len(self.idx2sym) - 1
 
+    def move_added_token(self, token: str, target_idx: int):
+        """
+        Moves an added token to a specific position in the vocab.
+        This method should be used when resizing an embedding layer other than the last one in the `AdaptiveEmbedding`
+        in order to move the token in the tokenizer from the default position (at the very end) to the desired one.
+
+        Args:
+            token: The token to move to a specific position in the vocab.
+            target_idx: The position where the token should be moved to.
+        """
+        assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
+        assert token not in self.idx2sym, "Token which should be moved is already in vocab"
+
+        # Insert sym into vocab
+        self.idx2sym.insert(target_idx, token)
+        self.sym2idx[token] = target_idx
+
+        # Shift following indices in sym2idx
+        for idx in range(target_idx + 1, len(self.idx2sym)):
+            current_sym = self.idx2sym[idx]
+            self.sym2idx[current_sym] = idx
+
+        # Delete token from added_tokens
+        old_index = self.added_tokens_encoder[token]
+        del self.added_tokens_decoder[old_index]
+        del self.added_tokens_encoder[token]
+
     def _convert_id_to_token(self, idx):
         """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
diff --git a/tests/test_tokenization_transfo_xl.py b/tests/test_tokenization_transfo_xl.py
index 257761fa38d..aa6b95e2da6 100644
--- a/tests/test_tokenization_transfo_xl.py
+++ b/tests/test_tokenization_transfo_xl.py
@@ -82,3 +82,16 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(
             tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
         )
+
+    def test_move_added_token(self):
+        tokenizer = self.get_tokenizer()
+        original_len = len(tokenizer)
+
+        tokenizer.add_tokens(["new1", "new2"])
+        tokenizer.move_added_token("new1", 1)
+
+        # Check that moved token is not copied (duplicate)
+        self.assertEqual(len(tokenizer), original_len + 2)
+        # Check that token is moved to specified id
+        self.assertEqual(tokenizer.encode("new1"), [1])
+        self.assertEqual(tokenizer.decode([1]), "new1")
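
For context, a minimal usage sketch of the new method (not part of the diff). The `transfo-xl-wt103` checkpoint and the target index `10` are illustrative only, and the matching model-side resize of the `AdaptiveEmbedding` cluster mentioned in the docstring is assumed to happen separately:

```python
from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

# By default, added tokens are appended at the very end of the vocab and
# therefore fall into the last (smallest) cluster of the AdaptiveEmbedding.
tokenizer.add_tokens(["new1"])

# Move the token to id 10 instead, e.g. after the first embedding cluster
# has been grown on the model side so that id 10 belongs to it. Note that
# this shifts all ids >= 10, so the embedding must be resized to match.
tokenizer.move_added_token("new1", 10)

assert tokenizer.encode("new1") == [10]
assert tokenizer.decode([10]) == "new1"
```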