mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-18 02:40:01 +00:00
whisper: add tests and allow to set upload size (#237)
This commit is contained in:
parent
5115b2faa3
commit
fd1df4e971
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install build-essential ffmpeg
|
||||
- name: Test
|
||||
run: |
|
||||
make test
|
||||
@ -38,7 +38,7 @@ jobs:
|
||||
- name: Dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install sdl2
|
||||
brew install sdl2 ffmpeg
|
||||
- name: Test
|
||||
run: |
|
||||
make test
|
5
Makefile
5
Makefile
@ -179,12 +179,15 @@ run: prepare ## run local-ai
|
||||
|
||||
test-models/testmodel:
|
||||
mkdir test-models
|
||||
mkdir test-dir
|
||||
wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
|
||||
wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||
wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
|
||||
cp tests/fixtures/* test-models
|
||||
|
||||
test: prepare test-models/testmodel
|
||||
cp tests/fixtures/* test-models
|
||||
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
|
||||
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
|
||||
|
||||
## Help:
|
||||
help: ## Show this help.
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
|
||||
func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
@ -20,6 +20,7 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
|
||||
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: disableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
|
@ -3,6 +3,7 @@ package api_test
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
@ -23,7 +24,7 @@ var _ = Describe("API test", func() {
|
||||
Context("API query", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
app = App("", modelLoader, 1, 512, false, true, true)
|
||||
app = App("", modelLoader, 15, 1, 512, false, true, true)
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
@ -45,7 +46,7 @@ var _ = Describe("API test", func() {
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(3))
|
||||
Expect(len(models.Models)).To(Equal(4))
|
||||
Expect(models.Models[0].ID).To(Equal("testmodel"))
|
||||
})
|
||||
It("can generate completions", func() {
|
||||
@ -81,13 +82,23 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
|
||||
})
|
||||
|
||||
PIt("transcribes audio", func() {
|
||||
resp, err := client.CreateTranscription(
|
||||
context.Background(),
|
||||
openai.AudioRequest{
|
||||
Model: openai.Whisper1,
|
||||
FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Config file", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true)
|
||||
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true)
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
@ -108,7 +119,7 @@ var _ = Describe("API test", func() {
|
||||
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(5))
|
||||
Expect(len(models.Models)).To(Equal(6))
|
||||
Expect(models.Models[0].ID).To(Equal("testmodel"))
|
||||
})
|
||||
It("can generate chat completions from config file", func() {
|
||||
@ -134,5 +145,6 @@ var _ = Describe("API test", func() {
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
@ -409,14 +409,13 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
log.Debug().Msgf("Audio file: %+v", file)
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
|
||||
@ -428,26 +427,33 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
|
||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
whisperModel, err := loader.BackendLoader("whisper", config.Model, []llama.ModelOption{}, uint32(config.Threads))
|
||||
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
|
||||
w := whisperModel.(whisper.Model)
|
||||
if whisperModel == nil {
|
||||
return fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language)
|
||||
w, ok := whisperModel.(whisper.Model)
|
||||
if !ok {
|
||||
return fmt.Errorf("loader returned non-whisper object")
|
||||
}
|
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
|
8
main.go
8
main.go
@ -62,6 +62,12 @@ func main() {
|
||||
EnvVars: []string{"CONTEXT_SIZE"},
|
||||
Value: 512,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "upload-limit",
|
||||
DefaultText: "Default upload-limit. MB",
|
||||
EnvVars: []string{"UPLOAD_LIMIT"},
|
||||
Value: 15,
|
||||
},
|
||||
},
|
||||
Description: `
|
||||
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
|
||||
@ -81,7 +87,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
|
||||
Copyright: "go-skynet authors",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
||||
return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address"))
|
||||
return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address"))
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ func audioToWav(src, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Transcript(model whisper.Model, audiopath, language string) (string, error) {
|
||||
func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) {
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
@ -65,8 +65,12 @@ func Transcript(model whisper.Model, audiopath, language string) (string, error)
|
||||
|
||||
}
|
||||
|
||||
context.SetThreads(threads)
|
||||
|
||||
if language != "" {
|
||||
context.SetLanguage(language)
|
||||
} else {
|
||||
context.SetLanguage("auto")
|
||||
}
|
||||
|
||||
if err := context.Process(data, nil); err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user