From 94aa56f19eed8b2419bc5ede6b7fda85d5ca59be Mon Sep 17 00:00:00 2001
From: AsukaMinato
Date: Sat, 29 Apr 2023 16:06:25 +0900
Subject: [PATCH] minor : improve C++ and Python style (#768)

* use some STL functions

* use self.field rather than setattr, use pathlib.Path

* recover some format

* const some iter

* Keep the original

* 2 space
---
 models/convert-h5-to-ggml.py        | 38 +++++++-------
 models/convert-pt-to-ggml.py        | 41 ++++++++-------
 models/convert-whisper-to-coreml.py | 47 ++++++++---------
 whisper.cpp                         | 79 +++++++++++++++--------------
 4 files changed, 100 insertions(+), 105 deletions(-)

diff --git a/models/convert-h5-to-ggml.py b/models/convert-h5-to-ggml.py
index 3ddee220..50836a21 100644
--- a/models/convert-h5-to-ggml.py
+++ b/models/convert-h5-to-ggml.py
@@ -23,6 +23,7 @@ import json
 import code
 import torch
 import numpy as np
+from pathlib import Path

 from transformers import WhisperForConditionalGeneration

@@ -75,16 +76,13 @@ if len(sys.argv) < 4:
     print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n")
     sys.exit(1)

-dir_model = sys.argv[1]
-dir_whisper = sys.argv[2]
-dir_out = sys.argv[3]
+dir_model = Path(sys.argv[1])
+dir_whisper = Path(sys.argv[2])
+dir_out = Path(sys.argv[3])

-with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
-    encoder = json.load(f)
-with open(dir_model + "/added_tokens.json", "r", encoding="utf8") as f:
-    encoder_added = json.load(f)
-with open(dir_model + "/config.json", "r", encoding="utf8") as f:
-    hparams = json.load(f)
+encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8"))
+encoder_added = json.load((dir_model / "added_tokens.json").open("r", encoding="utf8"))
+hparams = json.load((dir_model / "config.json").open("r", encoding="utf8"))

 model = WhisperForConditionalGeneration.from_pretrained(dir_model)

@@ -96,16 +94,15 @@ with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as

 dir_tokenizer = dir_model

-fname_out = dir_out + "/ggml-model.bin"
+fname_out = dir_out / "ggml-model.bin"

-with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
-    tokens = json.load(f)
+tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8"))

 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
     use_f16 = False
-    fname_out = dir_out + "/ggml-model-f32.bin"
+    fname_out = dir_out / "ggml-model-f32.bin"

 fout = open(fname_out, "wb")

@@ -171,10 +168,9 @@ for name in list_vars.keys():
             data = data.astype(np.float16)

     # reshape conv bias from [n] to [n, 1]
-    if name == "encoder.conv1.bias" or \
-       name == "encoder.conv2.bias":
+    if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
         data = data.reshape(data.shape[0], 1)
-        print("  Reshaped variable: " + name + " to shape: ", data.shape)
+        print("  Reshaped variable: ", name, " to shape: ", data.shape)

     n_dims = len(data.shape)
     print(name, n_dims, data.shape)
@@ -182,7 +178,7 @@ for name in list_vars.keys():
     # looks like the whisper models are in f16 by default
     # so we need to convert the small tensors to f32 until we fully support f16 in ggml
     # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype = 1;
+    ftype = 1
     if use_f16:
         if n_dims < 2 or \
                 name == "encoder.conv1.bias" or \
@@ -197,16 +193,16 @@ for name in list_vars.keys():
             ftype = 0

     # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    str_ = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
     for i in range(n_dims):
         fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str);
+    fout.write(str_)

     # data
     data.tofile(fout)

 fout.close()
-print("Done. Output file: " + fname_out)
+print("Done. Output file: ", fname_out)
 print("")
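
Illustrative sketch, not part of the patch: the hunks above trade string
concatenation and explicit with-blocks for pathlib.Path composition. A
minimal, self-contained version of the adopted pattern follows; the
directory name and "config.json" are hypothetical stand-ins for the
converter's real inputs:

    import json
    from pathlib import Path

    dir_model = Path("models/whisper-base")   # hypothetical model directory

    # "/" joins path segments, and Path.open mirrors the open() builtin,
    # so the old two-line with-block collapses into a single expression
    hparams = json.load((dir_model / "config.json").open("r", encoding="utf8"))
    print("loaded", len(hparams), "config entries")

One trade-off of the one-liner: the with-form it replaces closed the file
deterministically, while here closing the handle is left to garbage
collection.
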
fout.write(str); + fout.write(str_) # data data.tofile(fout) fout.close() -print("Done. Output file: " + fname_out) +print("Done. Output file: " , fname_out) print("") diff --git a/models/convert-pt-to-ggml.py b/models/convert-pt-to-ggml.py index 31ab688e..f5aa6bd3 100644 --- a/models/convert-pt-to-ggml.py +++ b/models/convert-pt-to-ggml.py @@ -40,7 +40,7 @@ import code import torch import numpy as np import base64 - +from pathlib import Path #from transformers import GPTJForCausalLM #from transformers import GPT2TokenizerFast @@ -194,17 +194,17 @@ if len(sys.argv) < 4: print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n") sys.exit(1) -fname_inp = sys.argv[1] -dir_whisper = sys.argv[2] -dir_out = sys.argv[3] +fname_inp = Path(sys.argv[1]) +dir_whisper = Path(sys.argv[2]) +dir_out = Path(sys.argv[3]) # try to load PyTorch binary data try: model_bytes = open(fname_inp, "rb").read() with io.BytesIO(model_bytes) as fp: checkpoint = torch.load(fp, map_location="cpu") -except: - print("Error: failed to load PyTorch model file: %s" % fname_inp) +except Exception: + print("Error: failed to load PyTorch model file:" , fname_inp) sys.exit(1) hparams = checkpoint["dims"] @@ -218,17 +218,17 @@ list_vars = checkpoint["model_state_dict"] # load mel filters n_mels = hparams["n_mels"] -with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f: +with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f: filters = torch.from_numpy(f[f"mel_{n_mels}"]) #print (filters) #code.interact(local=locals()) multilingual = hparams["n_vocab"] == 51865 -tokenizer = os.path.join(dir_whisper, "whisper/assets", multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") +tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") # output in the same directory as the model -fname_out = dir_out + "/ggml-model.bin" +fname_out = dir_out / "ggml-model.bin" with open(tokenizer, "rb") as f: contents = f.read() @@ -238,9 +238,9 @@ with open(tokenizer, "rb") as f: use_f16 = True if len(sys.argv) > 4: use_f16 = False - fname_out = dir_out + "/ggml-model-f32.bin" + fname_out = dir_out / "ggml-model-f32.bin" -fout = open(fname_out, "wb") +fout = fname_out.open("wb") fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex fout.write(struct.pack("i", hparams["n_vocab"])) @@ -273,20 +273,19 @@ for key in tokens: for name in list_vars.keys(): data = list_vars[name].squeeze().numpy() - print("Processing variable: " + name + " with shape: ", data.shape) + print("Processing variable: " , name , " with shape: ", data.shape) # reshape conv bias from [n] to [n, 1] - if name == "encoder.conv1.bias" or \ - name == "encoder.conv2.bias": + if name in ["encoder.conv1.bias", "encoder.conv2.bias"]: data = data.reshape(data.shape[0], 1) - print(" Reshaped variable: " + name + " to shape: ", data.shape) + print(f" Reshaped variable: {name} to shape: ", data.shape) - n_dims = len(data.shape); + n_dims = len(data.shape) # looks like the whisper models are in f16 by default # so we need to convert the small tensors to f32 until we fully support f16 in ggml # ftype == 0 -> float32, ftype == 1 -> float16 - ftype = 1; + ftype = 1 if use_f16: if n_dims < 2 or \ name == "encoder.conv1.bias" or \ @@ -307,16 +306,16 @@ for name in list_vars.keys(): # data = data.transpose() # header - str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype)) + str_ = name.encode('utf-8') + fout.write(struct.pack("iii", 
diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py
index 489854ed..4d4b46c3 100644
--- a/models/convert-whisper-to-coreml.py
+++ b/models/convert-whisper-to-coreml.py
@@ -20,7 +20,7 @@ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
     """
     for k in state_dict:
         is_attention = all(substr in k for substr in ['attn', '.weight'])
-        is_mlp = any([k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']])
+        is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight'])

         if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
             state_dict[k] = state_dict[k][:, :, None, None]
@@ -42,11 +42,10 @@ class LayerNormANE(LayerNormANEBase):
 class MultiHeadAttentionANE(MultiHeadAttention):
     def __init__(self, n_state: int, n_head: int):
         super().__init__(n_state, n_head)
-
-        setattr(self, 'query', nn.Conv2d(n_state, n_state, kernel_size=1))
-        setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1, bias=False))
-        setattr(self, 'value', nn.Conv2d(n_state, n_state, kernel_size=1))
-        setattr(self, 'out', nn.Conv2d(n_state, n_state, kernel_size=1))
+        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
+        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
+        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
+        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

     def forward(self,
                 x: Tensor,
@@ -104,30 +103,28 @@ class MultiHeadAttentionANE(MultiHeadAttention):
 class ResidualAttentionBlockANE(ResidualAttentionBlock):
     def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
         super().__init__(n_state, n_head, cross_attention)
-
-        setattr(self, 'attn', MultiHeadAttentionANE(n_state, n_head))
-        setattr(self, 'attn_ln', LayerNormANE(n_state))
-
-        setattr(self, 'cross_attn', MultiHeadAttentionANE(n_state, n_head) if cross_attention else None)
-        setattr(self, 'cross_attn_ln', LayerNormANE(n_state) if cross_attention else None)
+        self.attn = MultiHeadAttentionANE(n_state, n_head)
+        self.attn_ln = LayerNormANE(n_state)
+        self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
+        self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None

         n_mlp = n_state * 4
-        setattr(self, 'mlp', nn.Sequential(
+        self.mlp = nn.Sequential(
             nn.Conv2d(n_state, n_mlp, kernel_size=1),
             nn.GELU(),
             nn.Conv2d(n_mlp, n_state, kernel_size=1)
-        ))
-        setattr(self, 'mlp_ln', LayerNormANE(n_state))
+        )
+        self.mlp_ln = LayerNormANE(n_state)


 class AudioEncoderANE(AudioEncoder):
     def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

-        setattr(self, 'blocks', nn.ModuleList(
+        self.blocks = nn.ModuleList(
             [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
-        ))
-        setattr(self, 'ln_post', LayerNormANE(n_state))
+        )
+        self.ln_post = LayerNormANE(n_state)

     def forward(self, x: Tensor):
         """
@@ -168,10 +165,10 @@ class TextDecoderANE(TextDecoder):
     def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

-        setattr(self, 'blocks', nn.ModuleList(
+        self.blocks = nn.ModuleList(
             [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
-        ))
-        setattr(self, 'ln', LayerNormANE(n_state))
+        )
+        self.ln = LayerNormANE(n_state)

     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         """
@@ -213,20 +210,20 @@ class WhisperANE(Whisper):
     def __init__(self, dims: ModelDimensions):
         super().__init__(dims)

-        setattr(self, 'encoder', AudioEncoderANE(
+        self.encoder = AudioEncoderANE(
             self.dims.n_mels,
             self.dims.n_audio_ctx,
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
-        ))
-        setattr(self, 'decoder', TextDecoderANE(
+        )
+        self.decoder = TextDecoderANE(
             self.dims.n_vocab,
             self.dims.n_text_ctx,
             self.dims.n_text_state,
             self.dims.n_text_head,
             self.dims.n_text_layer,
-        ))
+        )

         self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

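
Illustrative sketch, not part of the patch: the setattr rewrites above are
behavior-preserving. With a literal attribute name, setattr(self, 'x', v)
and self.x = v both dispatch to nn.Module.__setattr__, so submodules are
registered identically either way. The TinyBlock class below is a
hypothetical minimal module:

    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, n_state: int):
            super().__init__()
            self.query = nn.Conv2d(n_state, n_state, kernel_size=1)           # direct assignment
            setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1))  # same effect, noisier

    block = TinyBlock(4)
    modules = dict(block.named_modules())
    # both attributes end up registered as submodules
    assert modules['query'] is block.query and modules['key'] is block.key
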
diff --git a/whisper.cpp b/whisper.cpp
index 3137373b..1a13cc11 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -2356,11 +2356,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
                     sum += fft_out[k] * filters.data[j * n_fft + k];
                 }

-                if (sum < 1e-10) {
-                    sum = 1e-10;
-                }
-
-                sum = log10(sum);
+                sum = log10(std::max(sum, 1e-10));

                 mel.data[j * mel.n_len + i] = sum;
             }
@@ -2602,7 +2598,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 }

 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-    whisper_model_loader loader = {};

     fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);

@@ -2612,22 +2607,27 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
         return nullptr;
     }

-    loader.context = &fin;
+    whisper_model_loader loader = {
+        .context = &fin,

-    loader.read = [](void * ctx, void * output, size_t read_size) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        fin->read((char *)output, read_size);
-        return read_size;
-    };
+        .read =
+            [](void *ctx, void *output, size_t read_size) {
+                std::ifstream *fin = (std::ifstream *)ctx;
+                fin->read((char *)output, read_size);
+                return read_size;
+            },

-    loader.eof = [](void * ctx) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        return fin->eof();
-    };
+        .eof =
+            [](void *ctx) {
+                std::ifstream *fin = (std::ifstream *)ctx;
+                return fin->eof();
+            },

-    loader.close = [](void * ctx) {
-        std::ifstream * fin = (std::ifstream*)ctx;
-        fin->close();
+        .close =
+            [](void *ctx) {
+                std::ifstream *fin = (std::ifstream *)ctx;
+                fin->close();
+            }
     };

     auto ctx = whisper_init_no_state(&loader);
@@ -2647,30 +2647,34 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
     };

     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
-    whisper_model_loader loader = {};

     fprintf(stderr, "%s: loading model from buffer\n", __func__);

-    loader.context = &ctx;
+    whisper_model_loader loader = {
+        .context = &ctx,

-    loader.read = [](void * ctx, void * output, size_t read_size) {
-        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+        .read =
+            [](void *ctx, void *output, size_t read_size) {
+                buf_context *buf = reinterpret_cast<buf_context *>(ctx);

-        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+                size_t size_to_copy = buf->current_offset + read_size < buf->size
+                                          ? read_size
+                                          : buf->size - buf->current_offset;

-        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
-        buf->current_offset += size_to_copy;
+                memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+                buf->current_offset += size_to_copy;

-        return size_to_copy;
-    };
+                return size_to_copy;
+            },

-    loader.eof = [](void * ctx) {
-        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+        .eof =
+            [](void *ctx) {
+                buf_context *buf = reinterpret_cast<buf_context *>(ctx);

-        return buf->current_offset >= buf->size;
-    };
+                return buf->current_offset >= buf->size;
+            },

-    loader.close = [](void * /*ctx*/) { };
+        .close = [](void * /*ctx*/) {}};

     return whisper_init_no_state(&loader);
 }
@@ -2909,7 +2913,6 @@ int whisper_lang_id(const char * lang) {
         fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
-
     return g_lang.at(lang).first;
 }

@@ -3303,15 +3306,15 @@ static void whisper_exp_compute_token_level_timestamps(

 // trim from start (in place)
 static inline void ltrim(std::string &s) {
-    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
-        return !std::isspace(ch);
+    s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
+        return std::isspace(ch);
     }));
 }

 // trim from end (in place)
 static inline void rtrim(std::string &s) {
-    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
-        return !std::isspace(ch);
+    s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
+        return std::isspace(ch);
     }).base(), s.end());
 }