From bb467e49fa448e40aabc4e434c81630d35673112 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 1 Apr 2025 12:26:22 +0200 Subject: [PATCH] coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process. Refs: https://github.com/ggerganov/whisper.cpp/issues/2783 --- models/convert-whisper-to-coreml.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 441efdd2..74575052 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weight from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model +# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. +# The Whisper implementation expects a specific behavior from +# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch +# versions. Setting use_sdpa=False forces Whisper to use its manual attention +# implementation instead, which is more stable across different PyTorch versions +# (2.5.0 required by coremltools vs newer versions). +import whisper.model +whisper.model.MultiHeadAttention.use_sdpa = False + # Use for changing dim of input in encoder and decoder embeddings def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):