whisper.cpp/models/convert-pt-to-ggml.py

343 lines
10 KiB
Python
Raw Normal View History

2022-09-25 21:23:15 +03:00
# Convert Whisper transformer model from PyTorch to ggml format
#
# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
#
# You need to clone the original repo in ~/path/to/repo/whisper/
#
# git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
#
# It is used to load various assets needed by the algorithm:
#
# - tokenizer
# - mel filters
#
# Also, you need to have the original models in ~/.cache/whisper/
# See the original repo for more details.
#
# This script loads the specified model and whisper assets and saves them in ggml format.
# The output is a single binary file containing the following information:
#
# - hparams
# - mel filters
# - tokenizer vocab
# - model variables
#
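# For each variable, write the following:
#
# - Number of dimensions (int)
# - Name length (int)
# - Data type (int): 0 = float32, 1 = float16
# - Dimensions (int[n_dims]), written in reverse order
# - Name (char[name_length])
# - Data (float32 or float16, one value per tensor element)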
#
import io
import os
import sys
import struct
import json
import code
import torch
import numpy as np
import base64
from pathlib import Path
#from transformers import GPTJForCausalLM
#from transformers import GPT2TokenizerFast
2022-09-25 21:23:15 +03:00
# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
#LANGUAGES = {
# "en": "english",
# "zh": "chinese",
# "de": "german",
# "es": "spanish",
# "ru": "russian",
# "ko": "korean",
# "fr": "french",
# "ja": "japanese",
# "pt": "portuguese",
# "tr": "turkish",
# "pl": "polish",
# "ca": "catalan",
# "nl": "dutch",
# "ar": "arabic",
# "sv": "swedish",
# "it": "italian",
# "id": "indonesian",
# "hi": "hindi",
# "fi": "finnish",
# "vi": "vietnamese",
# "iw": "hebrew",
# "uk": "ukrainian",
# "el": "greek",
# "ms": "malay",
# "cs": "czech",
# "ro": "romanian",
# "da": "danish",
# "hu": "hungarian",
# "ta": "tamil",
# "no": "norwegian",
# "th": "thai",
# "ur": "urdu",
# "hr": "croatian",
# "bg": "bulgarian",
# "lt": "lithuanian",
# "la": "latin",
# "mi": "maori",
# "ml": "malayalam",
# "cy": "welsh",
# "sk": "slovak",
# "te": "telugu",
# "fa": "persian",
# "lv": "latvian",
# "bn": "bengali",
# "sr": "serbian",
# "az": "azerbaijani",
# "sl": "slovenian",
# "kn": "kannada",
# "et": "estonian",
# "mk": "macedonian",
# "br": "breton",
# "eu": "basque",
# "is": "icelandic",
# "hy": "armenian",
# "ne": "nepali",
# "mn": "mongolian",
# "bs": "bosnian",
# "kk": "kazakh",
# "sq": "albanian",
# "sw": "swahili",
# "gl": "galician",
# "mr": "marathi",
# "pa": "punjabi",
# "si": "sinhala",
# "km": "khmer",
# "sn": "shona",
# "yo": "yoruba",
# "so": "somali",
# "af": "afrikaans",
# "oc": "occitan",
# "ka": "georgian",
# "be": "belarusian",
# "tg": "tajik",
# "sd": "sindhi",
# "gu": "gujarati",
# "am": "amharic",
# "yi": "yiddish",
# "lo": "lao",
# "uz": "uzbek",
# "fo": "faroese",
# "ht": "haitian creole",
# "ps": "pashto",
# "tk": "turkmen",
# "nn": "nynorsk",
# "mt": "maltese",
# "sa": "sanskrit",
# "lb": "luxembourgish",
# "my": "myanmar",
# "bo": "tibetan",
# "tl": "tagalog",
# "mg": "malagasy",
# "as": "assamese",
# "tt": "tatar",
# "haw": "hawaiian",
# "ln": "lingala",
# "ha": "hausa",
# "ba": "bashkir",
# "jw": "javanese",
# "su": "sundanese",
#}
2022-09-25 21:23:15 +03:00
## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
#def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
# tokenizer = GPT2TokenizerFast.from_pretrained(path)
#
# specials = [
# "<|startoftranscript|>",
# *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
# "<|translate|>",
# "<|transcribe|>",
# "<|startoflm|>",
# "<|startofprev|>",
# "<|nocaptions|>",
# "<|notimestamps|>",
# ]
#
# tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
# return tokenizer
2022-09-25 21:23:15 +03:00
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a single unicode character.

    The reversible bpe codes work on unicode strings, so each byte needs a
    character to stand in for it. Bytes in the ranges "!".."~", "¡".."¬" and
    "®".."ÿ" map to the character with the same code point. Every other byte
    (whitespace/control characters the bpe code barfs on) maps to
    chr(256 + k), where k counts those remaining bytes in ascending order.

    The dict lists the self-mapped bytes first, followed by the shifted ones.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in self_mapped}

    shifted = 0
    for byte in range(256):
        if byte in table:
            continue
        # move the byte past the 8-bit range so it gets a distinct character
        table[byte] = chr(256 + shifted)
        shifted += 1

    return table
# three positional arguments are required; the usage text lists an optional fourth one
if len(sys.argv) < 4:
    print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
    sys.exit(1)

# in order: model checkpoint, checkout of the whisper repo, output directory
fname_inp, dir_whisper, dir_out = map(Path, sys.argv[1:4])
2022-09-25 21:23:15 +03:00
# try to load PyTorch binary data
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -
#               only convert checkpoints that come from a trusted source
try:
    # read_bytes() opens, reads and closes the file, so no handle is left dangling
    model_bytes = fname_inp.read_bytes()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")
except Exception as err:
    print("Error: failed to load PyTorch model file:" , fname_inp)
    # report the underlying cause instead of silently discarding it
    print("Reason:", err)
    sys.exit(1)

# hyperparameters stored in the checkpoint (indexed by name further below)
hparams = checkpoint["dims"]
print("hparams:", hparams)

# model variables, keyed by variable name
list_vars = checkpoint["model_state_dict"]
#print(list_vars['encoder.positional_embedding'])
#print(list_vars['encoder.conv1.weight'])
#print(list_vars['encoder.conv1.weight'].shape)
# load mel filters
# the npz archive in the whisper repo is indexed by "mel_<n_mels>"
n_mels = hparams["n_mels"]
mel_key = f"mel_{n_mels}"
with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
    filters = torch.from_numpy(f[mel_key])
#print (filters)
#code.interact(local=locals())
# load tokenizer
# for backwards compatibility, also check for older hf_transformers format tokenizer files
# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
multilingual = hparams["n_vocab"] >= 51865
tokenizer_name = "multilingual" if multilingual else "gpt2"

# prefer the tiktoken file and fall back to the hf_transformers vocab.json
tokenizer = dir_whisper / "whisper" / "assets" / f"{tokenizer_name}.tiktoken"
tokenizer_type = "tiktoken"
if not tokenizer.is_file():
    tokenizer = dir_whisper / "whisper" / "assets" / tokenizer_name / "vocab.json"
    tokenizer_type = "hf_transformers"
    if not tokenizer.is_file():
        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
        sys.exit(1)

# byte <-> unicode character tables; the decoder turns vocab.json strings back into raw bytes
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

# either way, `tokens` maps the raw token bytes to the token's integer rank / index
if tokenizer_type == "tiktoken":
    # each non-empty line holds "<base64 token> <rank>"
    with open(tokenizer, "rb") as f:
        contents = f.read()
    tokens = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }
elif tokenizer_type == "hf_transformers":
    with open(tokenizer, "r", encoding="utf8") as f:
        _tokens_raw = json.load(f)
    # ensures exact same model as tokenizer_type == tiktoken
    # details: https://github.com/ggerganov/whisper.cpp/pull/725
    _tokens_raw.pop('<|endoftext|>', None)
    tokens = {
        bytes([byte_decoder[c] for c in token]): int(idx)
        for token, idx in _tokens_raw.items()
    }
2022-09-25 21:23:15 +03:00
# use 16-bit floats by default; any extra command-line argument selects 32-bit floats
use_f16 = not len(sys.argv) > 4

# the output is written into dir_out, with a file name that reflects the float width
fname_out = dir_out / ("ggml-model.bin" if use_f16 else "ggml-model-f32.bin")
2022-09-25 21:23:15 +03:00
# names of the tensors that are always stored as float32, even in f16 mode;
# the conv biases are listed explicitly because they are reshaped to 2 dims below
# and would otherwise escape the "n_dims < 2" rule
f32_only_names = (
    "encoder.conv1.bias",
    "encoder.conv2.bias",
    "encoder.positional_embedding",
    "decoder.positional_embedding",
)

# the context manager flushes and closes the output even if a write below raises
with fname_out.open("wb") as fout:
    # header: magic value, the hyperparameters in a fixed order, then the f16 flag,
    # each packed with struct format "i"
    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
    for hparam_name in (
        "n_vocab",
        "n_audio_ctx",
        "n_audio_state",
        "n_audio_head",
        "n_audio_layer",
        "n_text_ctx",
        "n_text_state",
        "n_text_head",
        "n_text_layer",
        "n_mels",
    ):
        fout.write(struct.pack("i", hparams[hparam_name]))
    fout.write(struct.pack("i", use_f16))

    # write mel filters: both dimensions, then every value as float32, row by row
    fout.write(struct.pack("i", filters.shape[0]))
    fout.write(struct.pack("i", filters.shape[1]))
    # one bulk write instead of packing each element separately
    filters.numpy().astype(np.float32).tofile(fout)

    # write tokenizer: token count, then (length, raw bytes) for each token in dict order
    fout.write(struct.pack("i", len(tokens)))
    for key in tokens:
        fout.write(struct.pack("i", len(key)))
        fout.write(key)

    # write model variables
    for name in list_vars.keys():
        data = list_vars[name].squeeze().numpy()
        print("Processing variable: " , name , " with shape: ", data.shape)

        # reshape conv bias from [n] to [n, 1]
        if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
            data = data.reshape(data.shape[0], 1)
            print(f" Reshaped variable: {name} to shape: ", data.shape)

        n_dims = len(data.shape)

        # looks like the whisper models are in f16 by default
        # so we need to convert the small tensors to f32 until we fully support f16 in ggml
        # ftype == 0 -> float32, ftype == 1 -> float16
        ftype = 1
        if use_f16:
            if n_dims < 2 or name in f32_only_names:
                print(" Converting to float32")
                data = data.astype(np.float32)
                ftype = 0
            else:
                # cast explicitly so the payload always matches ftype == 1,
                # even when the checkpoint stores this tensor in another float width
                data = data.astype(np.float16)
        else:
            data = data.astype(np.float32)
            ftype = 0

        # header: number of dims, name length, data type, then the dims in reverse order
        str_ = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(str_), ftype))
        for dim in reversed(data.shape):
            fout.write(struct.pack("i", dim))
        fout.write(str_)

        # data
        data.tofile(fout)

print("Done. Output file: " , fname_out)
print("")