mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-16 05:48:09 +00:00
coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)
* coreml: fix Whisper to CoreML conversion by disabling SDPA This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process. Refs: https://github.com/ggerganov/whisper.cpp/issues/2783 * coreml: fix audio shape in whisper decoder conversion This commit fixes the audio shape in the whisper decoder conversion script. The motivation for this is that the audio shape was incorrect and was causing the conversion to fail. * coreml : set -e in generate-coreml-interface.sh The commit sets the -e flag in the generate-coreml-interface.sh script to make sure the script fails if any command fails. * coreml : update generated encoder/decoder interfaces This commit updates the generated encoder/decoder interfaces for the whisper model which is the result of running the generate-coreml-interface.sh script.
This commit is contained in:
@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weight
|
||||
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
|
||||
from whisper import load_model
|
||||
|
||||
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
|
||||
# The Whisper implementation expects a specific behavior from
|
||||
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
|
||||
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
|
||||
# implementation instead, which is more stable across different PyTorch versions
|
||||
# (2.5.0 required by coremltools vs newer versions).
|
||||
import whisper.model
|
||||
whisper.model.MultiHeadAttention.use_sdpa = False
|
||||
|
||||
# Use for changing dim of input in encoder and decoder embeddings
|
||||
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
|
||||
model.eval()
|
||||
|
||||
tokens_shape = (1, 1)
|
||||
audio_shape = (1, hparams.n_audio_state, 1, 1500)
|
||||
audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
|
||||
|
||||
audio_data = torch.randn(audio_shape)
|
||||
token_data = torch.randint(50257, tokens_shape).long()
|
||||
token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
|
||||
|
||||
traced_model = torch.jit.trace(model, (token_data, audio_data))
|
||||
|
||||
model = ct.convert(
|
||||
|
Reference in New Issue
Block a user