[Whisper] Add conversion script for the tokenizer (#27338)

* draft

* updates

* full conversion taken from `https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee`

* psuh

* nits

* updates

* more nits

* Add co author

Co-authored-by: Joshua Lochner <admin@xenova.com>

* fixup

* cleanup

* styling

* add proper path

* update

* nits

* don't  push the exit

* clean

* update whisper doc

* don't error out if tiktoken is not here

* make sure we are BC with conversion

* nit

* Update docs/source/en/model_doc/whisper.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* merge and update

* update markdwon

* Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur 2023-11-07 15:07:55 +01:00 committed by GitHub
parent 0ded281557
commit 88832c01c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 128 additions and 4 deletions

View File

@ -34,8 +34,13 @@ The original code can be found [here](https://github.com/openai/whisper).
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
The original code can be found [here](https://github.com/openai/whisper).
- To convert the tokenizer, we recommend using the following:
```bash
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_tokenizer True --whisper_version 3 --multilingual True
```
Here the `whisper_version` will set the number of languages to `100` to account for `cantonese` which was added in `whisper-large-v3`.
## Inference

View File

@ -17,7 +17,9 @@
import argparse
import hashlib
import io
import json
import os
import tempfile
import urllib
import warnings
@ -25,7 +27,9 @@ import torch
from torch import nn
from tqdm import tqdm
from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperTokenizer
from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
from transformers.utils.import_utils import _is_package_available
_MODELS = {
@ -41,6 +45,11 @@ _MODELS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
_TOKENIZERS = {
"multilingual": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
"english": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken",
}
def remove_ignore_keys_(state_dict):
ignore_keys = ["layers", "blocks"]
@ -178,11 +187,119 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
model.save_pretrained(pytorch_dump_folder_path)
# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
def _bpe(mergeable_ranks, token: bytes, max_rank=None) -> list[bytes]:
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
return parts
def convert_tiktoken_bpe_to_hf(tiktoken_url: str):
bpe_ranks = load_tiktoken_bpe(tiktoken_url)
byte_encoder = bytes_to_unicode()
def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
merges = []
vocab = {}
for token, rank in bpe_ranks.items():
vocab[token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = tuple(_bpe(bpe_ranks, token, max_rank=rank))
if len(merged) == 2: # account for empty token
merges.append(" ".join(map(token_bytes_to_string, merged)))
return vocab, merges
def convert_tiktoken_to_hf(
pytorch_dump_folder_path: str, multilingual: bool = True, num_languages: int = 100, time_precision=0.02
) -> WhisperTokenizer:
# requires whisper, unless we use the path to the tiktoken file
tiktoken_tokenizer_path = _TOKENIZERS["multilingual" if multilingual else "english"]
start_of_transcript = ["<|endoftext|>", "<|startoftranscript|>"]
control_tokens = [
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nocaptions|>",
"<|notimestamps|>",
]
# these are special tokens, not normalized
language_tokens = [f"<|{k}|>" for k in list(LANGUAGES)[:num_languages]]
# These are not special but normalized
timestamp_tokens = [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]
vocab, merges = convert_tiktoken_bpe_to_hf(tiktoken_tokenizer_path)
with tempfile.TemporaryDirectory() as tmpdirname:
vocab_file = f"{tmpdirname}/vocab.json"
merge_file = f"{tmpdirname}/merges.txt"
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens in merges:
writer.write(bpe_tokens + "\n")
hf_tokenizer = WhisperTokenizer(vocab_file, merge_file)
hf_tokenizer.add_tokens(start_of_transcript + language_tokens + control_tokens, special_tokens=True)
hf_tokenizer.add_tokens(timestamp_tokens, special_tokens=False)
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# # Required parameters
parser.add_argument("--checkpoint_path", type=str, help="Patht to the downloaded checkpoints")
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--convert_tokenizer",
type=bool,
default=False,
help="Whether or not the tokenizer should be converted along with the model.",
)
parser.add_argument(
"--whisper_version",
type=int,
default=2,
help="Version of the whisper release",
)
parser.add_argument(
"--multilingual",
type=bool,
default="store_true",
help="Whether or not the model is multilingual or english only",
)
args = parser.parse_args()
if args.convert_tokenizer:
try:
if not _is_package_available("tiktoken"):
raise """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
except Exception:
pass
else:
from tiktoken.load import load_tiktoken_bpe
NUM_LANGUAGES_PER_RELEASE = {1: 99, 2: 99, 3: 100}
convert_tiktoken_to_hf(
args.pytorch_dump_folder_path, args.multilingual, NUM_LANGUAGES_PER_RELEASE[args.whisper_version]
)
convert_openai_whisper_to_tfms(args.checkpoint_path, args.pytorch_dump_folder_path)

View File

@ -191,6 +191,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@ -207,6 +208,7 @@ TO_LANGUAGE_CODE = {
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
TASK_IDS = ["translate", "transcribe"]
@ -1206,7 +1208,7 @@ def _combine_tokens_into_words(
if language is None:
language = "english"
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
# These languages don't typically use spaces.
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
else: