mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat: add API_KEY list support (#877)
Co-authored-by: Harold Sun <sunhua@amazon.com>
This commit is contained in:
parent
8c781a6a44
commit
1d1cae8e4d
65
api/api.go
65
api/api.go
@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/localai"
|
||||
@ -89,6 +90,32 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
|
||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||
auth := func(c *fiber.Ctx) error {
|
||||
if len(options.ApiKeys) > 0 {
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||
}
|
||||
authHeaderParts := strings.Split(authHeader, " ")
|
||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||
}
|
||||
|
||||
apiKey := authHeaderParts[1]
|
||||
validApiKey := false
|
||||
for _, key := range options.ApiKeys {
|
||||
if apiKey == key {
|
||||
validApiKey = true
|
||||
}
|
||||
}
|
||||
if !validApiKey {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
|
||||
return nil, err
|
||||
@ -116,42 +143,42 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
||||
galleryService.Start(options.Context, cm)
|
||||
|
||||
app.Get("/version", func(c *fiber.Ctx) error {
|
||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
|
||||
app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
|
||||
app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService))
|
||||
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
|
||||
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
|
||||
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options))
|
||||
app.Post("/chat/completions", openai.ChatEndpoint(cm, options))
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", openai.EditEndpoint(cm, options))
|
||||
app.Post("/edits", openai.EditEndpoint(cm, options))
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cm, options))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/completions", openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options))
|
||||
app.Post("/tts", localai.TTSEndpoint(cm, options))
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cm, options))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options))
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options))
|
||||
|
||||
if options.ImageDir != "" {
|
||||
app.Static("/generated-images", options.ImageDir)
|
||||
@ -170,8 +197,8 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
app.Get("/readyz", ok)
|
||||
|
||||
// models
|
||||
app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm))
|
||||
app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm))
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
|
@ -23,6 +23,7 @@ type Option struct {
|
||||
PreloadJSONModels string
|
||||
PreloadModelsFromPath string
|
||||
CORSAllowOrigins string
|
||||
ApiKeys []string
|
||||
|
||||
Galleries []gallery.Gallery
|
||||
|
||||
@ -184,3 +185,9 @@ func WithImageDir(imageDir string) AppOption {
|
||||
o.ImageDir = imageDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithApiKeys(apiKeys []string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ApiKeys = apiKeys
|
||||
}
|
||||
}
|
||||
|
6
main.go
6
main.go
@ -130,6 +130,11 @@ func main() {
|
||||
EnvVars: []string{"UPLOAD_LIMIT"},
|
||||
Value: 15,
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "api-keys",
|
||||
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
|
||||
EnvVars: []string{"API_KEY"},
|
||||
},
|
||||
},
|
||||
Description: `
|
||||
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
|
||||
@ -167,6 +172,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
options.WithApiKeys(ctx.StringSlice("api-keys")),
|
||||
}
|
||||
|
||||
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||
|
Loading…
Reference in New Issue
Block a user