diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 8ff40c573e3..082ccb47bdf 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -54,12 +54,15 @@ class SentencePieceExtractor: # Merges merges = [] - for piece_l in vocab.keys(): - for piece_r in vocab.keys(): - merge = f"{piece_l}{piece_r}" - piece_score = vocab_scores.get(merge, None) - if piece_score: - merges += [(piece_l, piece_r, piece_score)] + for merge, piece_score in vocab_scores.items(): + local = [] + for index in range(1, len(merge)): + piece_l, piece_r = merge[:index], merge[index:] + if piece_l in vocab and piece_r in vocab: + local.append((piece_l, piece_r, piece_score)) + local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) + merges.extend(local) + merges = sorted(merges, key=lambda val: val[2], reverse=reverse) merges = [(val[0], val[1]) for val in merges] return vocab, merges