diff --git a/CMakeLists.txt b/CMakeLists.txt index be6db903..8117436f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,9 @@ include(CheckIncludeFileCXX) set(SOVERSION 1) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + #set(CMAKE_WARN_DEPRECATED YES) set(CMAKE_WARN_UNUSED_CLI YES) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e4265aff..5106907f 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -56,6 +56,8 @@ add_library(${TARGET} STATIC common-whisper.cpp grammar-parser.h grammar-parser.cpp + wer.h + wer.cpp ${COMMON_SOURCES_FFMPEG} ) @@ -114,6 +116,7 @@ else() add_subdirectory(sycl) endif() endif (WHISPER_SDL2) + add_subdirectory(wer) add_subdirectory(deprecation-warning) endif() diff --git a/examples/wer.cpp b/examples/wer.cpp new file mode 100644 index 00000000..13ce9d65 --- /dev/null +++ b/examples/wer.cpp @@ -0,0 +1,114 @@ +#include "wer.h" + +#include +#include +#include +#include +#include +#include +#include + +std::vector split_into_words(const std::string& text) { + std::vector words; + std::stringstream ss(text); + std::string word; + + while (ss >> word) { + words.push_back(word); + } + + return words; +} + +std::tuple count_edit_ops(const std::vector& reference, + const std::vector& actual) { + int m = reference.size(); + int n = actual.size(); + + // Levenshtein matrix + std::vector> l_matrix(m + 1, std::vector(n + 1, 0)); + + // Initialize the first row and column of the matrix. + for (int i = 0; i <= m; i++) { + l_matrix[i][0] = i; + } + + for (int j = 0; j <= n; j++) { + l_matrix[0][j] = j; + } + + // Fill the matrix. 
+ for (int i = 1; i <= m; i++) { + for (int j = 1; j <= n; j++) { + if (reference[i-1] == actual[j-1]) { + l_matrix[i][j] = l_matrix[i-1][j-1]; + } else { + l_matrix[i][j] = 1 + std::min({ + l_matrix[i-1][j], // Deletion (top/above) + l_matrix[i][j-1], // Insertion (left) + l_matrix[i-1][j-1] // Substitution (diagonal) + }); + } + } + } + + // Start backtracking from the bottom-right corner of the matrix. + int i = m; // rows + int j = n; // columns + + int substitutions = 0; + int deletions = 0; + int insertions = 0; + + // Backtrack to find the edit operations. + while (i > 0 || j > 0) { + if (i > 0 && j > 0 && reference[i-1] == actual[j-1]) { + // Recall that reference and actual are word vectors, so this compares the + // reference word at row i with the actual word at column j. If they are equal + // this means there was no edit operation, so we move diagonally. + i--; + j--; + } else if (i > 0 && j > 0 && l_matrix[i][j] == l_matrix[i-1][j-1] + 1) { + // Check if the current cell is equal to the diagonal cell + 1 + // (for the operation cost), which means we have a substitution. + substitutions++; + i--; + j--; + } else if (i > 0 && l_matrix[i][j] == l_matrix[i-1][j] + 1) { + // Check if the current cell is equal to the top/above cell + 1 + // (for the operation cost), which means we have a deletion. + deletions++; + i--; + } else { + // If there was no match for the diagonal cell or the top/above + // cell, then we must have come from the left cell, which means we have an insertion. 
+ insertions++; + j--; + } + } + + return {substitutions, deletions, insertions}; +} + +wer_result calculate_wer(const std::string& reference_text, const std::string& actual_text) { + std::vector reference = split_into_words(reference_text); + std::vector actual = split_into_words(actual_text); + + auto [n_sub, n_del, n_ins] = count_edit_ops(reference, actual); + int n_edits = n_sub + n_del + n_ins; + + double wer = 0.0; + if (!reference.empty()) { + wer = static_cast(n_edits) / reference.size(); + } + + return wer_result{ + /* n_ref_words */ reference.size(), + /* n_act_words */ actual.size(), + /* n_sub */ n_sub, + /* n_del */ n_del, + /* n_ins */ n_ins, + /* n_edits */ n_edits, + /* wer */ wer + }; +} diff --git a/examples/wer.h b/examples/wer.h new file mode 100644 index 00000000..9ac87496 --- /dev/null +++ b/examples/wer.h @@ -0,0 +1,18 @@ +#ifndef WER_H +#define WER_H +#include +#include + +struct wer_result { + size_t n_ref_words; // Number of words in the reference text. + size_t n_act_words; // Number of words in the actual (transcribed) text. + int n_sub; // Number of substitutions. + int n_del; // Number of deletions. + int n_ins; // Number of insertions. + int n_edits; // Total number of edits. + double wer; // The word error rate. 
+}; + +wer_result calculate_wer(const std::string& reference_text, const std::string& actual_text); + +#endif // WER_H diff --git a/examples/wer/CMakeLists.txt b/examples/wer/CMakeLists.txt new file mode 100644 index 00000000..3315a1c5 --- /dev/null +++ b/examples/wer/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET whisper-wer) +add_executable(${TARGET} cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/wer/README.md b/examples/wer/README.md new file mode 100644 index 00000000..31a35d82 --- /dev/null +++ b/examples/wer/README.md @@ -0,0 +1,51 @@ +# whisper.cpp/examples/wer + +This is a command line tool for calculating the Word Error Rate (WER). This tool +expects that reference transcriptions (the known correct transcriptions) +and actual transcriptions from whisper.cpp are available in two separate +directories where the file names are identical. + +### Usage +```console +$ ./build/bin/whisper-wer +Usage: ./build/bin/whisper-wer [options] +Options: + -r, --reference PATH Full path to reference transcriptions directory + -a, --actual PATH Full path to actual transcriptions directory + --help Display this help message +``` + +### Example Usage with whisper-cli +First, generate transcription(s) using whisper-cli: +``` +./build/bin/whisper-cli \ + -m models/ggml-base.en.bin \ + -f samples/jfk.wav \ + --output-txt + ... + output_txt: saving output to 'samples/jfk.wav.txt' +``` +Next, copy the transcription to a directory where the actual transcriptions +are stored. In this example we will use a directory called `actual_transcriptions` +in this examples directory: +```console +$ cp samples/jfk.wav.txt examples/wer/actual_transcriptions +``` +In a real world scenario the reference transcriptions would be available +representing the known correct text. 
In this case we have already placed a file +in `examples/wer/reference_transcriptions` that can be used for testing, where +only a single word was changed (`Americans` -> `Swedes`). + +Finally, run the whisper-wer tool: +```console +$ ./build/bin/whisper-wer -r examples/wer/reference_transcriptions/ -a examples/wer/actual_transcriptions/ +Word Error Rate for : jfk.wav.txt + Reference words: 22 + Actual words: 22 + Substitutions: 1 + Deletions: 0 + Insertions: 0 + Total edits: 1 + WER: 0.045455 +``` + diff --git a/examples/wer/actual_transcriptions/jfk.wav.txt b/examples/wer/actual_transcriptions/jfk.wav.txt new file mode 100644 index 00000000..81291a1c --- /dev/null +++ b/examples/wer/actual_transcriptions/jfk.wav.txt @@ -0,0 +1 @@ + And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/examples/wer/cli.cpp b/examples/wer/cli.cpp new file mode 100644 index 00000000..1b6aee82 --- /dev/null +++ b/examples/wer/cli.cpp @@ -0,0 +1,148 @@ +#include "wer.h" + +#include +#include +#include +#include +#include +#include +#include + +std::vector read_files_from_directory(const std::string& dir_path) { + std::vector file_paths; + try { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".txt") { + file_paths.push_back(entry.path().string()); + } + } + } catch (const std::filesystem::filesystem_error& e) { + printf("Error reading directory %s: %s\n", dir_path.c_str(), e.what()); + } + return file_paths; +} + +std::string read_file_content(const std::string& file_path) { + std::ifstream file(file_path); + std::string content; + + if (file.is_open()) { + std::string line; + while (std::getline(file, line)) { + content += line + "\n"; + } + file.close(); + } else { + printf("Unable to open file: %s\n", file_path.c_str()); + } + + return content; +} + +std::string get_base_filename(const std::string& path) { + return 
std::filesystem::path(path).filename().string(); +} + +void print_usage(const char* program_name) { + printf("Usage: %s [options]\n", program_name); + printf("Options:\n"); + printf(" -r, --reference PATH Full path to reference transcriptions directory\n"); + printf(" -a, --actual PATH Full path to actual transcriptions directory\n"); + printf(" --help Display this help message\n"); +} + +int main(int argc, char** argv) { // Parses -r/-a, pairs the .txt files of both directories by file name and prints the WER of each pair. + if (argc == 1) { + print_usage(argv[0]); + return 0; + } + + std::string reference_path; + std::string actual_path; + bool reference_set = false; + bool actual_set = false; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--help") == 0) { + print_usage(argv[0]); + return 0; + } else if (strcmp(argv[i], "-r") == 0 || strcmp(argv[i], "--reference") == 0) { + if (i + 1 < argc) { + reference_path = argv[++i]; + reference_set = true; + } else { + printf("Error: Missing path after %s\n", argv[i]); + print_usage(argv[0]); + return 1; + } + } else if (strcmp(argv[i], "-a") == 0 || strcmp(argv[i], "--actual") == 0) { + if (i + 1 < argc) { + actual_path = argv[++i]; + actual_set = true; + } else { + printf("Error: Missing path after %s\n", argv[i]); + print_usage(argv[0]); + return 1; + } + } else { + printf("Error: Unknown option: %s\n", argv[i]); + print_usage(argv[0]); + return 1; + } + } + + if (!reference_set || !actual_set) { + printf("Error: Both reference and actual paths must be provided\n"); + print_usage(argv[0]); + return 1; + } + + if (!std::filesystem::exists(reference_path) || !std::filesystem::is_directory(reference_path)) { + printf("Error: Reference path '%s' does not exist or is not a directory\n", reference_path.c_str()); + return 1; + } + + if (!std::filesystem::exists(actual_path) || !std::filesystem::is_directory(actual_path)) { + printf("Error: Actual path '%s' does not exist or is not a directory\n", actual_path.c_str()); + return 1; + } + + std::vector reference_files = read_files_from_directory(reference_path); + 
std::vector actual_files = read_files_from_directory(actual_path); + + //printf("Found %zu reference files in %s\n", reference_files.size(), reference_path.c_str()); + //printf("Found %zu actual files in %s\n", actual_files.size(), actual_path.c_str()); + + std::map reference_map; + std::map actual_map; + + for (const auto& file : reference_files) { + reference_map[get_base_filename(file)] = file; + } + + for (const auto& file : actual_files) { + actual_map[get_base_filename(file)] = file; + } + + for (const auto& [filename, ref_path] : reference_map) { // Driven by the reference files: one without a same-named actual file only gets a warning. + auto actual_it = actual_map.find(filename); + if (actual_it != actual_map.end()) { + std::string reference_content = read_file_content(ref_path); + std::string actual_content = read_file_content(actual_it->second); + + wer_result result = calculate_wer(reference_content, actual_content); + printf("Word Error Rate for : %s\n", filename.c_str()); + printf(" Reference words: %zu\n", result.n_ref_words); // size_t needs %zu; %ld is wrong where long is 32-bit (e.g. 64-bit Windows). + printf(" Actual words: %zu\n", result.n_act_words); // size_t needs %zu here as well. + printf(" Substitutions: %d\n", result.n_sub); + printf(" Deletions: %d\n", result.n_del); + printf(" Insertions: %d\n", result.n_ins); + printf(" Total edits: %d\n", result.n_edits); + printf(" WER: %f\n", result.wer); + } else { + printf("Warning: No matching actual file found for reference file: %s\n", filename.c_str()); + } + } + + return 0; +} diff --git a/examples/wer/reference_transcriptions/jfk.wav.txt b/examples/wer/reference_transcriptions/jfk.wav.txt new file mode 100644 index 00000000..6ba53be5 --- /dev/null +++ b/examples/wer/reference_transcriptions/jfk.wav.txt @@ -0,0 +1 @@ + And so my fellow Swedes, ask not what your country can do for you, ask what you can do for your country. 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7cdfed82..6cdf7e65 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -85,3 +85,11 @@ if (WHISPER_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +find_package(Threads REQUIRED) + +# WER Unit Test +add_executable(test-wer test-wer.cpp) +target_include_directories(test-wer PRIVATE ../examples) +target_link_libraries(test-wer PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +add_test(NAME test-wer COMMAND test-wer) +set_tests_properties(test-wer PROPERTIES LABELS "unit") diff --git a/tests/test-wer.cpp b/tests/test-wer.cpp new file mode 100644 index 00000000..2f72b751 --- /dev/null +++ b/tests/test-wer.cpp @@ -0,0 +1,26 @@ +// assert() expands to nothing when NDEBUG is defined (typical for Release builds), +// which would make this test pass without checking anything, so keep it enabled. +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include "wer.h" + +#include +#include + +int main() { + std::string reference = "the cat sat on the mat"; + std::string actual = "the cat sat mat"; // "on" and "the" are missing: expect 2 deletions out of 6 reference words. + + wer_result result = calculate_wer(reference, actual); + assert(result.n_ref_words == 6); + assert(result.n_act_words == 4); + assert(result.n_sub == 0); + assert(result.n_del == 2); + assert(result.n_ins == 0); + assert(result.n_edits == 2); + assert(std::abs(result.wer - 0.333333) < 0.0001); + + return 0; +}