Add the quantized model names to the Go model downloader and to the Windows
download script, ask for confirmation before downloading every model, and
ignore log files and executables.

Summary of the changes carried by this patch:

  .gitignore
    Ignore *.log and *.exe files.

  bindings/go/examples/go-model-download/context.go
    Give the signal channel a one-slot buffer so that a signal delivered
    before the goroutine is receiving is not dropped.

  bindings/go/examples/go-model-download/main.go
    Model names are listed without the "ggml-" prefix; URLForModel adds the
    prefix and the ".bin" extension when they are missing. When no model is
    named on the command line, GetModels prints the total download size
    (CalculateTotalDownloadSize, HEAD requests) and asks for confirmation.

  models/download-ggml-model.cmd
    Extend the model list and print the whisper-cli.exe path relative to
    %root_path% in the usage hint.

NOTE(review): the "index" lines below are carried over from the original
patch and were not recomputed for the post-image of this revision. Whitespace
in context lines was reconstructed; if the patch does not apply cleanly, try
`git apply --ignore-whitespace`.

diff --git a/.gitignore b/.gitignore
index c1e584db..91368ec5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,3 +58,5 @@ cmake-build-debug/
 .cxx/
 .gradle/
 local.properties
+*.log
+*.exe
diff --git a/bindings/go/examples/go-model-download/context.go b/bindings/go/examples/go-model-download/context.go
index 639d8f5b..7d5f0ddb 100644
--- a/bindings/go/examples/go-model-download/context.go
+++ b/bindings/go/examples/go-model-download/context.go
@@ -9,22 +9,25 @@ import (
 // ContextForSignal returns a context object which is cancelled when a signal
 // is received. It returns nil if no signal parameter is provided
 func ContextForSignal(signals ...os.Signal) context.Context {
 	if len(signals) == 0 {
 		return nil
 	}
 
-	ch := make(chan os.Signal)
+	// signal.Notify does not block when delivering to the channel, so give it
+	// a one-slot buffer: a signal arriving before the goroutine below is
+	// receiving would otherwise be dropped
+	ch := make(chan os.Signal, 1)
 	ctx, cancel := context.WithCancel(context.Background())
 
 	// Send message on channel when signal received
 	signal.Notify(ch, signals...)
 
-	// When any signal received, call cancel
+	// When any signal is received, call cancel
 	go func() {
 		<-ch
 		cancel()
 	}()
 
 	// Return success
 	return ctx
 }
diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go
index d0c1cc78..728c6df5 100644
--- a/bindings/go/examples/go-model-download/main.go
+++ b/bindings/go/examples/go-model-download/main.go
@@ -9,6 +9,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"strings"
 	"syscall"
 	"time"
 )
@@ -17,14 +18,27 @@ import (
 // CONSTANTS
 
 const (
-	srcUrl  = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models
-	srcExt  = ".bin"                                                      // Filename extension
-	bufSize = 1024 * 64                                                   // Size of the buffer used for downloading the model
+	srcUrl  = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
+	srcExt  = ".bin"                                                       // Filename extension
+	bufSize = 1024 * 64                                                    // Size of the buffer used for downloading the model
 )
 
 var (
 	// The models which will be downloaded, if no model is specified as an argument
-	modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3", "large-v3-turbo"}
+	modelNames = []string{
+		"tiny", "tiny-q5_1", "tiny-q8_0",
+		"tiny.en", "tiny.en-q5_1", "tiny.en-q8_0",
+		"base", "base-q5_1", "base-q8_0",
+		"base.en", "base.en-q5_1", "base.en-q8_0",
+		"small", "small-q5_1", "small-q8_0",
+		"small.en", "small.en-q5_1", "small.en-q8_0",
+		"medium", "medium-q5_0", "medium-q8_0",
+		"medium.en", "medium.en-q5_0", "medium.en-q8_0",
+		"large-v1",
+		"large-v2", "large-v2-q5_0", "large-v2-q8_0",
+		"large-v3", "large-v3-q5_0",
+		"large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0",
+	}
 )
 
 var (
@@ -44,7 +58,25 @@ var (
 func main() {
 	flag.Usage = func() {
 		name := filepath.Base(flag.CommandLine.Name())
-		fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name)
+		fmt.Fprintf(flag.CommandLine.Output(), `
+Usage: %s [options] [<model>...]
+
+Options:
+  -out string         Specify the output folder where models will be saved.
+                      Default: Current working directory.
+  -timeout duration   Set the maximum duration for downloading a model.
+                      Example: 10m, 1h (default: 30m0s).
+  -quiet              Suppress all output except errors.
+
+Examples:
+  1. Download a specific model:
+     %s -out ./models tiny-q8_0
+
+  2. Download all models:
+     %s -out ./models
+
+`, name, name, name)
+
 		flag.PrintDefaults()
 	}
 	flag.Parse()
@@ -114,23 +146,106 @@ func GetOut() (string, error) {
 // GetModels returns the list of models to download
+//
+// The positional command-line arguments are returned as given. When there are
+// none, the total download size of every entry in modelNames is printed and
+// the user is asked to confirm; the process exits if the size lookup fails or
+// the user declines
 func GetModels() []string {
 	if flag.NArg() == 0 {
-		return modelNames
-	} else {
-		return flag.Args()
+		fmt.Println("No model specified.")
+		fmt.Println("Preparing to download all models...")
+
+		// Calculate total download size
+		fmt.Println("Calculating total download size...")
+		totalSize, err := CalculateTotalDownloadSize(modelNames)
+		if err != nil {
+			fmt.Println("Error calculating download sizes:", err)
+			os.Exit(1)
+		}
+
+		fmt.Println("View available models: https://huggingface.co/ggerganov/whisper.cpp/tree/main")
+		fmt.Printf("Total download size: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
+		fmt.Println("Would you like to download all models? (y/N)")
+
+		// Prompt for user input. A read error (for example EOF or an empty
+		// line) leaves response empty, which is treated as "no"
+		var response string
+		_, _ = fmt.Scanln(&response)
+		if response != "y" && response != "Y" {
+			fmt.Println("Aborting. Specify a model to download.")
+			os.Exit(0)
+		}
+
+		return modelNames // Return all models if confirmed
 	}
+	return flag.Args() // Return specific models if arguments are provided
+}
+
+// CalculateTotalDownloadSize returns the combined size, in bytes, of the given
+// models as reported by the Content-Length of a HEAD request for each model
+// URL. Models whose size cannot be determined are reported on standard output
+// and left out of the total. An error is returned if a URL cannot be built or
+// a request fails
+func CalculateTotalDownloadSize(models []string) (int64, error) {
+	var totalSize int64
+
+	// Bound each request so an unresponsive server cannot hang the program
+	client := http.Client{Timeout: 30 * time.Second}
+
+	for _, model := range models {
+		modelURL, err := URLForModel(model)
+		if err != nil {
+			return 0, err
+		}
+
+		// Issue a HEAD request to get the file size
+		req, err := http.NewRequest(http.MethodHead, modelURL, nil)
+		if err != nil {
+			return 0, err
+		}
+
+		resp, err := client.Do(req)
+		if err != nil {
+			return 0, err
+		}
+		resp.Body.Close()
+
+		if resp.StatusCode != http.StatusOK {
+			fmt.Printf("Warning: Unable to fetch size for %s (HTTP %d)\n", model, resp.StatusCode)
+			continue
+		}
+
+		// ContentLength is -1 when the length is unknown; adding it would
+		// shrink the total
+		if resp.ContentLength < 0 {
+			fmt.Printf("Warning: Unable to fetch size for %s (no Content-Length)\n", model)
+			continue
+		}
+		totalSize += resp.ContentLength
+	}
+	return totalSize, nil
 }
 
 // URLForModel returns the URL for the given model on huggingface.co
 func URLForModel(model string) (string, error) {
+	// Ensure "ggml-" prefix is added only once
+	if !strings.HasPrefix(model, "ggml-") {
+		model = "ggml-" + model
+	}
+
+	// Ensure ".bin" extension is added only once
 	if filepath.Ext(model) != srcExt {
 		model += srcExt
 	}
-	url, err := url.Parse(srcUrl)
+
+	// Parse the base URL
+	u, err := url.Parse(srcUrl)
 	if err != nil {
 		return "", err
-	} else {
-		url.Path = filepath.Join(url.Path, model)
 	}
-	return url.String(), nil
+
+	// Append the model file name; trimming first avoids a doubled "/" when
+	// the base URL path already ends with one
+	u.Path = strings.TrimSuffix(u.Path, "/") + "/" + model
+	return u.String(), nil
 }
 
diff --git a/models/download-ggml-model.cmd b/models/download-ggml-model.cmd
index f329011d..566aa1bf 100644
--- a/models/download-ggml-model.cmd
+++ b/models/download-ggml-model.cmd
@@ -8,7 +8,18 @@ popd
 set argc=0
 for %%x in (%*) do set /A argc+=1
 
-set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo
+set models=tiny tiny-q5_1 tiny-q8_0 ^
+tiny.en tiny.en-q5_1 tiny.en-q8_0 ^
+base base-q5_1 base-q8_0 ^
+base.en base.en-q5_1 base.en-q8_0 ^
+small small-q5_1 small-q8_0 ^
+small.en small.en-q5_1 small.en-q8_0 ^
+medium medium-q5_0 medium-q8_0 ^
+medium.en medium.en-q5_0 medium.en-q8_0 ^
+large-v1 ^
+large-v2 large-v2-q5_0 large-v2-q8_0 ^
+large-v3 large-v3-q5_0 ^
+large-v3-turbo large-v3-turbo-q5_0 large-v3-turbo-q8_0
 
 if %argc% neq 1 (
   echo.
@@ -50,7 +61,7 @@ if %ERRORLEVEL% neq 0 (
 
 echo Done! Model %model% saved in %root_path%\models\ggml-%model%.bin
 echo You can now use it like this:
-echo build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav
+echo %root_path%\build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav
 
 goto :eof
 