Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-06-16 05:48:09 +00:00)
minor : improve C++ and Python style (#768)
* use some STL functions
* use self.field rather than setattr; use pathlib.Path
* recover some formatting
* const some iterators
* keep the original
* 2-space indentation
@@ -23,6 +23,7 @@ import json
 import code
 import torch
 import numpy as np
+from pathlib import Path
 
 from transformers import WhisperForConditionalGeneration
 
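This new import backs the string-to-pathlib migration in the hunks below. A minimal sketch of what the / operator buys over plain string concatenation (the directory name here is purely illustrative):

from pathlib import Path

# string concatenation (old style) vs. the / operator (new style);
# "models/whisper-tiny" is a hypothetical directory used for illustration
dir_model_str = "models/whisper-tiny"
vocab_str = dir_model_str + "/vocab.json"    # separator handled by hand

dir_model = Path("models/whisper-tiny")
vocab_path = dir_model / "vocab.json"        # separator handled by pathlib

print(vocab_str)     # models/whisper-tiny/vocab.json
print(vocab_path)    # models/whisper-tiny/vocab.json (a PosixPath/WindowsPath)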
@@ -75,16 +76,13 @@ if len(sys.argv) < 4:
     print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n")
     sys.exit(1)
 
-dir_model = sys.argv[1]
-dir_whisper = sys.argv[2]
-dir_out = sys.argv[3]
+dir_model = Path(sys.argv[1])
+dir_whisper = Path(sys.argv[2])
+dir_out = Path(sys.argv[3])
 
-with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
-    encoder = json.load(f)
-with open(dir_model + "/added_tokens.json", "r", encoding="utf8") as f:
-    encoder_added = json.load(f)
-with open(dir_model + "/config.json", "r", encoding="utf8") as f:
-    hparams = json.load(f)
+encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8"))
+encoder_added = json.load((dir_model / "added_tokens.json").open("r", encoding="utf8"))
+hparams = json.load((dir_model / "config.json").open("r", encoding="utf8"))
 
 model = WhisperForConditionalGeneration.from_pretrained(dir_model)
 
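A side note on the new one-liner form: .open() without a with block leaves closing the handle to the garbage collector. A sketch of an alternative that stays one line per file but closes deterministically, via Path.read_text (the demo file below is a hypothetical stand-in for vocab.json):

import json
import tempfile
from pathlib import Path

def load_json(path: Path):
    # read_text() opens and closes the file itself, so no handle is leaked
    return json.loads(path.read_text(encoding="utf8"))

# demo with a throwaway file standing in for vocab.json
with tempfile.TemporaryDirectory() as d:
    demo = Path(d) / "vocab.json"
    demo.write_text(json.dumps({"hello": 0}), encoding="utf8")
    print(load_json(demo))  # {'hello': 0}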
@@ -96,16 +94,15 @@ with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as
 
 dir_tokenizer = dir_model
 
-fname_out = dir_out + "/ggml-model.bin"
+fname_out = dir_out / "ggml-model.bin"
 
-with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
-    tokens = json.load(f)
+tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8"))
 
 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
     use_f16 = False
-    fname_out = dir_out + "/ggml-model-f32.bin"
+    fname_out = dir_out / "ggml-model-f32.bin"
 
 fout = open(fname_out, "wb")
 
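open() has accepted Path objects since Python 3.6 (PEP 519), so fout = open(fname_out, "wb") works unchanged after the switch. A small sketch of the output-naming rule, factored into a hypothetical helper for clarity:

from pathlib import Path

def output_name(dir_out: Path, use_f16: bool) -> Path:
    # mirrors the script's naming rule: f16 is the default,
    # any extra CLI argument switches the output to f32
    return dir_out / ("ggml-model.bin" if use_f16 else "ggml-model-f32.bin")

print(output_name(Path("out"), use_f16=True))   # out/ggml-model.bin
print(output_name(Path("out"), use_f16=False))  # out/ggml-model-f32.bin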
@@ -171,10 +168,9 @@ for name in list_vars.keys():
         data = data.astype(np.float16)
 
     # reshape conv bias from [n] to [n, 1]
-    if name == "encoder.conv1.bias" or \
-       name == "encoder.conv2.bias":
+    if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
         data = data.reshape(data.shape[0], 1)
-        print(" Reshaped variable: " + name + " to shape: ", data.shape)
+        print(" Reshaped variable: ", name, " to shape: ", data.shape)
 
     n_dims = len(data.shape)
     print(name, n_dims, data.shape)
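For context, the reshape turns a conv bias of shape [n] into an [n, 1] column, presumably so it lines up with the layout ggml expects when the bias is added to the conv output. A standalone numpy sketch:

import numpy as np

# stand-in for encoder.conv1.bias; the real tensor comes from the model
bias = np.arange(4, dtype=np.float32)   # shape (4,)
col = bias.reshape(bias.shape[0], 1)    # shape (4, 1); a view where possible

print(bias.shape, "->", col.shape)      # (4,) -> (4, 1)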
@@ -182,7 +178,7 @@ for name in list_vars.keys():
     # looks like the whisper models are in f16 by default
     # so we need to convert the small tensors to f32 until we fully support f16 in ggml
     # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype = 1;
+    ftype = 1
     if use_f16:
         if n_dims < 2 or \
            name == "encoder.conv1.bias" or \
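Both of these hunks sit inside the per-tensor f16/f32 selection: 1-D tensors and the conv biases stay float32 even when f16 output is requested. The condition list is truncated in this hunk, so the sketch below only models the names actually shown:

import numpy as np

def pick_ftype(name: str, data: np.ndarray, use_f16: bool) -> int:
    # ftype == 0 -> float32, ftype == 1 -> float16 (same convention as the script)
    # note: the real script checks more tensor names than this truncated hunk shows
    if not use_f16:
        return 0
    if data.ndim < 2 or name in ("encoder.conv1.bias", "encoder.conv2.bias"):
        return 0
    return 1

print(pick_ftype("encoder.conv1.weight", np.zeros((3, 3)), use_f16=True))  # 1
print(pick_ftype("encoder.conv1.bias",   np.zeros((3,)),   use_f16=True))  # 0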
@@ -197,16 +193,16 @@ for name in list_vars.keys():
         ftype = 0
 
     # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    str_ = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
     for i in range(n_dims):
         fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str);
+    fout.write(str_)
 
     # data
     data.tofile(fout)
 
 fout.close()
 
-print("Done. Output file: " + fname_out)
+print("Done. Output file: ", fname_out)
 print("")
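For reference, the per-tensor record this loop emits is: an "iii" header (n_dims, name length in bytes, ftype), the dims in reverse order, the UTF-8 name, then the raw tensor data. A minimal writer/reader round trip under that layout (the layout comes from the diff; the helper names and demo tensor are hypothetical):

import io
import struct

import numpy as np

def write_tensor(fout, name, data, ftype):
    # same record layout as the script: "iii" header, dims in reverse
    # order, utf-8 name bytes, then the raw tensor data
    name_b = name.encode("utf-8")
    fout.write(struct.pack("iii", data.ndim, len(name_b), ftype))
    for i in range(data.ndim):
        fout.write(struct.pack("i", data.shape[data.ndim - 1 - i]))
    fout.write(name_b)
    fout.write(data.tobytes())  # tofile() needs a real file; tobytes() also suits BytesIO

def read_tensor(fin):
    # inverse of write_tensor, useful for sanity-checking the layout
    n_dims, name_len, ftype = struct.unpack("iii", fin.read(12))
    dims = [struct.unpack("i", fin.read(4))[0] for _ in range(n_dims)]
    name = fin.read(name_len).decode("utf-8")
    dtype = np.float16 if ftype == 1 else np.float32
    shape = tuple(reversed(dims))  # undo the reversed write order
    count = int(np.prod(shape))
    data = np.frombuffer(fin.read(count * np.dtype(dtype).itemsize), dtype=dtype)
    return name, data.reshape(shape), ftype

buf = io.BytesIO()
write_tensor(buf, "encoder.conv1.bias", np.zeros((4, 1), dtype=np.float32), ftype=0)
buf.seek(0)
name, data, ftype = read_tensor(buf)
print(name, data.shape, ftype)  # encoder.conv1.bias (4, 1) 0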