mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-12 20:18:08 +00:00
whisper.swiftui : add model download list & bench methods (#2546)
* swift : fix resources & exclude build * whisper : impl whisper_timings struct & api * whisper.swiftui : model list & bench methods * whisper : return ptr for whisper_get_timings * revert unnecessary change * whisper : avoid designated initializer * whisper.swiftui: code style changes * whisper.swiftui : get device name / os from UIDevice * whisper.swiftui : fix UIDevice usage * whisper.swiftui : add memcpy and ggml_mul_mat (commented)
This commit is contained in:
@ -0,0 +1,17 @@
|
||||
import Foundation
|
||||
|
||||
struct Model: Identifiable {
|
||||
var id = UUID()
|
||||
var name: String
|
||||
var info: String
|
||||
var url: String
|
||||
|
||||
var filename: String
|
||||
var fileURL: URL {
|
||||
FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename)
|
||||
}
|
||||
|
||||
func fileExists() -> Bool {
|
||||
FileManager.default.fileExists(atPath: fileURL.path)
|
||||
}
|
||||
}
|
@ -14,7 +14,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
private var recordedFile: URL? = nil
|
||||
private var audioPlayer: AVAudioPlayer?
|
||||
|
||||
private var modelUrl: URL? {
|
||||
private var builtInModelUrl: URL? {
|
||||
Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "models")
|
||||
}
|
||||
|
||||
@ -28,23 +28,59 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
|
||||
override init() {
|
||||
super.init()
|
||||
loadModel()
|
||||
}
|
||||
|
||||
func loadModel(path: URL? = nil, log: Bool = true) {
|
||||
do {
|
||||
try loadModel()
|
||||
whisperContext = nil
|
||||
if (log) { messageLog += "Loading model...\n" }
|
||||
let modelUrl = path ?? builtInModelUrl
|
||||
if let modelUrl {
|
||||
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
|
||||
if (log) { messageLog += "Loaded model \(modelUrl.lastPathComponent)\n" }
|
||||
} else {
|
||||
if (log) { messageLog += "Could not locate model\n" }
|
||||
}
|
||||
canTranscribe = true
|
||||
} catch {
|
||||
print(error.localizedDescription)
|
||||
messageLog += "\(error.localizedDescription)\n"
|
||||
if (log) { messageLog += "\(error.localizedDescription)\n" }
|
||||
}
|
||||
}
|
||||
|
||||
private func loadModel() throws {
|
||||
messageLog += "Loading model...\n"
|
||||
if let modelUrl {
|
||||
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
|
||||
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
|
||||
} else {
|
||||
messageLog += "Could not locate model\n"
|
||||
|
||||
func benchCurrentModel() async {
|
||||
if whisperContext == nil {
|
||||
messageLog += "Cannot bench without loaded model\n"
|
||||
return
|
||||
}
|
||||
messageLog += "Running benchmark for loaded model\n"
|
||||
let result = await whisperContext?.benchFull(modelName: "<current>", nThreads: Int32(min(4, cpuCount())))
|
||||
if (result != nil) { messageLog += result! + "\n" }
|
||||
}
|
||||
|
||||
func bench(models: [Model]) async {
|
||||
let nThreads = Int32(min(4, cpuCount()))
|
||||
|
||||
// messageLog += "Running memcpy benchmark\n"
|
||||
// messageLog += await WhisperContext.benchMemcpy(nThreads: nThreads) + "\n"
|
||||
//
|
||||
// messageLog += "Running ggml_mul_mat benchmark with \(nThreads) threads\n"
|
||||
// messageLog += await WhisperContext.benchGgmlMulMat(nThreads: nThreads) + "\n"
|
||||
|
||||
messageLog += "Running benchmark for all downloaded models\n"
|
||||
messageLog += "| CPU | OS | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |\n"
|
||||
messageLog += "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n"
|
||||
for model in models {
|
||||
loadModel(path: model.fileURL, log: false)
|
||||
if whisperContext == nil {
|
||||
messageLog += "Cannot bench without loaded model\n"
|
||||
break
|
||||
}
|
||||
let result = await whisperContext?.benchFull(modelName: model.name, nThreads: nThreads)
|
||||
if (result != nil) { messageLog += result! + "\n" }
|
||||
}
|
||||
messageLog += "Benchmarking completed\n"
|
||||
}
|
||||
|
||||
func transcribeSample() async {
|
||||
@ -160,3 +196,8 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fileprivate func cpuCount() -> Int {
|
||||
ProcessInfo.processInfo.processorCount
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
import SwiftUI
|
||||
import AVFoundation
|
||||
import Foundation
|
||||
|
||||
struct ContentView: View {
|
||||
@StateObject var whisperState = WhisperState()
|
||||
@ -29,15 +30,125 @@ struct ContentView: View {
|
||||
Text(verbatim: whisperState.messageLog)
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
}
|
||||
.font(.footnote)
|
||||
.padding()
|
||||
.background(Color.gray.opacity(0.1))
|
||||
.cornerRadius(10)
|
||||
|
||||
HStack {
|
||||
Button("Clear Logs", action: {
|
||||
whisperState.messageLog = ""
|
||||
})
|
||||
.font(.footnote)
|
||||
.buttonStyle(.bordered)
|
||||
|
||||
Button("Copy Logs", action: {
|
||||
UIPasteboard.general.string = whisperState.messageLog
|
||||
})
|
||||
.font(.footnote)
|
||||
.buttonStyle(.bordered)
|
||||
|
||||
Button("Bench", action: {
|
||||
Task {
|
||||
await whisperState.benchCurrentModel()
|
||||
}
|
||||
})
|
||||
.font(.footnote)
|
||||
.buttonStyle(.bordered)
|
||||
.disabled(!whisperState.canTranscribe)
|
||||
|
||||
Button("Bench All", action: {
|
||||
Task {
|
||||
await whisperState.bench(models: ModelsView.getDownloadedModels())
|
||||
}
|
||||
})
|
||||
.font(.footnote)
|
||||
.buttonStyle(.bordered)
|
||||
.disabled(!whisperState.canTranscribe)
|
||||
}
|
||||
|
||||
NavigationLink(destination: ModelsView(whisperState: whisperState)) {
|
||||
Text("View Models")
|
||||
}
|
||||
.font(.footnote)
|
||||
.padding()
|
||||
}
|
||||
.navigationTitle("Whisper SwiftUI Demo")
|
||||
.padding()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContentView_Previews: PreviewProvider {
|
||||
static var previews: some View {
|
||||
ContentView()
|
||||
struct ModelsView: View {
|
||||
@ObservedObject var whisperState: WhisperState
|
||||
@Environment(\.dismiss) var dismiss
|
||||
|
||||
private static let models: [Model] = [
|
||||
Model(name: "tiny", info: "(F16, 75 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin", filename: "tiny.bin"),
|
||||
Model(name: "tiny-q5_1", info: "(31 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny-q5_1.bin", filename: "tiny-q5_1.bin"),
|
||||
Model(name: "tiny-q8_0", info: "(42 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny-q8_0.bin", filename: "tiny-q8_0.bin"),
|
||||
Model(name: "tiny.en", info: "(F16, 75 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin", filename: "tiny.en.bin"),
|
||||
Model(name: "tiny.en-q5_1", info: "(31 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin", filename: "tiny.en-q5_1.bin"),
|
||||
Model(name: "tiny.en-q8_0", info: "(42 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin", filename: "tiny.en-q8_0.bin"),
|
||||
Model(name: "base", info: "(F16, 142 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin", filename: "base.bin"),
|
||||
Model(name: "base-q5_1", info: "(57 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base-q5_1.bin", filename: "base-q5_1.bin"),
|
||||
Model(name: "base-q8_0", info: "(78 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base-q8_0.bin", filename: "base-q8_0.bin"),
|
||||
Model(name: "base.en", info: "(F16, 142 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin", filename: "base.en.bin"),
|
||||
Model(name: "base.en-q5_1", info: "(57 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin", filename: "base.en-q5_1.bin"),
|
||||
Model(name: "base.en-q8_0", info: "(78 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q8_0.bin", filename: "base.en-q8_0.bin"),
|
||||
Model(name: "small", info: "(F16, 466 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin", filename: "small.bin"),
|
||||
Model(name: "small-q5_1", info: "(181 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small-q5_1.bin", filename: "small-q5_1.bin"),
|
||||
Model(name: "small-q8_0", info: "(252 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small-q8_0.bin", filename: "small-q8_0.bin"),
|
||||
Model(name: "small.en", info: "(F16, 466 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin", filename: "small.en.bin"),
|
||||
Model(name: "small.en-q5_1", info: "(181 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q5_1.bin", filename: "small.en-q5_1.bin"),
|
||||
Model(name: "small.en-q8_0", info: "(252 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q8_0.bin", filename: "small.en-q8_0.bin"),
|
||||
Model(name: "medium", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin", filename: "medium.bin"),
|
||||
Model(name: "medium-q5_0", info: "(514 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q5_0.bin", filename: "medium-q5_0.bin"),
|
||||
Model(name: "medium-q8_0", info: "(785 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q8_0.bin", filename: "medium-q8_0.bin"),
|
||||
Model(name: "medium.en", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin", filename: "medium.en.bin"),
|
||||
Model(name: "medium.en-q5_0", info: "(514 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin", filename: "medium.en-q5_0.bin"),
|
||||
Model(name: "medium.en-q8_0", info: "(785 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q8_0.bin", filename: "medium.en-q8_0.bin"),
|
||||
Model(name: "large-v1", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large.bin", filename: "large.bin"),
|
||||
Model(name: "large-v2", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2.bin", filename: "large-v2.bin"),
|
||||
Model(name: "large-v2-q5_0", info: "(1.1 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2-q5_0.bin", filename: "large-v2-q5_0.bin"),
|
||||
Model(name: "large-v2-q8_0", info: "(1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2-q8_0.bin", filename: "large-v2-q8_0.bin"),
|
||||
Model(name: "large-v3", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin", filename: "large-v3.bin"),
|
||||
Model(name: "large-v3-q5_0", info: "(1.1 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-q5_0.bin", filename: "large-v3-q5_0.bin"),
|
||||
Model(name: "large-v3-turbo", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin", filename: "large-v3-turbo.bin"),
|
||||
Model(name: "large-v3-turbo-q5_0", info: "(547 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo-q5_0.bin", filename: "large-v3-turbo-q5_0.bin"),
|
||||
Model(name: "large-v3-turbo-q8_0", info: "(834 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo-q8_0.bin", filename: "large-v3-turbo-q8_0.bin"),
|
||||
]
|
||||
|
||||
static func getDownloadedModels() -> [Model] {
|
||||
// Filter models that have been downloaded
|
||||
return models.filter {
|
||||
FileManager.default.fileExists(atPath: $0.fileURL.path())
|
||||
}
|
||||
}
|
||||
|
||||
func loadModel(model: Model) {
|
||||
Task {
|
||||
dismiss()
|
||||
whisperState.loadModel(path: model.fileURL)
|
||||
}
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
List {
|
||||
Section(header: Text("Models")) {
|
||||
ForEach(ModelsView.models) { model in
|
||||
DownloadButton(model: model)
|
||||
.onLoad(perform: loadModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
.listStyle(GroupedListStyle())
|
||||
.navigationBarTitle("Models", displayMode: .inline).toolbar {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//struct ContentView_Previews: PreviewProvider {
|
||||
// static var previews: some View {
|
||||
// ContentView()
|
||||
// }
|
||||
//}
|
||||
|
@ -0,0 +1,102 @@
|
||||
import SwiftUI
|
||||
|
||||
struct DownloadButton: View {
|
||||
private var model: Model
|
||||
|
||||
@State private var status: String
|
||||
|
||||
@State private var downloadTask: URLSessionDownloadTask?
|
||||
@State private var progress = 0.0
|
||||
@State private var observation: NSKeyValueObservation?
|
||||
|
||||
private var onLoad: ((_ model: Model) -> Void)?
|
||||
|
||||
init(model: Model) {
|
||||
self.model = model
|
||||
status = model.fileExists() ? "downloaded" : "download"
|
||||
}
|
||||
|
||||
func onLoad(perform action: @escaping (_ model: Model) -> Void) -> DownloadButton {
|
||||
var button = self
|
||||
button.onLoad = action
|
||||
return button
|
||||
}
|
||||
|
||||
private func download() {
|
||||
status = "downloading"
|
||||
print("Downloading model \(model.name) from \(model.url)")
|
||||
guard let url = URL(string: model.url) else { return }
|
||||
|
||||
downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in
|
||||
if let error = error {
|
||||
print("Error: \(error.localizedDescription)")
|
||||
return
|
||||
}
|
||||
|
||||
guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else {
|
||||
print("Server error!")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
if let temporaryURL = temporaryURL {
|
||||
try FileManager.default.copyItem(at: temporaryURL, to: model.fileURL)
|
||||
print("Writing to \(model.filename) completed")
|
||||
status = "downloaded"
|
||||
}
|
||||
} catch let err {
|
||||
print("Error: \(err.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in
|
||||
self.progress = progress.fractionCompleted
|
||||
}
|
||||
|
||||
downloadTask?.resume()
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
Button(action: {
|
||||
if (status == "download") {
|
||||
download()
|
||||
} else if (status == "downloading") {
|
||||
downloadTask?.cancel()
|
||||
status = "download"
|
||||
} else if (status == "downloaded") {
|
||||
if !model.fileExists() {
|
||||
download()
|
||||
}
|
||||
onLoad?(model)
|
||||
}
|
||||
}) {
|
||||
let title = "\(model.name) \(model.info)"
|
||||
if (status == "download") {
|
||||
Text("Download \(title)")
|
||||
} else if (status == "downloading") {
|
||||
Text("\(title) (Downloading \(Int(progress * 100))%)")
|
||||
} else if (status == "downloaded") {
|
||||
Text("Load \(title)")
|
||||
} else {
|
||||
Text("Unknown status")
|
||||
}
|
||||
}.swipeActions {
|
||||
if (status == "downloaded") {
|
||||
Button("Delete") {
|
||||
do {
|
||||
try FileManager.default.removeItem(at: model.fileURL)
|
||||
} catch {
|
||||
print("Error deleting file: \(error)")
|
||||
}
|
||||
status = "download"
|
||||
}
|
||||
.tint(.red)
|
||||
}
|
||||
}
|
||||
}
|
||||
.onDisappear() {
|
||||
downloadTask?.cancel()
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user