From b99ad457f4b7752bc8e5ffa21d46c80131e54d39 Mon Sep 17 00:00:00 2001
From: RafaelWO <38643099+RafaelWO@users.noreply.github.com>
Date: Mon, 22 Jun 2020 15:40:52 +0200
Subject: [PATCH] Added feature to move added tokens in vocabulary for
 Transformer-XL (#4953)

* Fixed resize_token_embeddings for transfo_xl model

* Fixed resize_token_embeddings for transfo_xl. Added custom methods to
  TransfoXLPreTrainedModel for resizing layers of the AdaptiveEmbedding.

* Updated docstring

* Fixed resizing cutoffs; added check for new size of embedding layer.

* Added test for resize_token_embeddings

* Fixed code quality

* Fixed unchanged cutoffs in model.config

* Added feature to move added tokens in tokenizer.

* Fixed code quality

* Added feature to move added tokens in tokenizer.

* Fixed code quality

* Fixed docstring, renamed sym to token.

Co-authored-by: Rafael Weingartner
---
 src/transformers/tokenization_transfo_xl.py | 27 +++++++++++++++++++++
 tests/test_tokenization_transfo_xl.py       | 13 ++++++++++
 2 files changed, 40 insertions(+)

diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py
index 5b174131bac..91d9ae1fab2 100644
--- a/src/transformers/tokenization_transfo_xl.py
+++ b/src/transformers/tokenization_transfo_xl.py
@@ -272,6 +272,33 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             self.idx2sym.append(sym)
             self.sym2idx[sym] = len(self.idx2sym) - 1
 
+    def move_added_token(self, token: str, target_idx: int):
+        """
+        Moves an added token to a specific position in the vocab.
+        This method should be used when resizing an embedding layer other than the last one in the
+        `AdaptiveEmbedding` in order to move the token in the tokenizer from the default position
+        (at the very end) to the desired one.
+
+        Args:
+            token: The token to move to a specific position in the vocab.
+            target_idx: The position where the token should be moved to.
+        """
+        assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
+        assert token not in self.idx2sym, "Token which should be moved is already in vocab"
+
+        # Insert sym into vocab
+        self.idx2sym.insert(target_idx, token)
+        self.sym2idx[token] = target_idx
+
+        # Shift following indices in sym2idx
+        for idx in range(target_idx + 1, len(self.idx2sym)):
+            current_sym = self.idx2sym[idx]
+            self.sym2idx[current_sym] = idx
+
+        # Delete token from added_tokens
+        old_index = self.added_tokens_encoder[token]
+        del self.added_tokens_decoder[old_index]
+        del self.added_tokens_encoder[token]
+
     def _convert_id_to_token(self, idx):
         """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
diff --git a/tests/test_tokenization_transfo_xl.py b/tests/test_tokenization_transfo_xl.py
index 257761fa38d..aa6b95e2da6 100644
--- a/tests/test_tokenization_transfo_xl.py
+++ b/tests/test_tokenization_transfo_xl.py
@@ -82,3 +82,16 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(
             tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
         )
+
+    def test_move_added_token(self):
+        tokenizer = self.get_tokenizer()
+        original_len = len(tokenizer)
+
+        tokenizer.add_tokens(["new1", "new2"])
+        tokenizer.move_added_token("new1", 1)
+
+        # Check that moved token is not copied (duplicate)
+        self.assertEqual(len(tokenizer), original_len + 2)
+        # Check that token is moved to specified id
+        self.assertEqual(tokenizer.encode("new1"), [1])
+        self.assertEqual(tokenizer.decode([1]), "new1")