mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-26 18:03:21 +00:00
Compare commits
10 Commits
v1.6.2
...
gg/ci-cuda
Author | SHA1 | Date | |
---|---|---|---|
059bcd3009 | |||
20c542c713 | |||
c2bdb960cd | |||
87acd6d629 | |||
f842d31171 | |||
ffef323c4c | |||
af5833e298 | |||
b87494bb8f | |||
ad130431aa | |||
e130b66642 |
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@ -459,7 +459,7 @@ jobs:
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-cublas:
|
||||
runs-on: windows-latest
|
||||
runs-on: windows-2019
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@ -498,7 +498,7 @@ jobs:
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_CUBLAS=${{ matrix.cublas }}
|
||||
-DWHISPER_CUDA=${{ matrix.cublas }}
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build ${{ matrix.cuda-toolkit }}
|
||||
|
@ -364,12 +364,12 @@ if (WHISPER_CUDA)
|
||||
if (WHISPER_STATIC)
|
||||
if (WIN32)
|
||||
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
|
||||
else ()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
|
||||
endif()
|
||||
else()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
|
||||
endif()
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
|
||||
@ -679,6 +679,10 @@ add_library(${TARGET}
|
||||
whisper.cpp
|
||||
)
|
||||
|
||||
if (WHISPER_CUDA)
|
||||
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
|
||||
endif()
|
||||
|
||||
include_directories (
|
||||
.
|
||||
)
|
||||
|
9
Makefile
9
Makefile
@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
|
||||
|
||||
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
|
||||
WHISPER_OBJ += ggml-cuda.o
|
||||
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
|
||||
WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
|
||||
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
||||
@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
endif
|
||||
|
||||
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
|
||||
ifdef WHISPER_HIPBLAS
|
||||
ROCM_PATH ?= /opt/rocm
|
||||
HIPCC ?= $(ROCM_PATH)/bin/hipcc
|
||||
@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
|
||||
|
||||
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
ifndef WHISPER_COREML
|
||||
|
13
README.md
13
README.md
@ -4,6 +4,7 @@
|
||||
|
||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://conan.io/center/whisper-cpp)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
@ -502,6 +503,16 @@ docker run -it --rm \
|
||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
||||
```
|
||||
|
||||
## Installing with Conan
|
||||
|
||||
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
|
||||
|
||||
```
|
||||
conan install --requires="whisper-cpp/[*]" --build=missing
|
||||
```
|
||||
|
||||
For detailed instructions on how to use Conan, please refer to the [Conan documentation](https://docs.conan.io/2/).
|
||||
|
||||
## Limitations
|
||||
|
||||
- Inference only
|
||||
@ -710,7 +721,7 @@ The [main](examples/main) example provides support for output of karaoke-style m
|
||||
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
|
||||
This requires to have `ffmpeg` installed.
|
||||
|
||||
Here are a few *"typical"* examples:
|
||||
Here are a few _"typical"_ examples:
|
||||
|
||||
```bash
|
||||
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
|
||||
|
@ -68,10 +68,6 @@ func (flags *Flags) GetOut() string {
|
||||
return strings.ToLower(flags.Lookup("out").Value.String())
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
@ -111,10 +107,6 @@ func (flags *Flags) SetParams(context whisper.Context) error {
|
||||
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
||||
context.SetDuration(duration)
|
||||
}
|
||||
if flags.IsSpeedup() {
|
||||
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if threads := flags.GetThreads(); threads != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
||||
context.SetThreads(threads)
|
||||
@ -146,7 +138,6 @@ func registerFlags(flag *Flags) {
|
||||
flag.Duration("offset", 0, "Time offset")
|
||||
flag.Duration("duration", 0, "Duration of audio to process")
|
||||
flag.Uint("threads", 0, "Number of threads to use")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
||||
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
||||
flag.Float64("word-thold", 0, "Maximum segment score")
|
||||
|
@ -47,10 +47,6 @@ func (p *Params) SetPrintTimestamps(v bool) {
|
||||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
func (p *Params) SetSpeedup(v bool) {
|
||||
p.speed_up = toBool(v)
|
||||
}
|
||||
|
||||
// Set language id
|
||||
func (p *Params) SetLanguage(lang int) error {
|
||||
if lang == -1 {
|
||||
@ -177,9 +173,6 @@ func (p *Params) String() string {
|
||||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.speed_up {
|
||||
str += " speed_up"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
||||
|
@ -76,11 +76,6 @@ func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
func (context *context) SetSplitOnWord(v bool) {
|
||||
context.params.SetSplitOnWord(v)
|
||||
}
|
||||
|
@ -41,7 +41,6 @@ type Context interface {
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetSplitOnWord(bool) // Set split on word flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
|
@ -20,7 +20,7 @@ public interface WhisperCppJnaLibrary extends Library {
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_from_file(String path_model);
|
||||
|
||||
|
||||
/**
|
||||
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
|
||||
* Because this function allocates memory for the params, the caller must call either:
|
||||
@ -304,14 +304,6 @@ public interface WhisperCppJnaLibrary extends Library {
|
||||
/** Language id associated with the provided state */
|
||||
int whisper_full_lang_id_from_state(Pointer state);
|
||||
|
||||
/**
|
||||
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
|
||||
* The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
/** Get the start time of the specified segment. */
|
||||
long whisper_full_get_segment_t0(Pointer ctx, int i_segment);
|
||||
|
@ -129,14 +129,6 @@ public class WhisperFullParams extends Structure {
|
||||
/** Maximum tokens per segment (0, default = no limit) */
|
||||
public int max_tokens;
|
||||
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||
public CBool speed_up;
|
||||
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||
public void speedUp(boolean enable) {
|
||||
speed_up = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Overwrite the audio context size (0 = use default). */
|
||||
public int audio_ctx;
|
||||
|
||||
@ -321,7 +313,7 @@ public class WhisperFullParams extends Structure {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment", "no_timestamps",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
|
||||
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
||||
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
||||
|
@ -1,6 +1,7 @@
|
||||
require 'mkmf'
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")
|
||||
|
@ -311,12 +311,6 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
|
||||
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, split_on_word, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
|
||||
BOOL_PARAMS_GETTER(self, speed_up)
|
||||
}
|
||||
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
|
||||
BOOL_PARAMS_SETTER(self, speed_up, value)
|
||||
}
|
||||
static VALUE ruby_whisper_params_get_diarize(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
@ -408,8 +402,6 @@ void Init_whisper() {
|
||||
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
|
||||
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
|
||||
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
|
||||
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
|
||||
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
|
||||
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
|
||||
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);
|
||||
|
||||
|
@ -117,13 +117,6 @@ class TestWhisper < Test::Unit::TestCase
|
||||
assert !@params.split_on_word
|
||||
end
|
||||
|
||||
def test_speed_up
|
||||
@params.speed_up = true
|
||||
assert @params.speed_up
|
||||
@params.speed_up = false
|
||||
assert !@params.speed_up
|
||||
end
|
||||
|
||||
def test_whisper
|
||||
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||
params = Whisper::Params.new
|
||||
|
@ -25,7 +25,6 @@ struct whisper_params {
|
||||
float entropy_thold = 2.4f;
|
||||
float logprob_thold = -1.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
@ -232,8 +231,6 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
|
@ -38,7 +38,6 @@ struct whisper_params {
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
@ -76,7 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
@ -115,7 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
@ -165,7 +162,6 @@ std::string transcribe(
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
@ -371,7 +367,6 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
@ -185,7 +185,7 @@ private:
|
||||
// It is assumed that PCM data is normalized to a range from -1 to 1
|
||||
bool write_audio(const float * data, size_t length) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
const int16_t intSample = data[i] * 32767;
|
||||
const int16_t intSample = int16_t(data[i] * 32767);
|
||||
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
|
||||
dataSize += sizeof(int16_t);
|
||||
}
|
||||
|
@ -26,7 +26,6 @@ struct whisper_params {
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
@ -70,7 +69,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
@ -102,7 +100,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
@ -184,7 +181,6 @@ json unguided_transcription(struct whisper_context * ctx, audio_async &audio, js
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.suppress_non_speech_tokens = true;
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
@ -223,7 +219,6 @@ json guided_transcription(struct whisper_context * ctx, audio_async &audio, cons
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
// TODO: Do some time testing. Does an overly long prompt slow down processing?
|
||||
// Set up command sets/precompute prompts
|
||||
|
@ -47,7 +47,6 @@ struct whisper_params {
|
||||
float temperature = 0.0f;
|
||||
float temperature_inc = 0.2f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool debug_mode = false;
|
||||
bool translate = false;
|
||||
bool detect_language = false;
|
||||
@ -138,7 +137,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
|
||||
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
|
||||
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
@ -206,7 +204,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
|
||||
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
|
||||
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
@ -1106,7 +1103,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.debug_mode = params.debug_mode;
|
||||
|
||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||
|
@ -61,7 +61,6 @@ struct whisper_params {
|
||||
float temperature = 0.00f;
|
||||
float temperature_inc = 0.20f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool debug_mode = false;
|
||||
bool translate = false;
|
||||
bool detect_language = false;
|
||||
@ -112,7 +111,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
@ -159,7 +157,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
@ -768,7 +765,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.debug_mode = params.debug_mode;
|
||||
|
||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||
|
@ -27,7 +27,6 @@ struct whisper_params {
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool no_fallback = false;
|
||||
bool print_special = false;
|
||||
@ -62,7 +61,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
@ -100,7 +98,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
@ -314,7 +311,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||
|
||||
|
@ -59,7 +59,6 @@ struct whisper_params {
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
@ -100,7 +99,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
@ -149,7 +147,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
@ -205,7 +202,6 @@ std::string transcribe(
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
|
@ -26,7 +26,6 @@ struct whisper_params {
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
@ -60,7 +59,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
@ -96,7 +94,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
@ -132,7 +129,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
|
@ -26,7 +26,6 @@ struct whisper_params {
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
@ -57,7 +56,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
@ -89,7 +87,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
||||
|
363
whisper-mel-cuda.cu
Normal file
363
whisper-mel-cuda.cu
Normal file
@ -0,0 +1,363 @@
|
||||
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
|
||||
#include "whisper-mel-cuda.hpp"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cufft.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuComplex.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <device_launch_parameters.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4324) // added padding
|
||||
#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
# define DO_CHECKS 1
|
||||
#else
|
||||
# define DO_CHECKS 0
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
#if DO_CHECKS
|
||||
const char* cufftGetErrorString(cufftResult_t res) {
|
||||
switch (res) {
|
||||
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
|
||||
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
|
||||
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
|
||||
case CUFFT_INVALID_TYPE: return "No longer used";
|
||||
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
|
||||
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
|
||||
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
|
||||
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
|
||||
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
|
||||
case CUFFT_UNALIGNED_DATA: return "No longer used";
|
||||
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
|
||||
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
|
||||
case CUFFT_PARSE_ERROR: return "Internal plan database error";
|
||||
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
|
||||
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
|
||||
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
|
||||
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
|
||||
default: return "Unknown error";
|
||||
}
|
||||
}
|
||||
|
||||
# define CUDA_CHECK_GEN(err, success, error_fn) \
|
||||
do { \
|
||||
auto err_ = (err); \
|
||||
if (err_ != (success)) { \
|
||||
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
# define CUDA_CHECK_GEN(err, success, error_fn) err
|
||||
#endif
|
||||
|
||||
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
||||
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
|
||||
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
|
||||
|
||||
__global__ void k_fill_stft_input(
|
||||
const float * padded_samples,
|
||||
const int n_frames,
|
||||
const float * hann_window,
|
||||
float * stft_in
|
||||
) {
|
||||
auto y = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
// if (y >= n_frames) return;
|
||||
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// if (x >= WHISPER_N_FFT) return;
|
||||
|
||||
auto line = padded_samples + y * WHISPER_HOP_LENGTH;
|
||||
auto outLine = stft_in + y * WHISPER_N_FFT;
|
||||
|
||||
outLine[x] = line[x] * hann_window[x];
|
||||
}
|
||||
|
||||
__global__ void k_calc_magnitudes(
|
||||
const cuComplex* stft_out,
|
||||
const int n_frames,
|
||||
float * magnitudes
|
||||
) {
|
||||
auto y = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
// if (y >= n_frames) return;
|
||||
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// if (x >= WHISPER_N_FFT_HALF) return;
|
||||
|
||||
auto idx = y * WHISPER_N_FFT_HALF + x;
|
||||
|
||||
auto r = stft_out[idx].x;
|
||||
auto i = stft_out[idx].y;
|
||||
magnitudes[idx] = r * r + i * i;
|
||||
}
|
||||
|
||||
__global__ void k_calc_log_mel(
|
||||
const float * mel_data,
|
||||
const int n_mel,
|
||||
const float * max_val,
|
||||
float * log_mel
|
||||
) {
|
||||
auto x = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (x >= n_mel) return;
|
||||
|
||||
float val = mel_data[x];
|
||||
|
||||
constexpr float e = 1e-10f;
|
||||
if (val < e) val = e;
|
||||
|
||||
val = log10(val);
|
||||
|
||||
const float max = log10(*max_val) - 8.f;
|
||||
if (val < max) val = max;
|
||||
|
||||
log_mel[x] = (val + 4) / 4;
|
||||
}
|
||||
|
||||
void fill_stft_input(
|
||||
const float * padded_samples,
|
||||
int n_frames,
|
||||
const float * hann_window,
|
||||
float * stft_in,
|
||||
cudaStream_t stream
|
||||
) {
|
||||
dim3 block(WHISPER_N_FFT, 1);
|
||||
dim3 grid(1, n_frames);
|
||||
|
||||
k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
|
||||
}
|
||||
|
||||
void calc_magnitudes(
|
||||
const cuComplex* stft_out,
|
||||
int n_frames,
|
||||
float * magnitudes,
|
||||
cudaStream_t stream
|
||||
) {
|
||||
dim3 block(WHISPER_N_FFT_HALF, 1);
|
||||
dim3 grid(1, n_frames);
|
||||
k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
|
||||
}
|
||||
|
||||
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
|
||||
|
||||
void calc_log_mel(
|
||||
const float * mel_data,
|
||||
int n_mel,
|
||||
void * tempStorage,
|
||||
int tempStorageSize,
|
||||
float * log_mel,
|
||||
cudaStream_t stream
|
||||
) {
|
||||
float * max_val = reinterpret_cast<float *>(tempStorage);
|
||||
void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
|
||||
|
||||
size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
|
||||
cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
|
||||
|
||||
int block = 256;
|
||||
int grid = (n_mel + block - 1) / block;
|
||||
|
||||
k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
|
||||
}
|
||||
|
||||
class mel_calc_cuda : public whisper_mel_calc {
|
||||
const int m_n_mel;
|
||||
|
||||
ggml_backend_t m_backend = nullptr;
|
||||
|
||||
cudaStream_t m_stream = nullptr;
|
||||
cublasHandle_t m_cublas_handle = nullptr;
|
||||
|
||||
float * m_hann_window = nullptr;
|
||||
|
||||
float * m_filters = nullptr;
|
||||
|
||||
// max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
|
||||
int m_n_max_samples = 0;
|
||||
|
||||
size_t m_cufft_workspace_size = 0;
|
||||
void * m_cufft_workspace = nullptr;
|
||||
|
||||
size_t m_log_mel_temp_storage_size = 0;
|
||||
void * m_log_mel_temp_storage = nullptr;
|
||||
public:
|
||||
mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
|
||||
: m_n_mel(filters.n_mel)
|
||||
, m_backend(backend)
|
||||
{
|
||||
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
||||
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
|
||||
}
|
||||
assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
|
||||
|
||||
CUDA_CHECK(cudaStreamCreate(&m_stream));
|
||||
CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
|
||||
CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
|
||||
CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
|
||||
|
||||
// create Hann window
|
||||
{
|
||||
auto hw = whisper_mel_calc::hann_window();
|
||||
CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
||||
}
|
||||
|
||||
// fill filters
|
||||
{
|
||||
auto& f = filters.data;
|
||||
CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
||||
}
|
||||
|
||||
// preallocate working areas enough for the most common cases (<= 30s)
|
||||
ensure_working_areas(WHISPER_N_SAMPLES);
|
||||
}
|
||||
|
||||
~mel_calc_cuda() {
|
||||
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
||||
CUDA_CHECK(cudaStreamDestroy(m_stream));
|
||||
CUDA_CHECK(cudaFree(m_hann_window));
|
||||
CUDA_CHECK(cudaFree(m_cufft_workspace));
|
||||
CUDA_CHECK(cudaFree(m_filters));
|
||||
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
|
||||
}
|
||||
|
||||
void ensure_working_areas(int n_samples) {
|
||||
if (n_samples <= m_n_max_samples) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
|
||||
const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
||||
|
||||
// cufft workspace
|
||||
{
|
||||
if (m_cufft_workspace) {
|
||||
CUDA_CHECK(cudaFree(m_cufft_workspace));
|
||||
m_cufft_workspace_size = 0;
|
||||
m_cufft_workspace = nullptr;
|
||||
}
|
||||
CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
|
||||
CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
|
||||
}
|
||||
|
||||
// device reduce working area
|
||||
{
|
||||
if (m_log_mel_temp_storage) {
|
||||
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
|
||||
m_log_mel_temp_storage_size = 0;
|
||||
m_log_mel_temp_storage = nullptr;
|
||||
}
|
||||
|
||||
const auto max_mels = 160;
|
||||
|
||||
size_t nbytes = 0;
|
||||
float* temp = nullptr;
|
||||
cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
|
||||
m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;
|
||||
|
||||
CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
|
||||
}
|
||||
|
||||
m_n_max_samples = n_samples;
|
||||
}
|
||||
|
||||
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
|
||||
ensure_working_areas(samples.len);
|
||||
|
||||
const size_t mirror_pad = WHISPER_N_FFT / 2;
|
||||
const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;
|
||||
|
||||
// pad
|
||||
std::vector<float> padded_samples(padded_size);
|
||||
std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
|
||||
std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
|
||||
|
||||
// fill the rest of the data
|
||||
// it should canonically be mirrored at the end as well,
|
||||
// but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
|
||||
std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);
|
||||
|
||||
const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
||||
|
||||
float * cu_padded_samples = nullptr;
|
||||
CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
|
||||
|
||||
float * stft_in = nullptr; // contiguous buffer for stft input
|
||||
CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
|
||||
|
||||
fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);
|
||||
|
||||
cufftComplex* stft_out;
|
||||
CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));
|
||||
|
||||
cufftHandle plan;
|
||||
CUFFT_CHECK(cufftCreate(&plan));
|
||||
CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
|
||||
{
|
||||
size_t waSize;
|
||||
CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
|
||||
assert(waSize <= m_cufft_workspace_size);
|
||||
CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
|
||||
CUFFT_CHECK(cufftSetStream(plan, m_stream));
|
||||
}
|
||||
CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));
|
||||
|
||||
const auto n_mag_frames = n_frames - 1; // drop last frame
|
||||
float * magnitudes;
|
||||
CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
|
||||
calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);
|
||||
|
||||
float * mel_data = nullptr;
|
||||
CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
|
||||
|
||||
const float fone = 1.0f, fzero = 0.0f;
|
||||
CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
|
||||
&fone,
|
||||
magnitudes, WHISPER_N_FFT_HALF,
|
||||
m_filters, WHISPER_N_FFT_HALF,
|
||||
&fzero,
|
||||
mel_data, int(n_mag_frames)));
|
||||
|
||||
whisper_mel ret;
|
||||
// Calculate semi-padded sample length to ensure compatibility
|
||||
int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
||||
whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
|
||||
assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
|
||||
|
||||
float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
|
||||
|
||||
calc_log_mel(
|
||||
mel_data, int(m_n_mel * n_mag_frames),
|
||||
m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
|
||||
log_mels, m_stream);
|
||||
|
||||
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
||||
|
||||
// cleanup
|
||||
CUFFT_CHECK(cufftDestroy(plan));
|
||||
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
|
||||
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
|
||||
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
|
||||
CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
|
||||
CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
|
||||
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
|
||||
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
||||
return nullptr;
|
||||
}
|
||||
return new mel_calc_cuda(backend, filters);
|
||||
}
|
3
whisper-mel-cuda.hpp
Normal file
3
whisper-mel-cuda.hpp
Normal file
@ -0,0 +1,3 @@
|
||||
#include "whisper-mel.hpp"
|
||||
|
||||
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);
|
34
whisper-mel.hpp
Normal file
34
whisper-mel.hpp
Normal file
@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
#include "ggml-backend.h"
|
||||
#include <vector>
|
||||
|
||||
struct whisper_mel {
|
||||
int n_len_org = 0;
|
||||
|
||||
ggml_context * ctx = nullptr;
|
||||
ggml_tensor * tensor = nullptr;
|
||||
ggml_backend_buffer_t buffer = nullptr;
|
||||
};
|
||||
|
||||
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
|
||||
|
||||
void whisper_mel_free(whisper_mel & mel);
|
||||
|
||||
struct whisper_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct whisper_span {
|
||||
T * data;
|
||||
int len;
|
||||
};
|
||||
|
||||
struct whisper_mel_calc {
|
||||
virtual ~whisper_mel_calc();
|
||||
virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
|
||||
static whisper_span<const float> hann_window();
|
||||
};
|
506
whisper.cpp
506
whisper.cpp
@ -10,6 +10,7 @@
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
#include "ggml-cuda.h"
|
||||
#include "whisper-mel-cuda.hpp"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_SYCL
|
||||
@ -24,6 +25,8 @@
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include "whisper-mel.hpp"
|
||||
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@ -380,21 +383,6 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
||||
|
||||
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
||||
|
||||
struct whisper_mel {
|
||||
int n_len;
|
||||
int n_len_org;
|
||||
int n_mel;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
struct whisper_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
struct whisper_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
@ -813,6 +801,8 @@ struct whisper_state {
|
||||
whisper_kv_cache kv_pad;
|
||||
|
||||
whisper_mel mel;
|
||||
whisper_mel_calc * mel_calc = nullptr;
|
||||
whisper_mel_calc * mel_calc_fallback = nullptr;
|
||||
|
||||
whisper_batch batch;
|
||||
|
||||
@ -833,7 +823,6 @@ struct whisper_state {
|
||||
struct ggml_tensor * embd_enc = nullptr;
|
||||
|
||||
// helpers for GPU offloading
|
||||
std::vector<float> inp_mel;
|
||||
std::vector<float> inp_mask;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
@ -904,7 +893,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
||||
BYTESWAP_VALUE(dest);
|
||||
}
|
||||
|
||||
static bool kv_cache_init(
|
||||
static bool whisper_kv_cache_init(
|
||||
struct whisper_kv_cache & cache,
|
||||
ggml_backend_t backend,
|
||||
ggml_type wtype,
|
||||
@ -947,7 +936,7 @@ static bool kv_cache_init(
|
||||
return true;
|
||||
}
|
||||
|
||||
static void kv_cache_free(struct whisper_kv_cache & cache) {
|
||||
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
|
||||
ggml_free(cache.ctx);
|
||||
ggml_backend_buffer_free(cache.buffer);
|
||||
cache.ctx = nullptr;
|
||||
@ -1261,9 +1250,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
|
||||
}
|
||||
#endif
|
||||
|
||||
GGML_UNUSED(params);
|
||||
|
||||
if (backend_gpu) {
|
||||
return backend_gpu;
|
||||
}
|
||||
|
||||
return ggml_backend_cpu_init();
|
||||
}
|
||||
|
||||
@ -1825,7 +1817,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
|
||||
|
||||
static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate) {
|
||||
whisper_state & wstate,
|
||||
const int mel_offset) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@ -1844,9 +1837,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
||||
ggml_set_name(mel, "mel");
|
||||
ggml_set_input(mel);
|
||||
ggml_tensor * mel_inp = wstate.mel.tensor;
|
||||
ggml_tensor * mel;
|
||||
if (mel_inp) {
|
||||
const int n_len = int(mel_inp->ne[0]);
|
||||
const int out_s = 2 * n_ctx;
|
||||
const int i0 = std::min(mel_offset, n_len);
|
||||
const int i1 = std::min(mel_offset + out_s, n_len);
|
||||
const int mel_s = i1 - i0;
|
||||
|
||||
assert(mel_inp->type == GGML_TYPE_F32);
|
||||
assert(mel_inp->ne[1] == n_mels);
|
||||
|
||||
ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
|
||||
|
||||
if (mel_s < out_s) {
|
||||
mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
|
||||
}
|
||||
else {
|
||||
mel = ggml_cont(ctx0, cur);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// just create some tensor so that the graph/buffer size estimation is correct
|
||||
mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
|
||||
}
|
||||
ggml_set_name(mel, "mel"); // used with external encoding
|
||||
|
||||
struct ggml_tensor * cur = nullptr;
|
||||
|
||||
@ -2228,45 +2244,21 @@ static bool whisper_encode_internal(
|
||||
{
|
||||
auto & alloc = wstate.alloc_conv.alloc;
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
|
||||
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
|
||||
|
||||
if (!ggml_gallocr_alloc_graph(alloc, gf)) {
|
||||
// should never happen as we pre-allocate the memory
|
||||
return false;
|
||||
}
|
||||
|
||||
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
|
||||
|
||||
// set the input
|
||||
{
|
||||
const auto & mel_inp = wstate.mel;
|
||||
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
|
||||
|
||||
assert(mel->type == GGML_TYPE_F32);
|
||||
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
|
||||
|
||||
wstate.inp_mel.resize(ggml_nelements(mel));
|
||||
|
||||
float * dst = wstate.inp_mel.data();
|
||||
memset(dst, 0, ggml_nbytes(mel));
|
||||
|
||||
const int i0 = std::min(mel_offset, mel_inp.n_len);
|
||||
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
|
||||
|
||||
for (int j = 0; j < mel_inp.n_mel; ++j) {
|
||||
for (int i = i0; i < i1; ++i) {
|
||||
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
|
||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!whisper_encode_external(wstate)) {
|
||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (whisper_encode_external(wstate)) {
|
||||
ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
|
||||
assert(mel->ne[1] == wctx.model.hparams.n_mels);
|
||||
GGML_UNUSED(mel);
|
||||
#if defined(WHISPER_USE_COREML)
|
||||
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
|
||||
#elif defined(WHISPER_USE_OPENVINO)
|
||||
@ -2857,20 +2849,70 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
}
|
||||
|
||||
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
||||
static float sin_vals[SIN_COS_N_COUNT];
|
||||
static float cos_vals[SIN_COS_N_COUNT];
|
||||
namespace {
|
||||
struct whisper_global_cache {
|
||||
// In FFT, we frequently use sine and cosine operations with the same values.
|
||||
// We can use precalculated values to speed up the process.
|
||||
float sin_vals[SIN_COS_N_COUNT];
|
||||
float cos_vals[SIN_COS_N_COUNT];
|
||||
|
||||
// In FFT, we frequently use sine and cosine operations with the same values.
|
||||
// We can use precalculated values to speed up the process.
|
||||
static void fill_sin_cos_table() {
|
||||
static bool is_filled = false;
|
||||
if (is_filled) return;
|
||||
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
||||
double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
// Hann window (Use cosf to eliminate difference)
|
||||
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
||||
float hann_window[WHISPER_N_FFT];
|
||||
|
||||
whisper_global_cache() {
|
||||
fill_sin_cos_table();
|
||||
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
||||
}
|
||||
is_filled = true;
|
||||
|
||||
void fill_sin_cos_table() {
|
||||
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
||||
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
}
|
||||
|
||||
void fill_hann_window(int length, bool periodic, float * output) {
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
}
|
||||
} global_cache;
|
||||
}
|
||||
|
||||
// Mel spectrogram
|
||||
|
||||
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
|
||||
WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
|
||||
mel.n_len_org = n_len_org;
|
||||
assert(!mel.ctx);
|
||||
mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
|
||||
mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
|
||||
mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
|
||||
auto alloc = ggml_tallocr_new(mel.buffer);
|
||||
ggml_tallocr_alloc(&alloc, mel.tensor);
|
||||
}
|
||||
|
||||
void whisper_mel_free(whisper_mel & mel) {
|
||||
ggml_free(mel.ctx);
|
||||
ggml_backend_buffer_free(mel.buffer);
|
||||
|
||||
mel.n_len_org = 0;
|
||||
mel.ctx = nullptr;
|
||||
mel.tensor = nullptr;
|
||||
mel.buffer = nullptr;
|
||||
}
|
||||
|
||||
whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
|
||||
|
||||
whisper_span<const float> whisper_mel_calc::hann_window() {
|
||||
return {global_cache.hann_window, WHISPER_N_FFT};
|
||||
}
|
||||
|
||||
// naive Discrete Fourier Transform
|
||||
@ -2888,8 +2930,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
||||
re += in[n]*cos_vals[idx]; // cos(t)
|
||||
im -= in[n]*sin_vals[idx]; // sin(t)
|
||||
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
||||
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
||||
}
|
||||
|
||||
out[k*2 + 0] = re;
|
||||
@ -2940,8 +2982,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
||||
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
||||
for (int k = 0; k < N/2; k++) {
|
||||
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
||||
float re = cos_vals[idx]; // cos(t)
|
||||
float im = -sin_vals[idx]; // sin(t)
|
||||
float re = global_cache.cos_vals[idx]; // cos(t)
|
||||
float im = -global_cache.sin_vals[idx]; // sin(t)
|
||||
|
||||
float re_odd = odd_fft[2*k + 0];
|
||||
float im_odd = odd_fft[2*k + 1];
|
||||
@ -2954,24 +2996,20 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
||||
if (output.size() < static_cast<size_t>(length)) {
|
||||
output.resize(length);
|
||||
}
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
||||
}
|
||||
namespace {
|
||||
|
||||
return true;
|
||||
}
|
||||
struct whisper_mel_data {
|
||||
int n_len;
|
||||
int n_len_org;
|
||||
int n_mel;
|
||||
float * data;
|
||||
};
|
||||
|
||||
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
|
||||
int n_samples, int frame_size, int frame_step, int n_threads,
|
||||
const whisper_filters & filters, whisper_mel & mel) {
|
||||
void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
||||
int n_samples, int n_threads,
|
||||
const whisper_filters & filters, whisper_mel_data & mel) {
|
||||
const auto frame_size = WHISPER_N_FFT;
|
||||
const auto frame_step = WHISPER_HOP_LENGTH;
|
||||
std::vector<float> fft_in(frame_size, 0.0);
|
||||
std::vector<float> fft_out(2 * frame_size);
|
||||
int n_fft = filters.n_fft;
|
||||
@ -2984,7 +3022,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
||||
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
||||
const int offset = i * frame_step;
|
||||
|
||||
// apply Hanning window (~10% faster)
|
||||
// apply Hann window (~10% faster)
|
||||
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
||||
fft_in[j] = hann[j] * samples[offset + j];
|
||||
}
|
||||
@ -3036,101 +3074,109 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
||||
}
|
||||
}
|
||||
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
|
||||
static bool log_mel_spectrogram(
|
||||
whisper_state & wstate,
|
||||
const float * samples,
|
||||
const int n_samples,
|
||||
const int /*sample_rate*/,
|
||||
const int frame_size,
|
||||
const int frame_step,
|
||||
const int n_mel,
|
||||
const int n_threads,
|
||||
const whisper_filters & filters,
|
||||
const bool debug,
|
||||
whisper_mel & mel) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
struct mel_calc_cpu : public whisper_mel_calc {
|
||||
ggml_backend_t m_backend;
|
||||
const whisper_filters & m_filters;
|
||||
mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
|
||||
|
||||
// Hanning window (Use cosf to eliminate difference)
|
||||
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
||||
std::vector<float> hann;
|
||||
hann_window(frame_size, true, hann);
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
|
||||
whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) override {
|
||||
// Hann window
|
||||
const float * hann = global_cache.hann_window;
|
||||
|
||||
// Calculate the length of padding
|
||||
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
||||
int64_t stage_2_pad = WHISPER_N_FFT / 2;
|
||||
|
||||
// Calculate the length of padding
|
||||
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
||||
int64_t stage_2_pad = frame_size / 2;
|
||||
const int n_samples = int(ssamples.len);
|
||||
const float * samples = ssamples.data;
|
||||
|
||||
// Initialize a vector and copy data from C array to it.
|
||||
std::vector<float> samples_padded;
|
||||
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
|
||||
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
|
||||
// Initialize a vector and copy data from C array to it.
|
||||
std::vector<float> samples_padded;
|
||||
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
|
||||
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
|
||||
|
||||
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
|
||||
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
|
||||
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
|
||||
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
|
||||
|
||||
// reflective pad 200 samples at the beginning of audio
|
||||
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
|
||||
// reflective pad 200 samples at the beginning of audio
|
||||
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
|
||||
|
||||
mel.n_mel = n_mel;
|
||||
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
|
||||
// Calculate number of frames + remove the last frame
|
||||
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
|
||||
// Calculate semi-padded sample length to ensure compatibility
|
||||
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
|
||||
mel.data.resize(mel.n_mel * mel.n_len);
|
||||
whisper_mel_data mel;
|
||||
mel.n_mel = m_filters.n_mel;
|
||||
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
|
||||
// Calculate number of frames + remove the last frame
|
||||
mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
||||
// Calculate semi-padded sample length to ensure compatibility
|
||||
mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
|
||||
|
||||
std::vector<float> host_mel_data;
|
||||
|
||||
{
|
||||
std::vector<std::thread> workers(n_threads - 1);
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw] = std::thread(
|
||||
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
|
||||
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
||||
std::cref(filters), std::ref(mel));
|
||||
whisper_mel ret;
|
||||
whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
|
||||
if (ggml_backend_buffer_is_host(ret.buffer)) {
|
||||
mel.data = reinterpret_cast<float*>(ret.tensor->data);
|
||||
} else {
|
||||
host_mel_data.resize(mel.n_len * mel.n_mel);
|
||||
mel.data = host_mel_data.data();
|
||||
}
|
||||
|
||||
// main thread
|
||||
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
|
||||
{
|
||||
std::vector<std::thread> workers(n_threads - 1);
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw] = std::thread(
|
||||
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
|
||||
n_samples + stage_2_pad, n_threads,
|
||||
std::cref(m_filters), std::ref(mel));
|
||||
}
|
||||
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw].join();
|
||||
// main thread
|
||||
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel);
|
||||
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw].join();
|
||||
}
|
||||
}
|
||||
|
||||
// clamping and normalization
|
||||
double mmax = -1e20;
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] > mmax) {
|
||||
mmax = mel.data[i];
|
||||
}
|
||||
}
|
||||
|
||||
mmax -= 8.0;
|
||||
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] < mmax) {
|
||||
mel.data[i] = mmax;
|
||||
}
|
||||
|
||||
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
||||
}
|
||||
|
||||
if (!host_mel_data.empty()) {
|
||||
// the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
|
||||
ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// clamping and normalization
|
||||
double mmax = -1e20;
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] > mmax) {
|
||||
mmax = mel.data[i];
|
||||
}
|
||||
}
|
||||
|
||||
mmax -= 8.0;
|
||||
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] < mmax) {
|
||||
mel.data[i] = mmax;
|
||||
}
|
||||
|
||||
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
||||
}
|
||||
|
||||
wstate.t_mel_us += ggml_time_us() - t_start_us;
|
||||
|
||||
// Dump log_mel_spectrogram
|
||||
if (debug) {
|
||||
std::ofstream outFile("log_mel_spectrogram.json");
|
||||
outFile << "[";
|
||||
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
|
||||
outFile << mel.data[i] << ", ";
|
||||
}
|
||||
outFile << mel.data[mel.data.size() - 1] << "]";
|
||||
outFile.close();
|
||||
}
|
||||
|
||||
return true;
|
||||
whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) {
|
||||
#if GGML_USE_CUDA
|
||||
if (ggml_backend_is_cuda(backend)) {
|
||||
auto ret = whisper_mel_calc_create_cuda(backend, filters);
|
||||
// run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
|
||||
const float warmup[256] = {0};
|
||||
ret->calculate({warmup, 256}, 1);
|
||||
return ret;
|
||||
} else
|
||||
#endif
|
||||
return new mel_calc_cpu(backend, filters);
|
||||
}
|
||||
|
||||
// split text into tokens
|
||||
@ -3246,8 +3292,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
||||
#endif
|
||||
|
||||
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
fill_sin_cos_table();
|
||||
|
||||
whisper_state * state = new whisper_state;
|
||||
|
||||
state->backend = whisper_backend_init(ctx->params);
|
||||
@ -3257,15 +3301,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
|
||||
|
||||
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
||||
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
||||
const int factor = 3;
|
||||
|
||||
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
|
||||
if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
|
||||
ctx->model.hparams.n_text_state,
|
||||
ctx->model.hparams.n_text_layer,
|
||||
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
||||
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
whisper_free_state(state);
|
||||
return nullptr;
|
||||
}
|
||||
@ -3275,11 +3321,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||
}
|
||||
|
||||
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
|
||||
if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
|
||||
ctx->model.hparams.n_text_state,
|
||||
ctx->model.hparams.n_text_layer,
|
||||
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
||||
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||
whisper_free_state(state);
|
||||
return nullptr;
|
||||
}
|
||||
@ -3289,11 +3335,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||
}
|
||||
|
||||
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
|
||||
if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
|
||||
ctx->model.hparams.n_audio_state,
|
||||
1,
|
||||
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
||||
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
whisper_free_state(state);
|
||||
return nullptr;
|
||||
}
|
||||
@ -3305,7 +3351,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||
if (ctx->params.dtw_token_timestamps) {
|
||||
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
||||
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
|
||||
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
|
||||
whisper_free_state(state);
|
||||
return nullptr;
|
||||
@ -3348,9 +3394,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// conv allocator
|
||||
{
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
|
||||
[&]() {
|
||||
return whisper_build_graph_conv(*ctx, *state);
|
||||
return whisper_build_graph_conv(*ctx, *state, 0);
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
@ -3364,7 +3410,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// encoder allocator
|
||||
if (!whisper_encode_external(*state)) {
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
|
||||
[&]() {
|
||||
return whisper_build_graph_encoder(*ctx, *state);
|
||||
});
|
||||
@ -3380,7 +3426,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// cross allocator
|
||||
{
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
|
||||
[&]() {
|
||||
return whisper_build_graph_cross(*ctx, *state);
|
||||
});
|
||||
@ -3396,7 +3442,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
// decoder allocator
|
||||
{
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
|
||||
bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
|
||||
[&]() {
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
@ -3668,9 +3714,16 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
||||
|
||||
void whisper_free_state(struct whisper_state * state) {
|
||||
if (state) {
|
||||
kv_cache_free(state->kv_self);
|
||||
kv_cache_free(state->kv_cross);
|
||||
kv_cache_free(state->kv_pad);
|
||||
whisper_kv_cache_free(state->kv_self);
|
||||
whisper_kv_cache_free(state->kv_cross);
|
||||
whisper_kv_cache_free(state->kv_pad);
|
||||
|
||||
whisper_mel_free(state->mel);
|
||||
|
||||
delete state->mel_calc;
|
||||
state->mel_calc = nullptr;
|
||||
delete state->mel_calc_fallback;
|
||||
state->mel_calc_fallback = nullptr;
|
||||
|
||||
#ifdef WHISPER_USE_COREML
|
||||
if (state->ctx_coreml != nullptr) {
|
||||
@ -3729,11 +3782,37 @@ void whisper_free_params(struct whisper_full_params * params) {
|
||||
}
|
||||
|
||||
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
||||
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
||||
return -1;
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
whisper_mel_free(state->mel);
|
||||
if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE) {
|
||||
// calculate mel spectrogram for lengths up to 5 minutes on the most optimal mel calculator
|
||||
state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
|
||||
} else {
|
||||
// calcuate mel spectrogram for longer audios on the CPU
|
||||
// 1. gpu calculations may use hundreds of megabytes of memory for longer audios so we're being conservative
|
||||
// with our gpu demands
|
||||
// 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation
|
||||
// taking longer is not a major concern
|
||||
if (!state->mel_calc_fallback) {
|
||||
state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
|
||||
}
|
||||
state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
|
||||
}
|
||||
|
||||
state->t_mel_us += ggml_time_us() - t_start_us;
|
||||
|
||||
// Dump log_mel_spectrogram
|
||||
//{
|
||||
// auto& mel = state->mel;
|
||||
// std::ofstream outFile("log_mel_spectrogram.json");
|
||||
// outFile << "[";
|
||||
// for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
|
||||
// outFile << mel.data[i] << ", ";
|
||||
// }
|
||||
// outFile << mel.data[mel.data.size() - 1] << "]";
|
||||
// outFile.close();
|
||||
//}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -3741,30 +3820,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
||||
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
||||
}
|
||||
|
||||
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
||||
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
||||
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
||||
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
||||
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
||||
}
|
||||
|
||||
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
|
||||
// TODO
|
||||
|
||||
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
|
||||
// TODO
|
||||
|
||||
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
|
||||
// TODO
|
||||
|
||||
int whisper_set_mel_with_state(
|
||||
struct whisper_context * ctx,
|
||||
struct whisper_state * state,
|
||||
@ -3776,12 +3831,10 @@ int whisper_set_mel_with_state(
|
||||
return -1;
|
||||
}
|
||||
|
||||
state->mel.n_len = n_len;
|
||||
state->mel.n_len_org = n_len;
|
||||
state->mel.n_mel = n_mel;
|
||||
whisper_mel_free(state->mel);
|
||||
whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
|
||||
|
||||
state->mel.data.resize(n_len*n_mel);
|
||||
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
||||
ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -4665,7 +4718,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||
/*.split_on_word =*/ false,
|
||||
/*.max_tokens =*/ 0,
|
||||
|
||||
/*.speed_up =*/ false,
|
||||
/*.debug_mode =*/ false,
|
||||
/*.audio_ctx =*/ 0,
|
||||
|
||||
@ -5339,15 +5391,9 @@ int whisper_full_with_state(
|
||||
|
||||
if (n_samples > 0) {
|
||||
// compute log mel spectrogram
|
||||
if (params.speed_up) {
|
||||
// TODO: Replace PV with more advanced algorithm
|
||||
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
||||
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
||||
return -1;
|
||||
} else {
|
||||
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
||||
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
||||
return -2;
|
||||
}
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
@ -5384,7 +5430,7 @@ int whisper_full_with_state(
|
||||
// if length of spectrogram is less than 1.0s (100 frames), then return
|
||||
// basically don't process anything that is less than 1.0s
|
||||
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
||||
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
||||
if (seek_end < seek_start + 100) {
|
||||
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
||||
return 0;
|
||||
}
|
||||
@ -6096,8 +6142,8 @@ int whisper_full_with_state(
|
||||
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
||||
|
||||
if (!text.empty()) {
|
||||
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
||||
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
||||
const auto tt0 = t0;
|
||||
const auto tt1 = t1;
|
||||
|
||||
if (params.print_realtime) {
|
||||
if (params.print_timestamps) {
|
||||
@ -6143,8 +6189,8 @@ int whisper_full_with_state(
|
||||
if (!text.empty()) {
|
||||
const auto t1 = seek + seek_delta;
|
||||
|
||||
const auto tt0 = params.speed_up ? 2*t0 : t0;
|
||||
const auto tt1 = params.speed_up ? 2*t1 : t1;
|
||||
const auto tt0 = t0;
|
||||
const auto tt1 = t1;
|
||||
|
||||
if (params.print_realtime) {
|
||||
if (params.print_timestamps) {
|
||||
@ -7235,7 +7281,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
||||
// operation (after median filter)
|
||||
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
||||
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
||||
w = ggml_norm(gctx, w, 1e-9);
|
||||
w = ggml_norm(gctx, w, 1e-9f);
|
||||
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
|
||||
|
||||
// Pass median filter - this is done over AUDIO_TOKENS dimension.
|
||||
|
19
whisper.h
19
whisper.h
@ -31,8 +31,10 @@
|
||||
|
||||
#define WHISPER_SAMPLE_RATE 16000
|
||||
#define WHISPER_N_FFT 400
|
||||
#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1)
|
||||
#define WHISPER_HOP_LENGTH 160
|
||||
#define WHISPER_CHUNK_SIZE 30
|
||||
#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -266,22 +268,6 @@ extern "C" {
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
|
||||
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||
// Returns 0 on success
|
||||
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
|
||||
struct whisper_context * ctx,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
|
||||
struct whisper_context * ctx,
|
||||
struct whisper_state * state,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
@ -499,7 +485,6 @@ extern "C" {
|
||||
|
||||
// [EXPERIMENTAL] speed-up techniques
|
||||
// note: these can significantly reduce the quality of the output
|
||||
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
|
||||
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
|
||||
int audio_ctx; // overwrite the audio context size (0 = use default)
|
||||
|
||||
|
Reference in New Issue
Block a user