coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)

* coreml: fix Whisper to CoreML conversion by disabling SDPA

This commit disables the use of PyTorch's
`scaled_dot_product_attention` in the Whisper model to avoid
compatibility issues during CoreML conversion.
The issue occurs because coremltools requires PyTorch 2.5.0, but the
Whisper implementation may expect behavior from newer PyTorch versions.

By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to
use its fallback manual attention implementation, which works correctly
with PyTorch 2.5.0 during the tracing process.

Refs: https://github.com/ggerganov/whisper.cpp/issues/2783

* coreml: fix audio shape in whisper decoder conversion

This commit fixes the audio shape in the whisper decoder conversion
script.

The motivation for this is that the  audio shape was incorrect and
was causing the conversion to fail.

* coreml : set -e in generate-coreml-interface.sh

The commit sets the -e flag in the generate-coreml-interface.sh script
to make sure the script fails if any command fails.

* coreml : update generated encoder/decoder interfaces

This commit updates the generated encoder/decoder interfaces for the
whisper model which is the result of running the
generate-coreml-interface.sh script.
This commit is contained in:
Daniel Bevenius 2025-04-01 18:01:23 +02:00 committed by GitHub
parent 04b9508fb3
commit 11688b262f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 125 additions and 39 deletions

View File

@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weight
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
# implementation instead, which is more stable across different PyTorch versions
# (2.5.0 required by coremltools vs newer versions).
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False
# Use for changing dim of input in encoder and decoder embeddings
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
model.eval()
tokens_shape = (1, 1)
audio_shape = (1, hparams.n_audio_state, 1, 1500)
audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
audio_data = torch.randn(audio_shape)
token_data = torch.randint(50257, tokens_shape).long()
token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
traced_model = torch.jit.trace(model, (token_data, audio_data))
model = ct.convert(

View File

@ -5,6 +5,8 @@
# - src/coreml/whisper-decoder-impl.h and src/coreml/whisper-decoder-impl.m
#
set -e
wd=$(dirname "$0")
cd "$wd/../" || exit

View File

@ -11,36 +11,33 @@
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
/// token_data as 1 by 1 matrix of 32-bit integers
/// token_data as 1 by 1 matrix of floats
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
/// var_1346 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
/// cast_76 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the standard interface
@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param options prediction options
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the convenience interface
@param token_data as 1 by 1 matrix of 32-bit integers:
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
@param token_data 1 by 1 matrix of floats
@param audio_data 1 × 1500 × 384 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/

View File

@ -39,21 +39,21 @@
@implementation whisper_decoder_implOutput
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
self = [super init];
if (self) {
_var_1346 = var_1346;
_cast_76 = cast_76;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1346"]];
return [NSSet setWithArray:@[@"cast_76"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1346"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
if ([featureName isEqualToString:@"cast_76"]) {
return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
}
return nil;
}
@ -80,10 +80,13 @@
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
if (model == nil) {
return nil;
}
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
if (self != nil) {
_model = model;
}
return self;
}
@ -177,7 +180,29 @@
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
}
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@ -192,7 +217,7 @@
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
[results addObject:result];
}
return results;

View File

@ -11,9 +11,8 @@
NS_ASSUME_NONNULL_BEGIN
/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@end
/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
/// output as multidimensional array of floats
@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@end
/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the standard interface
@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param options prediction options
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the convenience interface
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
@param logmel_data 1 × 80 × 3000 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/

View File

@ -76,10 +76,13 @@
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
if (model == nil) {
return nil;
}
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
if (self != nil) {
_model = model;
}
return self;
}
@ -176,6 +179,28 @@
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];