coreml: fix audio shape in whisper decoder conversion [no ci]

This commit fixes the audio shape in the whisper decoder conversion
script.

The motivation for this is that the  audio shape was incorrect and
was causing the conversion to fail.
This commit is contained in:
Daniel Bevenius 2025-04-01 12:29:27 +02:00
parent bb467e49fa
commit 858fce41e4

View File

@ -269,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
model.eval()
tokens_shape = (1, 1)
audio_shape = (1, hparams.n_audio_state, 1, 1500)
audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
audio_data = torch.randn(audio_shape)
token_data = torch.randint(50257, tokens_shape).long()
token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
traced_model = torch.jit.trace(model, (token_data, audio_data))
model = ct.convert(