diff --git a/api/api.go b/api/api.go index 8dcefa24..66a1db09 100644 --- a/api/api.go +++ b/api/api.go @@ -2,6 +2,7 @@ package api import ( "errors" + "strings" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/localai" @@ -89,6 +90,32 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Default middleware config app.Use(recover.New()) + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. + auth := func(c *fiber.Ctx) error { + if len(options.ApiKeys) > 0 { + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + validApiKey := false + for _, key := range options.ApiKeys { + if apiKey == key { + validApiKey = true + } + } + if !validApiKey { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + } + } + return c.Next() + } + if options.PreloadJSONModels != "" { if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil { return nil, err @@ -116,42 +143,42 @@ func App(opts ...options.AppOption) (*fiber.App, error) { galleryService := localai.NewGalleryService(options.Loader.ModelPath) galleryService.Start(options.Context, cm) - app.Get("/version", func(c *fiber.Ctx) error { + app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) - app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) - app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService)) + app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) + app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) + app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options)) - app.Post("/chat/completions", openai.ChatEndpoint(cm, options)) + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options)) // edit - app.Post("/v1/edits", openai.EditEndpoint(cm, options)) - app.Post("/edits", openai.EditEndpoint(cm, options)) + app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options)) + app.Post("/edits", auth, openai.EditEndpoint(cm, options)) // completion - app.Post("/v1/completions", openai.CompletionEndpoint(cm, options)) - app.Post("/completions", openai.CompletionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options)) + app.Post("/completions", auth, openai.CompletionEndpoint(cm, options)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options)) - app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options)) - app.Post("/tts", localai.TTSEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options)) + app.Post("/tts", auth, localai.TTSEndpoint(cm, options)) // images - app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options)) + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options)) if options.ImageDir != "" { app.Static("/generated-images", options.ImageDir) @@ -170,8 +197,8 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Get("/readyz", ok) // models - app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) - app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) + app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm)) + app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm)) // turn off any process that was started by GRPC if the context is canceled go func() { diff --git a/api/options/options.go b/api/options/options.go index b3269470..ada95d3f 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -23,6 +23,7 @@ type Option struct { PreloadJSONModels string PreloadModelsFromPath string CORSAllowOrigins string + ApiKeys []string Galleries []gallery.Gallery @@ -184,3 +185,9 @@ func WithImageDir(imageDir string) AppOption { o.ImageDir = imageDir } } + +func WithApiKeys(apiKeys []string) AppOption { + return func(o *Option) { + o.ApiKeys = apiKeys + } +} diff --git a/main.go b/main.go index 2cb86271..4f4b824c 100644 --- a/main.go +++ b/main.go @@ -130,6 +130,11 @@ func main() { EnvVars: []string{"UPLOAD_LIMIT"}, Value: 15, }, + &cli.StringSliceFlag{ + Name: "api-keys", + Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.", + EnvVars: []string{"API_KEY"}, + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -167,6 +172,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), options.WithUploadLimitMB(ctx.Int("upload-limit")), + options.WithApiKeys(ctx.StringSlice("api-keys")), } externalgRPC := ctx.StringSlice("external-grpc-backends")