Fix MarianTokenizer to remove metaspace character in decode (#26091)

* add: check to remove metaspace from marian tokenizer

* fix: metaspace character being removed from everywhere

* fix: remove redundant check at top

* add: test for marian tokenizer decode fix

* fix: simplified the test
This commit is contained in:
Tanay Mehta 2023-09-13 01:23:31 +05:30 committed by GitHub
parent 03e309d58e
commit 12f043eaea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 0 deletions

View File

@ -55,6 +55,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}
PRETRAINED_INIT_CONFIGURATION = {}
SPIECE_UNDERLINE = ""
# Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json
@ -278,6 +280,7 @@ class MarianTokenizer(PreTrainedTokenizer):
else:
current_sub_tokens.append(token)
out_string += sp_model.decode_pieces(current_sub_tokens)
out_string = out_string.replace(SPIECE_UNDERLINE, " ")
return out_string.strip()
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:

View File

@ -149,3 +149,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
self.assertEqual(decoded, target_text)
def test_tokenizer_decode(self):
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
source_text = "Hello World"
ids = tokenizer(source_text)["input_ids"]
output_text = tokenizer.decode(ids, skip_special_tokens=True)
self.assertEqual(source_text, output_text)