whisper : add support for large v3 (#1444)

* whisper : add support for large v3

* bench : fix build + fix go bindings

* bench : fix n_mels

* models : update readme
Author: Georgi Gerganov
Date: 2023-11-07 15:30:18 +02:00
Committed by: GitHub
Parent: 973111088b
Commit: 2cdfc4e025
20 changed files with 70 additions and 38 deletions

@@ -194,7 +194,7 @@ class TextDecoderANE(TextDecoder):
x = x.permute(0,2,3,1).squeeze(0)
# ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks
-if self.token_embedding.weight.shape[0] == 51865:
+if self.token_embedding.weight.shape[0] >= 51865:
# split in 11 chunks - 4715 each
splits = self.token_embedding.weight.split(self.token_embedding.weight.shape[0]//11, dim=0)
logits = torch.cat([torch.einsum('bid,jd->bij', x, split) for split in splits]).view(*x.shape[:2], -1)
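
The switch from `==` to `>=` accommodates large-v3, which grows the multilingual vocabulary from 51,865 to 51,866 tokens. Below is a minimal sketch of the chunking trick with dummy shapes standing in for the real decoder (sizes are assumptions; concatenating along the vocab axis is an equivalent formulation of the diff's flattened `view`):

import torch

# Dummy stand-ins for the real model (assumed large-v3 vocab and width);
# the actual code uses self.token_embedding.weight.
n_vocab, n_state = 51866, 1280
token_embedding = torch.randn(n_vocab, n_state)
x = torch.randn(1, 6, n_state)  # (batch, seq, state) decoder output

# ANE caps any tensor dim at 16,384, so the 51,866-row embedding is split
# into ~4,715-row chunks; each chunk yields partial logits that are
# concatenated back along the vocab axis.
splits = token_embedding.split(n_vocab // 11, dim=0)
logits = torch.cat([torch.einsum('bid,jd->bij', x, s) for s in splits], dim=2)
assert logits.shape == (1, 6, n_vocab)
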
@@ -296,13 +296,13 @@ def convert_decoder(hparams, model, quantize=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
-if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
+if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()
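
The "bench : fix n_mels" part of the commit message reflects another large-v3 change: the mel filterbank grows from 80 to 128 bins, so downstream code can no longer hard-code 80. A sketch of reading the value from the loaded checkpoint instead, assuming the openai-whisper package (where "large" resolves to large-v3 after its release):

from whisper import load_model

model = load_model("large").cpu()  # "large" aliases large-v3 upstream
n_mels = model.dims.n_mels         # 128 for large-v3, 80 for earlier models
print(f"n_mels = {n_mels}")
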