mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Whisper
] Add conversion script for the tokenizer (#27338)
* draft * updates * full conversion taken from `https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee` * psuh * nits * updates * more nits * Add co author Co-authored-by: Joshua Lochner <admin@xenova.com> * fixup * cleanup * styling * add proper path * update * nits * don't push the exit * clean * update whisper doc * don't error out if tiktoken is not here * make sure we are BC with conversion * nit * Update docs/source/en/model_doc/whisper.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * merge and update * update markdwon * Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --------- Co-authored-by: Joshua Lochner <admin@xenova.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
0ded281557
commit
88832c01c8
@ -34,8 +34,13 @@ The original code can be found [here](https://github.com/openai/whisper).
|
||||
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||
The original code can be found [here](https://github.com/openai/whisper).
|
||||
- To convert the tokenizer, we recommend using the following:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_tokenizer True --whisper_version 3 --multilingual True
|
||||
```
|
||||
Here the `whisper_version` will set the number of languages to `100` to account for `cantonese` which was added in `whisper-large-v3`.
|
||||
|
||||
|
||||
## Inference
|
||||
|
||||
|
@ -17,7 +17,9 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
@ -25,7 +27,9 @@ import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import WhisperConfig, WhisperForConditionalGeneration
|
||||
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperTokenizer
|
||||
from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
_MODELS = {
|
||||
@ -41,6 +45,11 @@ _MODELS = {
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
_TOKENIZERS = {
|
||||
"multilingual": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
|
||||
"english": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken",
|
||||
}
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = ["layers", "blocks"]
|
||||
@ -178,11 +187,119 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
|
||||
def _bpe(mergeable_ranks, token: bytes, max_rank=None) -> list[bytes]:
|
||||
parts = [bytes([b]) for b in token]
|
||||
while True:
|
||||
min_idx = None
|
||||
min_rank = None
|
||||
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
|
||||
rank = mergeable_ranks.get(pair[0] + pair[1])
|
||||
if rank is not None and (min_rank is None or rank < min_rank):
|
||||
min_idx = i
|
||||
min_rank = rank
|
||||
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
|
||||
break
|
||||
assert min_idx is not None
|
||||
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
|
||||
return parts
|
||||
|
||||
|
||||
def convert_tiktoken_bpe_to_hf(tiktoken_url: str):
|
||||
bpe_ranks = load_tiktoken_bpe(tiktoken_url)
|
||||
byte_encoder = bytes_to_unicode()
|
||||
|
||||
def token_bytes_to_string(b):
|
||||
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
for token, rank in bpe_ranks.items():
|
||||
vocab[token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
continue
|
||||
merged = tuple(_bpe(bpe_ranks, token, max_rank=rank))
|
||||
if len(merged) == 2: # account for empty token
|
||||
merges.append(" ".join(map(token_bytes_to_string, merged)))
|
||||
return vocab, merges
|
||||
|
||||
|
||||
def convert_tiktoken_to_hf(
|
||||
pytorch_dump_folder_path: str, multilingual: bool = True, num_languages: int = 100, time_precision=0.02
|
||||
) -> WhisperTokenizer:
|
||||
# requires whisper, unless we use the path to the tiktoken file
|
||||
tiktoken_tokenizer_path = _TOKENIZERS["multilingual" if multilingual else "english"]
|
||||
start_of_transcript = ["<|endoftext|>", "<|startoftranscript|>"]
|
||||
control_tokens = [
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nocaptions|>",
|
||||
"<|notimestamps|>",
|
||||
]
|
||||
# these are special tokens, not normalized
|
||||
language_tokens = [f"<|{k}|>" for k in list(LANGUAGES)[:num_languages]]
|
||||
# These are not special but normalized
|
||||
timestamp_tokens = [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]
|
||||
|
||||
vocab, merges = convert_tiktoken_bpe_to_hf(tiktoken_tokenizer_path)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = f"{tmpdirname}/vocab.json"
|
||||
merge_file = f"{tmpdirname}/merges.txt"
|
||||
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
||||
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write("#version: 0.2\n")
|
||||
for bpe_tokens in merges:
|
||||
writer.write(bpe_tokens + "\n")
|
||||
|
||||
hf_tokenizer = WhisperTokenizer(vocab_file, merge_file)
|
||||
|
||||
hf_tokenizer.add_tokens(start_of_transcript + language_tokens + control_tokens, special_tokens=True)
|
||||
hf_tokenizer.add_tokens(timestamp_tokens, special_tokens=False)
|
||||
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument("--checkpoint_path", type=str, help="Patht to the downloaded checkpoints")
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--convert_tokenizer",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether or not the tokenizer should be converted along with the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper_version",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Version of the whisper release",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multilingual",
|
||||
type=bool,
|
||||
default="store_true",
|
||||
help="Whether or not the model is multilingual or english only",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.convert_tokenizer:
|
||||
try:
|
||||
if not _is_package_available("tiktoken"):
|
||||
raise """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
NUM_LANGUAGES_PER_RELEASE = {1: 99, 2: 99, 3: 100}
|
||||
convert_tiktoken_to_hf(
|
||||
args.pytorch_dump_folder_path, args.multilingual, NUM_LANGUAGES_PER_RELEASE[args.whisper_version]
|
||||
)
|
||||
|
||||
convert_openai_whisper_to_tfms(args.checkpoint_path, args.pytorch_dump_folder_path)
|
||||
|
@ -191,6 +191,7 @@ LANGUAGES = {
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
@ -207,6 +208,7 @@ TO_LANGUAGE_CODE = {
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
"mandarin": "zh",
|
||||
}
|
||||
|
||||
TASK_IDS = ["translate", "transcribe"]
|
||||
@ -1206,7 +1208,7 @@ def _combine_tokens_into_words(
|
||||
if language is None:
|
||||
language = "english"
|
||||
|
||||
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
|
||||
if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
|
||||
# These languages don't typically use spaces.
|
||||
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user