feat: allow to set cors (#339)

This commit is contained in:
Ettore Di Giacinto 2023-05-21 14:38:25 +02:00 committed by GitHub
parent ed5df1e68e
commit 6f54cab3f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 199 additions and 65 deletions

View File

@ -1,10 +1,8 @@
package api
import (
"context"
"errors"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
@ -13,16 +11,18 @@ import (
"github.com/rs/zerolog/log"
)
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
func App(opts ...AppOption) *fiber.App {
options := newOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
if options.debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: disableMessage,
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.disableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
@ -43,24 +43,24 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
},
})
if debug {
if options.debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
cm := NewConfigMerger()
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if configFile != "" {
if err := cm.LoadConfigFile(configFile); err != nil {
if options.configFile != "" {
if err := cm.LoadConfigFile(options.configFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if debug {
if options.debug {
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
@ -68,46 +68,55 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
}
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
if options.cors {
if options.corsAllowOrigins == "" {
app.Use(cors.New())
} else {
app.Use(cors.New(cors.Config{
AllowOrigins: options.corsAllowOrigins,
}))
}
}
// LocalAI API endpoints
applier := newGalleryApplier(loader.ModelPath)
applier.start(c, cm)
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
applier := newGalleryApplier(options.loader.ModelPath)
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
app.Post("/chat/completions", chatEndpoint(cm, options))
// edit
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/edits", editEndpoint(cm, options))
app.Post("/edits", editEndpoint(cm, options))
// completion
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/completions", completionEndpoint(cm, options))
app.Post("/completions", completionEndpoint(cm, options))
// embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
app.Post("/embeddings", embeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
// audio
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
// images
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
app.Post("/v1/images/generations", imageEndpoint(cm, options))
if imageDir != "" {
app.Static("/generated-images", imageDir)
if options.imageDir != "" {
app.Static("/generated-images", options.imageDir)
}
// models
app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader, cm))
app.Get("/v1/models", listModels(options.loader, cm))
app.Get("/models", listModels(options.loader, cm))
return app
}

View File

@ -114,7 +114,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
app = App(WithContext(c), WithModelLoader(modelLoader))
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -198,7 +198,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
app = App(WithContext(c), WithModelLoader(modelLoader))
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -316,7 +316,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background())
app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
app = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE")))
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")

View File

@ -142,15 +142,15 @@ func defaultRequest(modelFile string) OpenAIRequest {
}
// https://platform.openai.com/docs/api-reference/completions
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
model, input, err := readInput(c, o.loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -166,7 +166,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
var result []Choice
for _, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: i})
if err == nil {
@ -174,7 +174,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
@ -199,14 +199,14 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
}
// https://platform.openai.com/docs/api-reference/embeddings
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
model, input, err := readInput(c, o.loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -216,7 +216,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding("", s, loader, *config)
embedFn, err := ModelEmbedding("", s, o.loader, *config)
if err != nil {
return err
}
@ -230,7 +230,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config)
if err != nil {
return err
}
@ -256,7 +256,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@ -273,12 +273,12 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
close(responses)
}
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
model, input, err := readInput(c, o.loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -319,7 +319,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
if err == nil {
@ -330,7 +330,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
if input.Stream {
responses := make(chan OpenAIResponse)
go process(predInput, input, config, loader, responses)
go process(predInput, input, config, o.loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -358,7 +358,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
return nil
}
result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
}, nil)
if err != nil {
@ -378,14 +378,14 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
}
}
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
model, input, err := readInput(c, o.loader, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -401,7 +401,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
var result []Choice
for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string
Instruction string
}{Input: i})
@ -410,7 +410,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
@ -449,9 +449,9 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
*
*/
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
m, input, err := readInput(c, o.loader, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -461,7 +461,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := readConfig(m, input, cm, loader, debug, 0, 0, false)
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -518,7 +518,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
tempDir := ""
if !b64JSON {
tempDir = imageDir
tempDir = o.imageDir
}
// Create a temporary file
outputFile, err := ioutil.TempFile(tempDir, "b64")
@ -535,7 +535,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
baseURL := c.BaseURL()
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, loader, *config)
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config)
if err != nil {
return err
}
@ -574,14 +574,14 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
}
// https://platform.openai.com/docs/api-reference/audio/create
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
m, input, err := readInput(c, o.loader, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(m, input, cm, loader, debug, threads, ctx, f16)
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@ -616,7 +616,7 @@ func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
if err != nil {
return err
}

108
api/options.go Normal file
View File

@ -0,0 +1,108 @@
package api
import (
"context"
model "github.com/go-skynet/LocalAI/pkg/model"
)
type Option struct {
context context.Context
configFile string
loader *model.ModelLoader
uploadLimitMB, threads, ctxSize int
f16 bool
debug, disableMessage bool
imageDir string
cors bool
corsAllowOrigins string
}
type AppOption func(*Option)
func newOptions(o ...AppOption) *Option {
opt := &Option{
context: context.Background(),
uploadLimitMB: 15,
threads: 1,
ctxSize: 512,
debug: true,
disableMessage: true,
}
for _, oo := range o {
oo(opt)
}
return opt
}
func WithCors(b bool) AppOption {
return func(o *Option) {
o.cors = b
}
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) {
o.corsAllowOrigins = b
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
o.context = ctx
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *Option) {
o.configFile = configFile
}
}
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) {
o.uploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *Option) {
o.threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *Option) {
o.ctxSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *Option) {
o.f16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *Option) {
o.debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) {
o.disableMessage = disableMessage
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *Option) {
o.imageDir = imageDir
}
}

21
main.go
View File

@ -1,7 +1,6 @@
package main
import (
"context"
"fmt"
"os"
"path/filepath"
@ -34,6 +33,14 @@ func main() {
Name: "debug",
EnvVars: []string{"DEBUG"},
},
&cli.BoolFlag{
Name: "cors",
EnvVars: []string{"CORS"},
},
&cli.StringFlag{
Name: "cors-allow-origins",
EnvVars: []string{"CORS_ALLOW_ORIGINS"},
},
&cli.IntFlag{
Name: "threads",
DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
@ -94,7 +101,17 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address"))
return api.App(
api.WithConfigFile(ctx.String("config-file")),
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
api.WithContextSize(ctx.Int("context-size")),
api.WithDebug(ctx.Bool("debug")),
api.WithImageDir(ctx.String("image-path")),
api.WithF16(ctx.Bool("f16")),
api.WithCors(ctx.Bool("cors")),
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
api.WithThreads(ctx.Int("threads")),
api.WithUploadLimitMB(ctx.Int("upload-limit"))).Listen(ctx.String("address"))
},
}