feat: auth v2 - supersedes #2894 (#3476)

feat: auth v2 - supercedes #2894, metrics to follow later

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-09-16 23:29:07 -04:00 committed by GitHub
parent a9a3a07c3b
commit db1159b651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 264 additions and 158 deletions

View File

@ -41,31 +41,34 @@ type RunCMD struct {
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"`
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
}
func (r *RunCMD) Run(ctx *cliContext.Context) error {
@ -97,6 +100,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
config.WithOpaqueErrors(r.OpaqueErrors),
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
}

View File

@ -4,6 +4,7 @@ import (
"context"
"embed"
"encoding/json"
"regexp"
"time"
"github.com/mudler/LocalAI/pkg/xsysinfo"
@ -16,7 +17,6 @@ type ApplicationConfig struct {
ModelPath string
LibPath string
UploadLimitMB, Threads, ContextSize int
DisableWebUI bool
F16 bool
Debug bool
ImageDir string
@ -31,11 +31,17 @@ type ApplicationConfig struct {
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
EnforcePredownloadScans bool
OpaqueErrors bool
P2PToken string
P2PNetworkID string
DisableWebUI bool
EnforcePredownloadScans bool
OpaqueErrors bool
UseSubtleKeyComparison bool
DisableApiKeyRequirementForHttpGet bool
HttpGetExemptedEndpoints []*regexp.Regexp
DisableGalleryEndpoint bool
ModelLibraryURL string
Galleries []Gallery
@ -57,8 +63,6 @@ type ApplicationConfig struct {
ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
DisableGalleryEndpoint bool
}
type AppOption func(*ApplicationConfig)
@ -327,6 +331,32 @@ func WithOpaqueErrors(opaque bool) AppOption {
}
}
func WithSubtleKeyComparison(subtle bool) AppOption {
return func(o *ApplicationConfig) {
o.UseSubtleKeyComparison = subtle
}
}
func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
return func(o *ApplicationConfig) {
o.DisableApiKeyRequirementForHttpGet = required
}
}
func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
return func(o *ApplicationConfig) {
o.HttpGetExemptedEndpoints = []*regexp.Regexp{}
for _, epr := range endpoints {
r, err := regexp.Compile(epr)
if err == nil && r != nil {
o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r)
} else {
log.Warn().Err(err).Str("regex", epr).Msg("Error while compiling HTTP Get Exemption regex, skipping this entry.")
}
}
}
}
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
// Some options defined at the application level are going to be passed as defaults for
// all the configuration for the models.

View File

@ -3,13 +3,15 @@ package http
import (
"embed"
"errors"
"fmt"
"net/http"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/config"
@ -137,37 +139,14 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
})
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
if err != nil || kaConfig == nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
app.Use(v2keyauth.New(*kaConfig))
if appConfig.CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
@ -192,13 +171,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
galleryService := services.NewGalleryService(appConfig)
galleryService.Start(appConfig.Context, cl)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
routes.RegisterJINARoutes(app, cl, ml, appConfig)
httpFS := http.FS(embedDirStatic)

View File

@ -0,0 +1,93 @@
package middleware
import (
"crypto/subtle"
"errors"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAI/core/config"
)
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key"}, keyauth.ConfigDefault.AuthScheme)
if err != nil {
return nil, err
}
return &v2keyauth.Config{
CustomKeyLookup: customLookup,
Next: getApiKeyRequiredFilterFunction(applicationConfig),
Validator: getApiKeyValidationFunction(applicationConfig),
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
}, nil
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
return func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 {
return ctx.Next() // if no keys are set up, any error we get here is not an error.
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403)
}
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
}
return err
}
}
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
if applicationConfig.UseSubtleKeyComparison {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if apiKey == validKey {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
if applicationConfig.DisableApiKeyRequirementForHttpGet {
return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
return false
}
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
if rx.MatchString(c.Path()) {
return true
}
}
return false
}
}
return func(c *fiber.Ctx) bool { return false }
}

View File

@ -10,12 +10,11 @@ import (
func RegisterElevenLabsRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
appConfig *config.ApplicationConfig) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", auth, elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
}

View File

@ -11,8 +11,7 @@ import (
func RegisterJINARoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
appConfig *config.ApplicationConfig) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))

View File

@ -15,33 +15,32 @@ func RegisterLocalAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
galleryService *services.GalleryService) {
app.Get("/swagger/*", swagger.HandlerDefault) // default
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
}
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
// Kubernetes health checks
ok := func(c *fiber.Ctx) error {
@ -51,20 +50,20 @@ func RegisterLocalAIRoutes(app *fiber.App,
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
if p2p.IsP2PEnabled() {
app.Get("/api/p2p", auth, localai.ShowP2PNodes(appConfig))
app.Get("/api/p2p/token", auth, localai.ShowP2PToken(appConfig))
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
}
app.Get("/version", auth, func(c *fiber.Ctx) error {
app.Get("/version", func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})

View File

@ -11,66 +11,65 @@ import (
func RegisterOpenAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
appConfig *config.ApplicationConfig) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
// assistant
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
@ -81,6 +80,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
}
// List models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
}

View File

@ -59,8 +59,7 @@ func RegisterUIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
galleryService *services.GalleryService) {
// keeps the state of models that are being installed from the UI
var processingModels = NewModelOpCache()
@ -85,10 +84,10 @@ func RegisterUIRoutes(app *fiber.App,
return processingModelsData, taskTypes
}
app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
if p2p.IsP2PEnabled() {
app.Get("/p2p", auth, func(c *fiber.Ctx) error {
app.Get("/p2p", func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI - P2P dashboard",
"Version": internal.PrintableVersion(),
@ -104,17 +103,17 @@ func RegisterUIRoutes(app *fiber.App,
})
/* show nodes live! */
app.Get("/p2p/ui/workers", auth, func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers", func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
})
app.Get("/p2p/ui/workers-federation", auth, func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-federation", func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
})
app.Get("/p2p/ui/workers-stats", auth, func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-stats", func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
})
app.Get("/p2p/ui/workers-federation-stats", auth, func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-federation-stats", func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
})
}
@ -122,7 +121,7 @@ func RegisterUIRoutes(app *fiber.App,
if !appConfig.DisableGalleryEndpoint {
// Show the Models page (all models)
app.Get("/browse", auth, func(c *fiber.Ctx) error {
app.Get("/browse", func(c *fiber.Ctx) error {
term := c.Query("term")
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
@ -167,7 +166,7 @@ func RegisterUIRoutes(app *fiber.App,
// Show the models, filtered from the user input
// https://htmx.org/examples/active-search/
app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error {
app.Post("/browse/search/models", func(c *fiber.Ctx) error {
form := struct {
Search string `form:"search"`
}{}
@ -188,7 +187,7 @@ func RegisterUIRoutes(app *fiber.App,
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
app.Post("/browse/install/model/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
log.Debug().Msgf("UI job submitted to install : %+v\n", galleryID)
@ -215,7 +214,7 @@ func RegisterUIRoutes(app *fiber.App,
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
app.Post("/browse/delete/model/:id", func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
log.Debug().Msgf("UI job submitted to delete : %+v\n", galleryID)
var galleryName = galleryID
@ -255,7 +254,7 @@ func RegisterUIRoutes(app *fiber.App,
// Display the job current progress status
// If the job is done, we trigger the /browse/job/:uid route
// https://htmx.org/examples/progress-bar/
app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error {
app.Get("/browse/job/progress/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
status := galleryService.GetStatus(jobUID)
@ -279,7 +278,7 @@ func RegisterUIRoutes(app *fiber.App,
// this route is hit when the job is done, and we display the
// final state (for now just displays "Installation completed")
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {
app.Get("/browse/job/:uid", func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
status := galleryService.GetStatus(jobUID)
@ -303,7 +302,7 @@ func RegisterUIRoutes(app *fiber.App,
}
// Show the Chat page
app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
app.Get("/chat/:model", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, "", true)
summary := fiber.Map{
@ -318,7 +317,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/chat", summary)
})
app.Get("/talk/", auth, func(c *fiber.Ctx) error {
app.Get("/talk/", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, "", true)
if len(backendConfigs) == 0 {
@ -338,7 +337,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/talk", summary)
})
app.Get("/chat/", auth, func(c *fiber.Ctx) error {
app.Get("/chat/", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, "", true)
@ -359,7 +358,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/chat", summary)
})
app.Get("/text2image/:model", auth, func(c *fiber.Ctx) error {
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
@ -374,7 +373,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/text2image", summary)
})
app.Get("/text2image/", auth, func(c *fiber.Ctx) error {
app.Get("/text2image/", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
@ -395,7 +394,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/text2image", summary)
})
app.Get("/tts/:model", auth, func(c *fiber.Ctx) error {
app.Get("/tts/:model", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
@ -410,7 +409,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/tts", summary)
})
app.Get("/tts/", auth, func(c *fiber.Ctx) error {
app.Get("/tts/", func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()

1
go.mod
View File

@ -74,6 +74,7 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 // indirect
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect

2
go.sum
View File

@ -110,6 +110,8 @@ github.com/creachadair/otp v0.4.2 h1:ngNMaD6Tzd7UUNRFyed7ykZFn/Wr5sSs5ffqZWm9pu8
github.com/creachadair/otp v0.4.2/go.mod h1:DqV9hJyUbcUme0pooYfiFvvMe72Aua5sfhNzwfZvk40=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=