* coreml : skip model load in convert-whisper-to-coreml.py

This commit updates the conversion process for Whisper models to use the
"mlprogram" format instead of "neuralnetwork".

The motivation for this change is that the "neuralnetwork" format produces a
protobuf-based model, and my understanding is that this format has limitations,
such as the sizes of strings and the complexity of the model. Currently,
converting larger models such as large-v3 fails, while it succeeds for smaller
models.

The "mlprogram" format is a more recent addition to Core ML and is designed to
be more flexible and powerful, allowing for more complex models and larger data
types. It seems to work for larger and smaller models alike, and unless there
are considerations that I'm not aware of, I think this is what we should be
using moving forward.

The error that is generated for large models is the following:

```console
Running MIL backend_neuralnetwork pipeline: 100%|█████████| 9/9 [00:00<00:00, 35.44 passes/s]
Translating MIL ==> NeuralNetwork Ops: 100%|███████████| 5641/5641 [03:31<00:00, 26.65 ops/s]
Traceback (most recent call last):
  File "/Users/danbev/work/ai/whisper-work/models/convert-whisper-to-coreml.py", line 322, in <module>
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/models/convert-whisper-to-coreml.py", line 255, in convert_encoder
    model = ct.convert(
            ^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/venv/lib/python3.11/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
              ^^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/venv/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 186, in mil_convert
    return _mil_convert(
           ^^^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/venv/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 245, in _mil_convert
    return modelClass(
           ^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/venv/lib/python3.11/site-packages/coremltools/models/model.py", line 489, in __init__
    self.__proxy__, self._spec, self._framework_error = self._get_proxy_and_spec(
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/danbev/work/ai/whisper-work/venv/lib/python3.11/site-packages/coremltools/models/model.py", line 550, in _get_proxy_and_spec
    _MLModelProxy(
ValueError: basic_string
```

Refs: https://github.com/ggml-org/whisper.cpp/issues/3012
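
The functional core of the change is the `convert_to` argument passed to `ct.convert()`. Below is a minimal sketch of the idea, using a hypothetical toy module in place of the traced Whisper encoder (the class name, input shapes, and output path here are illustrative only, not values from the script):

```python
import coremltools as ct
import torch

# Hypothetical stand-in for the traced Whisper encoder.
class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

traced = torch.jit.trace(Toy().eval(), torch.randn(1, 80, 3000))

# "mlprogram" targets the newer MIL-based .mlpackage format; the older
# protobuf-based "neuralnetwork" backend is the one that fails for large models.
model = ct.convert(
    traced,
    convert_to="mlprogram",
    inputs=[ct.TensorType(name="logmel_data", shape=(1, 80, 3000))],
)
model.save("toy-encoder.mlpackage")
```

The script below applies exactly this, passing `convert_to="mlprogram"` in both `convert_encoder` and `convert_decoder`.
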
convert-whisper-to-coreml.py · 329 lines · 12 KiB · Python
import argparse
import torch
import torch.nn.functional as F
import coremltools as ct

from torch import Tensor
from torch import nn
from typing import Dict
from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model

# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
# implementation instead, which is more stable across different PyTorch versions
# (2.5.0 required by coremltools vs newer versions).
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False


# Use for changing dim of input in encoder and decoder embeddings
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """
    Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
    """
    for k in state_dict:
        is_attention = all(substr in k for substr in ['attn', '.weight'])
        is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight'])

        if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
            state_dict[k] = state_dict[k][:, :, None, None]


def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    state_dict[prefix + 'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix + 'weight']
    return state_dict


class LayerNormANE(LayerNormANEBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._register_load_state_dict_pre_hook(
            correct_for_bias_scale_order_inversion)


class MultiHeadAttentionANE(MultiHeadAttention):
    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)
        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):

        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention_ane(q, k, v, mask)

        return self.out(wv), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        _, dim, _, seqlen = q.size()
        dim_per_head = dim // self.n_head

        scale = float(dim_per_head)**-0.5
        q = q * scale

        mh_q = q.split(dim_per_head, dim=1)
        mh_k = k.transpose(1,3).split(dim_per_head, dim=3)
        mh_v = v.split(dim_per_head, dim=1)

        mh_qk = [
            torch.einsum('bchq,bkhc->bkhq', [qi, ki])
            for qi, ki in zip(mh_q, mh_k)
        ]  # (batch_size, max_seq_length, 1, max_seq_length) * n_heads

        if mask is not None:
            for head_idx in range(self.n_head):
                mh_qk[head_idx] = mh_qk[head_idx] + mask[:, :seqlen, :, :seqlen]

        attn_weights = [aw.softmax(dim=1) for aw in mh_qk]  # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        attn = [torch.einsum('bkhq,bchk->bchq', wi, vi) for wi, vi in zip(attn_weights, mh_v)]  # (batch_size, dim_per_head, 1, max_seq_length) * n_heads
        attn = torch.cat(attn, dim=1)  # (batch_size, dim, 1, max_seq_length)

        return attn, torch.cat(mh_qk, dim=1).float().detach()


class ResidualAttentionBlockANE(ResidualAttentionBlock):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)
        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)
        self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Conv2d(n_state, n_mlp, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(n_mlp, n_state, kernel_size=1)
        )
        self.mlp_ln = LayerNormANE(n_state)


class AudioEncoderANE(AudioEncoder):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # Add positional embedding and add dummy dim for ANE
        x = (x + self.positional_embedding.transpose(0,1)).to(x.dtype).unsqueeze(2)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        x = x.squeeze(2).transpose(1, 2)

        return x


class TextDecoderANE(TextDecoder):

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE
        mask = self.mask[None, None, :, :].permute(0,3,1,2)
        x = x.transpose(1,2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)

        x = self.ln(x)

        # Reformat back from ANE
        x = x.permute(0,2,3,1).squeeze(0)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks
        if self.token_embedding.weight.shape[0] >= 51865:
            # split in 11 chunks - 4715 each
            splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//11, dim=0)
            logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)
        else:
            # split in 12 chunks - 4322 each
            assert self.token_embedding.weight.shape[0] == 51864
            splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//12, dim=0)
            logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)

        return logits


class WhisperANE(Whisper):
    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)

        self.encoder = AudioEncoderANE(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoderANE(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[3] > self.decoder.positional_embedding.shape[0]:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttentionANE):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks


def convert_encoder(hparams, model, quantize=False):
    model.eval()

    input_shape = (1, hparams.n_mels, 3000)
    input_data = torch.randn(input_shape)
    traced_model = torch.jit.trace(model, input_data)

    model = ct.convert(
        traced_model,
        convert_to="mlprogram",
        inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
        outputs=[ct.TensorType(name="output")],
        compute_units=ct.ComputeUnit.ALL,
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model


def convert_decoder(hparams, model, quantize=False):
    model.eval()

    tokens_shape = (1, 1)
    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)

    audio_data = torch.randn(audio_shape)
    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()

    traced_model = torch.jit.trace(model, (token_data, audio_data))

    model = ct.convert(
        traced_model,
        convert_to="mlprogram",
        inputs=[
            ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
            ct.TensorType(name="audio_data", shape=audio_shape)
        ],
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
    parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]:
        raise ValueError("Invalid model name")

    whisper = load_model(args.model).cpu()
    hparams = whisper.dims
    print(hparams)

    if args.optimize_ane:
        whisperANE = WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper.state_dict())

        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper.encoder
        decoder = whisper.decoder

    # Convert encoder
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")

    if args.encoder_only is False:
        # Convert decoder
        decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")

    print("done converting")
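
For reference, a hedged usage sketch of how the converter above is typically invoked from the repository root (whisper.cpp normally wraps it via `models/generate-coreml-model.sh`; the flags and environment shown here are illustrative, not prescriptive):

```console
# assumes openai-whisper, torch, coremltools and ane_transformers are installed
python3 models/convert-whisper-to-coreml.py --model large-v3 --encoder-only True
# expected output (per the save calls above): models/coreml-encoder-large-v3.mlpackage
```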