mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-12 16:03:41 +00:00
223 lines
6.6 KiB
Python
223 lines
6.6 KiB
Python
|
import os
|
||
|
import subprocess
|
||
|
import re
|
||
|
import csv
|
||
|
import wave
|
||
|
import contextlib
|
||
|
import argparse
|
||
|
|
||
|
|
||
|
# Custom action to handle comma-separated list
class ListAction(argparse.Action):
    """Argparse action that turns a comma-separated string into a list of ints.

    For example ``-t 1,2,4`` stores ``[1, 2, 4]`` on the namespace under
    ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Split ``values`` on commas, convert each item to int and store the list.

        Raises:
            argparse.ArgumentError: if any item is not a valid integer. argparse
                reports this as a normal usage error instead of a traceback.
        """
        try:
            parsed = [int(val) for val in values.split(",")]
        except ValueError:
            raise argparse.ArgumentError(
                self,
                f"expected a comma-separated list of integers, got {values!r}",
            ) from None
        setattr(namespace, self.dest, parsed)
|
||
|
|
||
|
|
||
|
# Command-line interface. -t/--threads and -p/--processors take
# comma-separated lists (handled by ListAction); -f/--filename is a plain path.
parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")

parser.add_argument(
    "-t", "--threads",
    dest="threads", action=ListAction, default=[4],
    help="List of thread counts to benchmark (comma-separated, default: 4)",
)
parser.add_argument(
    "-p", "--processors",
    dest="processors", action=ListAction, default=[1],
    help="List of processor counts to benchmark (comma-separated, default: 1)",
)
parser.add_argument(
    "-f", "--filename",
    type=str, default="./samples/jfk.wav",
    help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
)

# Parse the command line and expose the values under the names used below.
args = parser.parse_args()
sample_file, threads, processors = args.filename, args.threads, args.processors
|
||
|
|
||
|
# Model files to benchmark, smallest first; each row pairs the ".en" variant
# with the plain variant of the same size ("large" has only the plain one).
models = [
    "ggml-tiny.en.bin", "ggml-tiny.bin",
    "ggml-base.en.bin", "ggml-base.bin",
    "ggml-small.en.bin", "ggml-small.bin",
    "ggml-medium.en.bin", "ggml-medium.bin",
    "ggml-large.bin",
]
|
||
|
|
||
|
|
||
|
# Device name reported by the benchmarked binary; filled in while benchmarking.
metal_device = ""

# Benchmark results, keyed by (model name, thread count, processor count).
results = {}

# CSV column headers: run identification ...
gitHashHeader = "Commit"
modelHeader = "Model"
hardwareHeader = "Hardware"
recordingLengthHeader = "Recording Length (seconds)"
threadHeader = "Thread"
processorCountHeader = "Processor Count"

# ... absolute timings ...
loadTimeHeader = "Load Time (ms)"
sampleTimeHeader = "Sample Time (ms)"
encodeTimeHeader = "Encode Time (ms)"
decodeTimeHeader = "Decode Time (ms)"

# ... timings divided by the reported number of runs ...
sampleTimePerRunHeader = "Sample Time per Run (ms)"
encodeTimePerRunHeader = "Encode Time per Run (ms)"
decodeTimePerRunHeader = "Decode Time per Run (ms)"

# ... and the overall total.
totalTimeHeader = "Total Time (ms)"
|
||
|
|
||
|
|
||
|
def check_file_exists(file: str) -> bool:
    """Return True when ``file`` names an existing regular file, else False."""
    is_regular_file = os.path.isfile(file)
    return is_regular_file
|
||
|
|
||
|
|
||
|
def get_git_short_hash() -> str:
    """Return the abbreviated hash of the current git HEAD, or "" if unavailable.

    Best effort: an empty string is returned both when ``git rev-parse`` exits
    with an error (CalledProcessError) and when the ``git`` executable cannot
    be started at all (OSError, e.g. git is not installed).
    """
    try:
        raw = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    except (subprocess.CalledProcessError, OSError):
        return ""
    return raw.decode().strip()
|
||
|
|
||
|
|
||
|
def wav_file_length(file: str = sample_file) -> float:
    """Return the duration of the WAV file ``file`` in seconds.

    The duration is the file's frame count divided by its frame rate.
    ``file`` defaults to the value ``sample_file`` had when this function
    was defined.
    """
    wav = wave.open(file, "r")
    with contextlib.closing(wav):
        frame_count = wav.getnframes()
        frame_rate = wav.getframerate()
        return frame_count / float(frame_rate)
|
||
|
|
||
|
|
||
|
def extract_metrics(output: str, label: str) -> "tuple[float | None, float | None]":
    """Extract the time and run count reported for ``label`` in ``output``.

    Searches for text of the form ``<label> = <float> ms / <int> runs``.

    Args:
        output: captured program output to search.
        label: literal metric name, e.g. "sample time". It is escaped before
            being placed in the pattern, so regex metacharacters in it match
            only themselves.

    Returns:
        ``(time_ms, runs)`` as floats, or ``(None, None)`` when no such line
        is present.
    """
    pattern = rf"{re.escape(label)} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs"
    match = re.search(pattern, output)
    if match is None:
        return None, None
    return float(match.group(1)), float(match.group(2))
|
||
|
|
||
|
|
||
|
def extract_device(output: str) -> str:
    """Return the device name that follows "picking default device: " in ``output``.

    The name is the rest of that line. "Not found" is returned when the
    marker does not appear.
    """
    found = re.search(r"picking default device: (.*)", output)
    if found:
        return found.group(1)
    return "Not found"
|
||
|
|
||
|
|
||
|
# The benchmark cannot run without the sample recording, so fail early.
if not check_file_exists(sample_file):
    raise FileNotFoundError(f"Sample file {sample_file} not found")

recording_length = wav_file_length()

# Keep only the models whose file is present under models/ and report the rest.
filtered_models = []
for model in models:
    if not check_file_exists(f"models/{model}"):
        print(f"Model {model} not found, removing from list")
        continue
    filtered_models.append(model)

models = filtered_models
|
||
|
|
||
|
def _per_run(total_ms, runs):
    """Return ``total_ms / runs`` rounded to 2 decimals, or None if it cannot be computed.

    None is returned when either value is missing (metric not found in the
    output) or when ``runs`` is 0.
    """
    if total_ms is None or not runs:
        return None
    return round(total_ms / runs, 2)


# Loop over each combination of parameters
for model in filtered_models:
    for thread in threads:
        for processor_count in processors:
            # Build the command as an argument list and run it without a shell,
            # so a sample path containing spaces or shell metacharacters is
            # passed through verbatim.
            cmd = [
                "./main",
                "-m", f"models/{model}",
                "-t", str(thread),
                "-p", str(processor_count),
                "-f", sample_file,
            ]

            # subprocess.run waits for the process and collects everything it
            # wrote; stderr is merged into stdout so both are searched below.
            # Raises OSError (e.g. FileNotFoundError) if ./main cannot be started.
            completed = subprocess.run(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            )
            # errors="replace": undecodable bytes must not abort the benchmark.
            output = completed.stdout.decode(errors="replace")

            # Parse the output
            load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
            load_time = float(load_time_match.group(1)) if load_time_match else None

            metal_device = extract_device(output)
            sample_time, sample_runs = extract_metrics(output, "sample time")
            encode_time, encode_runs = extract_metrics(output, "encode time")
            decode_time, decode_runs = extract_metrics(output, "decode time")

            total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
            total_time = float(total_time_match.group(1)) if total_time_match else None

            model_name = model.replace("ggml-", "").replace(".bin", "")

            print(
                f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
            )

            # Store the times in the results dictionary. Metrics that were not
            # found in the output are stored as None instead of raising.
            results[(model_name, thread, processor_count)] = {
                loadTimeHeader: load_time,
                sampleTimeHeader: sample_time,
                encodeTimeHeader: encode_time,
                decodeTimeHeader: decode_time,
                sampleTimePerRunHeader: _per_run(sample_time, sample_runs),
                encodeTimePerRunHeader: _per_run(encode_time, encode_runs),
                decodeTimePerRunHeader: _per_run(decode_time, decode_runs),
                totalTimeHeader: total_time,
            }
|
||
|
|
||
|
# Write the results to a CSV file
fieldnames = [
    gitHashHeader,
    modelHeader,
    hardwareHeader,
    recordingLengthHeader,
    threadHeader,
    processorCountHeader,
    loadTimeHeader,
    sampleTimeHeader,
    encodeTimeHeader,
    decodeTimeHeader,
    sampleTimePerRunHeader,
    encodeTimePerRunHeader,
    decodeTimePerRunHeader,
    totalTimeHeader,
]

# Resolved before the output file is opened, so a failure here does not leave
# a truncated CSV behind.
shortHash = get_git_short_hash()

# Sort the results by total time in ascending order. A total time of None
# (not found in the program output) cannot be compared with floats, so those
# entries are ordered after all timed ones.
sorted_results = sorted(
    results.items(),
    key=lambda item: (
        item[1].get(totalTimeHeader) is None,
        item[1].get(totalTimeHeader) or 0,
    ),
)

with open("benchmark_results.csv", "w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for params, times in sorted_results:
        # params is the (model name, thread count, processor count) key.
        row = {
            gitHashHeader: shortHash,
            modelHeader: params[0],
            # Same value for every row: metal_device is a single module-level name.
            hardwareHeader: metal_device,
            recordingLengthHeader: recording_length,
            threadHeader: params[1],
            processorCountHeader: params[2],
        }
        row.update(times)
        writer.writerow(row)
|