tokenization: *align* fairseq and spm vocab to fix some tokenization errors

Stefan Schweter 2019-12-18 11:36:54 +01:00
parent cce3089b65
commit ca31abc6d6


@@ -61,7 +61,19 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<unk>": 1, "</s>": 2}
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
         self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
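For reference, a minimal sketch of the alignment this hunk encodes, assuming a local copy of the XLM-R sentencepiece model (the "sentencepiece.bpe.model" file name is an assumption for illustration, not part of this commit):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("sentencepiece.bpe.model")  # hypothetical path to the XLM-R spm model

fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
fairseq_offset = 1  # spm id + 1 == fairseq id for every regular piece

# e.g. "," is piece 3 in the spm vocab and id 4 in the fairseq vocab,
# matching the alignment table in the comment above
for piece in [",", ".", "▁de"]:
    print(piece, sp.PieceToId(piece) + fairseq_offset)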
@@ -131,13 +143,13 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         """ Converts a token (str/unicode) in an id using the vocab. """
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
-        return self.sp_model.PieceToId(token) + 1
+        return self.sp_model.PieceToId(token) + self.fairseq_offset

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
-        return self.sp_model.IdToPiece(index - 1)
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)

     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file