mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-14 14:33:18 +00:00
feat: allow to set cors (#339)
This commit is contained in:
parent
ed5df1e68e
commit
6f54cab3f0
67
api/api.go
67
api/api.go
@ -1,10 +1,8 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||||
@ -13,16 +11,18 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
|
func App(opts ...AppOption) *fiber.App {
|
||||||
|
options := newOptions(opts...)
|
||||||
|
|
||||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
if debug {
|
if options.debug {
|
||||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return errors as JSON responses
|
// Return errors as JSON responses
|
||||||
app := fiber.New(fiber.Config{
|
app := fiber.New(fiber.Config{
|
||||||
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||||
DisableStartupMessage: disableMessage,
|
DisableStartupMessage: options.disableMessage,
|
||||||
// Override default error handler
|
// Override default error handler
|
||||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||||
// Status code defaults to 500
|
// Status code defaults to 500
|
||||||
@ -43,24 +43,24 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if debug {
|
if options.debug {
|
||||||
app.Use(logger.New(logger.Config{
|
app.Use(logger.New(logger.Config{
|
||||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
cm := NewConfigMerger()
|
cm := NewConfigMerger()
|
||||||
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
|
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
|
||||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if configFile != "" {
|
if options.configFile != "" {
|
||||||
if err := cm.LoadConfigFile(configFile); err != nil {
|
if err := cm.LoadConfigFile(options.configFile); err != nil {
|
||||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug {
|
if options.debug {
|
||||||
for _, v := range cm.ListConfigs() {
|
for _, v := range cm.ListConfigs() {
|
||||||
cfg, _ := cm.GetConfig(v)
|
cfg, _ := cm.GetConfig(v)
|
||||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||||
@ -68,46 +68,55 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
|
|||||||
}
|
}
|
||||||
// Default middleware config
|
// Default middleware config
|
||||||
app.Use(recover.New())
|
app.Use(recover.New())
|
||||||
|
|
||||||
|
if options.cors {
|
||||||
|
if options.corsAllowOrigins == "" {
|
||||||
app.Use(cors.New())
|
app.Use(cors.New())
|
||||||
|
} else {
|
||||||
|
app.Use(cors.New(cors.Config{
|
||||||
|
AllowOrigins: options.corsAllowOrigins,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LocalAI API endpoints
|
// LocalAI API endpoints
|
||||||
applier := newGalleryApplier(loader.ModelPath)
|
applier := newGalleryApplier(options.loader.ModelPath)
|
||||||
applier.start(c, cm)
|
applier.start(options.context, cm)
|
||||||
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
|
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
|
||||||
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
||||||
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
// chat
|
// chat
|
||||||
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
|
||||||
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/chat/completions", chatEndpoint(cm, options))
|
||||||
|
|
||||||
// edit
|
// edit
|
||||||
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/edits", editEndpoint(cm, options))
|
||||||
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/edits", editEndpoint(cm, options))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/completions", completionEndpoint(cm, options))
|
||||||
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/completions", completionEndpoint(cm, options))
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
|
||||||
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/embeddings", embeddingsEndpoint(cm, options))
|
||||||
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
|
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
|
||||||
|
|
||||||
// images
|
// images
|
||||||
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
|
app.Post("/v1/images/generations", imageEndpoint(cm, options))
|
||||||
|
|
||||||
if imageDir != "" {
|
if options.imageDir != "" {
|
||||||
app.Static("/generated-images", imageDir)
|
app.Static("/generated-images", options.imageDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// models
|
// models
|
||||||
app.Get("/v1/models", listModels(loader, cm))
|
app.Get("/v1/models", listModels(options.loader, cm))
|
||||||
app.Get("/models", listModels(loader, cm))
|
app.Get("/models", listModels(options.loader, cm))
|
||||||
|
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ var _ = Describe("API test", func() {
|
|||||||
modelLoader = model.NewModelLoader(tmpdir)
|
modelLoader = model.NewModelLoader(tmpdir)
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
|
app = App(WithContext(c), WithModelLoader(modelLoader))
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@ -198,7 +198,7 @@ var _ = Describe("API test", func() {
|
|||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
|
app = App(WithContext(c), WithModelLoader(modelLoader))
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@ -316,7 +316,7 @@ var _ = Describe("API test", func() {
|
|||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
|
app = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE")))
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
|
@ -142,15 +142,15 @@ func defaultRequest(modelFile string) OpenAIRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
|
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
model, input, err := readInput(c, loader, true)
|
model, input, err := readInput(c, o.loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
|
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -166,7 +166,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
var result []Choice
|
var result []Choice
|
||||||
for _, i := range config.PromptStrings {
|
for _, i := range config.PromptStrings {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||||
Input string
|
Input string
|
||||||
}{Input: i})
|
}{Input: i})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -174,7 +174,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
|
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
|
||||||
*c = append(*c, Choice{Text: s})
|
*c = append(*c, Choice{Text: s})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -199,14 +199,14 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
|
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readInput(c, loader, true)
|
model, input, err := readInput(c, o.loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
|
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -216,7 +216,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
|
|
||||||
for i, s := range config.InputToken {
|
for i, s := range config.InputToken {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := ModelEmbedding("", s, loader, *config)
|
embedFn, err := ModelEmbedding("", s, o.loader, *config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -230,7 +230,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
|
|
||||||
for i, s := range config.InputStrings {
|
for i, s := range config.InputStrings {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
|
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -256,7 +256,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
|
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||||
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||||
@ -273,12 +273,12 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
close(responses)
|
close(responses)
|
||||||
}
|
}
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readInput(c, loader, true)
|
model, input, err := readInput(c, o.loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
|
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -319,7 +319,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||||
Input string
|
Input string
|
||||||
}{Input: predInput})
|
}{Input: predInput})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -330,7 +330,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
if input.Stream {
|
if input.Stream {
|
||||||
responses := make(chan OpenAIResponse)
|
responses := make(chan OpenAIResponse)
|
||||||
|
|
||||||
go process(predInput, input, config, loader, responses)
|
go process(predInput, input, config, o.loader, responses)
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
|
||||||
@ -358,7 +358,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
|
result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) {
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -378,14 +378,14 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
|
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readInput(c, loader, true)
|
model, input, err := readInput(c, o.loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16)
|
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -401,7 +401,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
var result []Choice
|
var result []Choice
|
||||||
for _, i := range config.InputStrings {
|
for _, i := range config.InputStrings {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||||
Input string
|
Input string
|
||||||
Instruction string
|
Instruction string
|
||||||
}{Input: i})
|
}{Input: i})
|
||||||
@ -410,7 +410,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
|
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
|
||||||
*c = append(*c, Choice{Text: s})
|
*c = append(*c, Choice{Text: s})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -449,9 +449,9 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
|
|||||||
|
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
|
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readInput(c, loader, false)
|
m, input, err := readInput(c, o.loader, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -461,7 +461,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
|
|||||||
}
|
}
|
||||||
log.Debug().Msgf("Loading model: %+v", m)
|
log.Debug().Msgf("Loading model: %+v", m)
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, loader, debug, 0, 0, false)
|
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -518,7 +518,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
|
|||||||
|
|
||||||
tempDir := ""
|
tempDir := ""
|
||||||
if !b64JSON {
|
if !b64JSON {
|
||||||
tempDir = imageDir
|
tempDir = o.imageDir
|
||||||
}
|
}
|
||||||
// Create a temporary file
|
// Create a temporary file
|
||||||
outputFile, err := ioutil.TempFile(tempDir, "b64")
|
outputFile, err := ioutil.TempFile(tempDir, "b64")
|
||||||
@ -535,7 +535,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
|
|||||||
|
|
||||||
baseURL := c.BaseURL()
|
baseURL := c.BaseURL()
|
||||||
|
|
||||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, loader, *config)
|
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -574,14 +574,14 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
|
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readInput(c, loader, false)
|
m, input, err := readInput(c, o.loader, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, loader, debug, threads, ctx, f16)
|
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@ -616,7 +616,7 @@ func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
|
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
108
api/options.go
Normal file
108
api/options.go
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Option struct {
|
||||||
|
context context.Context
|
||||||
|
configFile string
|
||||||
|
loader *model.ModelLoader
|
||||||
|
uploadLimitMB, threads, ctxSize int
|
||||||
|
f16 bool
|
||||||
|
debug, disableMessage bool
|
||||||
|
imageDir string
|
||||||
|
cors bool
|
||||||
|
corsAllowOrigins string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppOption func(*Option)
|
||||||
|
|
||||||
|
func newOptions(o ...AppOption) *Option {
|
||||||
|
opt := &Option{
|
||||||
|
context: context.Background(),
|
||||||
|
uploadLimitMB: 15,
|
||||||
|
threads: 1,
|
||||||
|
ctxSize: 512,
|
||||||
|
debug: true,
|
||||||
|
disableMessage: true,
|
||||||
|
}
|
||||||
|
for _, oo := range o {
|
||||||
|
oo(opt)
|
||||||
|
}
|
||||||
|
return opt
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCors(b bool) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.cors = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCorsAllowOrigins(b string) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.corsAllowOrigins = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithContext(ctx context.Context) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.context = ctx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithConfigFile(configFile string) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.configFile = configFile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.loader = loader
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithUploadLimitMB(limit int) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.uploadLimitMB = limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithThreads(threads int) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.threads = threads
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithContextSize(ctxSize int) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.ctxSize = ctxSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithF16(f16 bool) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.f16 = f16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDebug(debug bool) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.debug = debug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDisableMessage(disableMessage bool) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.disableMessage = disableMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithImageDir(imageDir string) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.imageDir = imageDir
|
||||||
|
}
|
||||||
|
}
|
21
main.go
21
main.go
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -34,6 +33,14 @@ func main() {
|
|||||||
Name: "debug",
|
Name: "debug",
|
||||||
EnvVars: []string{"DEBUG"},
|
EnvVars: []string{"DEBUG"},
|
||||||
},
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "cors",
|
||||||
|
EnvVars: []string{"CORS"},
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "cors-allow-origins",
|
||||||
|
EnvVars: []string{"CORS_ALLOW_ORIGINS"},
|
||||||
|
},
|
||||||
&cli.IntFlag{
|
&cli.IntFlag{
|
||||||
Name: "threads",
|
Name: "threads",
|
||||||
DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
|
DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
|
||||||
@ -94,7 +101,17 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
|
|||||||
Copyright: "go-skynet authors",
|
Copyright: "go-skynet authors",
|
||||||
Action: func(ctx *cli.Context) error {
|
Action: func(ctx *cli.Context) error {
|
||||||
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
||||||
return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address"))
|
return api.App(
|
||||||
|
api.WithConfigFile(ctx.String("config-file")),
|
||||||
|
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
||||||
|
api.WithContextSize(ctx.Int("context-size")),
|
||||||
|
api.WithDebug(ctx.Bool("debug")),
|
||||||
|
api.WithImageDir(ctx.String("image-path")),
|
||||||
|
api.WithF16(ctx.Bool("f16")),
|
||||||
|
api.WithCors(ctx.Bool("cors")),
|
||||||
|
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||||
|
api.WithThreads(ctx.Int("threads")),
|
||||||
|
api.WithUploadLimitMB(ctx.Int("upload-limit"))).Listen(ctx.String("address"))
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user