mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
tokenization: *align* fairseq and spm vocab to fix some tokenization errors
This commit is contained in:
parent
cce3089b65
commit
ca31abc6d6
@ -61,7 +61,19 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(str(vocab_file))
|
||||
self.vocab_file = vocab_file
|
||||
self.fairseq_tokens_to_ids = {"<s>": 0, "<unk>": 1, "</s>": 2}
|
||||
|
||||
# Original fairseq vocab and spm vocab must be "aligned":
|
||||
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
||||
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
|
||||
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
|
||||
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
|
||||
|
||||
# Mimic fairseq token-to-id alignment for the first 4 token
|
||||
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
|
||||
|
||||
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
||||
self.fairseq_offset = 1
|
||||
|
||||
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
|
||||
@ -131,13 +143,13 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
if token in self.fairseq_tokens_to_ids:
|
||||
return self.fairseq_tokens_to_ids[token]
|
||||
return self.sp_model.PieceToId(token) + 1
|
||||
return self.sp_model.PieceToId(token) + self.fairseq_offset
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
if index in self.fairseq_ids_to_tokens:
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index + 1)
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
|
Loading…
Reference in New Issue
Block a user