coreml : use the correct n_mel value (#1458)

This commit is contained in:
Xiao-Yong Jin
2023-11-08 14:01:41 -06:00
committed by GitHub
parent baeb733691
commit 0de8582f65
5 changed files with 12 additions and 6 deletions

View File

@ -252,7 +252,7 @@ class WhisperANE(Whisper):
def convert_encoder(hparams, model, quantize=False):
model.eval()
input_shape = (1, 80, 3000)
input_shape = (1, hparams.n_mels, 3000)
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)
@ -302,7 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()