Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-04-08 03:44:46 +00:00)
coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci]
This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions.

By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process.

Refs: https://github.com/ggerganov/whisper.cpp/issues/2783
This commit is contained in:
parent 04b9508fb3 · commit bb467e49fa
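As a rough illustration of what the flag toggles inside Whisper's multi-head attention (a simplified standalone sketch, not the verbatim `whisper.model` code): with `use_sdpa` disabled, attention is computed with plain matmul/softmax ops that trace consistently under the PyTorch 2.5.0 build pinned by coremltools, instead of the fused `scaled_dot_product_attention` kernel.

```python
# Illustrative sketch of the attention dispatch controlled by use_sdpa.
# This is a standalone approximation, not the exact whisper.model source.
import torch
import torch.nn.functional as F

use_sdpa = False  # the patch sets whisper.model.MultiHeadAttention.use_sdpa = False

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, heads, seq, head_dim)
    if use_sdpa:
        # Fused kernel; its behavior under torch.jit.trace differs across PyTorch versions.
        return F.scaled_dot_product_attention(q, k, v)
    # Manual fallback: same math, expressed with plain ops that trace consistently.
    scale = q.shape[-1] ** -0.25
    qk = (q * scale) @ (k * scale).transpose(-1, -2)
    return qk.softmax(dim=-1) @ v

q = k = v = torch.randn(1, 6, 10, 64)
print(attention(q, k, v).shape)  # torch.Size([1, 6, 10, 64])
```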
@@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weights
 from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
 from whisper import load_model

+# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
+# The Whisper implementation expects a specific behavior from
+# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
+# versions. Setting use_sdpa=False forces Whisper to use its manual attention
+# implementation instead, which is more stable across different PyTorch versions
+# (2.5.0 required by coremltools vs newer versions).
+import whisper.model
+whisper.model.MultiHeadAttention.use_sdpa = False
+
 # Use for changing dim of input in encoder and decoder embeddings
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
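For reference, below is a minimal sketch of the conversion flow this header change unblocks: apply the workaround, trace the encoder, and convert the traced graph with coremltools. The model name, input shape, and output path are illustrative assumptions and may not match the repository's conversion script exactly.

```python
# Minimal sketch of tracing the Whisper encoder and converting it to Core ML
# with the SDPA workaround applied first. Model name, shapes, and output path
# are illustrative assumptions, not the repository script's exact arguments.
import torch
import coremltools as ct
import whisper
import whisper.model

# Apply the same workaround as the patch before any tracing happens.
whisper.model.MultiHeadAttention.use_sdpa = False

model = whisper.load_model("tiny").cpu().eval()
encoder = model.encoder

# Dummy mel spectrogram input: (batch, n_mels, frames); 80 x 3000 for the tiny model.
mel = torch.zeros(1, 80, 3000)
with torch.no_grad():
    traced = torch.jit.trace(encoder, mel)

# Convert the traced graph to a Core ML package and save it.
mlmodel = ct.convert(
    traced,
    convert_to="mlprogram",
    inputs=[ct.TensorType(name="mel", shape=mel.shape)],
)
mlmodel.save("whisper-encoder-tiny.mlpackage")
```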