mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-18 23:08:08 +00:00
bindings : add java bindings (#931)
* WIP - java bindings * updated README * failed attempt at JNI * fullTranscribe() test passes * tested on Ubuntu 20 * link to Java bindings
This commit is contained in:
33
bindings/java/src/main/cpp/whisper_java.cpp
Normal file
33
bindings/java/src/main/cpp/whisper_java.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
#include <stdio.h>
|
||||
#include "whisper_java.h"
|
||||
|
||||
struct whisper_full_params default_params;
|
||||
struct whisper_context * whisper_ctx = nullptr;
|
||||
|
||||
struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
|
||||
default_params = whisper_full_default_params(strategy);
|
||||
|
||||
// struct whisper_java_params result = {};
|
||||
// return result;
|
||||
return;
|
||||
}
|
||||
|
||||
void whisper_java_init_from_file(const char * path_model) {
|
||||
whisper_ctx = whisper_init_from_file(path_model);
|
||||
if (0 == default_params.n_threads) {
|
||||
whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
}
|
||||
}
|
||||
|
||||
/** Delegates to whisper_full, but without having to pass `whisper_full_params` */
|
||||
int whisper_java_full(
|
||||
struct whisper_context * ctx,
|
||||
// struct whisper_java_params params,
|
||||
const float * samples,
|
||||
int n_samples) {
|
||||
return whisper_full(ctx, default_params, samples, n_samples);
|
||||
}
|
||||
|
||||
void whisper_java_free() {
|
||||
// free(default_params);
|
||||
}
|
24
bindings/java/src/main/cpp/whisper_java.h
Normal file
24
bindings/java/src/main/cpp/whisper_java.h
Normal file
@ -0,0 +1,24 @@
|
||||
#define WHISPER_BUILD
|
||||
#include <whisper.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_java_params {
|
||||
};
|
||||
|
||||
WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
|
||||
|
||||
WHISPER_API void whisper_java_init_from_file(const char * path_model);
|
||||
|
||||
WHISPER_API int whisper_java_full(
|
||||
struct whisper_context * ctx,
|
||||
// struct whisper_java_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
@ -0,0 +1,39 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
import com.sun.jna.ptr.PointerByReference;
|
||||
import io.github.ggerganov.whispercpp.ggml.GgmlType;
|
||||
import io.github.ggerganov.whispercpp.WhisperModel;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class WhisperContext extends Structure {
|
||||
int t_load_us = 0;
|
||||
int t_start_us = 0;
|
||||
|
||||
/** weight type (FP32 / FP16 / QX) */
|
||||
GgmlType wtype = GgmlType.GGML_TYPE_F16;
|
||||
/** intermediate type (FP32 or FP16) */
|
||||
GgmlType itype = GgmlType.GGML_TYPE_F16;
|
||||
|
||||
// WhisperModel model;
|
||||
public PointerByReference model;
|
||||
// whisper_vocab vocab;
|
||||
// whisper_state * state = nullptr;
|
||||
public PointerByReference vocab;
|
||||
public PointerByReference state;
|
||||
|
||||
/** populated by whisper_init_from_file() */
|
||||
String path_model;
|
||||
|
||||
// public static class ByReference extends WhisperContext implements Structure.ByReference {
|
||||
// }
|
||||
//
|
||||
// public static class ByValue extends WhisperContext implements Structure.ByValue {
|
||||
// }
|
||||
//
|
||||
// @Override
|
||||
// protected List<String> getFieldOrder() {
|
||||
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
|
||||
// }
|
||||
}
|
@ -0,0 +1,124 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
|
||||
*/
|
||||
public class WhisperCpp implements AutoCloseable {
|
||||
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
||||
private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
|
||||
private Pointer ctx = null;
|
||||
|
||||
public File modelDir() {
|
||||
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
||||
if (modelDirPath == null) {
|
||||
modelDirPath = System.getProperty("user.home") + "/.cache";
|
||||
}
|
||||
|
||||
return new File(modelDirPath, "whisper");
|
||||
}
|
||||
|
||||
/**
|
||||
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
||||
* @return a Pointer to the WhisperContext
|
||||
*/
|
||||
void initContext(String modelPath) throws FileNotFoundException {
|
||||
if (ctx != null) {
|
||||
lib.whisper_free(ctx);
|
||||
}
|
||||
|
||||
if (!modelPath.contains("/") && !modelPath.contains("\\")) {
|
||||
if (!modelPath.endsWith(".bin")) {
|
||||
modelPath = "ggml-" + modelPath.replace("-", ".") + ".bin";
|
||||
}
|
||||
|
||||
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
||||
}
|
||||
|
||||
javaLib.whisper_java_init_from_file(modelPath);
|
||||
ctx = lib.whisper_init_from_file(modelPath);
|
||||
|
||||
if (ctx == null) {
|
||||
throw new FileNotFoundException(modelPath);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.
|
||||
* `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.
|
||||
*/
|
||||
public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {
|
||||
javaLib.whisper_java_default_params(strategy.ordinal());
|
||||
// return lib.whisper_full_default_params(strategy.value)
|
||||
}
|
||||
|
||||
// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params
|
||||
// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {
|
||||
// return lib.whisper_full_default_params(strategy.value)
|
||||
// }
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
freeContext();
|
||||
System.out.println("Whisper closed");
|
||||
}
|
||||
|
||||
private void freeContext() {
|
||||
if (ctx != null) {
|
||||
lib.whisper_free(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*/
|
||||
public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
|
||||
StringBuilder str = new StringBuilder();
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
System.out.println("Segment:" + text);
|
||||
str.append(text);
|
||||
}
|
||||
|
||||
return str.toString().trim();
|
||||
}
|
||||
|
||||
// public int getTextSegmentCount(Pointer ctx) {
|
||||
// return lib.whisper_full_n_segments(ctx);
|
||||
// }
|
||||
// public String getTextSegment(Pointer ctx, int index) {
|
||||
// return lib.whisper_full_get_segment_text(ctx, index);
|
||||
// }
|
||||
|
||||
public String getSystemInfo() {
|
||||
return lib.whisper_print_system_info();
|
||||
}
|
||||
|
||||
public int benchMemcpy(int nthread) {
|
||||
return lib.whisper_bench_memcpy(nthread);
|
||||
}
|
||||
|
||||
public int benchGgmlMulMat(int nthread) {
|
||||
return lib.whisper_bench_ggml_mul_mat(nthread);
|
||||
}
|
||||
}
|
@ -0,0 +1,365 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Library;
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
|
||||
public interface WhisperCppJnaLibrary extends Library {
|
||||
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
|
||||
|
||||
String whisper_print_system_info();
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model by loading from a file.
|
||||
*
|
||||
* @param path_model Path to the model file
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_from_file(String path_model);
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model by loading from a buffer.
|
||||
*
|
||||
* @param buffer Model buffer
|
||||
* @param buffer_size Size of the model buffer
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_from_buffer(Pointer buffer, int buffer_size);
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model using a model loader.
|
||||
*
|
||||
* @param loader Model loader
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init(WhisperModelLoader loader);
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model by loading from a file without allocating the state.
|
||||
*
|
||||
* @param path_model Path to the model file
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_from_file_no_state(String path_model);
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model by loading from a buffer without allocating the state.
|
||||
*
|
||||
* @param buffer Model buffer
|
||||
* @param buffer_size Size of the model buffer
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_from_buffer_no_state(Pointer buffer, int buffer_size);
|
||||
|
||||
// Pointer whisper_init_from_buffer_no_state(Pointer buffer, long buffer_size);
|
||||
|
||||
/**
|
||||
* Allocate (almost) all memory needed for the model using a model loader without allocating the state.
|
||||
*
|
||||
* @param loader Model loader
|
||||
* @return Whisper context on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_no_state(WhisperModelLoader loader);
|
||||
|
||||
/**
|
||||
* Allocate memory for the Whisper state.
|
||||
*
|
||||
* @param ctx Whisper context
|
||||
* @return Whisper state on success, null on failure
|
||||
*/
|
||||
Pointer whisper_init_state(Pointer ctx);
|
||||
|
||||
/**
|
||||
* Free all allocated memory associated with the Whisper context.
|
||||
*
|
||||
* @param ctx Whisper context
|
||||
*/
|
||||
void whisper_free(Pointer ctx);
|
||||
|
||||
/**
|
||||
* Free all allocated memory associated with the Whisper state.
|
||||
*
|
||||
* @param state Whisper state
|
||||
*/
|
||||
void whisper_free_state(Pointer state);
|
||||
|
||||
|
||||
/**
|
||||
* Convert RAW PCM audio to log mel spectrogram.
|
||||
* The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||
*
|
||||
* @param ctx - Pointer to a WhisperContext
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_pcm_to_mel(Pointer ctx, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
/**
|
||||
* @param ctx Pointer to a WhisperContext
|
||||
* @param state Pointer to WhisperState
|
||||
* @param n_samples
|
||||
* @param n_threads
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_pcm_to_mel_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
/**
|
||||
* This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
|
||||
* Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
* n_mel must be 80
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_set_mel(Pointer ctx, final float[] data, int n_len, int n_mel);
|
||||
int whisper_set_mel_with_state(Pointer ctx, Pointer state, final float[] data, int n_len, int n_mel);
|
||||
|
||||
/**
|
||||
* Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
|
||||
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
* Offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_encode(Pointer ctx, int offset, int n_threads);
|
||||
|
||||
int whisper_encode_with_state(Pointer ctx, Pointer state, int offset, int n_threads);
|
||||
|
||||
/**
|
||||
* Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
* Make sure to call whisper_encode() first.
|
||||
* tokens + n_tokens is the provided context for the decoder.
|
||||
* n_past is the number of tokens to use from previous decoder calls.
|
||||
* Returns 0 on success
|
||||
* TODO: add support for multiple decoders
|
||||
*/
|
||||
int whisper_decode(Pointer ctx, Pointer tokens, int n_tokens, int n_past, int n_threads);
|
||||
|
||||
/**
|
||||
* @param ctx
|
||||
* @param state
|
||||
* @param tokens Pointer to int tokens
|
||||
* @param n_tokens
|
||||
* @param n_past
|
||||
* @param n_threads
|
||||
* @return
|
||||
*/
|
||||
int whisper_decode_with_state(Pointer ctx, Pointer state, Pointer tokens, int n_tokens, int n_past, int n_threads);
|
||||
|
||||
/**
|
||||
* Convert the provided text into tokens.
|
||||
* The tokens pointer must be large enough to hold the resulting tokens.
|
||||
* Returns the number of tokens on success, no more than n_max_tokens
|
||||
* Returns -1 on failure
|
||||
* TODO: not sure if correct
|
||||
*/
|
||||
int whisper_tokenize(Pointer ctx, String text, Pointer tokens, int n_max_tokens);
|
||||
|
||||
/** Largest language id (i.e. number of available languages - 1) */
|
||||
int whisper_lang_max_id();
|
||||
|
||||
/**
|
||||
* @return the id of the specified language, returns -1 if not found.
|
||||
* Examples:
|
||||
* "de" -> 2
|
||||
* "german" -> 2
|
||||
*/
|
||||
int whisper_lang_id(String lang);
|
||||
|
||||
/** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
|
||||
String whisper_lang_str(int id);
|
||||
|
||||
/**
|
||||
* Use mel data at offset_ms to try and auto-detect the spoken language.
|
||||
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
||||
* Returns the top language id or negative on failure
|
||||
* If not null, fills the lang_probs array with the probabilities of all languages
|
||||
* The array must be whisper_lang_max_id() + 1 in size
|
||||
*
|
||||
* ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
*/
|
||||
int whisper_lang_auto_detect(Pointer ctx, int offset_ms, int n_threads, float[] lang_probs);
|
||||
|
||||
int whisper_lang_auto_detect_with_state(Pointer ctx, Pointer state, int offset_ms, int n_threads, float[] lang_probs);
|
||||
|
||||
int whisper_n_len (Pointer ctx); // mel length
|
||||
int whisper_n_len_from_state(Pointer state); // mel length
|
||||
int whisper_n_vocab (Pointer ctx);
|
||||
int whisper_n_text_ctx (Pointer ctx);
|
||||
int whisper_n_audio_ctx (Pointer ctx);
|
||||
int whisper_is_multilingual (Pointer ctx);
|
||||
|
||||
int whisper_model_n_vocab (Pointer ctx);
|
||||
int whisper_model_n_audio_ctx (Pointer ctx);
|
||||
int whisper_model_n_audio_state(Pointer ctx);
|
||||
int whisper_model_n_audio_head (Pointer ctx);
|
||||
int whisper_model_n_audio_layer(Pointer ctx);
|
||||
int whisper_model_n_text_ctx (Pointer ctx);
|
||||
int whisper_model_n_text_state (Pointer ctx);
|
||||
int whisper_model_n_text_head (Pointer ctx);
|
||||
int whisper_model_n_text_layer (Pointer ctx);
|
||||
int whisper_model_n_mels (Pointer ctx);
|
||||
int whisper_model_ftype (Pointer ctx);
|
||||
int whisper_model_type (Pointer ctx);
|
||||
|
||||
/**
|
||||
* Token logits obtained from the last call to whisper_decode().
|
||||
* The logits for the last token are stored in the last row
|
||||
* Rows: n_tokens
|
||||
* Cols: n_vocab
|
||||
*/
|
||||
float[] whisper_get_logits (Pointer ctx);
|
||||
float[] whisper_get_logits_from_state(Pointer state);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
String whisper_token_to_str(Pointer ctx, int token);
|
||||
String whisper_model_type_readable(Pointer ctx);
|
||||
|
||||
// Special tokens
|
||||
int whisper_token_eot (Pointer ctx);
|
||||
int whisper_token_sot (Pointer ctx);
|
||||
int whisper_token_prev(Pointer ctx);
|
||||
int whisper_token_solm(Pointer ctx);
|
||||
int whisper_token_not (Pointer ctx);
|
||||
int whisper_token_beg (Pointer ctx);
|
||||
int whisper_token_lang(Pointer ctx, int lang_id);
|
||||
|
||||
// Task tokens
|
||||
int whisper_token_translate();
|
||||
int whisper_token_transcribe();
|
||||
|
||||
// Performance information from the default state.
|
||||
void whisper_print_timings(Pointer ctx);
|
||||
void whisper_reset_timings(Pointer ctx);
|
||||
|
||||
/**
|
||||
* @param strategy - WhisperSamplingStrategy.value
|
||||
*/
|
||||
WhisperFullParams whisper_full_default_params(int strategy);
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*/
|
||||
int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
|
||||
|
||||
int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
|
||||
// Result is stored in the default state of the context
|
||||
// Not thread safe if executed in parallel on the same context.
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
|
||||
|
||||
/**
|
||||
* Number of generated text segments.
|
||||
* A segment can be a few words, a sentence, or even a paragraph.
|
||||
* @param ctx Pointer to WhisperContext
|
||||
*/
|
||||
int whisper_full_n_segments (Pointer ctx);
|
||||
|
||||
/**
|
||||
* @param state Pointer to WhisperState
|
||||
*/
|
||||
int whisper_full_n_segments_from_state(Pointer state);
|
||||
|
||||
/**
|
||||
* Language id associated with the context's default state.
|
||||
* @param ctx Pointer to WhisperContext
|
||||
*/
|
||||
int whisper_full_lang_id(Pointer ctx);
|
||||
|
||||
/** Language id associated with the provided state */
|
||||
int whisper_full_lang_id_from_state(Pointer state);
|
||||
|
||||
/**
|
||||
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
|
||||
* The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||
* @return 0 on success
|
||||
*/
|
||||
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
|
||||
|
||||
/** Get the start time of the specified segment. */
|
||||
long whisper_full_get_segment_t0(Pointer ctx, int i_segment);
|
||||
|
||||
/** Get the start time of the specified segment from the state. */
|
||||
long whisper_full_get_segment_t0_from_state(Pointer state, int i_segment);
|
||||
|
||||
/** Get the end time of the specified segment. */
|
||||
long whisper_full_get_segment_t1(Pointer ctx, int i_segment);
|
||||
|
||||
/** Get the end time of the specified segment from the state. */
|
||||
long whisper_full_get_segment_t1_from_state(Pointer state, int i_segment);
|
||||
|
||||
/** Get the text of the specified segment. */
|
||||
String whisper_full_get_segment_text(Pointer ctx, int i_segment);
|
||||
|
||||
/** Get the text of the specified segment from the state. */
|
||||
String whisper_full_get_segment_text_from_state(Pointer state, int i_segment);
|
||||
|
||||
/** Get the number of tokens in the specified segment. */
|
||||
int whisper_full_n_tokens(Pointer ctx, int i_segment);
|
||||
|
||||
/** Get the number of tokens in the specified segment from the state. */
|
||||
int whisper_full_n_tokens_from_state(Pointer state, int i_segment);
|
||||
|
||||
/** Get the token text of the specified token in the specified segment. */
|
||||
String whisper_full_get_token_text(Pointer ctx, int i_segment, int i_token);
|
||||
|
||||
|
||||
/** Get the token text of the specified token in the specified segment from the state. */
|
||||
String whisper_full_get_token_text_from_state(Pointer ctx, Pointer state, int i_segment, int i_token);
|
||||
|
||||
/** Get the token ID of the specified token in the specified segment. */
|
||||
int whisper_full_get_token_id(Pointer ctx, int i_segment, int i_token);
|
||||
|
||||
/** Get the token ID of the specified token in the specified segment from the state. */
|
||||
int whisper_full_get_token_id_from_state(Pointer state, int i_segment, int i_token);
|
||||
|
||||
/** Get token data for the specified token in the specified segment. */
|
||||
WhisperTokenData whisper_full_get_token_data(Pointer ctx, int i_segment, int i_token);
|
||||
|
||||
/** Get token data for the specified token in the specified segment from the state. */
|
||||
WhisperTokenData whisper_full_get_token_data_from_state(Pointer state, int i_segment, int i_token);
|
||||
|
||||
/** Get the probability of the specified token in the specified segment. */
|
||||
float whisper_full_get_token_p(Pointer ctx, int i_segment, int i_token);
|
||||
|
||||
/** Get the probability of the specified token in the specified segment from the state. */
|
||||
float whisper_full_get_token_p_from_state(Pointer state, int i_segment, int i_token);
|
||||
|
||||
/**
|
||||
* Benchmark function for memcpy.
|
||||
*
|
||||
* @param nThreads Number of threads to use for the benchmark.
|
||||
* @return The result of the benchmark.
|
||||
*/
|
||||
int whisper_bench_memcpy(int nThreads);
|
||||
|
||||
/**
|
||||
* Benchmark function for memcpy as a string.
|
||||
*
|
||||
* @param nThreads Number of threads to use for the benchmark.
|
||||
* @return The result of the benchmark as a string.
|
||||
*/
|
||||
String whisper_bench_memcpy_str(int nThreads);
|
||||
|
||||
/**
|
||||
* Benchmark function for ggml_mul_mat.
|
||||
*
|
||||
* @param nThreads Number of threads to use for the benchmark.
|
||||
* @return The result of the benchmark.
|
||||
*/
|
||||
int whisper_bench_ggml_mul_mat(int nThreads);
|
||||
|
||||
/**
|
||||
* Benchmark function for ggml_mul_mat as a string.
|
||||
*
|
||||
* @param nThreads Number of threads to use for the benchmark.
|
||||
* @return The result of the benchmark as a string.
|
||||
*/
|
||||
String whisper_bench_ggml_mul_mat_str(int nThreads);
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Library;
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
|
||||
interface WhisperJavaJnaLibrary extends Library {
|
||||
WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
|
||||
|
||||
void whisper_java_default_params(int strategy);
|
||||
|
||||
void whisper_java_free();
|
||||
|
||||
void whisper_java_init_from_file(String modelPath);
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*/
|
||||
int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
|
||||
/**
|
||||
* Callback before the encoder starts.
|
||||
* If not null, called before the encoder starts.
|
||||
* If it returns false, the computation is aborted.
|
||||
*/
|
||||
public interface WhisperEncoderBeginCallback extends Callback {
|
||||
|
||||
/**
|
||||
* Callback method before the encoder starts.
|
||||
*
|
||||
* @param ctx The whisper context.
|
||||
* @param state The whisper state.
|
||||
* @param user_data User data.
|
||||
* @return True if the computation should proceed, false otherwise.
|
||||
*/
|
||||
boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
||||
|
||||
import javax.security.auth.callback.Callback;
|
||||
|
||||
/**
|
||||
* Callback to filter logits.
|
||||
* Can be used to modify the logits before sampling.
|
||||
* If not null, called after applying temperature to logits.
|
||||
*/
|
||||
public interface WhisperLogitsFilterCallback extends Callback {
|
||||
|
||||
/**
|
||||
* Callback method to filter logits.
|
||||
*
|
||||
* @param ctx The whisper context.
|
||||
* @param state The whisper state.
|
||||
* @param tokens The array of whisper_token_data.
|
||||
* @param n_tokens The number of tokens.
|
||||
* @param logits The array of logits.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
|
||||
/**
|
||||
* Callback for the text segment.
|
||||
* Called on every newly generated text segment.
|
||||
* Use the whisper_full_...() functions to obtain the text segments.
|
||||
*/
|
||||
public interface WhisperNewSegmentCallback extends Callback {
|
||||
|
||||
/**
|
||||
* Callback method for the text segment.
|
||||
*
|
||||
* @param ctx The whisper context.
|
||||
* @param state The whisper state.
|
||||
* @param n_new The number of newly generated text segments.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
|
||||
import javax.security.auth.callback.Callback;
|
||||
|
||||
/**
|
||||
* Callback for progress updates.
|
||||
*/
|
||||
public interface WhisperProgressCallback extends Callback {
|
||||
|
||||
/**
|
||||
* Callback method for progress updates.
|
||||
*
|
||||
* @param ctx The whisper context.
|
||||
* @param state The whisper state.
|
||||
* @param progress The progress value.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
package io.github.ggerganov.whispercpp.ggml;
|
||||
|
||||
public class GgmlTensor {
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
package io.github.ggerganov.whispercpp.ggml;
|
||||
|
||||
public enum GgmlType {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1,
|
||||
REMOVED_GGML_TYPE_Q4_2, // support has been removed
|
||||
REMOVED_GGML_TYPE_Q4_3, // support has been removed
|
||||
GGML_TYPE_Q5_0,
|
||||
GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q8_1,
|
||||
GGML_TYPE_I8,
|
||||
GGML_TYPE_I16,
|
||||
GGML_TYPE_I32,
|
||||
GGML_TYPE_COUNT,
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package io.github.ggerganov.whispercpp.model;
|
||||
|
||||
public enum EModel {
|
||||
MODEL_UNKNOWN,
|
||||
MODEL_TINY,
|
||||
MODEL_BASE,
|
||||
MODEL_SMALL,
|
||||
MODEL_MEDIUM,
|
||||
MODEL_LARGE,
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import io.github.ggerganov.whispercpp.ggml.GgmlTensor;
|
||||
import io.github.ggerganov.whispercpp.model.EModel;
|
||||
|
||||
public class WhisperModel {
|
||||
// EModel type = EModel.MODEL_UNKNOWN;
|
||||
//
|
||||
// WhisperHParams hparams;
|
||||
// WhisperFilters filters;
|
||||
//
|
||||
// // encoder.positional_embedding
|
||||
// GgmlTensor e_pe;
|
||||
//
|
||||
// // encoder.conv1
|
||||
// GgmlTensor e_conv_1_w;
|
||||
// GgmlTensor e_conv_1_b;
|
||||
//
|
||||
// // encoder.conv2
|
||||
// GgmlTensor e_conv_2_w;
|
||||
// GgmlTensor e_conv_2_b;
|
||||
//
|
||||
// // encoder.ln_post
|
||||
// GgmlTensor e_ln_w;
|
||||
// GgmlTensor e_ln_b;
|
||||
//
|
||||
// // decoder.positional_embedding
|
||||
// GgmlTensor d_pe;
|
||||
//
|
||||
// // decoder.token_embedding
|
||||
// GgmlTensor d_te;
|
||||
//
|
||||
// // decoder.ln
|
||||
// GgmlTensor d_ln_w;
|
||||
// GgmlTensor d_ln_b;
|
||||
//
|
||||
// std::vector<whisper_layer_encoder> layers_encoder;
|
||||
// std::vector<whisper_layer_decoder> layers_decoder;
|
||||
//
|
||||
// // context
|
||||
// struct ggml_context * ctx;
|
||||
//
|
||||
// // the model memory buffer is read-only and can be shared between processors
|
||||
// std::vector<uint8_t> * buf;
|
||||
//
|
||||
// // tensors
|
||||
// int n_loaded;
|
||||
// Map<String, GgmlTensor> tensors;
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package io.github.ggerganov.whispercpp.model;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
|
||||
public class WhisperModelLoader extends Structure {
|
||||
public Pointer context;
|
||||
public ReadFunction read;
|
||||
public EOFFunction eof;
|
||||
public CloseFunction close;
|
||||
|
||||
public static class ReadFunction implements Callback {
|
||||
public Pointer invoke(Pointer ctx, Pointer output, int readSize) {
|
||||
// TODO
|
||||
return ctx;
|
||||
}
|
||||
}
|
||||
|
||||
public static class EOFFunction implements Callback {
|
||||
public boolean invoke(Pointer ctx) {
|
||||
// TODO
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static class CloseFunction implements Callback {
|
||||
public void invoke(Pointer ctx) {
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
||||
// public WhisperModelLoader(Pointer p) {
|
||||
// super(p);
|
||||
// read = new ReadFunction();
|
||||
// eof = new EOFFunction();
|
||||
// close = new CloseFunction();
|
||||
// read.setCallback(this);
|
||||
// eof.setCallback(this);
|
||||
// close.setCallback(this);
|
||||
// read.write();
|
||||
// eof.write();
|
||||
// close.write();
|
||||
// }
|
||||
|
||||
public WhisperModelLoader() {
|
||||
super();
|
||||
}
|
||||
|
||||
public interface ReadCallback extends Callback {
|
||||
Pointer invoke(Pointer ctx, Pointer output, int readSize);
|
||||
}
|
||||
|
||||
public interface EOFCallback extends Callback {
|
||||
boolean invoke(Pointer ctx);
|
||||
}
|
||||
|
||||
public interface CloseCallback extends Callback {
|
||||
void invoke(Pointer ctx);
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
package io.github.ggerganov.whispercpp.model;
|
||||
|
||||
public class WhisperState {
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package io.github.ggerganov.whispercpp.model;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Structure representing token data.
|
||||
*/
|
||||
public class WhisperTokenData extends Structure {
|
||||
|
||||
/** Token ID. */
|
||||
public int id;
|
||||
|
||||
/** Forced timestamp token ID. */
|
||||
public int tid;
|
||||
|
||||
/** Probability of the token. */
|
||||
public float p;
|
||||
|
||||
/** Log probability of the token. */
|
||||
public float plog;
|
||||
|
||||
/** Probability of the timestamp token. */
|
||||
public float pt;
|
||||
|
||||
/** Sum of probabilities of all timestamp tokens. */
|
||||
public float ptsum;
|
||||
|
||||
/**
|
||||
* Start time of the token (token-level timestamp data).
|
||||
* Do not use if you haven't computed token-level timestamps.
|
||||
*/
|
||||
public long t0;
|
||||
|
||||
/**
|
||||
* End time of the token (token-level timestamp data).
|
||||
* Do not use if you haven't computed token-level timestamps.
|
||||
*/
|
||||
public long t1;
|
||||
|
||||
/** Voice length of the token. */
|
||||
public float vlen;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("id", "tid", "p", "plog", "pt", "ptsum", "t0", "t1", "vlen");
|
||||
}
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class WhisperFilters {
|
||||
int n_mel;
|
||||
int n_fft;
|
||||
|
||||
List<Float> data;
|
||||
}
|
@ -0,0 +1,187 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import com.sun.jna.Structure;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
||||
|
||||
/**
|
||||
* Parameters for the whisper_full() function.
|
||||
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
||||
* whisper_full_default_params()
|
||||
*/
|
||||
public class WhisperFullParams extends Structure {
|
||||
|
||||
/** Sampling strategy for whisper_full() function. */
|
||||
public int strategy;
|
||||
|
||||
/** Number of threads. */
|
||||
public int n_threads;
|
||||
|
||||
/** Maximum tokens to use from past text as a prompt for the decoder. */
|
||||
public int n_max_text_ctx;
|
||||
|
||||
/** Start offset in milliseconds. */
|
||||
public int offset_ms;
|
||||
|
||||
/** Audio duration to process in milliseconds. */
|
||||
public int duration_ms;
|
||||
|
||||
/** Translate flag. */
|
||||
public boolean translate;
|
||||
|
||||
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */
|
||||
public boolean no_context;
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). */
|
||||
public boolean single_segment;
|
||||
|
||||
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */
|
||||
public boolean print_special;
|
||||
|
||||
/** Flag to print progress information. */
|
||||
public boolean print_progress;
|
||||
|
||||
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). */
|
||||
public boolean print_realtime;
|
||||
|
||||
/** Flag to print timestamps for each text segment when printing realtime. */
|
||||
public boolean print_timestamps;
|
||||
|
||||
/** [EXPERIMENTAL] Flag to enable token-level timestamps. */
|
||||
public boolean token_timestamps;
|
||||
|
||||
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */
|
||||
public float thold_pt;
|
||||
|
||||
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
||||
public float thold_ptsum;
|
||||
|
||||
/** Maximum segment length in characters. */
|
||||
public int max_len;
|
||||
|
||||
/** Flag to split on word rather than on token (when used with max_len). */
|
||||
public boolean split_on_word;
|
||||
|
||||
/** Maximum tokens per segment (0 = no limit). */
|
||||
public int max_tokens;
|
||||
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. */
|
||||
public boolean speed_up;
|
||||
|
||||
/** Overwrite the audio context size (0 = use default). */
|
||||
public int audio_ctx;
|
||||
|
||||
/** Tokens to provide to the whisper decoder as an initial prompt.
|
||||
* These are prepended to any existing text context from a previous call. */
|
||||
public String initial_prompt;
|
||||
|
||||
/** Prompt tokens. */
|
||||
public Pointer prompt_tokens;
|
||||
|
||||
/** Number of prompt tokens. */
|
||||
public int prompt_n_tokens;
|
||||
|
||||
/** Language for auto-detection.
|
||||
* For auto-detection, set to `null`, `""`, or "auto". */
|
||||
public String language;
|
||||
|
||||
/** Flag to indicate whether to detect language automatically. */
|
||||
public boolean detect_language;
|
||||
|
||||
/** Common decoding parameters. */
|
||||
|
||||
/** Flag to suppress blank tokens. */
|
||||
public boolean suppress_blank;
|
||||
|
||||
/** Flag to suppress non-speech tokens. */
|
||||
public boolean suppress_non_speech_tokens;
|
||||
|
||||
/** Initial decoding temperature. */
|
||||
public float temperature;
|
||||
|
||||
/** Maximum initial timestamp. */
|
||||
public float max_initial_ts;
|
||||
|
||||
/** Length penalty. */
|
||||
public float length_penalty;
|
||||
|
||||
/** Fallback parameters. */
|
||||
|
||||
/** Temperature increment. */
|
||||
public float temperature_inc;
|
||||
|
||||
/** Entropy threshold (similar to OpenAI's "compression_ratio_threshold"). */
|
||||
public float entropy_thold;
|
||||
|
||||
/** Log probability threshold. */
|
||||
public float logprob_thold;
|
||||
|
||||
/** No speech threshold. */
|
||||
public float no_speech_thold;
|
||||
|
||||
class GreedyParams extends Structure {
|
||||
/** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
|
||||
public int best_of;
|
||||
}
|
||||
|
||||
/** Greedy decoding parameters. */
|
||||
public GreedyParams greedy;
|
||||
|
||||
class BeamSearchParams extends Structure {
|
||||
/** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
|
||||
int beam_size;
|
||||
|
||||
/** ref: https://arxiv.org/pdf/2204.05424.pdf */
|
||||
float patience;
|
||||
}
|
||||
|
||||
/**
|
||||
* Beam search decoding parameters.
|
||||
*/
|
||||
public BeamSearchParams beam_search;
|
||||
|
||||
/**
|
||||
* Callback for every newly generated text segment.
|
||||
*/
|
||||
public WhisperNewSegmentCallback new_segment_callback;
|
||||
|
||||
/**
|
||||
* User data for the new_segment_callback.
|
||||
*/
|
||||
public Pointer new_segment_callback_user_data;
|
||||
|
||||
/**
|
||||
* Callback on each progress update.
|
||||
*/
|
||||
public WhisperProgressCallback progress_callback;
|
||||
|
||||
/**
|
||||
* User data for the progress_callback.
|
||||
*/
|
||||
public Pointer progress_callback_user_data;
|
||||
|
||||
/**
|
||||
* Callback each time before the encoder starts.
|
||||
*/
|
||||
public WhisperEncoderBeginCallback encoder_begin_callback;
|
||||
|
||||
/**
|
||||
* User data for the encoder_begin_callback.
|
||||
*/
|
||||
public Pointer encoder_begin_callback_user_data;
|
||||
|
||||
/**
|
||||
* Callback by each decoder to filter obtained logits.
|
||||
*/
|
||||
public WhisperLogitsFilterCallback logits_filter_callback;
|
||||
|
||||
/**
|
||||
* User data for the logits_filter_callback.
|
||||
*/
|
||||
public Pointer logits_filter_callback_user_data;
|
||||
}
|
||||
|
@ -0,0 +1,15 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
public class WhisperHParams {
|
||||
int n_vocab = 51864;
|
||||
int n_audio_ctx = 1500;
|
||||
int n_audio_state = 384;
|
||||
int n_audio_head = 6;
|
||||
int n_audio_layer = 4;
|
||||
int n_text_ctx = 448;
|
||||
int n_text_state = 384;
|
||||
int n_text_head = 6;
|
||||
int n_text_layer = 4;
|
||||
int n_mels = 80;
|
||||
int ftype = 1;
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
public class WhisperJavaParams extends Structure {
|
||||
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
/** Available sampling strategies */
|
||||
public enum WhisperSamplingStrategy {
|
||||
/** similar to OpenAI's GreedyDecoder */
|
||||
WHISPER_SAMPLING_GREEDY,
|
||||
|
||||
/** similar to OpenAI's BeamSearchDecoder */
|
||||
WHISPER_SAMPLING_BEAM_SEARCH
|
||||
}
|
@ -0,0 +1,75 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
class WhisperCppTest {
|
||||
private static WhisperCpp whisper = new WhisperCpp();
|
||||
private static boolean modelInitialised = false;
|
||||
|
||||
@BeforeAll
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
String modelName = "base.en";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetDefaultJavaParams() {
|
||||
// When
|
||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
// Then if it doesn't throw we've connected to whisper.cpp
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFullTranscribe() throws Exception {
|
||||
if (!modelInitialised) {
|
||||
System.out.println("Model not initialised, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
// Given
|
||||
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
|
||||
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
|
||||
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
|
||||
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
|
||||
floats[j] = intSample / 32767.0f;
|
||||
}
|
||||
|
||||
// When
|
||||
String result = whisper.fullTranscribe(/*params,*/ floats);
|
||||
|
||||
// Then
|
||||
System.out.println(result);
|
||||
assertEquals("And so my fellow Americans, ask not what your country can do for you, " +
|
||||
"ask what you can do for your country.",
|
||||
result);
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class WhisperJnaLibraryTest {
|
||||
|
||||
@Test
|
||||
void testWhisperPrint_system_info() {
|
||||
String systemInfo = WhisperCppJnaLibrary.instance.whisper_print_system_info();
|
||||
// eg: "AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0
|
||||
// | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | "
|
||||
System.out.println("System info: " + systemInfo);
|
||||
assertTrue(systemInfo.length() > 10);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user