Compare commits

...

1 Commit

Author SHA1 Message Date
8cbc363561 coreml : attempt to fix ANE-optimized models 2023-07-11 23:03:53 +03:00
5 changed files with 23 additions and 24 deletions

View File

@ -31,10 +31,10 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden"))) API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider> @interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
/// var_1346 as multidimensional array of floats /// var_1195 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346; @property (readwrite, nonatomic, strong) MLMultiArray * var_1195;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER; - (instancetype)initWithVar_1195:(MLMultiArray *)var_1195 NS_DESIGNATED_INITIALIZER;
@end @end

View File

@ -39,21 +39,21 @@
@implementation whisper_decoder_implOutput @implementation whisper_decoder_implOutput
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 { - (instancetype)initWithVar_1195:(MLMultiArray *)var_1195 {
self = [super init]; self = [super init];
if (self) { if (self) {
_var_1346 = var_1346; _var_1195 = var_1195;
} }
return self; return self;
} }
- (NSSet<NSString *> *)featureNames { - (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1346"]]; return [NSSet setWithArray:@[@"var_1195"]];
} }
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1346"]) { if ([featureName isEqualToString:@"var_1195"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1346]; return [MLFeatureValue featureValueWithMultiArray:self.var_1195];
} }
return nil; return nil;
} }
@ -177,7 +177,7 @@
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error { - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error]; id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; } if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue]; return [[whisper_decoder_implOutput alloc] initWithVar_1195:(MLMultiArray *)[outFeatures featureValueForName:@"var_1195"].multiArrayValue];
} }
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error { - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@ -192,7 +192,7 @@
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count]; NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) { for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i]; id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue]; whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1195:(MLMultiArray *)[resultProvider featureValueForName:@"var_1195"].multiArrayValue];
[results addObject:result]; [results addObject:result];
} }
return results; return results;

View File

@ -7,7 +7,6 @@ from torch import Tensor
from torch import nn from torch import nn
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model from whisper import load_model
@ -32,12 +31,12 @@ def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
state_dict[prefix + 'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix + 'weight'] state_dict[prefix + 'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix + 'weight']
return state_dict return state_dict
class LayerNormANE(LayerNormANEBase): class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
def __init__(self, *args, **kwargs): x = x.transpose(1,3)
super().__init__(*args, **kwargs) x = super().forward(x)
self._register_load_state_dict_pre_hook( x = x.transpose(1,3)
correct_for_bias_scale_order_inversion) return x
class MultiHeadAttentionANE(MultiHeadAttention): class MultiHeadAttentionANE(MultiHeadAttention):
def __init__(self, n_state: int, n_head: int): def __init__(self, n_state: int, n_head: int):
@ -104,9 +103,9 @@ class ResidualAttentionBlockANE(ResidualAttentionBlock):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__(n_state, n_head, cross_attention) super().__init__(n_state, n_head, cross_attention)
self.attn = MultiHeadAttentionANE(n_state, n_head) self.attn = MultiHeadAttentionANE(n_state, n_head)
self.attn_ln = LayerNormANE(n_state) self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4 n_mlp = n_state * 4
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
@ -114,7 +113,7 @@ class ResidualAttentionBlockANE(ResidualAttentionBlock):
nn.GELU(), nn.GELU(),
nn.Conv2d(n_mlp, n_state, kernel_size=1) nn.Conv2d(n_mlp, n_state, kernel_size=1)
) )
self.mlp_ln = LayerNormANE(n_state) self.mlp_ln = LayerNorm(n_state)
class AudioEncoderANE(AudioEncoder): class AudioEncoderANE(AudioEncoder):
@ -124,7 +123,7 @@ class AudioEncoderANE(AudioEncoder):
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)] [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
) )
self.ln_post = LayerNormANE(n_state) self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor): def forward(self, x: Tensor):
""" """
@ -168,7 +167,7 @@ class TextDecoderANE(TextDecoder):
self.blocks= nn.ModuleList( self.blocks= nn.ModuleList(
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)] [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
) )
self.ln= LayerNormANE(n_state) self.ln= LayerNorm(n_state)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
""" """

View File

@ -8,7 +8,7 @@
wd=$(dirname "$0") wd=$(dirname "$0")
cd "$wd/../" cd "$wd/../"
python3 models/convert-whisper-to-coreml.py --model tiny.en python3 models/convert-whisper-to-coreml.py --model tiny.en --optimize-ane True
mv -v models/coreml-encoder-tiny.en.mlpackage models/whisper-encoder-impl.mlpackage mv -v models/coreml-encoder-tiny.en.mlpackage models/whisper-encoder-impl.mlpackage
xcrun coremlc generate models/whisper-encoder-impl.mlpackage coreml/ xcrun coremlc generate models/whisper-encoder-impl.mlpackage coreml/

View File

@ -13,7 +13,7 @@ mname="$1"
wd=$(dirname "$0") wd=$(dirname "$0")
cd "$wd/../" cd "$wd/../"
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True --optimize-ane True
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/ xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
rm -rf models/ggml-${mname}-encoder.mlmodelc rm -rf models/ggml-${mname}-encoder.mlmodelc