mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-19 20:57:52 +00:00
py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer files (#1001)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer * typo fix
This commit is contained in:
parent
a7f822ef59
commit
3ec7bfffe0
@ -224,16 +224,39 @@ with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
|
|||||||
|
|
||||||
#code.interact(local=locals())
|
#code.interact(local=locals())
|
||||||
|
|
||||||
|
# load tokenizer
|
||||||
|
# for backwards compatibility, also check for older hf_transformers format tokenizer files
|
||||||
|
# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
|
||||||
|
# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
|
||||||
multilingual = hparams["n_vocab"] == 51865
|
multilingual = hparams["n_vocab"] == 51865
|
||||||
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
||||||
|
tokenizer_type = "tiktoken"
|
||||||
|
if not tokenizer.is_file():
|
||||||
|
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
|
||||||
|
tokenizer_type = "hf_transformers"
|
||||||
|
if not tokenizer.is_file():
|
||||||
|
print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
byte_encoder = bytes_to_unicode()
|
||||||
|
byte_decoder = {v:k for k, v in byte_encoder.items()}
|
||||||
|
|
||||||
|
if tokenizer_type == "tiktoken":
|
||||||
|
with open(tokenizer, "rb") as f:
|
||||||
|
contents = f.read()
|
||||||
|
tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
|
||||||
|
elif tokenizer_type == "hf_transformers":
|
||||||
|
with open(tokenizer, "r", encoding="utf8") as f:
|
||||||
|
_tokens_raw = json.load(f)
|
||||||
|
if '<|endoftext|>' in _tokens_raw:
|
||||||
|
# ensures exact same model as tokenizer_type == tiktoken
|
||||||
|
# details: https://github.com/ggerganov/whisper.cpp/pull/725
|
||||||
|
del _tokens_raw['<|endoftext|>']
|
||||||
|
tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
|
||||||
|
|
||||||
# output in the same directory as the model
|
# output in the same directory as the model
|
||||||
fname_out = dir_out / "ggml-model.bin"
|
fname_out = dir_out / "ggml-model.bin"
|
||||||
|
|
||||||
with open(tokenizer, "rb") as f:
|
|
||||||
contents = f.read()
|
|
||||||
tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
|
|
||||||
|
|
||||||
# use 16-bit or 32-bit floats
|
# use 16-bit or 32-bit floats
|
||||||
use_f16 = True
|
use_f16 = True
|
||||||
if len(sys.argv) > 4:
|
if len(sys.argv) > 4:
|
||||||
@ -262,9 +285,7 @@ for i in range(filters.shape[0]):
|
|||||||
for j in range(filters.shape[1]):
|
for j in range(filters.shape[1]):
|
||||||
fout.write(struct.pack("f", filters[i][j]))
|
fout.write(struct.pack("f", filters[i][j]))
|
||||||
|
|
||||||
byte_encoder = bytes_to_unicode()
|
# write tokenizer
|
||||||
byte_decoder = {v:k for k, v in byte_encoder.items()}
|
|
||||||
|
|
||||||
fout.write(struct.pack("i", len(tokens)))
|
fout.write(struct.pack("i", len(tokens)))
|
||||||
|
|
||||||
for key in tokens:
|
for key in tokens:
|
||||||
|
Loading…
Reference in New Issue
Block a user