Mirror of https://github.com/mudler/LocalAI.git (synced 2024-12-24 06:46:39 +00:00)
[Refactor]: Core/API Split (#1506)

Refactors the api folder to core, creating a firm split between the backend code and the API frontend.
This commit is contained in:
parent bcf02449b3, commit ab7b4d5ee9
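For downstream consumers, the practical effect of this split is an import-path move from api/ to core/. A hedged sketch of that move (the core/ targets are an assumption for illustration; this page itself only records the .gitignore rename from api/localai to core/**/localai):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Pre-refactor import paths, exactly as they appear in the deleted files below.
	before := []string{
		`config "github.com/go-skynet/LocalAI/api/config"`,
		`"github.com/go-skynet/LocalAI/api/backend"`,
		`"github.com/go-skynet/LocalAI/api/options"`,
	}
	// The core/ destinations printed here are an assumption, not taken from this diff.
	for _, imp := range before {
		fmt.Println(strings.Replace(imp, "/LocalAI/api/", "/LocalAI/core/", 1))
	}
}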
.gitignore (vendored): 4 lines changed
@@ -19,8 +19,8 @@ LocalAI
 local-ai
 # prevent above rules from omitting the helm chart
 !charts/*
-# prevent above rules from omitting the api/localai folder
-!api/localai
+# prevent above rules from omitting the core/**/localai folder
+!core/**/localai
 
 # Ignore models
 models/*
Dockerfile
@@ -88,7 +88,7 @@ ENV NVIDIA_VISIBLE_DEVICES=all
 WORKDIR /build
 
 COPY . .
-COPY .git .
+COPY .git/ .git/
 RUN make prepare
 
 # stablediffusion does not tolerate a newer version of abseil, build it first
api/api.go (deleted, 302 lines)
@@ -1,302 +0,0 @@

package api

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/localai"
	"github.com/go-skynet/LocalAI/api/openai"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/go-skynet/LocalAI/internal"
	"github.com/go-skynet/LocalAI/metrics"
	"github.com/go-skynet/LocalAI/pkg/assets"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gofiber/fiber/v2/middleware/recover"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
	options := options.NewOptions(opts...)

	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if options.Debug {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	}

	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

	modelPath := options.Loader.ModelPath
	if len(options.ModelsURL) > 0 {
		for _, url := range options.ModelsURL {
			if utils.LooksLikeURL(url) {
				// md5 of model name
				md5Name := utils.MD5(url)

				// check if file exists
				if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
					err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
						utils.DisplayDownloadFunction(fileName, current, total, percent)
					})
					if err != nil {
						log.Error().Msgf("error loading model: %s", err.Error())
					}
				}
			}
		}
	}

	cl := config.NewConfigLoader()
	if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
		log.Error().Msgf("error loading config files: %s", err.Error())
	}

	if options.ConfigFile != "" {
		if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
			log.Error().Msgf("error loading config file: %s", err.Error())
		}
	}

	if err := cl.Preload(options.Loader.ModelPath); err != nil {
		log.Error().Msgf("error downloading models: %s", err.Error())
	}

	if options.PreloadJSONModels != "" {
		if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
			return nil, nil, err
		}
	}

	if options.PreloadModelsFromPath != "" {
		if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
			return nil, nil, err
		}
	}

	if options.Debug {
		for _, v := range cl.ListConfigs() {
			cfg, _ := cl.GetConfig(v)
			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
		}
	}

	if options.AssetsDestination != "" {
		// Extract files from the embedded FS
		err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
		log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
		if err != nil {
			log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
		}
	}

	// turn off any process that was started by GRPC if the context is canceled
	go func() {
		<-options.Context.Done()
		log.Debug().Msgf("Context canceled, shutting down")
		options.Loader.StopAllGRPC()
	}()

	if options.WatchDog {
		wd := model.NewWatchDog(
			options.Loader,
			options.WatchDogBusyTimeout,
			options.WatchDogIdleTimeout,
			options.WatchDogBusy,
			options.WatchDogIdle)
		options.Loader.SetWatchDog(wd)
		go wd.Run()
		go func() {
			<-options.Context.Done()
			log.Debug().Msgf("Context canceled, shutting down")
			wd.Shutdown()
		}()
	}

	return options, cl, nil
}

func App(opts ...options.AppOption) (*fiber.App, error) {

	options, cl, err := Startup(opts...)
	if err != nil {
		return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
	}

	// Return errors as JSON responses
	app := fiber.New(fiber.Config{
		BodyLimit:             options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
		DisableStartupMessage: options.DisableMessage,
		// Override default error handler
		ErrorHandler: func(ctx *fiber.Ctx, err error) error {
			// Status code defaults to 500
			code := fiber.StatusInternalServerError

			// Retrieve the custom status code if it's a *fiber.Error
			var e *fiber.Error
			if errors.As(err, &e) {
				code = e.Code
			}

			// Send custom error page
			return ctx.Status(code).JSON(
				schema.ErrorResponse{
					Error: &schema.APIError{Message: err.Error(), Code: code},
				},
			)
		},
	})

	if options.Debug {
		app.Use(logger.New(logger.Config{
			Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
		}))
	}

	// Default middleware config
	app.Use(recover.New())
	if options.Metrics != nil {
		app.Use(metrics.APIMiddleware(options.Metrics))
	}

	// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
	auth := func(c *fiber.Ctx) error {
		if len(options.ApiKeys) == 0 {
			return c.Next()
		}

		// Check for api_keys.json file
		fileContent, err := os.ReadFile("api_keys.json")
		if err == nil {
			// Parse JSON content from the file
			var fileKeys []string
			err := json.Unmarshal(fileContent, &fileKeys)
			if err != nil {
				return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
			}

			// Add file keys to options.ApiKeys
			options.ApiKeys = append(options.ApiKeys, fileKeys...)
		}

		if len(options.ApiKeys) == 0 {
			return c.Next()
		}

		authHeader := c.Get("Authorization")
		if authHeader == "" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
		}
		authHeaderParts := strings.Split(authHeader, " ")
		if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
		}

		apiKey := authHeaderParts[1]
		for _, key := range options.ApiKeys {
			if apiKey == key {
				return c.Next()
			}
		}

		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
	}

	if options.CORS {
		var c func(ctx *fiber.Ctx) error
		if options.CORSAllowOrigins == "" {
			c = cors.New()
		} else {
			c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
		}

		app.Use(c)
	}

	// LocalAI API endpoints
	galleryService := localai.NewGalleryService(options.Loader.ModelPath)
	galleryService.Start(options.Context, cl)

	app.Get("/version", auth, func(c *fiber.Ctx) error {
		return c.JSON(struct {
			Version string `json:"version"`
		}{Version: internal.PrintableVersion()})
	})

	modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
	app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
	app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
	app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
	app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
	app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
	app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
	app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())

	// openAI compatible API endpoint

	// chat
	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
	app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))

	// edit
	app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
	app.Post("/edits", auth, openai.EditEndpoint(cl, options))

	// completion
	app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
	app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))

	// embeddings
	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))

	// audio
	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
	app.Post("/tts", auth, localai.TTSEndpoint(cl, options))

	// images
	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))

	if options.ImageDir != "" {
		app.Static("/generated-images", options.ImageDir)
	}

	if options.AudioDir != "" {
		app.Static("/generated-audio", options.AudioDir)
	}

	ok := func(c *fiber.Ctx) error {
		return c.SendStatus(200)
	}

	// Kubernetes health checks
	app.Get("/healthz", ok)
	app.Get("/readyz", ok)

	// Experimental Backend Statistics Module
	backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
	app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
	app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))

	// models
	app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
	app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))

	app.Get("/metrics", metrics.MetricsHandler())

	return app, nil
}
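Since the auth middleware in the deleted api.go above only accepts keys as "Bearer <key>" in the Authorization header, a client calling the guarded routes looks roughly like this (a minimal sketch; the address, key, and model name are placeholders, not values from this commit):

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Placeholder values; none of these come from the commit itself.
	base := "http://localhost:8080"
	apiKey := "my-api-key"

	body := []byte(`{"model":"my-model","messages":[{"role":"user","content":"Hello"}]}`)
	req, err := http.NewRequest(http.MethodPost, base+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	// The middleware splits this header on a single space and expects exactly
	// two parts, the first being "Bearer".
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}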
@@ -1,61 +0,0 @@

package backend

import (
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {

	opts := modelOpts(c, o, []model.Option{
		model.WithBackendString(c.Backend),
		model.WithAssetDir(o.AssetsDestination),
		model.WithThreads(uint32(c.Threads)),
		model.WithContext(o.Context),
		model.WithModel(c.Model),
		model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
			CUDA:          c.CUDA || c.Diffusers.CUDA,
			SchedulerType: c.Diffusers.SchedulerType,
			PipelineType:  c.Diffusers.PipelineType,
			CFGScale:      c.Diffusers.CFGScale,
			LoraAdapter:   c.LoraAdapter,
			LoraScale:     c.LoraScale,
			LoraBase:      c.LoraBase,
			IMG2IMG:       c.Diffusers.IMG2IMG,
			CLIPModel:     c.Diffusers.ClipModel,
			CLIPSubfolder: c.Diffusers.ClipSubFolder,
			CLIPSkip:      int32(c.Diffusers.ClipSkip),
			ControlNet:    c.Diffusers.ControlNet,
		}),
	})

	inferenceModel, err := loader.BackendLoader(
		opts...,
	)
	if err != nil {
		return nil, err
	}

	fn := func() error {
		_, err := inferenceModel.GenerateImage(
			o.Context,
			&proto.GenerateImageRequest{
				Height:           int32(height),
				Width:            int32(width),
				Mode:             int32(mode),
				Step:             int32(step),
				Seed:             int32(seed),
				CLIPSkip:         int32(c.Diffusers.ClipSkip),
				PositivePrompt:   positive_prompt,
				NegativePrompt:   negative_prompt,
				Dst:              dst,
				Src:              src,
				EnableParameters: c.Diffusers.EnableParameters,
			})
		return err
	}

	return fn, nil
}
@@ -1,167 +0,0 @@

package backend

import (
	"context"
	"os"
	"regexp"
	"strings"
	"sync"
	"unicode/utf8"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/gallery"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/utils"
)

type LLMResponse struct {
	Response string // should this be []byte?
	Usage    TokenUsage
}

type TokenUsage struct {
	Prompt     int
	Completion int
}

func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
	modelFile := c.Model

	grpcOpts := gRPCModelOpts(c)

	var inferenceModel *grpc.Client
	var err error

	opts := modelOpts(c, o, []model.Option{
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)), // some models use this to allocate threads during startup
		model.WithAssetDir(o.AssetsDestination),
		model.WithModel(modelFile),
		model.WithContext(o.Context),
	})

	if c.Backend != "" {
		opts = append(opts, model.WithBackendString(c.Backend))
	}

	// Check if the modelFile exists, if it doesn't try to load it from the gallery
	if o.AutoloadGalleries { // experimental
		if _, err := os.Stat(modelFile); os.IsNotExist(err) {
			utils.ResetDownloadTimers()
			// if we failed to load the model, we try to download it
			err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
			if err != nil {
				return nil, err
			}
		}
	}

	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		inferenceModel, err = loader.BackendLoader(opts...)
	}

	if err != nil {
		return nil, err
	}

	// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
	fn := func() (LLMResponse, error) {
		opts := gRPCPredictOpts(c, loader.ModelPath)
		opts.Prompt = s
		opts.Images = images

		tokenUsage := TokenUsage{}

		// check the per-model feature flag for usage, since tokenCallback may have a cost.
		// Defaults to off as for now it is still experimental
		if c.FeatureFlag.Enabled("usage") {
			userTokenCallback := tokenCallback
			if userTokenCallback == nil {
				userTokenCallback = func(token string, usage TokenUsage) bool {
					return true
				}
			}

			promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
			if pErr == nil && promptInfo.Length > 0 {
				tokenUsage.Prompt = int(promptInfo.Length)
			}

			tokenCallback = func(token string, usage TokenUsage) bool {
				tokenUsage.Completion++
				return userTokenCallback(token, tokenUsage)
			}
		}

		if tokenCallback != nil {
			ss := ""

			var partialRune []byte
			err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
				partialRune = append(partialRune, chars...)

				for len(partialRune) > 0 {
					r, size := utf8.DecodeRune(partialRune)
					if r == utf8.RuneError {
						// incomplete rune, wait for more bytes
						break
					}

					tokenCallback(string(r), tokenUsage)
					ss += string(r)

					partialRune = partialRune[size:]
				}
			})
			return LLMResponse{
				Response: ss,
				Usage:    tokenUsage,
			}, err
		} else {
			// TODO: Is the chicken bit the only way to get here? is that acceptable?
			reply, err := inferenceModel.Predict(ctx, opts)
			if err != nil {
				return LLMResponse{}, err
			}
			return LLMResponse{
				Response: string(reply.Message),
				Usage:    tokenUsage,
			}, err
		}
	}

	return fn, nil
}

var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}

func Finetune(config config.Config, input, prediction string) string {
	if config.Echo {
		prediction = input + prediction
	}

	for _, c := range config.Cutstrings {
		mu.Lock()
		reg, ok := cutstrings[c]
		if !ok {
			cutstrings[c] = regexp.MustCompile(c)
			reg = cutstrings[c]
		}
		mu.Unlock()
		prediction = reg.ReplaceAllString(prediction, "")
	}

	for _, c := range config.TrimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}

	for _, c := range config.TrimSuffix {
		prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
	}
	return prediction
}
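The PredictStream callback in the deleted file above buffers raw bytes and only emits complete UTF-8 runes, because a multi-byte character can arrive split across two gRPC chunks. A standalone sketch of that buffering pattern:

package main

import (
	"fmt"
	"unicode/utf8"
)

func main() {
	// "héllo": the two bytes of 'é' (0xc3, 0xa9) arrive in different chunks,
	// as can happen when a stream splits a token mid-rune.
	chunks := [][]byte{{'h', 0xc3}, {0xa9, 'l', 'l', 'o'}}

	var partialRune []byte
	for _, chars := range chunks {
		partialRune = append(partialRune, chars...)

		for len(partialRune) > 0 {
			r, size := utf8.DecodeRune(partialRune)
			if r == utf8.RuneError {
				// incomplete rune at the head of the buffer: wait for more bytes
				break
			}
			fmt.Printf("emit %q\n", r)
			partialRune = partialRune[size:]
		}
	}
}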
@@ -1,39 +0,0 @@

package backend

import (
	"context"
	"fmt"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/schema"

	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {

	opts := modelOpts(c, o, []model.Option{
		model.WithBackendString(model.WhisperBackend),
		model.WithModel(c.Model),
		model.WithContext(o.Context),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
	})

	whisperModel, err := o.Loader.BackendLoader(opts...)
	if err != nil {
		return nil, err
	}

	if whisperModel == nil {
		return nil, fmt.Errorf("could not load whisper model")
	}

	return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
		Dst:      audio,
		Language: language,
		Threads:  uint32(c.Threads),
	})
}
@@ -1,162 +0,0 @@

package localai

import (
	"context"
	"fmt"
	"strings"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"

	"github.com/go-skynet/LocalAI/api/options"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"

	gopsutil "github.com/shirou/gopsutil/v3/process"
)

type BackendMonitorRequest struct {
	Model string `json:"model" yaml:"model"`
}

type BackendMonitorResponse struct {
	MemoryInfo    *gopsutil.MemoryInfoStat
	MemoryPercent float32
	CPUPercent    float64
}

type BackendMonitor struct {
	configLoader *config.ConfigLoader
	options      *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}

func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
	return BackendMonitor{
		configLoader: configLoader,
		options:      options,
	}
}

func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
	config, exists := bm.configLoader.GetConfig(model)
	var backend string
	if exists {
		backend = config.Model
	} else {
		// Last ditch effort: use it raw, see if a backend happens to match.
		backend = model
	}

	if !strings.HasSuffix(backend, ".bin") {
		backend = fmt.Sprintf("%s.bin", backend)
	}

	pid, err := bm.options.Loader.GetGRPCPID(backend)

	if err != nil {
		log.Error().Msgf("model %s : failed to find pid %+v", model, err)
		return nil, err
	}

	// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
	backendProcess, err := gopsutil.NewProcess(int32(pid))

	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
		return nil, err
	}

	memInfo, err := backendProcess.MemoryInfo()

	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
		return nil, err
	}

	memPercent, err := backendProcess.MemoryPercent()
	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
		return nil, err
	}

	cpuPercent, err := backendProcess.CPUPercent()
	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
		return nil, err
	}

	return &BackendMonitorResponse{
		MemoryInfo:    memInfo,
		MemoryPercent: memPercent,
		CPUPercent:    cpuPercent,
	}, nil
}

func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
	input := new(BackendMonitorRequest)
	// Get input data from the request body
	if err := c.BodyParser(input); err != nil {
		return "", err
	}

	config, exists := bm.configLoader.GetConfig(input.Model)
	var backendId string
	if exists {
		backendId = config.Model
	} else {
		// Last ditch effort: use it raw, see if a backend happens to match.
		backendId = input.Model
	}

	if !strings.HasSuffix(backendId, ".bin") {
		backendId = fmt.Sprintf("%s.bin", backendId)
	}

	return backendId, nil
}

func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {

		backendId, err := bm.getModelLoaderIDFromCtx(c)
		if err != nil {
			return err
		}

		model := bm.options.Loader.CheckIsLoaded(backendId)
		if model == "" {
			return fmt.Errorf("backend %s is not currently loaded", backendId)
		}

		status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
		if rpcErr != nil {
			log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
			val, slbErr := bm.SampleLocalBackendProcess(backendId)
			if slbErr != nil {
				return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
			}
			return c.JSON(proto.StatusResponse{
				State: proto.StatusResponse_ERROR,
				Memory: &proto.MemoryUsageData{
					Total: val.MemoryInfo.VMS,
					Breakdown: map[string]uint64{
						"gopsutil-RSS": val.MemoryInfo.RSS,
					},
				},
			})
		}

		return c.JSON(status)
	}
}

func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		backendId, err := bm.getModelLoaderIDFromCtx(c)
		if err != nil {
			return err
		}

		return bm.options.Loader.ShutdownModel(backendId)
	}
}
@@ -1,326 +0,0 @@

package localai

import (
	"context"
	"fmt"
	"os"
	"slices"
	"strings"
	"sync"

	json "github.com/json-iterator/go"
	"gopkg.in/yaml.v3"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/pkg/gallery"
	"github.com/go-skynet/LocalAI/pkg/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
)

type galleryOp struct {
	req         gallery.GalleryModel
	id          string
	galleries   []gallery.Gallery
	galleryName string
}

type galleryOpStatus struct {
	FileName           string  `json:"file_name"`
	Error              error   `json:"error"`
	Processed          bool    `json:"processed"`
	Message            string  `json:"message"`
	Progress           float64 `json:"progress"`
	TotalFileSize      string  `json:"file_size"`
	DownloadedFileSize string  `json:"downloaded_size"`
}

type galleryApplier struct {
	modelPath string
	sync.Mutex
	C        chan galleryOp
	statuses map[string]*galleryOpStatus
}

func NewGalleryService(modelPath string) *galleryApplier {
	return &galleryApplier{
		modelPath: modelPath,
		C:         make(chan galleryOp),
		statuses:  make(map[string]*galleryOpStatus),
	}
}

func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {

	config, err := gallery.GetGalleryConfigFromURL(req.URL)
	if err != nil {
		return err
	}

	config.Files = append(config.Files, req.AdditionalFiles...)

	return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}

func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
	g.Lock()
	defer g.Unlock()
	g.statuses[s] = op
}

func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
	g.Lock()
	defer g.Unlock()

	return g.statuses[s]
}

func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
	g.Lock()
	defer g.Unlock()

	return g.statuses
}

func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
	go func() {
		for {
			select {
			case <-c.Done():
				return
			case op := <-g.C:
				utils.ResetDownloadTimers()

				g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})

				// updates the status with an error
				updateError := func(e error) {
					g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
				}

				// displayDownload displays the download progress
				progressCallback := func(fileName string, current string, total string, percentage float64) {
					g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
					utils.DisplayDownloadFunction(fileName, current, total, percentage)
				}

				var err error
				// if the request contains a gallery name, we apply the gallery from the gallery list
				if op.galleryName != "" {
					if strings.Contains(op.galleryName, "@") {
						err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
					} else {
						err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
					}
				} else {
					err = prepareModel(g.modelPath, op.req, cm, progressCallback)
				}

				if err != nil {
					updateError(err)
					continue
				}

				// Reload models
				err = cm.LoadConfigs(g.modelPath)
				if err != nil {
					updateError(err)
					continue
				}

				err = cm.Preload(g.modelPath)
				if err != nil {
					updateError(err)
					continue
				}

				g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
			}
		}
	}()
}

type galleryModel struct {
	gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
	ID                   string           `json:"id"`
}

func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
	var err error
	for _, r := range requests {
		utils.ResetDownloadTimers()
		if r.ID == "" {
			err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
		} else {
			if strings.Contains(r.ID, "@") {
				err = gallery.InstallModelFromGallery(
					galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
			} else {
				err = gallery.InstallModelFromGalleryByName(
					galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
			}
		}
	}
	return err
}

func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
	dat, err := os.ReadFile(s)
	if err != nil {
		return err
	}
	var requests []galleryModel

	if err := yaml.Unmarshal(dat, &requests); err != nil {
		return err
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}

func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
	var requests []galleryModel
	err := json.Unmarshal([]byte(s), &requests)
	if err != nil {
		return err
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}

/// Endpoint Service

type ModelGalleryService struct {
	galleries      []gallery.Gallery
	modelPath      string
	galleryApplier *galleryApplier
}

type GalleryModel struct {
	ID string `json:"id"`
	gallery.GalleryModel
}

func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
	return ModelGalleryService{
		galleries:      galleries,
		modelPath:      modelPath,
		galleryApplier: galleryApplier,
	}
}

func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		status := mgs.galleryApplier.getStatus(c.Params("uuid"))
		if status == nil {
			return fmt.Errorf("could not find any status for ID")
		}
		return c.JSON(status)
	}
}

func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		return c.JSON(mgs.galleryApplier.getAllStatus())
	}
}

func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(GalleryModel)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		uuid, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		mgs.galleryApplier.C <- galleryOp{
			req:         input.GalleryModel,
			id:          uuid.String(),
			galleryName: input.ID,
			galleries:   mgs.galleries,
		}
		return c.JSON(struct {
			ID        string `json:"uuid"`
			StatusURL string `json:"status"`
		}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
	}
}

func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)

		models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
		if err != nil {
			return err
		}
		log.Debug().Msgf("Models found from galleries: %+v", models)
		for _, m := range models {
			log.Debug().Msgf("Model found from galleries: %+v", m)
		}
		dat, err := json.Marshal(models)
		if err != nil {
			return err
		}
		return c.Send(dat)
	}
}

// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
		dat, err := json.Marshal(mgs.galleries)
		if err != nil {
			return err
		}
		return c.Send(dat)
	}
}

func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(gallery.Gallery)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}
		if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		}) {
			return fmt.Errorf("%s already exists", input.Name)
		}
		dat, err := json.Marshal(mgs.galleries)
		if err != nil {
			return err
		}
		log.Debug().Msgf("Adding %+v to gallery list", *input)
		mgs.galleries = append(mgs.galleries, *input)
		return c.Send(dat)
	}
}

func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(gallery.Gallery)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}
		if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		}) {
			return fmt.Errorf("%s is not currently registered", input.Name)
		}
		mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		})
		return c.Send(nil)
	}
}
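Model installs in the gallery service above run asynchronously: ApplyModelGalleryEndpoint queues a galleryOp and answers with a job UUID plus a status URL, and GetOpStatusEndpoint serves the tracked progress. A hedged client sketch against those routes (the server address and gallery model ID are placeholders, and it assumes no API keys are configured):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func main() {
	base := "http://localhost:8080" // placeholder address

	// Queue an install; the gallery ID is a placeholder, not a value from this commit.
	payload := []byte(`{"id":"model-gallery@bert-embeddings"}`)
	resp, err := http.Post(base+"/models/apply", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	var job struct {
		UUID      string `json:"uuid"`
		StatusURL string `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
		panic(err)
	}
	resp.Body.Close()

	// Poll the job until the gallery applier marks it processed.
	for {
		st, err := http.Get(base + "/models/jobs/" + job.UUID)
		if err != nil {
			panic(err)
		}
		var status struct {
			Processed bool    `json:"processed"`
			Progress  float64 `json:"progress"`
			Message   string  `json:"message"`
		}
		if err := json.NewDecoder(st.Body).Decode(&status); err != nil {
			panic(err)
		}
		st.Body.Close()

		fmt.Printf("%s: %.0f%%\n", status.Message, status.Progress)
		if status.Processed {
			return
		}
		time.Sleep(time.Second)
	}
}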
@@ -1,32 +0,0 @@

package localai

import (
	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"

	"github.com/go-skynet/LocalAI/api/options"
	"github.com/gofiber/fiber/v2"
)

type TTSRequest struct {
	Model   string `json:"model" yaml:"model"`
	Input   string `json:"input" yaml:"input"`
	Backend string `json:"backend" yaml:"backend"`
}

func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {

		input := new(TTSRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o)
		if err != nil {
			return err
		}
		return c.Download(filePath)
	}
}
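TTSEndpoint above parses a TTSRequest from the body and returns the synthesized file as a download. A minimal client sketch (the address and model name are placeholders, and the output extension is an assumption):

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
)

func main() {
	base := "http://localhost:8080" // placeholder address

	// Field names mirror TTSRequest above; the model name is a placeholder.
	payload := `{"model":"my-tts-model","input":"Hello from LocalAI"}`
	resp, err := http.Post(base+"/tts", "application/json", strings.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		msg, _ := io.ReadAll(resp.Body)
		panic(fmt.Sprintf("tts failed: %s: %s", resp.Status, msg))
	}

	// The endpoint streams the generated file back; the extension is an assumption.
	out, err := os.Create("tts-output.wav")
	if err != nil {
		panic(err)
	}
	defer out.Close()
	if _, err := io.Copy(out, resp.Body); err != nil {
		panic(err)
	}
	fmt.Println("wrote tts-output.wav")
}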
@ -1,399 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
|
||||||
emptyMessage := ""
|
|
||||||
id := uuid.New().String()
|
|
||||||
created := int(time.Now().Unix())
|
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
responses <- initialMessage
|
|
||||||
|
|
||||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
|
||||||
resp := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Usage: schema.OpenAIUsage{
|
|
||||||
PromptTokens: usage.Prompt,
|
|
||||||
CompletionTokens: usage.Completion,
|
|
||||||
TotalTokens: usage.Prompt + usage.Completion,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
responses <- resp
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
close(responses)
|
|
||||||
}
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
processFunctions := false
|
|
||||||
funcs := grammar.Functions{}
|
|
||||||
modelFile, input, err := readInput(c, o, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Configuration read: %+v", config)
|
|
||||||
|
|
||||||
// Allow the user to set custom actions via config file
|
|
||||||
// to be "embedded" in each model
|
|
||||||
noActionName := "answer"
|
|
||||||
noActionDescription := "use this action to answer without performing any action"
|
|
||||||
|
|
||||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
|
||||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
|
||||||
}
|
|
||||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
|
||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ResponseFormat.Type == "json_object" {
|
|
||||||
input.Grammar = grammar.JSONBNF
|
|
||||||
}
|
|
||||||
|
|
||||||
// process functions if we have any defined or if we have a function call string
|
|
||||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
|
||||||
log.Debug().Msgf("Response needs to process functions")
|
|
||||||
|
|
||||||
processFunctions = true
|
|
||||||
|
|
||||||
noActionGrammar := grammar.Function{
|
|
||||||
Name: noActionName,
|
|
||||||
Description: noActionDescription,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "The message to reply the user with",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append the no action function
|
|
||||||
funcs = append(funcs, input.Functions...)
|
|
||||||
if !config.FunctionsConfig.DisableNoAction {
|
|
||||||
funcs = append(funcs, noActionGrammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force picking one of the functions by the request
|
|
||||||
if config.FunctionToCall() != "" {
|
|
||||||
funcs = funcs.Select(config.FunctionToCall())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update input grammar
|
|
||||||
jsStruct := funcs.ToJSONStructure()
|
|
||||||
config.Grammar = jsStruct.Grammar("")
|
|
||||||
} else if input.JSONFunctionGrammarObject != nil {
|
|
||||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
|
||||||
}
|
|
||||||
|
|
||||||
// functions are not supported in stream mode (yet?)
|
|
||||||
toStream := input.Stream && !processFunctions
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameters: %+v", config)
|
|
||||||
|
|
||||||
var predInput string
|
|
||||||
|
|
||||||
suppressConfigSystemPrompt := false
|
|
||||||
mess := []string{}
|
|
||||||
for messageIndex, i := range input.Messages {
|
|
||||||
var content string
|
|
||||||
role := i.Role
|
|
||||||
|
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
|
||||||
roleFn := "assistant_function_call"
|
|
||||||
r := config.Roles[roleFn]
|
|
||||||
if r != "" {
|
|
||||||
role = roleFn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r := config.Roles[role]
|
|
||||||
contentExists := i.Content != nil && i.StringContent != ""
|
|
||||||
// First attempt to populate content via a chat message specific template
|
|
||||||
if config.TemplateConfig.ChatMessage != "" {
|
|
||||||
chatMessageData := model.ChatMessageTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
Role: r,
|
|
||||||
RoleName: role,
|
|
||||||
Content: i.StringContent,
|
|
||||||
MessageIndex: messageIndex,
|
|
||||||
}
|
|
||||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
|
||||||
} else {
|
|
||||||
if templatedChatMessage == "" {
|
|
||||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
|
||||||
content = templatedChatMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
|
||||||
if content == "" {
|
|
||||||
if r != "" {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(r, i.StringContent)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
|
||||||
} else {
|
|
||||||
content = fmt.Sprint(r, " ", string(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(i.StringContent)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + string(j)
|
|
||||||
} else {
|
|
||||||
content = string(j)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
|
||||||
if contentExists && role == "system" {
|
|
||||||
suppressConfigSystemPrompt = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mess = append(mess, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
predInput = strings.Join(mess, "\n")
|
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
log.Debug().Msgf("Stream request received")
|
|
||||||
c.Context().SetContentType("text/event-stream")
|
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
// c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache")
|
|
||||||
c.Set("Connection", "keep-alive")
|
|
||||||
c.Set("Transfer-Encoding", "chunked")
|
|
||||||
}
|
|
||||||
|
|
||||||
templateFile := ""
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
||||||
templateFile = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Chat
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Functions
|
|
||||||
}
|
|
||||||
|
|
||||||
if templateFile != "" {
|
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
|
||||||
Input: predInput,
|
|
||||||
Functions: funcs,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
predInput = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
|
||||||
if processFunctions {
|
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
|
||||||
|
|
||||||
go process(predInput, input, config, o.Loader, responses)
|
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
|
||||||
|
|
||||||
usage := &schema.OpenAIUsage{}
|
|
||||||
|
|
||||||
for ev := range responses {
|
|
||||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc := json.NewEncoder(&buf)
|
|
||||||
enc.Encode(ev)
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
|
||||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
|
||||||
input.Cancel()
|
					break
				}
				w.Flush()
			}

			resp := &schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{
					{
						FinishReason: "stop",
						Index:        0,
						Delta:        &schema.Message{Content: &emptyMessage},
					}},
				Object: "chat.completion.chunk",
				Usage:  *usage,
			}
			respData, _ := json.Marshal(resp)

			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
			w.WriteString("data: [DONE]\n\n")
			w.Flush()
		}))
		return nil
	}

	result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
		if processFunctions {
			// As we have to change the result before processing, we can't stream the answer (yet?)
			ss := map[string]interface{}{}
			// This prevents newlines from breaking JSON parsing for clients
			s = utils.EscapeNewLines(s)
			json.Unmarshal([]byte(s), &ss)
			log.Debug().Msgf("Function return: %s %+v", s, ss)

			// The grammar defines the function name as "function", while OpenAI returns "name"
			func_name := ss["function"]
			// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
			args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
			d, _ := json.Marshal(args)

			ss["arguments"] = string(d)
			ss["name"] = func_name

			// if there is nothing to do, reply with a message
			if func_name == noActionName {
				log.Debug().Msgf("nothing to do, computing a reply")

				// If there is a message that the LLM already sends as part of the JSON reply, use it
				arguments := map[string]interface{}{}
				json.Unmarshal([]byte(d), &arguments)
				m, exists := arguments["message"]
				if exists {
					switch message := m.(type) {
					case string:
						if message != "" {
							log.Debug().Msgf("Reply received from LLM: %s", message)
							message = backend.Finetune(*config, predInput, message)
							log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)

							*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
							return
						}
					}
				}

				log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
				// Otherwise ask the LLM to understand the JSON output and the context, and return a message
				// Note: this costs (in terms of CPU) another computation
				config.Grammar = ""
				images := []string{}
				for _, m := range input.Messages {
					images = append(images, m.StringImages...)
				}
				predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
				if err != nil {
					log.Error().Msgf("inference error: %s", err.Error())
					return
				}

				prediction, err := predFunc()
				if err != nil {
					log.Error().Msgf("inference error: %s", err.Error())
					return
				}

				fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
				*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
			} else {
				// otherwise reply with the function call
				*c = append(*c, schema.Choice{
					FinishReason: "function_call",
					Message:      &schema.Message{Role: "assistant", FunctionCall: ss},
				})
			}

			return
		}
		*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
	}, nil)
	if err != nil {
		return err
	}

	resp := &schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
		Choices: result,
		Object:  "chat.completion",
		Usage: schema.OpenAIUsage{
			PromptTokens:     tokenUsage.Prompt,
			CompletionTokens: tokenUsage.Completion,
			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
		},
	}
	respData, _ := json.Marshal(resp)
	log.Debug().Msgf("Response: %s", respData)

	// Return the prediction in the response body
	return c.JSON(resp)
	}
}
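The remapping above is easy to lose in the handler body: the grammar-constrained model emits {"function": ..., "arguments": {...}}, while the OpenAI wire format expects {"name": ..., "arguments": "<stringified JSON>"}. A minimal, self-contained sketch of that transformation follows; the sample payload and field values are invented for illustration only.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical grammar output: "arguments" is still a JSON object here.
	raw := `{"function":"get_weather","arguments":{"city":"Rome"}}`

	ss := map[string]interface{}{}
	if err := json.Unmarshal([]byte(raw), &ss); err != nil {
		panic(err)
	}

	// OpenAI expects "name" plus a *stringified* arguments object.
	d, _ := json.Marshal(ss["arguments"])
	ss["name"] = ss["function"]
	ss["arguments"] = string(d)
	delete(ss, "function") // the real handler keeps both keys; dropped here for clarity

	out, _ := json.Marshal(ss)
	fmt.Println(string(out)) // {"arguments":"{\"city\":\"Rome\"}","name":"get_weather"}
}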
@ -1,199 +0,0 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/go-skynet/LocalAI/pkg/grammar"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp"
)

// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	id := uuid.New().String()
	created := int(time.Now().Unix())

	process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
		ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
			resp := schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{
					{
						Index: 0,
						Text:  s,
					},
				},
				Object: "text_completion",
				Usage: schema.OpenAIUsage{
					PromptTokens:     usage.Prompt,
					CompletionTokens: usage.Completion,
					TotalTokens:      usage.Prompt + usage.Completion,
				},
			}
			log.Debug().Msgf("Sending goroutine: %s", s)

			responses <- resp
			return true
		})
		close(responses)
	}

	return func(c *fiber.Ctx) error {
		modelFile, input, err := readInput(c, o, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("`input`: %+v", input)

		config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		if input.ResponseFormat.Type == "json_object" {
			input.Grammar = grammar.JSONBNF
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		if input.Stream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			//c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")
		}

		templateFile := ""

		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
		if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
			templateFile = config.Model
		}

		if config.TemplateConfig.Completion != "" {
			templateFile = config.TemplateConfig.Completion
		}

		if input.Stream {
			if len(config.PromptStrings) > 1 {
				return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
			}

			predInput := config.PromptStrings[0]

			if templateFile != "" {
				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
					Input: predInput,
				})
				if err == nil {
					predInput = templatedInput
					log.Debug().Msgf("Template found, input modified to: %s", predInput)
				}
			}

			responses := make(chan schema.OpenAIResponse)

			go process(predInput, input, config, o.Loader, responses)

			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

				for ev := range responses {
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					enc.Encode(ev)

					log.Debug().Msgf("Sending chunk: %s", buf.String())
					fmt.Fprintf(w, "data: %v\n", buf.String())
					w.Flush()
				}

				resp := &schema.OpenAIResponse{
					ID:      id,
					Created: created,
					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []schema.Choice{
						{
							Index:        0,
							FinishReason: "stop",
						},
					},
					Object: "text_completion",
				}
				respData, _ := json.Marshal(resp)

				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}

		var result []schema.Choice

		totalTokenUsage := backend.TokenUsage{}

		for k, i := range config.PromptStrings {
			if templateFile != "" {
				// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
					SystemPrompt: config.SystemPrompt,
					Input:        i,
				})
				if err == nil {
					i = templatedInput
					log.Debug().Msgf("Template found, input modified to: %s", i)
				}
			}

			r, tokenUsage, err := ComputeChoices(
				input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
					*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
				}, nil)
			if err != nil {
				return err
			}

			totalTokenUsage.Prompt += tokenUsage.Prompt
			totalTokenUsage.Completion += tokenUsage.Completion

			result = append(result, r...)
		}

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
			Object:  "text_completion",
			Usage: schema.OpenAIUsage{
				PromptTokens:     totalTokenUsage.Prompt,
				CompletionTokens: totalTokenUsage.Completion,
				TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
			},
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
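The streaming branch above emits OpenAI-style server-sent events and terminates the stream with a literal "data: [DONE]" line. A minimal client sketch, assuming a local instance on port 8080 and the /v1/completions route (both placeholders, not confirmed by this diff):

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Hypothetical local endpoint; "stream": true selects the SSE branch.
	body := strings.NewReader(`{"model":"gpt-3.5-turbo","prompt":"Hello","stream":true}`)
	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank separator lines
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		fmt.Println("chunk:", payload) // each payload is one OpenAIResponse JSON document
	}
}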
@ -1,94 +0,0 @@
package openai

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/api/schema"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"

	"github.com/rs/zerolog/log"
)

func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelFile, input, err := readInput(c, o, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		templateFile := ""

		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
		if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
			templateFile = config.Model
		}

		if config.TemplateConfig.Edit != "" {
			templateFile = config.TemplateConfig.Edit
		}

		var result []schema.Choice
		totalTokenUsage := backend.TokenUsage{}

		for _, i := range config.InputStrings {
			if templateFile != "" {
				templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
					Input:        i,
					Instruction:  input.Instruction,
					SystemPrompt: config.SystemPrompt,
				})
				if err == nil {
					i = templatedInput
					log.Debug().Msgf("Template found, input modified to: %s", i)
				}
			}

			r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
				*c = append(*c, schema.Choice{Text: s})
			}, nil)
			if err != nil {
				return err
			}

			totalTokenUsage.Prompt += tokenUsage.Prompt
			totalTokenUsage.Completion += tokenUsage.Completion

			result = append(result, r...)
		}

		id := uuid.New().String()
		created := int(time.Now().Unix())
		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
			Object:  "edit",
			Usage: schema.OpenAIUsage{
				PromptTokens:     totalTokenUsage.Prompt,
				CompletionTokens: totalTokenUsage.Completion,
				TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
			},
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
@ -1,78 +0,0 @@
package openai

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/google/uuid"

	"github.com/go-skynet/LocalAI/api/options"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		model, input, err := readInput(c, o, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)
		items := []schema.Item{}

		for i, s := range config.InputToken {
			// get the model function to call for the result
			embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
			if err != nil {
				return err
			}

			embeddings, err := embedFn()
			if err != nil {
				return err
			}
			items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}

		for i, s := range config.InputStrings {
			// get the model function to call for the result
			embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
			if err != nil {
				return err
			}

			embeddings, err := embedFn()
			if err != nil {
				return err
			}
			items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}

		id := uuid.New().String()
		created := int(time.Now().Unix())
		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Data:    items,
			Object:  "list",
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
@ -1,239 +0,0 @@
package openai

import (
	"bufio"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/google/uuid"

	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

func downloadFile(url string) (string, error) {
	// Get the data
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Create the file
	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}
	defer out.Close()

	// Write the body to file
	_, err = io.Copy(out, resp.Body)
	return out.Name(), err
}

// https://platform.openai.com/docs/api-reference/images/create

/*
*

	curl http://localhost:8080/v1/images/generations \
		-H "Content-Type: application/json" \
		-d '{
			"prompt": "A cute baby sea otter",
			"n": 1,
			"size": "512x512"
		}'

*
*/
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		m, input, err := readInput(c, o, false)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		if m == "" {
			m = model.StableDiffusionBackend
		}
		log.Debug().Msgf("Loading model: %+v", m)

		config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		src := ""
		if input.File != "" {

			fileData := []byte{}
			// check if input.File is a URL; if so, download it and save it
			// to a temporary file
			if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
				out, err := downloadFile(input.File)
				if err != nil {
					return fmt.Errorf("failed downloading file:%w", err)
				}
				defer os.RemoveAll(out)

				fileData, err = os.ReadFile(out)
				if err != nil {
					return fmt.Errorf("failed reading file:%w", err)
				}

			} else {
				// base64-decode the file and write it somewhere
				// that we will clean up
				fileData, err = base64.StdEncoding.DecodeString(input.File)
				if err != nil {
					return err
				}
			}

			// Create a temporary file
			outputFile, err := os.CreateTemp(o.ImageDir, "b64")
			if err != nil {
				return err
			}
			// write the base64 result
			writer := bufio.NewWriter(outputFile)
			_, err = writer.Write(fileData)
			if err != nil {
				outputFile.Close()
				return err
			}
			outputFile.Close()
			src = outputFile.Name()
			defer os.RemoveAll(src)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		switch config.Backend {
		case "stablediffusion":
			config.Backend = model.StableDiffusionBackend
		case "tinydream":
			config.Backend = model.TinyDreamBackend
		case "":
			config.Backend = model.StableDiffusionBackend
		}

		sizeParts := strings.Split(input.Size, "x")
		if len(sizeParts) != 2 {
			return fmt.Errorf("Invalid value for 'size'")
		}
		width, err := strconv.Atoi(sizeParts[0])
		if err != nil {
			return fmt.Errorf("Invalid value for 'size'")
		}
		height, err := strconv.Atoi(sizeParts[1])
		if err != nil {
			return fmt.Errorf("Invalid value for 'size'")
		}

		b64JSON := false
		if input.ResponseFormat.Type == "b64_json" {
			b64JSON = true
		}
		// src and clip_skip
		var result []schema.Item
		for _, i := range config.PromptStrings {
			n := input.N
			if input.N == 0 {
				n = 1
			}
			for j := 0; j < n; j++ {
				prompts := strings.Split(i, "|")
				positive_prompt := prompts[0]
				negative_prompt := ""
				if len(prompts) > 1 {
					negative_prompt = prompts[1]
				}

				mode := 0
				step := config.Step
				if step == 0 {
					step = 15
				}

				if input.Mode != 0 {
					mode = input.Mode
				}

				if input.Step != 0 {
					step = input.Step
				}

				tempDir := ""
				if !b64JSON {
					tempDir = o.ImageDir
				}
				// Create a temporary file
				outputFile, err := os.CreateTemp(tempDir, "b64")
				if err != nil {
					return err
				}
				outputFile.Close()
				output := outputFile.Name() + ".png"
				// Rename the temporary file
				err = os.Rename(outputFile.Name(), output)
				if err != nil {
					return err
				}

				baseURL := c.BaseURL()

				fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
				if err != nil {
					return err
				}
				if err := fn(); err != nil {
					return err
				}

				item := &schema.Item{}

				if b64JSON {
					defer os.RemoveAll(output)
					data, err := os.ReadFile(output)
					if err != nil {
						return err
					}
					item.B64JSON = base64.StdEncoding.EncodeToString(data)
				} else {
					base := filepath.Base(output)
					item.URL = baseURL + "/generated-images/" + base
				}

				result = append(result, *item)
			}
		}

		id := uuid.New().String()
		created := int(time.Now().Unix())
		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Data:    result,
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
@ -1,55 +0,0 @@
package openai

import (
	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/api/schema"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ComputeChoices(
	req *schema.OpenAIRequest,
	predInput string,
	config *config.Config,
	o *options.Option,
	loader *model.ModelLoader,
	cb func(string, *[]schema.Choice),
	tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
	n := req.N // number of completions to return
	result := []schema.Choice{}

	if n == 0 {
		n = 1
	}

	images := []string{}
	for _, m := range req.Messages {
		images = append(images, m.StringImages...)
	}

	// get the model function to call for the result
	predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
	if err != nil {
		return result, backend.TokenUsage{}, err
	}

	tokenUsage := backend.TokenUsage{}

	for i := 0; i < n; i++ {
		prediction, err := predFunc()
		if err != nil {
			return result, backend.TokenUsage{}, err
		}

		tokenUsage.Prompt += prediction.Usage.Prompt
		tokenUsage.Completion += prediction.Usage.Completion

		finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
		cb(finetunedResponse, &result)

		//result = append(result, Choice{Text: prediction})

	}
	return result, tokenUsage, err
}
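Since every endpoint above funnels through ComputeChoices, a compressed view of its contract helps: run the model once per requested completion, accumulate token usage, and let the caller shape each prediction into a choice. A hedged, stdlib-only restatement of that loop (the predict function here is a stand-in for the real gRPC-backed one, and the string slice stands in for []schema.Choice):

package main

import "fmt"

type usage struct{ prompt, completion int }

// computeChoices mirrors the loop above: n predictions, usage totals,
// and a caller-supplied callback that turns raw output into choices.
func computeChoices(n int, predict func() (string, usage, error),
	cb func(string, *[]string)) ([]string, usage, error) {
	if n == 0 {
		n = 1 // the request defaults to a single completion
	}
	var result []string
	var total usage
	for i := 0; i < n; i++ {
		out, u, err := predict()
		if err != nil {
			return result, total, err
		}
		total.prompt += u.prompt
		total.completion += u.completion
		cb(out, &result) // caller decides how a prediction becomes a choice
	}
	return result, total, nil
}

func main() {
	fake := func() (string, usage, error) { return "hello", usage{3, 1}, nil }
	choices, total, _ := computeChoices(2, fake, func(s string, c *[]string) {
		*c = append(*c, s)
	})
	fmt.Println(choices, total) // [hello hello] {6 2}
}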
@ -1,336 +0,0 @@
package openai

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	config "github.com/go-skynet/LocalAI/api/config"
	options "github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/api/schema"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) {
	loader := o.Loader
	input := new(schema.OpenAIRequest)
	ctx, cancel := context.WithCancel(o.Context)
	input.Context = ctx
	input.Cancel = cancel
	// Get input data from the request body
	if err := c.BodyParser(input); err != nil {
		return "", nil, fmt.Errorf("failed parsing request body: %w", err)
	}

	modelFile := input.Model

	if c.Params("model") != "" {
		modelFile = c.Params("model")
	}

	received, _ := json.Marshal(input)

	log.Debug().Msgf("Request received: %s", string(received))

	// Set model from bearer token, if available
	bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
	bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)

	// If no model was specified, take the first available
	if modelFile == "" && !bearerExists && randomModel {
		models, _ := loader.ListModels()
		if len(models) > 0 {
			modelFile = models[0]
			log.Debug().Msgf("No model specified, using: %s", modelFile)
		} else {
			log.Debug().Msgf("No model specified, returning error")
			return "", nil, fmt.Errorf("no model specified")
		}
	}

	// A model found in the bearer token takes precedence
	if bearerExists {
		log.Debug().Msgf("Using model from bearer token: %s", bearer)
		modelFile = bearer
	}
	return modelFile, input, nil
}

// this function checks if the string is a URL; if it is a URL, it downloads the image in memory,
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
	if strings.HasPrefix(s, "http") {
		// download the image
		resp, err := http.Get(s)
		if err != nil {
			return "", err
		}
		defer resp.Body.Close()

		// read the image data into memory
		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return "", err
		}

		// encode the image data in base64
		encoded := base64.StdEncoding.EncodeToString(data)

		// return the base64 string
		return encoded, nil
	}

	// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
	if strings.HasPrefix(s, "data:image/jpeg;base64,") {
		return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
	}
	return "", fmt.Errorf("not valid string")
}

func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != 0 {
		config.TopK = input.TopK
	}
	if input.TopP != 0 {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.ModelBaseName != "" {
		config.AutoGPTQ.ModelBaseName = input.ModelBaseName
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.UseFastTokenizer {
		config.UseFastTokenizer = input.UseFastTokenizer
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != 0 {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != 0 {
		config.Maxtokens = input.Maxtokens
	}

	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	// Decode each request's message content
	index := 0
	for i, m := range input.Messages {
		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			dat, _ := json.Marshal(content)
			c := []schema.Content{}
			json.Unmarshal(dat, &c)
			for _, pp := range c {
				if pp.Type == "text" {
					input.Messages[i].StringContent = pp.Text
				} else if pp.Type == "image_url" {
					// Detect if pp.ImageURL is a URL; if it is, download the image and encode it in base64:
					base64, err := getBase64Image(pp.ImageURL.URL)
					if err == nil {
						input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
						// set a placeholder for each image
						input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
						index++
					} else {
						fmt.Print("Failed encoding image", err)
					}
				}
			}
		}
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.F16 {
		config.F16 = input.F16
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != 0 {
		config.Seed = input.Seed
	}

	if input.Mirostat != 0 {
		config.LLMConfig.Mirostat = input.Mirostat
	}

	if input.MirostatETA != 0 {
		config.LLMConfig.MirostatETA = input.MirostatETA
	}

	if input.MirostatTAU != 0 {
		config.LLMConfig.MirostatTAU = input.MirostatTAU
	}

	if input.TypicalP != 0 {
		config.TypicalP = input.TypicalP
	}

	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []interface{}:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []interface{}:
				tokens := []int{}
				for _, ii := range i {
					tokens = append(tokens, int(ii.(float64)))
				}
				config.InputToken = append(config.InputToken, tokens)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		var name string
		n, exists := fnc["name"]
		if exists {
			nn, e := n.(string)
			if e {
				name = nn
			}
		}
		config.SetFunctionCallNameString(name)
	}

	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}
}

func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
	// Load a config file if present after the model name
	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")

	var cfg *config.Config

	defaults := func() {
		cfg = config.DefaultConfig(modelFile)
		cfg.ContextSize = ctx
		cfg.Threads = threads
		cfg.F16 = f16
		cfg.Debug = debug
	}

	cfgExisting, exists := cm.GetConfig(modelFile)
	if !exists {
		if _, err := os.Stat(modelConfig); err == nil {
			if err := cm.LoadConfig(modelConfig); err != nil {
				return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
			}
			cfgExisting, exists = cm.GetConfig(modelFile)
			if exists {
				cfg = &cfgExisting
			} else {
				defaults()
			}
		} else {
			defaults()
		}
	} else {
		cfg = &cfgExisting
	}

	// Set the parameters for the language model prediction
	updateConfig(cfg, input)

	// Don't allow 0 as setting
	if cfg.Threads == 0 {
		if threads != 0 {
			cfg.Threads = threads
		} else {
			cfg.Threads = 4
		}
	}

	// Enforce debug flag if passed from CLI
	if debug {
		cfg.Debug = true
	}

	return cfg, input, nil
}
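The updateConfig body above is a long run of "copy the request field over the config field when it is non-zero" checks. A compact way to express that same precedence rule, shown here only as an illustration of the pattern and not as code from the repository:

package main

import "fmt"

// override returns the request value when it is set (non-zero),
// otherwise it keeps the configured default.
func override[T comparable](configured, requested T) T {
	var zero T
	if requested != zero {
		return requested
	}
	return configured
}

func main() {
	type cfg struct {
		TopK        int
		Temperature float64
	}
	c := cfg{TopK: 40, Temperature: 0.7}

	// Hypothetical request: only Temperature is set.
	c.TopK = override(c.TopK, 0)
	c.Temperature = override(c.Temperature, 0.9)
	fmt.Printf("%+v\n", c) // {TopK:40 Temperature:0.9}
}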
@ -1,71 +0,0 @@
package openai

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"

	"github.com/go-skynet/LocalAI/api/backend"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"

	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		m, input, err := readInput(c, o, false)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}
		// retrieve the file data from the request
		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer f.Close()

		dir, err := os.MkdirTemp("", "whisper")

		if err != nil {
			return err
		}
		defer os.RemoveAll(dir)

		dst := filepath.Join(dir, path.Base(file.Filename))
		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}

		if _, err := io.Copy(dstFile, f); err != nil {
			log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
			return err
		}

		log.Debug().Msgf("Audio file copied to: %+v", dst)

		tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
		if err != nil {
			return err
		}

		log.Debug().Msgf("Transcribed: %+v", tr)
		// TODO: handle different outputs here
		return c.Status(http.StatusOK).JSON(tr)
	}
}
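The handler above reads the upload from the "file" multipart form field. A minimal client sketch using only the standard library; the endpoint path and model name are assumptions modeled on the OpenAI audio API shape, not confirmed by this diff:

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)

	// The handler reads the upload from the "file" form field.
	fw, _ := mw.CreateFormFile("file", "sample.wav")
	f, err := os.Open("sample.wav") // hypothetical local audio file
	if err != nil {
		panic(err)
	}
	defer f.Close()
	io.Copy(fw, f)
	mw.WriteField("model", "whisper-1") // placeholder model name
	mw.Close()

	// Endpoint path assumed to follow the OpenAI audio API shape.
	resp, err := http.Post("http://localhost:8080/v1/audio/transcriptions", mw.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
}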
@ -8,7 +8,7 @@ import (

 	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
 	"github.com/go-audio/wav"
-	"github.com/go-skynet/LocalAI/api/schema"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 )

 func sh(c string) (string, error) {
@ -29,8 +29,8 @@ func audioToWav(src, dst string) error {
 	return nil
 }

-func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) {
-	res := schema.Result{}
+func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.WhisperResult, error) {
+	res := schema.WhisperResult{}

 	dir, err := os.MkdirTemp("", "whisper")
 	if err != nil {
@ -90,7 +90,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
 		tokens = append(tokens, t.Id)
 	}

-	segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
+	segment := schema.WhisperSegment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
 	res.Segments = append(res.Segments, segment)

 	res.Text += s.Text
@ -4,9 +4,9 @@ package main
 // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
 import (
 	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
-	"github.com/go-skynet/LocalAI/api/schema"
 	"github.com/go-skynet/LocalAI/pkg/grpc/base"
 	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 )

 type Whisper struct {
@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
 	return err
 }

-func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) {
+func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.WhisperResult, error) {
 	return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
 }
0
config/.keep
Normal file
@ -2,14 +2,17 @@ package backend

 import (
 	"fmt"
+	"time"

-	config "github.com/go-skynet/LocalAI/api/config"
-	"github.com/go-skynet/LocalAI/api/options"
+	"github.com/go-skynet/LocalAI/core/services"
 	"github.com/go-skynet/LocalAI/pkg/grpc"
-	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
+	"github.com/google/uuid"
+	"github.com/rs/zerolog/log"
 )

-func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
+func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() ([]float32, error), error) {
 	if !c.Embeddings {
 		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
 	}
@ -27,6 +30,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
 		model.WithAssetDir(o.AssetsDestination),
 		model.WithModel(modelFile),
 		model.WithContext(o.Context),
+		model.WithExternalBackends(o.ExternalGRPCBackends, false),
 	})

 	if c.Backend == "" {
@ -90,3 +94,51 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
 		return embeds, nil
 	}, nil
 }
+
+func EmbeddingOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
+	config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
+	if err != nil {
+		return nil, fmt.Errorf("failed reading parameters from request:%w", err)
+	}
+
+	log.Debug().Msgf("Parameter Config: %+v", config)
+	items := []schema.Item{}
+
+	for i, s := range config.InputToken {
+		// get the model function to call for the result
+		embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions)
+		if err != nil {
+			return nil, err
+		}
+
+		embeddings, err := embedFn()
+		if err != nil {
+			return nil, err
+		}
+		items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
+	}
+
+	for i, s := range config.InputStrings {
+		// get the model function to call for the result
+		embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions)
+		if err != nil {
+			return nil, err
+		}
+
+		embeddings, err := embedFn()
+		if err != nil {
+			return nil, err
+		}
+		items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
+	}
+
+	id := uuid.New().String()
+	created := int(time.Now().Unix())
+	return &schema.OpenAIResponse{
+		ID:      id,
+		Created: created,
+		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+		Data:    items,
+		Object:  "list",
+	}, nil
+}
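The new EmbeddingOpenAIRequest helper moves the whole request/response cycle out of the HTTP layer, so a frontend handler can shrink to roughly the sketch below. This is an illustration of the intended split, not code from the repository; the package name, handler name, and wiring of cl, ml, and opts are assumptions:

package frontend

import (
	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
)

// embeddingsHandler is a hypothetical thin HTTP wrapper over the new backend helper.
func embeddingsHandler(cl *services.ConfigLoader, ml *model.ModelLoader, opts *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.OpenAIRequest)
		if err := c.BodyParser(input); err != nil {
			return err
		}
		// All config merging, inference, and response assembly now live in core/backend.
		resp, err := backend.EmbeddingOpenAIRequest(input.Model, input, cl, ml, opts)
		if err != nil {
			return err
		}
		return c.JSON(resp)
	}
}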
210
core/backend/image.go
Normal file
@ -0,0 +1,210 @@
package backend

import (
	"encoding/base64"
	"fmt"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/go-skynet/LocalAI/pkg/utils"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
)

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() error, error) {

	opts := modelOpts(c, o, []model.Option{
		model.WithBackendString(c.Backend),
		model.WithAssetDir(o.AssetsDestination),
		model.WithThreads(uint32(c.Threads)),
		model.WithContext(o.Context),
		model.WithModel(c.Model),
		model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
			CUDA:          c.CUDA || c.Diffusers.CUDA,
			SchedulerType: c.Diffusers.SchedulerType,
			PipelineType:  c.Diffusers.PipelineType,
			CFGScale:      c.Diffusers.CFGScale,
			LoraAdapter:   c.LoraAdapter,
			LoraScale:     c.LoraScale,
			LoraBase:      c.LoraBase,
			IMG2IMG:       c.Diffusers.IMG2IMG,
			CLIPModel:     c.Diffusers.ClipModel,
			CLIPSubfolder: c.Diffusers.ClipSubFolder,
			CLIPSkip:      int32(c.Diffusers.ClipSkip),
			ControlNet:    c.Diffusers.ControlNet,
		}),
		model.WithExternalBackends(o.ExternalGRPCBackends, false),
	})

	inferenceModel, err := loader.BackendLoader(
		opts...,
	)
	if err != nil {
		return nil, err
	}

	fn := func() error {
		_, err := inferenceModel.GenerateImage(
			o.Context,
			&proto.GenerateImageRequest{
				Height:           int32(height),
				Width:            int32(width),
				Mode:             int32(mode),
				Step:             int32(step),
				Seed:             int32(seed),
				CLIPSkip:         int32(c.Diffusers.ClipSkip),
				PositivePrompt:   positive_prompt,
				NegativePrompt:   negative_prompt,
				Dst:              dst,
				Src:              src,
				EnableParameters: c.Diffusers.EnableParameters,
			})
		return err
	}

	return fn, nil
}

func ImageGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
	id := uuid.New().String()
	created := int(time.Now().Unix())

	if modelName == "" {
		modelName = model.StableDiffusionBackend
	}
	log.Debug().Msgf("Loading model: %+v", modelName)

	config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
	if err != nil {
		return nil, fmt.Errorf("failed reading parameters from request: %w", err)
	}

	src := ""
	if input.File != "" {
		if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
			src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src")
			if err != nil {
				return nil, fmt.Errorf("failed downloading file:%w", err)
			}
		} else {
			src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src")
			if err != nil {
				return nil, fmt.Errorf("error creating temporary image source file: %w", err)
			}
		}
	}

	log.Debug().Msgf("Parameter Config: %+v", config)

	switch config.Backend {
	case "stablediffusion":
		config.Backend = model.StableDiffusionBackend
	case "tinydream":
		config.Backend = model.TinyDreamBackend
	case "":
		config.Backend = model.StableDiffusionBackend
	}

	sizeParts := strings.Split(input.Size, "x")
	if len(sizeParts) != 2 {
		return nil, fmt.Errorf("invalid value for 'size'")
	}
	width, err := strconv.Atoi(sizeParts[0])
	if err != nil {
		return nil, fmt.Errorf("invalid value for 'size'")
	}
	height, err := strconv.Atoi(sizeParts[1])
	if err != nil {
		return nil, fmt.Errorf("invalid value for 'size'")
	}

	b64JSON := false
	if input.ResponseFormat.Type == "b64_json" {
		b64JSON = true
	}
	// src and clip_skip
	var result []schema.Item
	for _, i := range config.PromptStrings {
		n := input.N
		if input.N == 0 {
			n = 1
		}
		for j := 0; j < n; j++ {
			prompts := strings.Split(i, "|")
			positive_prompt := prompts[0]
			negative_prompt := ""
			if len(prompts) > 1 {
				negative_prompt = prompts[1]
			}

			mode := 0
			step := config.Step
			if step == 0 {
				step = 15
			}

			if input.Mode != 0 {
				mode = input.Mode
			}

			if input.Step != 0 {
				step = input.Step
			}

			tempDir := ""
			if !b64JSON {
				tempDir = startupOptions.ImageDir
			}
			// Create a temporary file
			outputFile, err := os.CreateTemp(tempDir, "b64")
			if err != nil {
				return nil, err
			}
			outputFile.Close()
			output := outputFile.Name() + ".png"
			// Rename the temporary file
			err = os.Rename(outputFile.Name(), output)
			if err != nil {
				return nil, err
			}

			fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions)
			if err != nil {
				return nil, err
			}
			if err := fn(); err != nil {
				return nil, err
			}

			item := &schema.Item{}

			if b64JSON {
				defer os.RemoveAll(output)
				data, err := os.ReadFile(output)
				if err != nil {
					return nil, err
				}
				item.B64JSON = base64.StdEncoding.EncodeToString(data)
			} else {
				base := filepath.Base(output)
				item.URL = path.Join(startupOptions.ImageDir, base)
			}

			result = append(result, *item)
		}
	}

	return &schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Data:    result,
	}, nil
}
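One convention worth noting in the new file: each prompt string is split on "|" into a positive prompt and an optional negative prompt. A standalone illustration of that split (the sample prompt is invented):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// "positive|negative" prompt convention used by the image generation path.
	prompts := strings.Split("a cute baby sea otter|blurry, deformed", "|")
	positive := prompts[0]
	negative := ""
	if len(prompts) > 1 {
		negative = prompts[1]
	}
	fmt.Printf("positive=%q negative=%q\n", positive, negative)
}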
861
core/backend/llm.go
Normal file
861
core/backend/llm.go
Normal file
@ -0,0 +1,861 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
////////// TYPES //////////////
|
||||||
|
|
||||||
|
type LLMResponse struct {
|
||||||
|
Response string // should this be []byte?
|
||||||
|
Usage TokenUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Test removing this and using the variant in pkg/schema someday?
|
||||||
|
type TokenUsage struct {
|
||||||
|
Prompt int
|
||||||
|
Completion int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateConfigBindingFn func(*schema.Config) *string
|
||||||
|
|
||||||
|
// type LLMStreamProcessor func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse)
|
||||||
|
|
||||||
|
/////// CONSTS ///////////
|
||||||
|
|
||||||
|
const DEFAULT_NO_ACTION_NAME = "answer"
|
||||||
|
const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action"
|
||||||
|
|
||||||
|
////// INFERENCE /////////
|
||||||
|
|
||||||
|
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
	modelFile := c.Model

	grpcOpts := gRPCModelOpts(c)

	var inferenceModel *grpc.Client
	var err error

	opts := modelOpts(c, o, []model.Option{
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)), // some models use this to allocate threads during startup
		model.WithAssetDir(o.AssetsDestination),
		model.WithModel(modelFile),
		model.WithContext(o.Context),
		model.WithExternalBackends(o.ExternalGRPCBackends, false),
	})

	if c.Backend != "" {
		opts = append(opts, model.WithBackendString(c.Backend))
	}

	// Check if the modelFile exists; if it doesn't, try to load it from the gallery
	if o.AutoloadGalleries { // experimental
		if _, err := os.Stat(modelFile); os.IsNotExist(err) {
			utils.ResetDownloadTimers()
			// if we failed to load the model, we try to download it
			err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
			if err != nil {
				return nil, err
			}
		}
	}

	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		inferenceModel, err = loader.BackendLoader(opts...)
	}

	if err != nil {
		return nil, err
	}

	// in gRPC, the backend is supposed to answer with a single token if streaming is not supported
	fn := func() (LLMResponse, error) {
		opts := gRPCPredictOpts(c, loader.ModelPath)
		opts.Prompt = s
		opts.Images = images

		tokenUsage := TokenUsage{}

		// check the per-model feature flag for usage, since tokenCallback may have a cost.
		// Defaults to off, as for now it is still experimental
		if c.FeatureFlag.Enabled("usage") {
			userTokenCallback := tokenCallback
			if userTokenCallback == nil {
				userTokenCallback = func(token string, usage TokenUsage) bool {
					return true
				}
			}

			promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
			if pErr == nil && promptInfo.Length > 0 {
				tokenUsage.Prompt = int(promptInfo.Length)
			}

			tokenCallback = func(token string, usage TokenUsage) bool {
				tokenUsage.Completion++
				return userTokenCallback(token, tokenUsage)
			}
		}

		if tokenCallback != nil {
			ss := ""

			var partialRune []byte
			err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
				partialRune = append(partialRune, chars...)

				for len(partialRune) > 0 {
					r, size := utf8.DecodeRune(partialRune)
					if r == utf8.RuneError {
						// incomplete rune, wait for more bytes
						break
					}

					tokenCallback(string(r), tokenUsage)
					ss += string(r)

					partialRune = partialRune[size:]
				}
			})
			return LLMResponse{
				Response: ss,
				Usage:    tokenUsage,
			}, err
		} else {
			// TODO: Is the chicken bit the only way to get here? Is that acceptable?
			reply, err := inferenceModel.Predict(ctx, opts)
			if err != nil {
				return LLMResponse{}, err
			}
			return LLMResponse{
				Response: string(reply.Message),
				Usage:    tokenUsage,
			}, err
		}
	}

	return fn, nil
}
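The loop above buffers raw bytes from PredictStream until a complete UTF-8 rune is available, since a multi-byte character can be split across stream chunks. Below is a self-contained sketch of that buffering technique in isolation; the chunk boundaries are illustrative, not taken from a real backend.

package main

import (
	"fmt"
	"unicode/utf8"
)

func main() {
	// "héllo" delivered so that the two-byte 'é' straddles a chunk boundary.
	chunks := [][]byte{{0x68, 0xc3}, {0xa9, 0x6c, 0x6c, 0x6f}}

	var partialRune []byte
	for _, chars := range chunks {
		partialRune = append(partialRune, chars...)
		for len(partialRune) > 0 {
			r, size := utf8.DecodeRune(partialRune)
			if r == utf8.RuneError {
				break // incomplete rune, wait for more bytes
			}
			fmt.Printf("token: %q\n", r)
			partialRune = partialRune[size:]
		}
	}
}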
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}

func Finetune(config schema.Config, input, prediction string) string {
	if config.Echo {
		prediction = input + prediction
	}

	for _, c := range config.Cutstrings {
		mu.Lock()
		reg, ok := cutstrings[c]
		if !ok {
			cutstrings[c] = regexp.MustCompile(c)
			reg = cutstrings[c]
		}
		mu.Unlock()
		prediction = reg.ReplaceAllString(prediction, "")
	}

	for _, c := range config.TrimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}

	for _, c := range config.TrimSuffix {
		prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
	}
	return prediction
}
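To make the cleanup above concrete, here is a minimal, self-contained sketch of the same transformation with illustrative config values (the cutstring and suffix below are assumptions, not taken from a real model file): regex cutstrings are stripped first, then suffixes are trimmed.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	prediction := "### Response: Hello there!</s>"

	cutstrings := []string{"### Response:"} // removed wherever the regex matches
	trimSuffix := []string{"</s>"}          // stop token trimmed from the end

	for _, c := range cutstrings {
		prediction = regexp.MustCompile(c).ReplaceAllString(prediction, "")
	}
	for _, c := range trimSuffix {
		prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
	}
	fmt.Println(prediction) // Hello there!
}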
////// CONFIG AND REQUEST HANDLING ///////////////

func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *schema.OpenAIRequest, cm *services.ConfigLoader, startupOptions *schema.StartupOptions) (*schema.Config, *schema.OpenAIRequest, error) {
	// Load a config file if present after the model name
	modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml")

	var cfg *schema.Config

	defaults := func() {
		cfg = schema.DefaultConfig(modelFile)
		cfg.ContextSize = startupOptions.ContextSize
		cfg.Threads = startupOptions.Threads
		cfg.F16 = startupOptions.F16
		cfg.Debug = startupOptions.Debug
	}

	cfgExisting, exists := cm.GetConfig(modelFile)
	if !exists {
		if _, err := os.Stat(modelConfig); err == nil {
			if err := cm.LoadConfig(modelConfig); err != nil {
				return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
			}
			cfgExisting, exists = cm.GetConfig(modelFile)
			if exists {
				cfg = &cfgExisting
			} else {
				defaults()
			}
		} else {
			defaults()
		}
	} else {
		cfg = &cfgExisting
	}

	// Set the parameters for the language model prediction
	schema.UpdateConfigFromOpenAIRequest(cfg, input)

	// Don't allow 0 as a setting
	if cfg.Threads == 0 {
		if startupOptions.Threads != 0 {
			cfg.Threads = startupOptions.Threads
		} else {
			cfg.Threads = 4
		}
	}

	// Enforce the debug flag if passed from the CLI
	if startupOptions.Debug {
		cfg.Debug = true
	}

	return cfg, input, nil
}
func ComputeChoices(
	req *schema.OpenAIRequest,
	predInput string,
	config *schema.Config,
	o *schema.StartupOptions,
	loader *model.ModelLoader,
	cb func(string, *[]schema.Choice),
	tokenCallback func(string, TokenUsage) bool) ([]schema.Choice, TokenUsage, error) {
	n := req.N // number of completions to return
	result := []schema.Choice{}

	if n == 0 {
		n = 1
	}

	images := []string{}
	for _, m := range req.Messages {
		images = append(images, m.StringImages...)
	}

	// get the model function to call for the result
	predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
	if err != nil {
		return result, TokenUsage{}, err
	}

	tokenUsage := TokenUsage{}

	for i := 0; i < n; i++ {
		prediction, err := predFunc()
		if err != nil {
			return result, TokenUsage{}, err
		}

		tokenUsage.Prompt += prediction.Usage.Prompt
		tokenUsage.Completion += prediction.Usage.Completion

		finetunedResponse := Finetune(*config, predInput, prediction.Response)
		cb(finetunedResponse, &result)

		//result = append(result, Choice{Text: prediction})
	}
	return result, tokenUsage, err
}
// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below?
func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, error) {
	config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
	if err != nil {
		return nil, fmt.Errorf("failed reading parameters from request: %w", err)
	}

	if input.ResponseFormat.Type == "json_object" {
		input.Grammar = grammar.JSONBNF
	}

	log.Debug().Msgf("Parameter Config: %+v", config)

	configTemplate := bindingFn(config)

	// A model can have a "file.bin.tmpl" file associated with it as a prompt template prefix
	if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) {
		*configTemplate = config.Model
	}
	if *configTemplate == "" {
		return nil, fmt.Errorf("failed to find templateConfig")
	}

	return config, nil
}
////////// SPECIFIC REQUESTS //////////////
// TODO: For round one of the refactor, give each of the three primary text endpoints its own function?
// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split.
// They can be cleaned up into a common form later if possible; that is easier if they are all here for now.
// If they remain different, extract each of these named segments to a separate file.

func prepareChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, string, bool, error) {

	// IMPORTANT DEFS
	funcs := grammar.Functions{}

	// The basic beginning

	config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
	if err != nil {
		return nil, "", false, fmt.Errorf("failed reading parameters from request: %w", err)
	}
	log.Debug().Msgf("Configuration read: %+v", config)

	// Special input/config handling

	// Allow the user to set custom actions via config file,
	// to be "embedded" in each model - but if they are missing, use defaults.
	if config.FunctionsConfig.NoActionFunctionName == "" {
		config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME
	}
	if config.FunctionsConfig.NoActionDescriptionName == "" {
		config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION
	}

	if input.ResponseFormat.Type == "json_object" {
		input.Grammar = grammar.JSONBNF
	}

	processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions()

	if processFunctions {
		log.Debug().Msgf("Response needs to process functions")

		noActionGrammar := grammar.Function{
			Name:        config.FunctionsConfig.NoActionFunctionName,
			Description: config.FunctionsConfig.NoActionDescriptionName,
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"message": map[string]interface{}{
						"type":        "string",
						"description": "The message to reply the user with",
					}},
			},
		}

		// Append the no-action function
		funcs = append(funcs, input.Functions...)
		if !config.FunctionsConfig.DisableNoAction {
			funcs = append(funcs, noActionGrammar)
		}

		// Force picking one of the functions by the request
		if config.FunctionToCall() != "" {
			funcs = funcs.Select(config.FunctionToCall())
		}

		// Update input grammar
		jsStruct := funcs.ToJSONStructure()
		config.Grammar = jsStruct.Grammar("")
	} else if input.JSONFunctionGrammarObject != nil {
		config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
	}

	log.Debug().Msgf("Parameters: %+v", config)

	var predInput string

	suppressConfigSystemPrompt := false
	mess := []string{}
	for messageIndex, i := range input.Messages {
		var content string
		role := i.Role

		// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
		// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed in by the request
		if i.FunctionCall != nil && i.Role == "assistant" {
			roleFn := "assistant_function_call"
			r := config.Roles[roleFn]
			if r != "" {
				role = roleFn
			}
		}
		r := config.Roles[role]
		contentExists := i.Content != nil && i.StringContent != ""
		// First attempt to populate content via a chat-message-specific template
		if config.TemplateConfig.ChatMessage != "" {
			chatMessageData := model.ChatMessageTemplateData{
				SystemPrompt: config.SystemPrompt,
				Role:         r,
				RoleName:     role,
				Content:      i.StringContent,
				MessageIndex: messageIndex,
			}
			templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
			if err != nil {
				log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
			} else {
				if templatedChatMessage == "" {
					log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
					continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
				}
				log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
				content = templatedChatMessage
			}
		}
		// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
		if content == "" {
			if r != "" {
				if contentExists {
					content = fmt.Sprint(r, i.StringContent)
				}
				if i.FunctionCall != nil {
					j, err := json.Marshal(i.FunctionCall)
					if err == nil {
						if contentExists {
							content += "\n" + fmt.Sprint(r, " ", string(j))
						} else {
							content = fmt.Sprint(r, " ", string(j))
						}
					}
				}
			} else {
				if contentExists {
					content = fmt.Sprint(i.StringContent)
				}
				if i.FunctionCall != nil {
					j, err := json.Marshal(i.FunctionCall)
					if err == nil {
						if contentExists {
							content += "\n" + string(j)
						} else {
							content = string(j)
						}
					}
				}
			}
			// Special handling: system. We care whether it was printed at all, not about the r branch, so check separately
			if contentExists && role == "system" {
				suppressConfigSystemPrompt = true
			}
		}

		mess = append(mess, content)
	}

	predInput = strings.Join(mess, "\n")
	log.Debug().Msgf("Prompt (before templating): %s", predInput)

	templateFile := ""

	// A model can have a "file.bin.tmpl" file associated with it as a prompt template prefix
	if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
		templateFile = config.Model
	}

	if config.TemplateConfig.Chat != "" && !processFunctions {
		templateFile = config.TemplateConfig.Chat
	}

	if config.TemplateConfig.Functions != "" && processFunctions {
		templateFile = config.TemplateConfig.Functions
	}

	if templateFile != "" {
		templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
			SystemPrompt:         config.SystemPrompt,
			SuppressSystemPrompt: suppressConfigSystemPrompt,
			Input:                predInput,
			Functions:            funcs,
		})
		if err == nil {
			predInput = templatedInput
			log.Debug().Msgf("Template found, input modified to: %s", predInput)
		} else {
			log.Debug().Msgf("Template failed loading: %s", err.Error())
		}
	}

	log.Debug().Msgf("Prompt (after templating): %s", predInput)
	if processFunctions {
		log.Debug().Msgf("Grammar: %+v", config.Grammar)
	}

	return config, predInput, processFunctions, nil
}
func EditGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
	id := uuid.New().String()
	created := int(time.Now().Unix())

	binding := func(config *schema.Config) *string {
		return &config.TemplateConfig.Edit
	}

	config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
	if err != nil {
		return nil, err
	}

	var result []schema.Choice
	totalTokenUsage := TokenUsage{}

	for _, i := range config.InputStrings {
		// A model can have a "file.bin.tmpl" file associated with it as a prompt template prefix
		templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{
			Input:        i,
			Instruction:  input.Instruction,
			SystemPrompt: config.SystemPrompt,
		})
		if err == nil {
			i = templatedInput
			log.Debug().Msgf("Template found, input modified to: %s", i)
		}

		r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
			*c = append(*c, schema.Choice{Text: s})
		}, nil)
		if err != nil {
			return nil, err
		}

		totalTokenUsage.Prompt += tokenUsage.Prompt
		totalTokenUsage.Completion += tokenUsage.Completion

		result = append(result, r...)
	}

	return &schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
		Choices: result,
		Object:  "edit",
		Usage: schema.OpenAIUsage{
			PromptTokens:     totalTokenUsage.Prompt,
			CompletionTokens: totalTokenUsage.Completion,
			TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
		},
	}, nil
}
func ChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {

	// DEFS
	id := uuid.New().String()
	created := int(time.Now().Unix())

	// Prepare
	config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
	if err != nil {
		return nil, err
	}

	result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
		if processFunctions {
			// As we have to change the result before processing, we can't stream the answer (yet?)
			ss := map[string]interface{}{}
			// This prevents newlines from breaking JSON parsing for clients
			s = utils.EscapeNewLines(s)
			json.Unmarshal([]byte(s), &ss)
			log.Debug().Msgf("Function return: %s %+v", s, ss)

			// The grammar defines the function name as "function", while OpenAI returns "name"
			func_name := ss["function"]
			// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
			args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
			d, _ := json.Marshal(args)

			ss["arguments"] = string(d)
			ss["name"] = func_name

			// if there is nothing to do, reply with a message
			if func_name == config.FunctionsConfig.NoActionFunctionName {
				log.Debug().Msgf("nothing to do, computing a reply")

				// If there is a message that the LLM already sends as part of the JSON reply, use it
				arguments := map[string]interface{}{}
				json.Unmarshal([]byte(d), &arguments)
				m, exists := arguments["message"]
				if exists {
					switch message := m.(type) {
					case string:
						if message != "" {
							log.Debug().Msgf("Reply received from LLM: %s", message)
							message = Finetune(*config, predInput, message)
							log.Debug().Msgf("Reply received from LLM (finetuned): %s", message)

							*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
							return
						}
					}
				}

				log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
				// Otherwise ask the LLM to understand the JSON output and the context, and return a message
				// Note: This costs (in terms of CPU) another computation
				config.Grammar = ""
				images := []string{}
				for _, m := range input.Messages {
					images = append(images, m.StringImages...)
				}
				predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil)
				if err != nil {
					log.Error().Msgf("inference error: %s", err.Error())
					return
				}

				prediction, err := predFunc()
				if err != nil {
					log.Error().Msgf("inference error: %s", err.Error())
					return
				}

				fineTunedResponse := Finetune(*config, predInput, prediction.Response)
				*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
			} else {
				// otherwise reply with the function call
				*c = append(*c, schema.Choice{
					FinishReason: "function_call",
					Message:      &schema.Message{Role: "assistant", FunctionCall: ss},
				})
			}

			return
		}
		*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
	}, nil)
	if err != nil {
		return nil, err
	}

	return &schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
		Choices: result,
		Object:  "chat.completion",
		Usage: schema.OpenAIUsage{
			PromptTokens:     tokenUsage.Prompt,
			CompletionTokens: tokenUsage.Completion,
			TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
		},
	}, nil
}
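For orientation, a hedged sketch of exercising this handler end-to-end over HTTP with the same client library the test suite uses (github.com/sashabaranov/go-openai); the listen address and model name are assumptions.

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("") // no API key configured server-side
	cfg.BaseURL = "http://127.0.0.1:9090/v1"
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "testmodel", // hypothetical model name
		Messages: []openai.ChatCompletionMessage{
			{Role: "user", Content: "Say hello"},
		},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}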
func CompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
	// Prepare
	id := uuid.New().String()
	created := int(time.Now().Unix())

	binding := func(config *schema.Config) *string {
		return &config.TemplateConfig.Completion
	}

	config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
	if err != nil {
		return nil, err
	}

	var result []schema.Choice

	totalTokenUsage := TokenUsage{}

	for k, i := range config.PromptStrings {
		// A model can have a "file.bin.tmpl" file associated with it as a prompt template prefix
		templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
			SystemPrompt: config.SystemPrompt,
			Input:        i,
		})
		if err == nil {
			i = templatedInput
			log.Debug().Msgf("Template found, input modified to: %s", i)
		}

		r, tokenUsage, err := ComputeChoices(
			input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
				*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
			}, nil)
		if err != nil {
			return nil, err
		}

		totalTokenUsage.Prompt += tokenUsage.Prompt
		totalTokenUsage.Completion += tokenUsage.Completion

		result = append(result, r...)
	}

	return &schema.OpenAIResponse{
		ID:      id,
		Created: created,
		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
		Choices: result,
		Object:  "text_completion",
		Usage: schema.OpenAIUsage{
			PromptTokens:     totalTokenUsage.Prompt,
			CompletionTokens: totalTokenUsage.Completion,
			TotalTokens:      totalTokenUsage.Prompt + totalTokenUsage.Completion,
		},
	}, nil
}
func StreamingChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {

	// DEFS
	emptyMessage := ""
	id := uuid.New().String()
	created := int(time.Now().Unix())

	// Prepare
	config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
	if err != nil {
		return nil, err
	}

	if processFunctions {
		// TODO: an unused variable means I did something wrong. Investigate once stable.
		log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name)
	}

	processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
		initialMessage := schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
			Object:  "chat.completion.chunk",
		}
		responses <- initialMessage

		ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
			resp := schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
				Object:  "chat.completion.chunk",
				Usage: schema.OpenAIUsage{
					PromptTokens:     usage.Prompt,
					CompletionTokens: usage.Completion,
					TotalTokens:      usage.Prompt + usage.Completion,
				},
			}

			responses <- resp
			return true
		})
		close(responses)
	}
	log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel")

	responses := make(chan schema.OpenAIResponse)

	log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine")

	go processor(predInput, input, config, ml, responses)

	log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!")

	return responses, nil
}
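A hedged sketch of how a caller might drain the channel returned above; the core/backend import path is an assumption about where this file lands in the refactor, and the SSE framing of the real HTTP handler is omitted.

package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend" // assumed package path for the code above
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
)

func streamChat(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, o *schema.StartupOptions) error {
	responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, o)
	if err != nil {
		return err
	}
	// The processor goroutine closes the channel once generation ends.
	for resp := range responses {
		if len(resp.Choices) > 0 && resp.Choices[0].Delta != nil && resp.Choices[0].Delta.Content != nil {
			fmt.Print(*resp.Choices[0].Delta.Content)
		}
	}
	return nil
}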
func StreamingCompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
	// DEFS
	id := uuid.New().String()
	created := int(time.Now().Unix())

	binding := func(config *schema.Config) *string {
		return &config.TemplateConfig.Completion
	}

	// Prepare

	config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
	if err != nil {
		return nil, err
	}

	processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
		ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
			resp := schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{
					{
						Index: 0,
						Text:  s,
					},
				},
				Object: "text_completion",
				Usage: schema.OpenAIUsage{
					PromptTokens:     usage.Prompt,
					CompletionTokens: usage.Completion,
					TotalTokens:      usage.Prompt + usage.Completion,
				},
			}
			log.Debug().Msgf("Sending goroutine: %s", s)

			responses <- resp
			return true
		})
		close(responses)
	}

	if len(config.PromptStrings) > 1 {
		return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
	}

	predInput := config.PromptStrings[0]

	// A model can have a "file.bin.tmpl" file associated with it as a prompt template prefix
	templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
		Input: predInput,
	})
	if err == nil {
		predInput = templatedInput
		log.Debug().Msgf("Template found, input modified to: %s", predInput)
	}

	log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel")

	responses := make(chan schema.OpenAIResponse)

	log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine")

	go processor(predInput, input, config, ml, responses)

	log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!")

	return responses, nil
}
@@ -5,13 +5,11 @@ import (
 	"path/filepath"
 
 	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
-	model "github.com/go-skynet/LocalAI/pkg/model"
-
-	config "github.com/go-skynet/LocalAI/api/config"
-	"github.com/go-skynet/LocalAI/api/options"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 )
 
-func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
+func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) []model.Option {
 	if o.SingleBackend {
 		opts = append(opts, model.WithSingleActiveBackend())
 	}
@@ -35,7 +33,7 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
 	return opts
 }
 
-func gRPCModelOpts(c config.Config) *pb.ModelOptions {
+func gRPCModelOpts(c schema.Config) *pb.ModelOptions {
 	b := 512
 	if c.Batch != 0 {
 		b = c.Batch
@@ -82,7 +80,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
 	}
 }
 
-func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
+func gRPCPredictOpts(c schema.Config, modelPath string) *pb.PredictOptions {
 	promptCachePath := ""
 	if c.PromptCachePath != "" {
 		p := filepath.Join(modelPath, c.PromptCachePath)
core/backend/transcription.go (new file, 52 lines)
@@ -0,0 +1,52 @@
package backend

import (
	"context"
	"fmt"

	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
)

func ModelTranscription(audio, language string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (*schema.WhisperResult, error) {

	opts := modelOpts(c, o, []model.Option{
		model.WithBackendString(model.WhisperBackend),
		model.WithModel(c.Model),
		model.WithContext(o.Context),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
		model.WithExternalBackends(o.ExternalGRPCBackends, false),
	})

	whisperModel, err := loader.BackendLoader(opts...)
	if err != nil {
		return nil, err
	}

	if whisperModel == nil {
		return nil, fmt.Errorf("could not load whisper model")
	}

	return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
		Dst:      audio,
		Language: language,
		Threads:  uint32(c.Threads),
	})
}

func TranscriptionOpenAIRequest(modelName string, input *schema.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.WhisperResult, error) {
	config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
	if err != nil {
		return nil, fmt.Errorf("failed reading parameters from request: %w", err)
	}

	tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions)
	if err != nil {
		return nil, err
	}

	return tr, nil
}
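A hedged sketch of driving the transcription path over HTTP once the endpoint is registered (see core/http/api.go below); the multipart field names follow the OpenAI convention and, like the file and model names, are assumptions here.

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)

	part, _ := w.CreateFormFile("file", "audio.wav")
	f, err := os.Open("audio.wav") // hypothetical input file
	if err != nil {
		panic(err)
	}
	io.Copy(part, f)
	f.Close()
	w.WriteField("model", "whisper-1") // hypothetical model name
	w.Close()

	resp, err := http.Post("http://127.0.0.1:9090/v1/audio/transcriptions", w.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
}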
@@ -6,10 +6,9 @@ import (
 	"os"
 	"path/filepath"
 
-	api_config "github.com/go-skynet/LocalAI/api/config"
-	"github.com/go-skynet/LocalAI/api/options"
 	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
-	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 	"github.com/go-skynet/LocalAI/pkg/utils"
 )
 
@@ -29,18 +28,19 @@ func generateUniqueFileName(dir, baseName, ext string) string {
 	}
 }
 
-func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
+func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *schema.StartupOptions) (string, *proto.Result, error) {
 	bb := backend
 	if bb == "" {
 		bb = model.PiperBackend
 	}
-	opts := modelOpts(api_config.Config{}, o, []model.Option{
+	opts := modelOpts(schema.Config{}, o, []model.Option{
 		model.WithBackendString(bb),
 		model.WithModel(modelFile),
 		model.WithContext(o.Context),
 		model.WithAssetDir(o.AssetsDestination),
+		model.WithExternalBackends(o.ExternalGRPCBackends, false),
 	})
-	piperModel, err := o.Loader.BackendLoader(opts...)
+	piperModel, err := loader.BackendLoader(opts...)
 	if err != nil {
 		return "", nil, err
 	}
@@ -60,8 +60,8 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
 	modelPath := ""
 	if modelFile != "" {
 		if bb != model.TransformersMusicGen {
-			modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
-			if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
+			modelPath = filepath.Join(o.ModelPath, modelFile)
+			if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil {
 				return "", nil, err
 			}
 		} else {
core/http/api.go (new file, 169 lines)
@@ -0,0 +1,169 @@
package http

import (
	"errors"
	"strings"

	"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
	"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/internal"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gofiber/fiber/v2/middleware/recover"
)

func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*fiber.App, error) {

	// Return errors as JSON responses
	app := fiber.New(fiber.Config{
		BodyLimit:             options.UploadLimitMB * 1024 * 1024, // fiber's default limit is 4MB
		DisableStartupMessage: options.DisableMessage,
		// Override default error handler
		ErrorHandler: func(ctx *fiber.Ctx, err error) error {
			// Status code defaults to 500
			code := fiber.StatusInternalServerError

			// Retrieve the custom status code if it's a *fiber.Error
			var e *fiber.Error
			if errors.As(err, &e) {
				code = e.Code
			}

			// Send custom error page
			return ctx.Status(code).JSON(
				schema.ErrorResponse{
					Error: &schema.APIError{Message: err.Error(), Code: code},
				},
			)
		},
	})

	if options.Debug {
		app.Use(logger.New(logger.Config{
			Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
		}))
	}

	// Default middleware config
	app.Use(recover.New())

	if options.Metrics != nil {
		app.Use(localai.MetricsAPIMiddleware(options.Metrics))
	}

	// Auth middleware checking if the API key is valid. If no API key is set, no auth is required.
	auth := func(c *fiber.Ctx) error {
		if len(options.ApiKeys) == 0 {
			return c.Next()
		}

		authHeader := c.Get("Authorization")
		if authHeader == "" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
		}
		authHeaderParts := strings.Split(authHeader, " ")
		if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
		}

		apiKey := authHeaderParts[1]
		for _, key := range options.ApiKeys {
			if apiKey == key {
				return c.Next()
			}
		}

		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
	}

	if options.CORS {
		var c func(ctx *fiber.Ctx) error
		if options.CORSAllowOrigins == "" {
			c = cors.New()
		} else {
			c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
		}

		app.Use(c)
	}

	// LocalAI API endpoints
	galleryService := services.NewGalleryApplier(options.ModelPath)
	galleryService.Start(options.Context, cl)

	app.Get("/version", auth, func(c *fiber.Ctx) error {
		return c.JSON(struct {
			Version string `json:"version"`
		}{Version: internal.PrintableVersion()})
	})

	modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService)
	app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
	app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
	app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
	app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
	app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
	app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
	app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())

	// OpenAI-compatible API endpoints

	// chat
	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
	app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))

	// edit
	app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options))
	app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options))

	// completion
	app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options))
	app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options))
	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options))

	// embeddings
	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))

	// audio
	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options))
	app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options))

	// images
	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options))

	if options.ImageDir != "" {
		app.Static("/generated-images", options.ImageDir)
	}

	if options.AudioDir != "" {
		app.Static("/generated-audio", options.AudioDir)
	}

	ok := func(c *fiber.Ctx) error {
		return c.SendStatus(200)
	}

	// Kubernetes health checks
	app.Get("/healthz", ok)
	app.Get("/readyz", ok)

	app.Get("/metrics", localai.MetricsHandler())

	backendMonitor := services.NewBackendMonitor(cl, ml, options)
	app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
	app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))

	// model listing
	app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
	app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))

	return app, nil
}
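A minimal bootstrap sketch assembled from the wiring the test suite below uses: startup.Startup builds the config loader, model loader, and options, and App turns them into a listening Fiber server. The model path and listen address are illustrative assumptions.

package main

import (
	server "github.com/go-skynet/LocalAI/core/http"
	"github.com/go-skynet/LocalAI/core/startup"
	"github.com/go-skynet/LocalAI/pkg/schema"
)

func main() {
	cl, ml, options, err := startup.Startup(
		schema.WithModelPath("./models"), // hypothetical models directory
		schema.WithDebug(true),
	)
	if err != nil {
		panic(err)
	}

	app, err := server.App(cl, ml, options)
	if err != nil {
		panic(err)
	}
	app.Listen("127.0.0.1:9090")
}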
@@ -1,4 +1,4 @@
-package api_test
+package http_test
 
 import (
 	"bytes"
@@ -13,11 +13,12 @@ import (
 	"path/filepath"
 	"runtime"
 
-	. "github.com/go-skynet/LocalAI/api"
-	"github.com/go-skynet/LocalAI/api/options"
-	"github.com/go-skynet/LocalAI/metrics"
+	server "github.com/go-skynet/LocalAI/core/http"
+	"github.com/go-skynet/LocalAI/core/services"
+	"github.com/go-skynet/LocalAI/core/startup"
 	"github.com/go-skynet/LocalAI/pkg/gallery"
 	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/gofiber/fiber/v2"
 	. "github.com/onsi/ginkgo/v2"
@@ -118,16 +119,15 @@ var backendAssets embed.FS
 var _ = Describe("API test", func() {
 
 	var app *fiber.App
-	var modelLoader *model.ModelLoader
 	var client *openai.Client
 	var client2 *openaigo.Client
 	var c context.Context
 	var cancel context.CancelFunc
 	var tmpdir string
 
-	commonOpts := []options.AppOption{
-		options.WithDebug(true),
-		options.WithDisableMessage(true),
+	commonOpts := []schema.AppOption{
+		schema.WithDebug(true),
+		schema.WithDisableMessage(true),
 	}
 
 	Context("API with ephemeral models", func() {
@@ -136,7 +136,6 @@ var _ = Describe("API test", func() {
 			tmpdir, err = os.MkdirTemp("", "")
 			Expect(err).ToNot(HaveOccurred())
 
-			modelLoader = model.NewModelLoader(tmpdir)
 			c, cancel = context.WithCancel(context.Background())
 
 			g := []gallery.GalleryModel{
@@ -163,15 +162,20 @@ var _ = Describe("API test", func() {
 				},
 			}
 
-			metricsService, err := metrics.SetupMetrics()
+			metricsService, err := services.SetupMetrics()
 			Expect(err).ToNot(HaveOccurred())
 
-			app, err = App(
+			cl, ml, options, err := startup.Startup(
 				append(commonOpts,
-					options.WithMetrics(metricsService),
-					options.WithContext(c),
-					options.WithGalleries(galleries),
-					options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
+					schema.WithMetrics(metricsService),
+					schema.WithContext(c),
+					schema.WithGalleries(galleries),
+					schema.WithModelPath(tmpdir),
+					schema.WithBackendAssets(backendAssets),
+					schema.WithBackendAssetsOutput(tmpdir))...)
+
+			Expect(err).ToNot(HaveOccurred())
+			app, err = server.App(cl, ml, options)
 			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
 
@@ -475,7 +479,6 @@ var _ = Describe("API test", func() {
 			tmpdir, err = os.MkdirTemp("", "")
 			Expect(err).ToNot(HaveOccurred())
 
-			modelLoader = model.NewModelLoader(tmpdir)
 			c, cancel = context.WithCancel(context.Background())
 
 			galleries := []gallery.Gallery{
@@ -485,21 +488,22 @@ var _ = Describe("API test", func() {
 				},
 			}
 
-			metricsService, err := metrics.SetupMetrics()
+			metricsService, err := services.SetupMetrics()
 			Expect(err).ToNot(HaveOccurred())
 
-			app, err = App(
+			cl, ml, options, err := startup.Startup(
 				append(commonOpts,
-					options.WithContext(c),
-					options.WithMetrics(metricsService),
-					options.WithAudioDir(tmpdir),
-					options.WithImageDir(tmpdir),
-					options.WithGalleries(galleries),
-					options.WithModelLoader(modelLoader),
-					options.WithBackendAssets(backendAssets),
-					options.WithBackendAssetsOutput(tmpdir))...,
+					schema.WithContext(c),
+					schema.WithMetrics(metricsService),
+					schema.WithAudioDir(tmpdir),
+					schema.WithImageDir(tmpdir),
+					schema.WithGalleries(galleries),
+					schema.WithModelPath(tmpdir),
+					schema.WithBackendAssets(backendAssets),
+					schema.WithBackendAssetsOutput(tmpdir))...,
 			)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = server.App(cl, ml, options)
 			go app.Listen("127.0.0.1:9090")
 
 			defaultConfig := openai.DefaultConfig("")
@@ -590,20 +594,21 @@ var _ = Describe("API test", func() {
 
 	Context("API query", func() {
 		BeforeEach(func() {
-			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
 			c, cancel = context.WithCancel(context.Background())
 
-			metricsService, err := metrics.SetupMetrics()
+			metricsService, err := services.SetupMetrics()
 			Expect(err).ToNot(HaveOccurred())
 
-			app, err = App(
+			cl, ml, options, err := startup.Startup(
 				append(commonOpts,
-					options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
-					options.WithContext(c),
-					options.WithModelLoader(modelLoader),
-					options.WithMetrics(metricsService),
+					schema.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
+					schema.WithContext(c),
+					schema.WithModelPath(os.Getenv("MODELS_PATH")),
+					schema.WithMetrics(metricsService),
 				)...)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = server.App(cl, ml, options)
+			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
 
 			defaultConfig := openai.DefaultConfig("")
@@ -802,20 +807,21 @@ var _ = Describe("API test", func() {
 
 	Context("Config file", func() {
 		BeforeEach(func() {
-			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
 			c, cancel = context.WithCancel(context.Background())
 
-			metricsService, err := metrics.SetupMetrics()
+			metricsService, err := services.SetupMetrics()
 			Expect(err).ToNot(HaveOccurred())
 
-			app, err = App(
+			cl, ml, options, err := startup.Startup(
 				append(commonOpts,
-					options.WithContext(c),
-					options.WithMetrics(metricsService),
-					options.WithModelLoader(modelLoader),
-					options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
+					schema.WithContext(c),
+					schema.WithMetrics(metricsService),
+					schema.WithModelPath(os.Getenv("MODELS_PATH")),
+					schema.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
 			)
 			Expect(err).ToNot(HaveOccurred())
+			app, err = server.App(cl, ml, options)
+			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
 
 			defaultConfig := openai.DefaultConfig("")
@@ -1,4 +1,4 @@
-package api_test
+package http_test
 
 import (
 	"testing"
core/http/endpoints/localai/backend_monitor.go (new file, 34 lines)
@@ -0,0 +1,34 @@
package localai

import (
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
)

func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		resp, err := bm.CheckAndSample(input.Model)
		if err != nil {
			return err
		}
		return c.JSON(resp)
	}
}

func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}
		return bm.ShutdownModel(input.Model)
	}
}
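A hedged sketch of calling the monitor endpoint registered in api.go above; the address is an assumption, and the "model" JSON key assumes schema.BackendMonitorRequest tags its Model field that way.

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	resp, err := http.Post("http://127.0.0.1:9090/backend/monitor",
		"application/json", strings.NewReader(`{"model":"testmodel"}`))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // backend status sample for the model, as JSON
}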
148
core/http/endpoints/localai/gallery.go
Normal file
148
core/http/endpoints/localai/gallery.go
Normal file
@ -0,0 +1,148 @@
package localai

import (
	"encoding/json"
	"fmt"
	"slices"

	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/gallery"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
)

/// Endpoint Service

type ModelGalleryEndpointService struct {
	galleries      []gallery.Gallery
	modelPath      string
	galleryApplier *services.GalleryApplier
}

type GalleryModel struct {
	ID string `json:"id"`
	gallery.GalleryModel
}

func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService {
	return ModelGalleryEndpointService{
		galleries:      galleries,
		modelPath:      modelPath,
		galleryApplier: galleryApplier,
	}
}

func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
		if status == nil {
			return fmt.Errorf("could not find any status for ID")
		}
		return c.JSON(status)
	}
}

func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		return c.JSON(mgs.galleryApplier.GetAllStatus())
	}
}

func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(GalleryModel)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		uuid, err := uuid.NewUUID()
		if err != nil {
			return err
		}
		mgs.galleryApplier.C <- gallery.GalleryOp{
			Req:         input.GalleryModel,
			Id:          uuid.String(),
			GalleryName: input.ID,
			Galleries:   mgs.galleries,
		}
		return c.JSON(struct {
			ID        string `json:"uuid"`
			StatusURL string `json:"status"`
		}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
	}
}

func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)

		models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
		if err != nil {
			return err
		}
		log.Debug().Msgf("Models found from galleries: %+v", models)
		for _, m := range models {
			log.Debug().Msgf("Model found from galleries: %+v", m)
		}
		dat, err := json.Marshal(models)
		if err != nil {
			return err
		}
		return c.Send(dat)
	}
}

// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
		dat, err := json.Marshal(mgs.galleries)
		if err != nil {
			return err
		}
		return c.Send(dat)
	}
}

func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(gallery.Gallery)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}
		if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		}) {
			return fmt.Errorf("%s already exists", input.Name)
		}
		dat, err := json.Marshal(mgs.galleries)
		if err != nil {
			return err
		}
		log.Debug().Msgf("Adding %+v to gallery list", *input)
		mgs.galleries = append(mgs.galleries, *input)
		return c.Send(dat)
	}
}

func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(gallery.Gallery)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}
		if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		}) {
			return fmt.Errorf("%s is not currently registered", input.Name)
		}
		mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
			return gallery.Name == input.Name
		})
		return c.Send(nil)
	}
}
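For orientation, a minimal sketch of how these new gallery handlers might be mounted on a Fiber app. This wiring is not part of the commit itself; route paths other than `/models/jobs/:uuid` (which matches the `StatusURL` the apply handler builds above) are assumptions for illustration.

```go
// gallery_routes_sketch.go — a minimal wiring sketch, not part of this commit.
package main

import (
	"context"

	"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/gofiber/fiber/v2"
)

func main() {
	ctx := context.Background()
	cl := services.NewConfigLoader()
	applier := services.NewGalleryApplier("./models")
	applier.Start(ctx, cl) // consumes queued gallery.GalleryOp values

	mgs := localai.CreateModelGalleryEndpointService(nil, "./models", applier)

	app := fiber.New()
	app.Post("/models/apply", mgs.ApplyModelGalleryEndpoint()) // returns {uuid, status}
	app.Get("/models/jobs/:uuid", mgs.GetOpStatusEndpoint())   // poll a single op
	app.Get("/models/jobs", mgs.GetAllStatusEndpoint())        // poll all ops
	app.Get("/models/available", mgs.ListModelFromGalleryEndpoint())
	app.Listen(":8080")
}
```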
42
core/http/endpoints/localai/metrics.go
Normal file
@ -0,0 +1,42 @@
package localai

import (
	"time"

	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/adaptor"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func MetricsHandler() fiber.Handler {
	return adaptor.HTTPHandler(promhttp.Handler())
}

type apiMiddlewareConfig struct {
	Filter  func(c *fiber.Ctx) bool
	metrics *schema.LocalAIMetrics
}

func MetricsAPIMiddleware(metrics *schema.LocalAIMetrics) fiber.Handler {
	cfg := apiMiddlewareConfig{
		metrics: metrics,
		Filter: func(c *fiber.Ctx) bool {
			return c.Path() == "/metrics"
		},
	}

	return func(c *fiber.Ctx) error {
		if cfg.Filter != nil && cfg.Filter(c) {
			return c.Next()
		}
		path := c.Path()
		method := c.Method()

		start := time.Now()
		err := c.Next()
		elapsed := float64(time.Since(start)) / float64(time.Second)
		cfg.metrics.ObserveAPICall(method, path, elapsed)
		return err
	}
}
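A minimal sketch of wiring the middleware and scrape endpoint together, assuming the `SetupMetrics` service added later in this commit; the port and registration order are illustrative only.

```go
// metrics_wiring_sketch.go — a minimal sketch, not part of this commit.
package main

import (
	"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

func main() {
	metrics, err := services.SetupMetrics() // builds the Prometheus-backed meter
	if err != nil {
		log.Fatal().Err(err).Msg("metrics setup failed")
	}

	app := fiber.New()
	app.Use(localai.MetricsAPIMiddleware(metrics)) // times every call except /metrics
	app.Get("/metrics", localai.MetricsHandler())  // Prometheus scrape endpoint
	app.Listen(":8080")
}
```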
25
core/http/endpoints/localai/tts.go
Normal file
@ -0,0 +1,25 @@
package localai

import (
	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
)

func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.TTSRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so)
		if err != nil {
			return err
		}
		return c.Download(filePath)
	}
}
97
core/http/endpoints/openai/chat.go
Normal file
@ -0,0 +1,97 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp"
)

func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) func(c *fiber.Ctx) error {

	emptyMessage := ""

	return func(c *fiber.Ctx) error {
		modelName, input, err := readInput(c, startupOptions, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		// The scary comment I feel like I forgot about along the way:
		//
		// functions are not supported in stream mode (yet?)
		//
		if input.Stream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			// c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")

			responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
			if err != nil {
				return fmt.Errorf("failed establishing streaming chat request :%w", err)
			}
			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
				usage := &schema.OpenAIUsage{}
				id := ""
				created := 0
				for ev := range responses {
					usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
					id = ev.ID
					created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					enc.Encode(ev)
					log.Debug().Msgf("Sending chunk: %s", buf.String())
					_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
					if err != nil {
						log.Debug().Msgf("Sending chunk failed: %v", err)
						input.Cancel()
						break
					}
					w.Flush()
				}

				resp := &schema.OpenAIResponse{
					ID:      id,
					Created: created,
					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []schema.Choice{
						{
							FinishReason: "stop",
							Index:        0,
							Delta:        &schema.Message{Content: &emptyMessage},
						}},
					Object: "chat.completion.chunk",
					Usage:  *usage,
				}
				respData, _ := json.Marshal(resp)

				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}
		//////////////////////////////////////////

		resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
		if err != nil {
			return fmt.Errorf("error generating chat request: +%w", err)
		}
		respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this?
		log.Debug().Msgf("Response: %s", respData)
		return c.JSON(resp)
	}
}
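Since the streaming branch above writes server-sent events, a client-side sketch may help: read the `data: ...` lines until `[DONE]`. This is not part of the commit; the `/v1/chat/completions` path and the model name are assumptions for illustration.

```go
// chat_stream_client_sketch.go — a client sketch, not part of this commit.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := `{"model":"gpt-3.5-turbo","stream":true,"messages":[{"role":"user","content":"Hi"}]}`
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		// The handler writes "data: <json chunk>" lines and ends with "data: [DONE]".
		if payload, ok := strings.CutPrefix(scanner.Text(), "data: "); ok && payload != "[DONE]" {
			fmt.Println(payload) // one schema.OpenAIResponse chunk per event
		}
	}
}
```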
91
core/http/endpoints/openai/completion.go
Normal file
@ -0,0 +1,91 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp"
)

// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	id := uuid.New().String()
	created := int(time.Now().Unix())

	return func(c *fiber.Ctx) error {
		modelName, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("`input`: %+v", input)

		if input.Stream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			//c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")

			responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
			if err != nil {
				return fmt.Errorf("failed establishing streaming completion request :%w", err)
			}
			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

				for ev := range responses {
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					enc.Encode(ev)

					log.Debug().Msgf("Sending chunk: %s", buf.String())
					fmt.Fprintf(w, "data: %v\n", buf.String())
					w.Flush()
				}

				resp := &schema.OpenAIResponse{
					ID:      id,
					Created: created,
					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []schema.Choice{
						{
							Index:        0,
							FinishReason: "stop",
						},
					},
					Object: "text_completion",
				}
				respData, _ := json.Marshal(resp)

				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}

		///////////

		resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
		if err != nil {
			return fmt.Errorf("error generating completion request: +%w", err)
		}
		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
34
core/http/endpoints/openai/edit.go
Normal file
@ -0,0 +1,34 @@
package openai

import (
	"encoding/json"
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"

	"github.com/rs/zerolog/log"
)

func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelFile, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so)
		if err != nil {
			return err
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
35
core/http/endpoints/openai/embeddings.go
Normal file
@ -0,0 +1,35 @@
package openai

import (
	"encoding/json"
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"

	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelFile, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so)
		if err != nil {
			return err
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
48
core/http/endpoints/openai/image.go
Normal file
@ -0,0 +1,48 @@
package openai

import (
	"encoding/json"
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

// https://platform.openai.com/docs/api-reference/images/create

/*
*

	curl http://localhost:8080/v1/images/generations \
	  -H "Content-Type: application/json" \
	  -d '{
	    "prompt": "A cute baby sea otter",
	    "n": 1,
	    "size": "512x512"
	  }'

*
*/
func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelName, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so)
		if err != nil {
			return fmt.Errorf("error generating image request: +%w", err)
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
@ -3,21 +3,21 @@ package openai
import (
	"regexp"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/api/schema"
	"github.com/go-skynet/LocalAI/pkg/model"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
)

func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		models, err := loader.ListModels()
		models, err := ml.ListModels()
		if err != nil {
			return err
		}
		var mm map[string]interface{} = map[string]interface{}{}

		dataModels := []schema.OpenAIModel{}
		openAIModels := []schema.OpenAIModel{}

		var filterFn func(name string) bool
		filter := c.Query("filter")
@ -40,13 +40,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
		excludeConfigured := c.QueryBool("excludeConfigured", true)

		// Start with the known configurations
		for _, c := range cm.GetAllConfigs() {
		for _, c := range cl.GetAllConfigs() {
			if excludeConfigured {
				mm[c.Model] = nil
			}

			if filterFn(c.Name) {
				dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
				openAIModels = append(openAIModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
			}
		}

@ -54,7 +54,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
		for _, m := range models {
			// And only adds them if they shouldn't be skipped.
			if _, exists := mm[m]; !exists && filterFn(m) {
				dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
				openAIModels = append(openAIModels, schema.OpenAIModel{ID: m, Object: "model"})
			}
		}

@ -63,7 +63,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
			Data []schema.OpenAIModel `json:"data"`
		}{
			Object: "list",
			Data:   dataModels,
			Data:   openAIModels,
		})
	}
}
57
core/http/endpoints/openai/request.go
Normal file
@ -0,0 +1,57 @@
package openai

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

func readInput(c *fiber.Ctx, o *schema.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *schema.OpenAIRequest, error) {
	input := new(schema.OpenAIRequest)
	ctx, cancel := context.WithCancel(o.Context)
	input.Context = ctx
	input.Cancel = cancel
	// Get input data from the request body
	if err := c.BodyParser(input); err != nil {
		return "", nil, fmt.Errorf("failed parsing request body: %w", err)
	}

	modelFile := input.Model

	if c.Params("model") != "" {
		modelFile = c.Params("model")
	}

	received, _ := json.Marshal(input)

	log.Debug().Msgf("Request received: %s", string(received))

	// Set model from bearer token, if available
	bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
	bearerExists := bearer != "" && ml.ExistsInModelPath(bearer)

	// If no model was specified, take the first available
	if modelFile == "" && !bearerExists && randomModel {
		models, _ := ml.ListModels()
		if len(models) > 0 {
			modelFile = models[0]
			log.Debug().Msgf("No model specified, using: %s", modelFile)
		} else {
			log.Debug().Msgf("No model specified, returning error")
			return "", nil, fmt.Errorf("no model specified")
		}
	}

	// If a model is found in bearer token takes precedence
	if bearerExists {
		log.Debug().Msgf("Using model from bearer token: %s", bearer)
		modelFile = bearer
	}
	return modelFile, input, nil
}
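The resolution order `readInput` implements is: request body `model`, overridden by the route parameter, overridden by a bearer token that names an installed model file, with the first installed model as a last resort. A request sketch, not part of this commit; the route and model file names are placeholders for illustration.

```go
// bearer_model_sketch.go — a request sketch, not part of this commit.
package main

import (
	"net/http"
	"strings"
)

func main() {
	body := `{"model":"ignored.bin","messages":[{"role":"user","content":"Hi"}]}`
	req, _ := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	// If "ggml-gpt4all-j.bin" exists in the model path, readInput lets the
	// bearer value override the "model" field from the body.
	req.Header.Set("Authorization", "Bearer ggml-gpt4all-j.bin")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
}
```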
49
core/http/endpoints/openai/transcription.go
Normal file
@ -0,0 +1,49 @@
package openai

import (
	"fmt"
	"net/http"
	"os"
	"path"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/go-skynet/LocalAI/pkg/utils"

	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
)

// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		modelName, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		// retrieve the file data from the request
		file, err := c.FormFile("file")
		if err != nil {
			return err
		}

		dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper
		if err != nil {
			return err
		}

		log.Debug().Msgf("Audio file copied to: %+v", dst)
		defer os.RemoveAll(path.Dir(dst))

		tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so)
		if err != nil {
			return fmt.Errorf("error generating transcription request: +%w", err)
		}
		log.Debug().Msgf("Trascribed: %+v", tr)
		// TODO: handle different outputs here
		return c.Status(http.StatusOK).JSON(tr)
	}
}
24
core/mqtt/manager.go
Normal file
@ -0,0 +1,24 @@
package mqtt

import (
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
)

// PLACEHOLDER DURING PART 1 OF THE REFACTOR

type MQTTManager struct {
	configLoader   *services.ConfigLoader
	modelLoader    *model.ModelLoader
	startupOptions *schema.StartupOptions
}

func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*MQTTManager, error) {

	return &MQTTManager{
		configLoader:   cl,
		modelLoader:    ml,
		startupOptions: options,
	}, nil
}
138
core/services/backend_monitor.go
Normal file
@ -0,0 +1,138 @@
package services

import (
	"context"
	"fmt"
	"strings"

	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/rs/zerolog/log"

	gopsutil "github.com/shirou/gopsutil/v3/process"
)

type BackendMonitor struct {
	configLoader *ConfigLoader
	modelLoader  *model.ModelLoader
	options      *schema.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}

func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *schema.StartupOptions) *BackendMonitor {
	return &BackendMonitor{
		configLoader: configLoader,
		modelLoader:  modelLoader,
		options:      options,
	}
}

func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
	config, exists := bm.configLoader.GetConfig(model)
	var backend string
	if exists {
		backend = config.Model
	} else {
		// Last ditch effort: use it raw, see if a backend happens to match.
		backend = model
	}

	if !strings.HasSuffix(backend, ".bin") {
		backend = fmt.Sprintf("%s.bin", backend)
	}

	pid, err := bm.modelLoader.GetGRPCPID(backend)

	if err != nil {
		log.Error().Msgf("model %s : failed to find pid %+v", model, err)
		return nil, err
	}

	// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
	backendProcess, err := gopsutil.NewProcess(int32(pid))

	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
		return nil, err
	}

	memInfo, err := backendProcess.MemoryInfo()

	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
		return nil, err
	}

	memPercent, err := backendProcess.MemoryPercent()
	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
		return nil, err
	}

	cpuPercent, err := backendProcess.CPUPercent()
	if err != nil {
		log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
		return nil, err
	}

	return &schema.BackendMonitorResponse{
		MemoryInfo:    memInfo,
		MemoryPercent: memPercent,
		CPUPercent:    cpuPercent,
	}, nil
}

func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
	config, exists := bm.configLoader.GetConfig(modelName)
	var backendId string
	if exists {
		backendId = config.Model
	} else {
		// Last ditch effort: use it raw, see if a backend happens to match.
		backendId = modelName
	}

	if !strings.HasSuffix(backendId, ".bin") {
		backendId = fmt.Sprintf("%s.bin", backendId)
	}

	return backendId, nil
}

func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
	backendId, err := bm.getModelLoaderIDFromModelName(modelName)
	if err != nil {
		return nil, err
	}
	modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
	if modelAddr == "" {
		return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
	}

	status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
	if rpcErr != nil {
		log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
		val, slbErr := bm.SampleLocalBackendProcess(backendId)
		if slbErr != nil {
			return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
		}
		return &proto.StatusResponse{
			State: proto.StatusResponse_ERROR,
			Memory: &proto.MemoryUsageData{
				Total: val.MemoryInfo.VMS,
				Breakdown: map[string]uint64{
					"gopsutil-RSS": val.MemoryInfo.RSS,
				},
			},
		}, nil
	}
	return status, nil
}

func (bm BackendMonitor) ShutdownModel(modelName string) error {
	backendId, err := bm.getModelLoaderIDFromModelName(modelName)
	if err != nil {
		return err
	}
	return bm.modelLoader.ShutdownModel(backendId)
}
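A usage sketch for the monitor service, not part of this commit; `"my-model"` is a placeholder model name. Note that `CheckAndSample` prefers the backend's own status RPC and falls back to sampling the local gRPC process only when the RPC fails.

```go
// backend_monitor_sketch.go — a usage sketch, not part of this commit.
package main

import (
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/rs/zerolog/log"
)

func sampleBackend(cl *services.ConfigLoader, ml *model.ModelLoader, opts *schema.StartupOptions) {
	bm := services.NewBackendMonitor(cl, ml, opts)

	status, err := bm.CheckAndSample("my-model") // placeholder model name
	if err != nil {
		log.Error().Msgf("monitor failed: %v", err)
		return
	}
	log.Info().Msgf("backend state: %v", status.State)
}
```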
157
core/services/config.go
Normal file
@ -0,0 +1,157 @@
package services

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/go-skynet/LocalAI/pkg/utils"
	"github.com/rs/zerolog/log"
)

type ConfigLoader struct {
	configs map[string]schema.Config
	sync.Mutex
}

func NewConfigLoader() *ConfigLoader {
	return &ConfigLoader{
		configs: make(map[string]schema.Config),
	}
}

// TODO: check this is correct post-merge
func (cm *ConfigLoader) LoadConfig(file string) error {
	cm.Lock()
	defer cm.Unlock()
	c, err := schema.ReadSingleConfigFile(file)
	if err != nil {
		return fmt.Errorf("cannot read config file: %w", err)
	}

	cm.configs[c.Name] = *c
	return nil
}

func (cm *ConfigLoader) GetConfig(m string) (schema.Config, bool) {
	cm.Lock()
	defer cm.Unlock()
	v, exists := cm.configs[m]
	return v, exists
}

func (cm *ConfigLoader) GetAllConfigs() []schema.Config {
	cm.Lock()
	defer cm.Unlock()
	var res []schema.Config
	for _, v := range cm.configs {
		res = append(res, v)
	}
	return res
}

func (cm *ConfigLoader) ListConfigs() []string {
	cm.Lock()
	defer cm.Unlock()
	var res []string
	for k := range cm.configs {
		res = append(res, k)
	}
	return res
}

func (cm *ConfigLoader) LoadConfigs(path string) error {
	cm.Lock()
	defer cm.Unlock()
	entries, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	files := make([]fs.FileInfo, 0, len(entries))
	for _, entry := range entries {
		info, err := entry.Info()
		if err != nil {
			return err
		}
		files = append(files, info)
	}
	for _, file := range files {
		// Skip templates, YAML and .keep files
		if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
			continue
		}
		c, err := schema.ReadSingleConfigFile(filepath.Join(path, file.Name()))
		if err == nil {
			cm.configs[c.Name] = *c
		}
	}

	return nil
}

// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
	cm.Lock()
	defer cm.Unlock()

	status := func(fileName, current, total string, percent float64) {
		utils.DisplayDownloadFunction(fileName, current, total, percent)
	}

	log.Info().Msgf("Preloading models from %s", modelPath)

	for _, config := range cm.configs {

		// Download files and verify their SHA
		for _, file := range config.DownloadFiles {
			log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

			if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
				return err
			}
			// Create file path
			filePath := filepath.Join(modelPath, file.Filename)

			if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
				return err
			}
		}

		modelURL := config.PredictionOptions.Model
		modelURL = utils.ConvertURL(modelURL)

		if utils.LooksLikeURL(modelURL) {
			// md5 of model name
			md5Name := utils.MD5(modelURL)

			// check if file exists
			if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
				err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
				if err != nil {
					return err
				}
			}
		}
	}

	return nil
}

func (cl *ConfigLoader) LoadConfigFile(file string) error {
	cl.Lock()
	defer cl.Unlock()
	c, err := schema.ReadConfigFile(file)
	if err != nil {
		return fmt.Errorf("cannot load config file: %w", err)
	}

	for _, cc := range c {
		cl.configs[cc.Name] = *cc
	}
	return nil
}
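The loader above supports three phases: per-model YAML files, an optional aggregate config file, and remote-model resolution. A usage sketch, not part of this commit; the paths are placeholders.

```go
// config_loader_sketch.go — a usage sketch, not part of this commit.
package main

import (
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/rs/zerolog/log"
)

func main() {
	cl := services.NewConfigLoader()
	if err := cl.LoadConfigs("./models"); err != nil { // one *.yaml / *.yml per model
		log.Fatal().Err(err).Msg("loading per-model configs")
	}
	if err := cl.LoadConfigFile("./config.yaml"); err != nil { // optional aggregate file
		log.Fatal().Err(err).Msg("loading aggregate config")
	}
	if err := cl.Preload("./models"); err != nil { // fetch URL / HuggingFace models
		log.Fatal().Err(err).Msg("preloading models")
	}
	log.Info().Msgf("loaded configs: %v", cl.ListConfigs())
}
```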
160
core/services/gallery.go
Normal file
@ -0,0 +1,160 @@
package services

import (
	"context"
	"encoding/json"
	"os"
	"strings"
	"sync"

	"github.com/go-skynet/LocalAI/pkg/gallery"
	"github.com/go-skynet/LocalAI/pkg/utils"
	"gopkg.in/yaml.v2"
)

type GalleryApplier struct {
	modelPath string
	sync.Mutex
	C        chan gallery.GalleryOp
	statuses map[string]*gallery.GalleryOpStatus
}

func NewGalleryApplier(modelPath string) *GalleryApplier {
	return &GalleryApplier{
		modelPath: modelPath,
		C:         make(chan gallery.GalleryOp),
		statuses:  make(map[string]*gallery.GalleryOpStatus),
	}
}

func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
	g.Lock()
	defer g.Unlock()
	g.statuses[s] = op
}

func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus {
	g.Lock()
	defer g.Unlock()

	return g.statuses[s]
}

func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus {
	g.Lock()
	defer g.Unlock()

	return g.statuses
}

func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) {
	go func() {
		for {
			select {
			case <-c.Done():
				return
			case op := <-g.C:
				utils.ResetDownloadTimers()

				g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})

				// updates the status with an error
				updateError := func(e error) {
					g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
				}

				// displayDownload displays the download progress
				progressCallback := func(fileName string, current string, total string, percentage float64) {
					g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
					utils.DisplayDownloadFunction(fileName, current, total, percentage)
				}

				var err error
				// if the request contains a gallery name, we apply the gallery from the gallery list
				if op.GalleryName != "" {
					if strings.Contains(op.GalleryName, "@") {
						err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
					} else {
						err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
					}
				} else {
					err = PrepareModel(g.modelPath, op.Req, cm, progressCallback)
				}

				if err != nil {
					updateError(err)
					continue
				}

				// Reload models
				err = cm.LoadConfigs(g.modelPath)
				if err != nil {
					updateError(err)
					continue
				}

				g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
			}
		}
	}()
}

type galleryModel struct {
	gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
	ID                   string           `json:"id"`
}

func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error {

	config, err := gallery.GetInstallableModelFromURL(req.URL)
	if err != nil {
		return err
	}

	config.Files = append(config.Files, req.AdditionalFiles...)

	return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}

func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
	var err error
	for _, r := range requests {
		utils.ResetDownloadTimers()
		if r.ID == "" {
			err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
		} else {
			if strings.Contains(r.ID, "@") {
				err = gallery.InstallModelFromGallery(
					galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
			} else {
				err = gallery.InstallModelFromGalleryByName(
					galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
			}
		}
	}
	return err
}

func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
	dat, err := os.ReadFile(s)
	if err != nil {
		return err
	}
	var requests []galleryModel

	if err := yaml.Unmarshal(dat, &requests); err != nil {
		return err
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}

func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
	var requests []galleryModel
	err := json.Unmarshal([]byte(s), &requests)
	if err != nil {
		return err
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}
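The applier is a single worker goroutine fed by an unbuffered channel, with statuses polled by ID. A usage sketch, not part of this commit; the `gallery@model` name is a placeholder.

```go
// gallery_applier_sketch.go — a usage sketch, not part of this commit.
package main

import (
	"context"
	"time"

	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/gallery"
	"github.com/google/uuid"
)

func main() {
	cl := services.NewConfigLoader()
	applier := services.NewGalleryApplier("./models")
	applier.Start(context.Background(), cl)

	opID := uuid.New().String()
	applier.C <- gallery.GalleryOp{
		Id:          opID,
		GalleryName: "model-gallery@bert-embeddings", // placeholder gallery@model name
	}

	// Poll until the worker marks the op processed (success or error).
	for {
		if s := applier.GetStatus(opID); s != nil && s.Processed {
			break
		}
		time.Sleep(time.Second)
	}
}
```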
29
core/services/metrics.go
Normal file
@ -0,0 +1,29 @@
package services

import (
	"github.com/go-skynet/LocalAI/pkg/schema"
	"go.opentelemetry.io/otel/exporters/prometheus"
	api "go.opentelemetry.io/otel/metric"
	"go.opentelemetry.io/otel/sdk/metric"
)

// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*schema.LocalAIMetrics, error) {
	exporter, err := prometheus.New()
	if err != nil {
		return nil, err
	}
	provider := metric.NewMeterProvider(metric.WithReader(exporter))
	meter := provider.Meter("github.com/go-skynet/LocalAI")

	apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
	if err != nil {
		return nil, err
	}

	return &schema.LocalAIMetrics{
		Meter:         meter,
		ApiTimeMetric: apiTimeMetric,
	}, nil
}
100
core/startup/config_file_watcher.go
Normal file
@ -0,0 +1,100 @@
package startup

import (
	"encoding/json"
	"fmt"
	"os"
	"path"

	"github.com/fsnotify/fsnotify"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/imdario/mergo"
	"github.com/rs/zerolog/log"
)

type WatchConfigDirectoryCloser func() error

func ReadApiKeysJson(configDir string, options *schema.StartupOptions) error {
	fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
	if err == nil {
		// Parse JSON content from the file
		var fileKeys []string
		err := json.Unmarshal(fileContent, &fileKeys)
		if err == nil {
			options.ApiKeys = append(options.ApiKeys, fileKeys...)
			return nil
		}
		return err
	}
	return err
}

func ReadExternalBackendsJson(configDir string, options *schema.StartupOptions) error {
	fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
	if err != nil {
		return err
	}
	// Parse JSON content from the file
	var fileBackends map[string]string
	err = json.Unmarshal(fileContent, &fileBackends)
	if err != nil {
		return err
	}
	err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends)
	if err != nil {
		return err
	}
	return nil
}

var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *schema.StartupOptions) error{
	"api_keys.json":          ReadApiKeysJson,
	"external_backends.json": ReadExternalBackendsJson,
}

func WatchConfigDirectory(configDir string, options *schema.StartupOptions) (WatchConfigDirectoryCloser, error) {
	if len(configDir) == 0 {
		return nil, fmt.Errorf("configDir blank")
	}
	configWatcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
	}
	ret := func() error {
		configWatcher.Close()
		return nil
	}

	// Start listening for events.
	go func() {
		for {
			select {
			case event, ok := <-configWatcher.Events:
				if !ok {
					return
				}
				if event.Has(fsnotify.Write) {
					for targetName, watchFn := range CONFIG_FILE_UPDATES {
						if event.Name == targetName {
							err := watchFn(configDir, options)
							log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
						}
					}
				}
			case _, ok := <-configWatcher.Errors:
				if !ok {
					return
				}
				log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
			}
		}
	}()

	// Add a path.
	err = configWatcher.Add(configDir)
	if err != nil {
		return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
	}

	return ret, nil
}
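A usage sketch for the watcher, not part of this commit; the directory name is a placeholder. `api_keys.json` is expected to hold a JSON array of strings (e.g. `["sk-local-1"]`), while `external_backends.json` maps backend names to addresses.

```go
// config_watcher_sketch.go — a usage sketch, not part of this commit.
package main

import (
	"github.com/go-skynet/LocalAI/core/startup"
	"github.com/go-skynet/LocalAI/pkg/schema"
)

func watchConfig(options *schema.StartupOptions) (startup.WatchConfigDirectoryCloser, error) {
	closer, err := startup.WatchConfigDirectory("./configuration", options)
	if err != nil {
		return nil, err
	}
	return closer, nil // call closer() on shutdown to stop the fsnotify watcher
}
```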
93
core/startup/startup.go
Normal file
@ -0,0 +1,93 @@
package startup

import (
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/internal"
	"github.com/go-skynet/LocalAI/pkg/assets"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

func Startup(opts ...schema.AppOption) (*services.ConfigLoader, *model.ModelLoader, *schema.StartupOptions, error) {
	options := schema.NewStartupOptions(opts...)

	ml := model.NewModelLoader(options.ModelPath)

	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if options.Debug {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	}

	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

	cl := services.NewConfigLoader()
	if err := cl.LoadConfigs(options.ModelPath); err != nil {
		log.Error().Msgf("error loading config files: %s", err.Error())
	}

	if options.ConfigFile != "" {
		if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
			log.Error().Msgf("error loading config file: %s", err.Error())
		}
	}

	if err := cl.Preload(options.ModelPath); err != nil {
		log.Error().Msgf("error downloading models: %s", err.Error())
	}

	if options.PreloadJSONModels != "" {
		if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
			return nil, nil, nil, err
		}
	}

	if options.PreloadModelsFromPath != "" {
		if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
			return nil, nil, nil, err
		}
	}

	if options.Debug {
		for _, v := range cl.ListConfigs() {
			cfg, _ := cl.GetConfig(v)
			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
		}
	}

	if options.AssetsDestination != "" {
		// Extract files from the embedded FS
		err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
		log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
		if err != nil {
			log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
		}
	}

	// turn off any process that was started by GRPC if the context is canceled
	go func() {
		<-options.Context.Done()
		log.Debug().Msgf("Context canceled, shutting down")
		ml.StopAllGRPC()
	}()

	if options.WatchDog {
		wd := model.NewWatchDog(
			ml,
			options.WatchDogBusyTimeout,
			options.WatchDogIdleTimeout,
			options.WatchDogBusy,
			options.WatchDogIdle)
		ml.SetWatchDog(wd)
		go wd.Run()
		go func() {
			<-options.Context.Done()
			log.Debug().Msgf("Context canceled, shutting down")
			wd.Shutdown()
		}()
	}

	return cl, ml, options, nil
}
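A caller sketch for the refactored entry point, not part of this commit. The `schema.With...` option constructors are assumed names for illustration; the key point is that `Startup` now returns the three handles the HTTP (and future MQTT) frontends need.

```go
// startup_caller_sketch.go — a caller sketch, not part of this commit.
package main

import (
	"github.com/go-skynet/LocalAI/core/startup"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/rs/zerolog/log"
)

func main() {
	cl, ml, options, err := startup.Startup(
		schema.WithModelPath("./models"), // assumed option constructor
		schema.WithDebug(true),           // assumed option constructor
	)
	if err != nil {
		log.Fatal().Err(err).Msg("startup failed")
	}
	_, _, _ = cl, ml, options // hand these to the HTTP / MQTT frontends
}
```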
@ -17,6 +17,53 @@ This section will collect how-to, notes and development documentation

We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages.

## LocalAI Project Structure

**LocalAI is made of multiple components, developed in multiple repositories:**

The core repository, containing the primary `local-ai` server code, gRPC stubs, this documentation website, and Docker container build resources, is located at [mudler/LocalAI](https://github.com/mudler/LocalAI).

As LocalAI is designed to make use of multiple, independent model galleries, those are maintained separately. The following public model galleries are available for use:

* [go-skynet/model-gallery](https://github.com/go-skynet/model-gallery) - The original gallery. Its `golang` HuggingFace scraper ran into rate limits and was largely retired, so this now holds handmade YAML configs.
* [dave-gray101/model-gallery](https://github.com/dave-gray101/model-gallery) - An automated gallery designed to track HuggingFace uploads and produce best-effort, automatically generated configurations for LocalAI. It is designed to produce one LocalAI gallery per repository on HuggingFace.

### Directory Structure of this Repo

The core repository is broken up into the following primary chunks:

* `/backend`: gRPC protobuf specification and gRPC backends. Subfolders for each language.
* **`/core`**: Go source code for the core LocalAI application. Broken down below.
* `/docs`: the localai.io website that you are reading now
* `/examples`: example code integrating LocalAI into other projects and/or developer samples and tools
* `/internal`: **here be dragons**. Don't touch this, it's used for automatic versioning.
* `/models`: _No code here!_ This is where models are installed!
* **`/pkg`**: Go source code that is intended to be reusable, or at least widely imported, across LocalAI. Broken down below.
* `/prompt-templates`: _No code here!_ This is where **example** prompt templates were historically stored. Somewhat obsolete these days; model galleries have largely replaced the need to create these by hand.
* `/tests`: Does what it says on the tin. Please write tests and put them here when you do.

The `core` folder is broken down further:

* **`/core/backend`**: code that interacts with a gRPC backend to perform AI tasks.
* `/core/http`: code specifically related to the REST server
* `/core/http/endpoints`: Has two subdirectories, `openai` and `localai`, for binding the respective endpoints to the correct backend or service.
* `/core/mqtt`: code specifically related to the MQTT server. Stub for now. Coming soon!
* **`/core/services`**: code implementing functionality performed by `local-ai` itself, rather than delegated to a backend.
* `/core/startup`: code related specifically to application startup of `local-ai`. Potentially to be refactored into a part of `/core/services` at a later date.

The `pkg` folder is broken down further:

* `/pkg/assets`: Currently contains a single function related to extracting files from archives. Potentially to be refactored into a part of `/core/utils` at a later date.
* **`/pkg/datamodel`**: Contains the data types and definitions used by the LocalAI project. Imported widely!
* `/pkg/gallery`: Code related to interacting with a `model-gallery`
* `/pkg/grammar`: Code related to BNF / functions for LLM
* `/pkg/grpc`: base classes and interfaces for gRPC backends to implement
* `/pkg/langchain`: langchain-related code in Go
* **`/pkg/model`**: Code related to loading and initializing a model and creating the appropriate gRPC backend.
* `/pkg/stablediffusion`: Code related to Stable Diffusion in Go.
* `/pkg/utils`: Every real programmer knows what they are going to find in here... it's our junk drawer of utility functions.

## Creating a gRPC backend

LocalAI backends are `gRPC` servers.
@ -20,7 +20,7 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{

Returns an `audio/wav` file.

#### Setup
#### Text-To-Speech Setup

LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`:

@ -52,6 +52,8 @@ Note:
- The model name is case sensitive.
- LocalAI must be compiled with the `GO_TAGS=tts` flag.

#### Music

LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech:

```
@ -62,7 +64,8 @@ curl --request POST \
    "backend": "transformers-musicgen",
    "model": "facebook/musicgen-medium",
    "input": "Cello Rave"
}' | aplay```
}' | aplay
```

Future versions of LocalAI will expose additional control over audio generation beyond the text prompt.
4 go.mod

@ -5,6 +5,7 @@ go 1.21

require (
	github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
	github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
+	github.com/fsnotify/fsnotify v1.7.0
	github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
	github.com/go-audio/wav v1.1.0
	github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1

@ -15,7 +16,6 @@ require (

	github.com/hashicorp/go-multierror v1.1.1
	github.com/hpcloud/tail v1.0.0
	github.com/imdario/mergo v0.3.16
-	github.com/json-iterator/go v1.1.12
	github.com/mholt/archiver/v3 v3.5.1
	github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
	github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af

@ -63,8 +63,6 @@ require (

	github.com/klauspost/pgzip v1.2.5 // indirect
	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
-	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
-	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/nwaples/rardecode v1.1.0 // indirect
	github.com/pierrec/lz4/v4 v4.1.2 // indirect
	github.com/pkoukk/tiktoken-go v0.1.2 // indirect
11 go.sum

@ -24,8 +24,9 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L

github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
-github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
+github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=

@ -74,7 +75,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/

github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
-github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=

@ -88,8 +88,6 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO

github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
-github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
-github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=

@ -119,11 +117,6 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl

github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
-github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
-github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI=
120 main.go

@ -12,14 +12,14 @@ import (

	"syscall"
	"time"

-	api "github.com/go-skynet/LocalAI/api"
-	"github.com/go-skynet/LocalAI/api/backend"
-	config "github.com/go-skynet/LocalAI/api/config"
-	"github.com/go-skynet/LocalAI/api/options"
+	"github.com/go-skynet/LocalAI/core/backend"
+	"github.com/go-skynet/LocalAI/core/http"
+	"github.com/go-skynet/LocalAI/core/services"
+	"github.com/go-skynet/LocalAI/core/startup"
	"github.com/go-skynet/LocalAI/internal"
-	"github.com/go-skynet/LocalAI/metrics"
	"github.com/go-skynet/LocalAI/pkg/gallery"
-	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	progressbar "github.com/schollz/progressbar/v3"

@ -190,6 +190,12 @@ func main() {

			EnvVars: []string{"PRELOAD_BACKEND_ONLY"},
			Value:   false,
		},
+		&cli.StringFlag{
+			Name:    "localai-config-dir",
+			Usage:   "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.",
+			EnvVars: []string{"LOCALAI_CONFIG_DIR"},
+			Value:   "./config",
+		},
	},
	Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally.

@ -208,54 +214,54 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit

	UsageText: `local-ai [options]`,
	Copyright: "Ettore Di Giacinto",
	Action: func(ctx *cli.Context) error {
-		opts := []options.AppOption{
-			options.WithConfigFile(ctx.String("config-file")),
-			options.WithJSONStringPreload(ctx.String("preload-models")),
-			options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
-			options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
-			options.WithContextSize(ctx.Int("context-size")),
-			options.WithDebug(ctx.Bool("debug")),
-			options.WithImageDir(ctx.String("image-path")),
-			options.WithAudioDir(ctx.String("audio-path")),
-			options.WithF16(ctx.Bool("f16")),
-			options.WithStringGalleries(ctx.String("galleries")),
-			options.WithDisableMessage(false),
-			options.WithCors(ctx.Bool("cors")),
-			options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
-			options.WithThreads(ctx.Int("threads")),
-			options.WithBackendAssets(backendAssets),
-			options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
-			options.WithUploadLimitMB(ctx.Int("upload-limit")),
-			options.WithApiKeys(ctx.StringSlice("api-keys")),
-			options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
-		}
+		opts := []schema.AppOption{
+			schema.WithConfigFile(ctx.String("config-file")),
+			schema.WithJSONStringPreload(ctx.String("preload-models")),
+			schema.WithYAMLConfigPreload(ctx.String("preload-models-config")),
+			schema.WithModelPath(ctx.String("models-path")),
+			schema.WithContextSize(ctx.Int("context-size")),
+			schema.WithDebug(ctx.Bool("debug")),
+			schema.WithImageDir(ctx.String("image-path")),
+			schema.WithAudioDir(ctx.String("audio-path")),
+			schema.WithF16(ctx.Bool("f16")),
+			schema.WithStringGalleries(ctx.String("galleries")),
+			schema.WithDisableMessage(false),
+			schema.WithCors(ctx.Bool("cors")),
+			schema.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
+			schema.WithThreads(ctx.Int("threads")),
+			schema.WithBackendAssets(backendAssets),
+			schema.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
+			schema.WithUploadLimitMB(ctx.Int("upload-limit")),
+			schema.WithApiKeys(ctx.StringSlice("api-keys")),
+			schema.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
+		}

		idleWatchDog := ctx.Bool("enable-watchdog-idle")
		busyWatchDog := ctx.Bool("enable-watchdog-busy")
		if idleWatchDog || busyWatchDog {
-			opts = append(opts, options.EnableWatchDog)
+			opts = append(opts, schema.EnableWatchDog)
			if idleWatchDog {
-				opts = append(opts, options.EnableWatchDogIdleCheck)
+				opts = append(opts, schema.EnableWatchDogIdleCheck)
				dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
				if err != nil {
					return err
				}
-				opts = append(opts, options.SetWatchDogIdleTimeout(dur))
+				opts = append(opts, schema.SetWatchDogIdleTimeout(dur))
			}
			if busyWatchDog {
-				opts = append(opts, options.EnableWatchDogBusyCheck)
+				opts = append(opts, schema.EnableWatchDogBusyCheck)
				dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
				if err != nil {
					return err
				}
-				opts = append(opts, options.SetWatchDogBusyTimeout(dur))
+				opts = append(opts, schema.SetWatchDogBusyTimeout(dur))
			}
		}
		if ctx.Bool("parallel-requests") {
-			opts = append(opts, options.EnableParallelBackendRequests)
+			opts = append(opts, schema.EnableParallelBackendRequests)
		}
		if ctx.Bool("single-active-backend") {
-			opts = append(opts, options.EnableSingleBackend)
+			opts = append(opts, schema.EnableSingleBackend)
		}

		externalgRPC := ctx.StringSlice("external-grpc-backends")

@ -263,30 +269,42 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit

		for _, v := range externalgRPC {
			backend := v[:strings.IndexByte(v, ':')]
			uri := v[strings.IndexByte(v, ':')+1:]
-			opts = append(opts, options.WithExternalBackend(backend, uri))
+			opts = append(opts, schema.WithExternalBackend(backend, uri))
		}

		if ctx.Bool("autoload-galleries") {
-			opts = append(opts, options.EnableGalleriesAutoload)
+			opts = append(opts, schema.EnableGalleriesAutoload)
		}

		if ctx.Bool("preload-backend-only") {
-			_, _, err := api.Startup(opts...)
+			_, _, _, err := startup.Startup(opts...)
			return err
		}

-		metrics, err := metrics.SetupMetrics()
+		metrics, err := services.SetupMetrics()
		if err != nil {
			return err
		}
-		opts = append(opts, options.WithMetrics(metrics))
+		opts = append(opts, schema.WithMetrics(metrics))

-		app, err := api.App(opts...)
+		cl, ml, options, err := startup.Startup(opts...)
+		if err != nil {
+			return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
+		}
+
+		closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options)
+		defer closeConfigWatcherFn()
+		if err != nil {
+			return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir"))
+		}
+
+		appHTTP, err := http.App(cl, ml, options)
		if err != nil {
			return err
		}

-		return app.Listen(ctx.String("address"))
+		return appHTTP.Listen(ctx.String("address"))
	},
	Commands: []*cli.Command{
		{

@ -384,16 +402,18 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit

			text := strings.Join(ctx.Args().Slice(), " ")

-			opts := &options.Option{
-				Loader:            model.NewModelLoader(ctx.String("models-path")),
+			opts := &schema.StartupOptions{
+				ModelPath:         ctx.String("models-path"),
				Context:           context.Background(),
				AudioDir:          outputDir,
				AssetsDestination: ctx.String("backend-assets-path"),
			}

-			defer opts.Loader.StopAllGRPC()
+			loader := model.NewModelLoader(opts.ModelPath)
+			defer loader.StopAllGRPC()

-			filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts)
+			filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, loader, opts)
			if err != nil {
				return err
			}

@ -446,13 +466,15 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit

			language := ctx.String("language")
			threads := ctx.Int("threads")

-			opts := &options.Option{
-				Loader:            model.NewModelLoader(ctx.String("models-path")),
+			opts := &schema.StartupOptions{
+				ModelPath:         ctx.String("models-path"),
				Context:           context.Background(),
				AssetsDestination: ctx.String("backend-assets-path"),
			}

-			cl := config.NewConfigLoader()
+			ml := model.NewModelLoader(opts.ModelPath)
+
+			cl := services.NewConfigLoader()
			if err := cl.LoadConfigs(ctx.String("models-path")); err != nil {
				return err
			}

@ -464,9 +486,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit

			c.Threads = threads

-			defer opts.Loader.StopAllGRPC()
+			defer ml.StopAllGRPC()

-			tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts)
+			tr, err := backend.ModelTranscription(filename, language, ml, c, opts)
			if err != nil {
				return err
			}
The standalone `metrics` package is deleted outright; its entry points move into `core/services` (see `services.SetupMetrics` in `main.go` above):

@ -1,83 +0,0 @@

package metrics

import (
	"context"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/adaptor"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/exporters/prometheus"
	api "go.opentelemetry.io/otel/metric"
	"go.opentelemetry.io/otel/sdk/metric"
)

type Metrics struct {
	meter         api.Meter
	apiTimeMetric api.Float64Histogram
}

// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*Metrics, error) {
	exporter, err := prometheus.New()
	if err != nil {
		return nil, err
	}
	provider := metric.NewMeterProvider(metric.WithReader(exporter))
	meter := provider.Meter("github.com/go-skynet/LocalAI")

	apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
	if err != nil {
		return nil, err
	}

	return &Metrics{
		meter:         meter,
		apiTimeMetric: apiTimeMetric,
	}, nil
}

func MetricsHandler() fiber.Handler {
	return adaptor.HTTPHandler(promhttp.Handler())
}

type apiMiddlewareConfig struct {
	Filter  func(c *fiber.Ctx) bool
	metrics *Metrics
}

func APIMiddleware(metrics *Metrics) fiber.Handler {
	cfg := apiMiddlewareConfig{
		metrics: metrics,
		Filter: func(c *fiber.Ctx) bool {
			if c.Path() == "/metrics" {
				return true
			}
			return false
		},
	}

	return func(c *fiber.Ctx) error {
		if cfg.Filter != nil && cfg.Filter(c) {
			return c.Next()
		}
		path := c.Path()
		method := c.Method()

		start := time.Now()
		err := c.Next()
		elapsed := float64(time.Since(start)) / float64(time.Second)
		cfg.metrics.ObserveAPICall(method, path, elapsed)
		return err
	}
}

func (m *Metrics) ObserveAPICall(method string, path string, duration float64) {
	opts := api.WithAttributes(
		attribute.String("method", method),
		attribute.String("path", path),
	)
	m.apiTimeMetric.Record(context.Background(), duration, opts)
}
@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,

	applyModel := func(model *GalleryModel) error {
		name = strings.ReplaceAll(name, string(os.PathSeparator), "__")

-		var config Config
+		var config InstallableModel

		if len(model.URL) > 0 {
			var err error
-			config, err = GetGalleryConfigFromURL(model.URL)
+			config, err = GetInstallableModelFromURL(model.URL)
			if err != nil {
				return err
			}

@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,

			if err != nil {
				return err
			}
-			config = Config{
+			config = InstallableModel{
				ConfigFile:  string(reYamlConfig),
				Description: model.Description,
				License:     model.License,
@ -1,13 +1,9 @@

package gallery

import (
-	"crypto/sha256"
	"fmt"
-	"hash"
-	"io"
	"os"
	"path/filepath"
-	"strconv"

	"github.com/go-skynet/LocalAI/pkg/utils"
	"github.com/imdario/mergo"

@ -41,9 +37,9 @@ prompt_templates:

  content: ""

*/
-// Config is the model configuration which contains all the model details
+// InstallableModel is the model configuration which contains all the model details
// This configuration is read from the gallery endpoint and is used to download and install the model
-type Config struct {
+type InstallableModel struct {
	Description string   `yaml:"description"`
	License     string   `yaml:"license"`
	URLs        []string `yaml:"urls"`

@ -64,8 +60,8 @@ type PromptTemplate struct {

	Content string `yaml:"content"`
}

-func GetGalleryConfigFromURL(url string) (Config, error) {
-	var config Config
+func GetInstallableModelFromURL(url string) (InstallableModel, error) {
+	var config InstallableModel
	err := utils.GetURI(url, func(url string, d []byte) error {
		return yaml.Unmarshal(d, &config)
	})

@ -76,7 +72,7 @@ func GetGalleryConfigFromURL(url string) (Config, error) {

	return config, nil
}

-func ReadConfigFile(filePath string) (*Config, error) {
+func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
	// Read the YAML file
	yamlFile, err := os.ReadFile(filePath)
	if err != nil {

@ -84,7 +80,7 @@ func ReadConfigFile(filePath string) (*Config, error) {

	}

	// Unmarshal YAML data into a Config struct
-	var config Config
+	var config InstallableModel
	err = yaml.Unmarshal(yamlFile, &config)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)

@ -93,7 +89,7 @@ func ReadConfigFile(filePath string) (*Config, error) {

	return &config, nil
}

-func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
+func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
	// Create base path if it doesn't exist
	err := os.MkdirAll(basePath, 0755)
	if err != nil {

@ -183,54 +179,3 @@ func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides

	return nil
}

The now-unused download helpers at the bottom of the file are removed:

type progressWriter struct {
	fileName       string
	total          int64
	written        int64
	downloadStatus func(string, string, string, float64)
	hash           hash.Hash
}

func (pw *progressWriter) Write(p []byte) (n int, err error) {
	n, err = pw.hash.Write(p)
	pw.written += int64(n)

	if pw.total > 0 {
		percentage := float64(pw.written) / float64(pw.total) * 100
		//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
		pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
	} else {
		pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
	}

	return
}

func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}
	div, exp := int64(unit), 0
	for n := bytes / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

func calculateSHA(filePath string) (string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer file.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return "", err
	}

	return fmt.Sprintf("%x", hash.Sum(nil)), nil
}
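Taken together, the renamed gallery API can be exercised roughly like this (a sketch; the gallery URL is the one used by the tests below, error handling is abbreviated, and a real run will download model files):

```
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/gallery"
)

func main() {
	// Fetch the installable-model definition (formerly gallery.Config).
	m, err := gallery.GetInstallableModelFromURL("github:go-skynet/model-gallery/gpt4all-j.yaml@main")
	if err != nil {
		panic(err) // sketch only
	}

	// Install it under ./models, reporting download progress as it goes.
	err = gallery.InstallModel("./models", "", &m, map[string]interface{}{},
		func(fileName, current, total string, percent float64) {
			fmt.Printf("%s: %s/%s (%.1f%%)\n", fileName, current, total, percent)
		})
	if err != nil {
		panic(err) // sketch only
	}
}
```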
@ -16,7 +16,7 @@ var _ = Describe("Model test", func() {
|
|||||||
tempdir, err := os.MkdirTemp("", "test")
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer os.RemoveAll(tempdir)
|
defer os.RemoveAll(tempdir)
|
||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
|
|||||||
tempdir, err := os.MkdirTemp("", "test")
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer os.RemoveAll(tempdir)
|
defer os.RemoveAll(tempdir)
|
||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
@ -103,7 +103,7 @@ var _ = Describe("Model test", func() {
|
|||||||
tempdir, err := os.MkdirTemp("", "test")
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer os.RemoveAll(tempdir)
|
defer os.RemoveAll(tempdir)
|
||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
|
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
|
||||||
@ -129,7 +129,7 @@ var _ = Describe("Model test", func() {
|
|||||||
tempdir, err := os.MkdirTemp("", "test")
|
tempdir, err := os.MkdirTemp("", "test")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer os.RemoveAll(tempdir)
|
defer os.RemoveAll(tempdir)
|
||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
|
18 pkg/gallery/op.go (new file)

@ -0,0 +1,18 @@

package gallery

type GalleryOp struct {
	Req         GalleryModel
	Id          string
	Galleries   []Gallery
	GalleryName string
}

type GalleryOpStatus struct {
	FileName           string  `json:"file_name"`
	Error              error   `json:"error"`
	Processed          bool    `json:"processed"`
	Message            string  `json:"message"`
	Progress           float64 `json:"progress"`
	TotalFileSize      string  `json:"file_size"`
	DownloadedFileSize string  `json:"downloaded_size"`
}
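The `json` tags above define the wire format of gallery-operation status updates; as a sketch, marshalling a partially-complete status (field values illustrative) produces:

```
package main

import (
	"encoding/json"
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/gallery"
)

func main() {
	st := gallery.GalleryOpStatus{
		FileName:           "gpt4all-j.bin", // illustrative values throughout
		Processed:          false,
		Message:            "downloading",
		Progress:           42.5,
		TotalFileSize:      "3.8 GiB",
		DownloadedFileSize: "1.6 GiB",
	}
	b, _ := json.Marshal(st)
	fmt.Println(string(b))
	// {"file_name":"gpt4all-j.bin","error":null,"processed":false,"message":"downloading","progress":42.5,"file_size":"3.8 GiB","downloaded_size":"1.6 GiB"}
}
```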
@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
|
|||||||
Context("requests", func() {
|
Context("requests", func() {
|
||||||
It("parses github with a branch", func() {
|
It("parses github with a branch", func() {
|
||||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||||
e, err := GetGalleryConfigFromURL(req.URL)
|
e, err := GetInstallableModelFromURL(req.URL)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||||
})
|
})
|
||||||
|
@ -6,8 +6,8 @@ import (

	"fmt"
	"os"

-	"github.com/go-skynet/LocalAI/api/schema"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/LocalAI/pkg/schema"
	gopsutil "github.com/shirou/gopsutil/v3/process"
)

@ -53,8 +53,9 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {

	return fmt.Errorf("unimplemented")
}

-func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
-	return schema.Result{}, fmt.Errorf("unimplemented")
+// TODO CHECK THIS
+func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) {
+	return schema.WhisperResult{}, fmt.Errorf("unimplemented")
}

func (llm *Base) TTS(*pb.TTSRequest) error {
@ -7,8 +7,8 @@ import (

	"sync"
	"time"

-	"github.com/go-skynet/LocalAI/api/schema"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/LocalAI/pkg/schema"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp

	return client.TTS(ctx, in, opts...)
}

-func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
+func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()

@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques

	if err != nil {
		return nil, err
	}
-	tresult := &schema.Result{}
+	tresult := &schema.WhisperResult{}
	for _, s := range res.Segments {
		tks := []int{}
		for _, t := range s.Tokens {
			tks = append(tks, int(t))
		}
		tresult.Segments = append(tresult.Segments,
-			schema.Segment{
+			schema.WhisperSegment{
				Text:  s.Text,
				Id:    int(s.Id),
				Start: time.Duration(s.Start),
@ -1,8 +1,8 @@

package grpc

import (
-	"github.com/go-skynet/LocalAI/api/schema"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/LocalAI/pkg/schema"
)

type LLM interface {

@ -15,7 +15,7 @@ type LLM interface {

	Load(*pb.ModelOptions) error
	Embeddings(*pb.PredictOptions) ([]float32, error)
	GenerateImage(*pb.GenerateImageRequest) error
-	AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
+	AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error)
	TTS(*pb.TTSRequest) error
	TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
	Status() (pb.StatusResponse, error)
@ -1,7 +1,7 @@

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
-// protoc-gen-go v1.28.1
-// protoc v3.6.1
+// protoc-gen-go v1.26.0
+// protoc v4.26.0
// source: backend.proto

package proto
@ -1,7 +1,7 @@

// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
-// - protoc-gen-go-grpc v1.2.0
-// - protoc v3.6.1
+// - protoc-gen-go-grpc v1.3.0
+// - protoc v4.26.0
// source: backend.proto

package proto

@ -18,6 +18,19 @@ import (

// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7

+const (
+	Backend_Health_FullMethodName             = "/backend.Backend/Health"
+	Backend_Predict_FullMethodName            = "/backend.Backend/Predict"
+	Backend_LoadModel_FullMethodName          = "/backend.Backend/LoadModel"
+	Backend_PredictStream_FullMethodName      = "/backend.Backend/PredictStream"
+	Backend_Embedding_FullMethodName          = "/backend.Backend/Embedding"
+	Backend_GenerateImage_FullMethodName      = "/backend.Backend/GenerateImage"
+	Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
+	Backend_TTS_FullMethodName                = "/backend.Backend/TTS"
+	Backend_TokenizeString_FullMethodName     = "/backend.Backend/TokenizeString"
+	Backend_Status_FullMethodName             = "/backend.Backend/Status"
+)
+
// BackendClient is the client API for Backend service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.

@ -44,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {

func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
	out := new(Reply)
-	err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -53,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g

func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
	out := new(Reply)
-	err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -62,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..

func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
	out := new(Result)
-	err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -70,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..

}

func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
-	stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
+	stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
	if err != nil {
		return nil, err
	}

@ -103,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {

func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
	out := new(EmbeddingResult)
-	err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -112,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts

func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
	out := new(Result)
-	err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -121,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ

func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
	out := new(TranscriptResult)
-	err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -130,7 +143,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe

func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
	out := new(Result)
-	err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -139,7 +152,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca

func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
	out := new(TokenizationResponse)
-	err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -148,7 +161,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,

func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
	out := new(StatusResponse)
-	err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
+	err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
	if err != nil {
		return nil, err
	}

@ -229,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/Health",
+		FullMethod: Backend_Health_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).Health(ctx, req.(*HealthMessage))

@ -247,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/Predict",
+		FullMethod: Backend_Predict_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))

@ -265,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/LoadModel",
+		FullMethod: Backend_LoadModel_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))

@ -304,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/Embedding",
+		FullMethod: Backend_Embedding_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))

@ -322,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/GenerateImage",
+		FullMethod: Backend_GenerateImage_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))

@ -340,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/AudioTranscription",
+		FullMethod: Backend_AudioTranscription_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))

@ -358,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/TTS",
+		FullMethod: Backend_TTS_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))

@ -376,7 +389,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/TokenizeString",
+		FullMethod: Backend_TokenizeString_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))

@ -394,7 +407,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte

	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
-		FullMethod: "/backend.Backend/Status",
+		FullMethod: Backend_Status_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
@ -8,7 +8,7 @@ import (

	"strings"
	"time"

-	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+	"github.com/go-skynet/LocalAI/pkg/grpc"
	"github.com/hashicorp/go-multierror"
	"github.com/phayes/freeport"
	"github.com/rs/zerolog/log"

@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{

// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
-func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
+func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) {
	return func(modelName, modelFile string) (ModelAddress, error) {
		log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
@ -10,7 +10,7 @@ import (

	"sync"
	"text/template"

-	grammar "github.com/go-skynet/LocalAI/pkg/grammar"
+	"github.com/go-skynet/LocalAI/pkg/grammar"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	process "github.com/mudler/go-processmanager"
	"github.com/rs/zerolog/log"
@ -6,7 +6,7 @@ import (

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

-type Options struct {
+type ModelOptions struct {
	backendString string
	model         string
	threads       uint32

@ -23,14 +23,14 @@ type Options struct {

	parallelRequests bool
}

-type Option func(*Options)
+type Option func(*ModelOptions)

-var EnableParallelRequests = func(o *Options) {
+var EnableParallelRequests = func(o *ModelOptions) {
	o.parallelRequests = true
}

func WithExternalBackend(name string, uri string) Option {
-	return func(o *Options) {
+	return func(o *ModelOptions) {
		if o.externalBackends == nil {
			o.externalBackends = make(map[string]string)
		}

@ -38,62 +38,81 @@ func WithExternalBackend(name string, uri string) Option {

	}
}

+// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates
+func WithExternalBackends(backends map[string]string, overwrite bool) Option {
+	return func(o *ModelOptions) {
+		if backends == nil {
+			return
+		}
+		if o.externalBackends == nil {
+			o.externalBackends = backends
+			return
+		}
+		for name, url := range backends {
+			_, exists := o.externalBackends[name]
+			if !exists || overwrite {
+				o.externalBackends[name] = url
+			}
+		}
+	}
+}
func WithGRPCAttempts(attempts int) Option {
|
func WithGRPCAttempts(attempts int) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.grpcAttempts = attempts
|
o.grpcAttempts = attempts
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithGRPCAttemptsDelay(delay int) Option {
|
func WithGRPCAttemptsDelay(delay int) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.grpcAttemptsDelay = delay
|
o.grpcAttemptsDelay = delay
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendString(backend string) Option {
|
func WithBackendString(backend string) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.backendString = backend
|
o.backendString = backend
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModel(modelFile string) Option {
|
func WithModel(modelFile string) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.model = modelFile
|
o.model = modelFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
|
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.gRPCOptions = opts
|
o.gRPCOptions = opts
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithThreads(threads uint32) Option {
|
func WithThreads(threads uint32) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.threads = threads
|
o.threads = threads
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAssetDir(assetDir string) Option {
|
func WithAssetDir(assetDir string) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.assetDir = assetDir
|
o.assetDir = assetDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContext(ctx context.Context) Option {
|
func WithContext(ctx context.Context) Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.context = ctx
|
o.context = ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithSingleActiveBackend() Option {
|
func WithSingleActiveBackend() Option {
|
||||||
return func(o *Options) {
|
return func(o *ModelOptions) {
|
||||||
o.singleActiveBackend = true
|
o.singleActiveBackend = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOptions(opts ...Option) *Options {
|
func NewOptions(opts ...Option) *ModelOptions {
|
||||||
o := &Options{
|
o := &ModelOptions{
|
||||||
gRPCOptions: &pb.ModelOptions{},
|
gRPCOptions: &pb.ModelOptions{},
|
||||||
context: context.Background(),
|
context: context.Background(),
|
||||||
grpcAttempts: 20,
|
grpcAttempts: 20,
|
||||||
|
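The rename from Options to ModelOptions leaves the functional-options pattern intact. A usage sketch assuming illustrative backend and model names; the only behavioral addition is the non-overwriting merge in WithExternalBackends:

	// Illustrative composition of the renamed model options; values are made up.
	o := NewOptions(
		WithBackendString("llama"),
		WithModel("ggml-model-q4_0.bin"),
		WithThreads(4),
		// With overwrite=false, WithExternalBackends keeps any backend already
		// registered under the same name and only adds new entries.
		WithExternalBackends(map[string]string{"my-backend": "127.0.0.1:50051"}, false),
	)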
@@ -1,16 +1,11 @@
-package api_config
+package schema

 import (
-	"errors"
+	"encoding/json"
 	"fmt"
-	"io/fs"
 	"os"
-	"path/filepath"
-	"strings"
-	"sync"

 	"github.com/go-skynet/LocalAI/pkg/utils"
-	"github.com/rs/zerolog/log"
 	"gopkg.in/yaml.v3"
 )

@@ -152,11 +147,6 @@ type TemplateConfig struct {
 	Functions string `yaml:"function"`
 }

-type ConfigLoader struct {
-	configs map[string]Config
-	sync.Mutex
-}

 func (c *Config) SetFunctionCallString(s string) {
 	c.functionCallString = s
 }

@@ -193,11 +183,6 @@ func DefaultConfig(modelFile string) *Config {
 	}
 }

-func NewConfigLoader() *ConfigLoader {
-	return &ConfigLoader{
-		configs: make(map[string]Config),
-	}
-}
 func ReadConfigFile(file string) ([]*Config, error) {
 	c := &[]*Config{}
 	f, err := os.ReadFile(file)

@@ -211,7 +196,7 @@ func ReadConfigFile(file string) ([]*Config, error) {
 	return *c, nil
 }

-func ReadConfig(file string) (*Config, error) {
+func ReadSingleConfigFile(file string) (*Config, error) {
 	c := &Config{}
 	f, err := os.ReadFile(file)
 	if err != nil {
@@ -224,136 +209,192 @@ func ReadConfig(file string) (*Config, error) {
 	return c, nil
 }

-func (cm *ConfigLoader) LoadConfigFile(file string) error {
-	cm.Lock()
-	defer cm.Unlock()
-	c, err := ReadConfigFile(file)
-	if err != nil {
-		return fmt.Errorf("cannot load config file: %w", err)
+func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
+	if input.Echo {
+		config.Echo = input.Echo
+	}
+	if input.TopK != 0 {
+		config.TopK = input.TopK
+	}
+	if input.TopP != 0 {
+		config.TopP = input.TopP
 	}

-	for _, cc := range c {
-		cm.configs[cc.Name] = *cc
-	}
-	return nil
-}
-
-func (cm *ConfigLoader) LoadConfig(file string) error {
-	cm.Lock()
-	defer cm.Unlock()
-	c, err := ReadConfig(file)
-	if err != nil {
-		return fmt.Errorf("cannot read config file: %w", err)
+	if input.Backend != "" {
+		config.Backend = input.Backend
 	}

-	cm.configs[c.Name] = *c
-	return nil
-}
-
-func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
-	cm.Lock()
-	defer cm.Unlock()
-	v, exists := cm.configs[m]
-	return v, exists
-}
-
-func (cm *ConfigLoader) GetAllConfigs() []Config {
-	cm.Lock()
-	defer cm.Unlock()
-	var res []Config
-	for _, v := range cm.configs {
-		res = append(res, v)
-	}
-	return res
-}
-
-func (cm *ConfigLoader) ListConfigs() []string {
-	cm.Lock()
-	defer cm.Unlock()
-	var res []string
-	for k := range cm.configs {
-		res = append(res, k)
-	}
-	return res
-}
-
-// Preload prepare models if they are not local but url or huggingface repositories
-func (cm *ConfigLoader) Preload(modelPath string) error {
-	cm.Lock()
-	defer cm.Unlock()
-
-	status := func(fileName, current, total string, percent float64) {
-		utils.DisplayDownloadFunction(fileName, current, total, percent)
+	if input.ClipSkip != 0 {
+		config.Diffusers.ClipSkip = input.ClipSkip
 	}

-	log.Info().Msgf("Preloading models from %s", modelPath)
+	if input.ModelBaseName != "" {
+		config.AutoGPTQ.ModelBaseName = input.ModelBaseName
+	}

-	for i, config := range cm.configs {
+	if input.NegativePromptScale != 0 {
+		config.NegativePromptScale = input.NegativePromptScale
+	}

-		// Download files and verify their SHA
-		for _, file := range config.DownloadFiles {
-			log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
+	if input.UseFastTokenizer {
+		config.UseFastTokenizer = input.UseFastTokenizer
+	}

-			if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
-				return err
+	if input.NegativePrompt != "" {
+		config.NegativePrompt = input.NegativePrompt
 	}
-			// Create file path
-			filePath := filepath.Join(modelPath, file.Filename)

-			if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
-				return err
+	if input.RopeFreqBase != 0 {
+		config.RopeFreqBase = input.RopeFreqBase
+	}

+	if input.RopeFreqScale != 0 {
+		config.RopeFreqScale = input.RopeFreqScale
+	}

+	if input.Grammar != "" {
+		config.Grammar = input.Grammar
+	}

+	if input.Temperature != 0 {
+		config.Temperature = input.Temperature
+	}

+	if input.Maxtokens != 0 {
+		config.Maxtokens = input.Maxtokens
+	}

+	if input.RepeatPenalty != 0 {
+		config.RepeatPenalty = input.RepeatPenalty
+	}

+	if input.Keep != 0 {
+		config.Keep = input.Keep
+	}

+	if input.Batch != 0 {
+		config.Batch = input.Batch
+	}

+	if input.F16 {
+		config.F16 = input.F16
+	}

+	if input.IgnoreEOS {
+		config.IgnoreEOS = input.IgnoreEOS
+	}

+	if input.Seed != 0 {
+		config.Seed = input.Seed
+	}

+	if input.Mirostat != 0 {
+		config.LLMConfig.Mirostat = input.Mirostat
+	}

+	if input.MirostatETA != 0 {
+		config.LLMConfig.MirostatETA = input.MirostatETA
+	}

+	if input.MirostatTAU != 0 {
+		config.LLMConfig.MirostatTAU = input.MirostatTAU
+	}

+	if input.TypicalP != 0 {
+		config.TypicalP = input.TypicalP
+	}

+	switch stop := input.Stop.(type) {
+	case string:
+		if stop != "" {
+			config.StopWords = append(config.StopWords, stop)
+		}
+	case []interface{}:
+		for _, pp := range stop {
+			if s, ok := pp.(string); ok {
+				config.StopWords = append(config.StopWords, s)
 			}
 		}
+	}

-		modelURL := config.PredictionOptions.Model
-		modelURL = utils.ConvertURL(modelURL)
+	// Decode each request's message content
+	index := 0
+	for i, m := range input.Messages {
+		switch content := m.Content.(type) {
+		case string:
+			input.Messages[i].StringContent = content
+		case []interface{}:
+			dat, _ := json.Marshal(content)
+			c := []Content{}
+			json.Unmarshal(dat, &c)
+			for _, pp := range c {
+				if pp.Type == "text" {
+					input.Messages[i].StringContent = pp.Text
+				} else if pp.Type == "image_url" {
+					// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
+					base64, err := utils.GetBase64Image(pp.ImageURL.URL)
+					if err == nil {
+						input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
+						// set a placeholder for each image
+						input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
+						index++
+					} else {
+						fmt.Print("Failed encoding image", err)
+					}

-		if utils.LooksLikeURL(modelURL) {
-			// md5 of model name
-			md5Name := utils.MD5(modelURL)
-
-			// check if file exists
-			if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
-				err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
-				if err != nil {
-					return err
 				}
 			}

-			cc := cm.configs[i]
-			c := &cc
-			c.PredictionOptions.Model = md5Name
-			cm.configs[i] = *c
 		}
 	}
-	return nil
-}
-
-func (cm *ConfigLoader) LoadConfigs(path string) error {
-	cm.Lock()
-	defer cm.Unlock()
-	entries, err := os.ReadDir(path)
-	if err != nil {
-		return err
-	}
-	files := make([]fs.FileInfo, 0, len(entries))
-	for _, entry := range entries {
-		info, err := entry.Info()
-		if err != nil {
-			return err
-		}
-		files = append(files, info)
-	}
-	for _, file := range files {
-		// Skip templates, YAML and .keep files
-		if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
-			continue
-		}
-		c, err := ReadConfig(filepath.Join(path, file.Name()))
-		if err == nil {
-			cm.configs[c.Name] = *c
-		}
-	}
-
-	return nil
+	// TODO: check that this was merged correctly? I _think_ it is?
+	switch inputs := input.Input.(type) {
+	case string:
+		if inputs != "" {
+			config.InputStrings = append(config.InputStrings, inputs)
+		}
+	case []interface{}:
+		for _, pp := range inputs {
+			switch i := pp.(type) {
+			case string:
+				config.InputStrings = append(config.InputStrings, i)
+			case []interface{}:
+				tokens := []int{}
+				for _, ii := range i {
+					tokens = append(tokens, int(ii.(float64)))
+				}
+				config.InputToken = append(config.InputToken, tokens)
+			}
+		}
+	}
+
+	// Can be either a string or an object
+	switch fnc := input.FunctionCall.(type) {
+	case string:
+		if fnc != "" {
+			config.SetFunctionCallString(fnc)
+		}
+	case map[string]interface{}:
+		var name string
+		n, exists := fnc["name"]
+		if exists {
+			nn, e := n.(string)
+			if e {
+				name = nn
+			}
+		}
+		config.SetFunctionCallNameString(name)
+	}
+
+	switch p := input.Prompt.(type) {
+	case string:
+		config.PromptStrings = append(config.PromptStrings, p)
+	case []interface{}:
+		for _, pp := range p {
+			if s, ok := pp.(string); ok {
+				config.PromptStrings = append(config.PromptStrings, s)
+			}
+		}
+	}
 }
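The replacement function folds per-request OpenAI parameters into the model's Config: a field set in the request (non-zero, non-empty, or true) overrides the YAML default, and untouched fields are left alone. A minimal sketch of a hypothetical call site; the model name and values here are illustrative only:

	// Request fields win over the model's defaults when set.
	cfg := DefaultConfig("my-model") // illustrative model file name
	req := &OpenAIRequest{}
	req.TopK = 40        // non-zero, so it overrides cfg.TopK
	req.Temperature = 0.7 // likewise
	UpdateConfigFromOpenAIRequest(cfg, req)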
@@ -1,11 +1,10 @@
-package api_config_test
+package schema_test

 import (
 	"os"

-	. "github.com/go-skynet/LocalAI/api/config"
-	"github.com/go-skynet/LocalAI/api/options"
-	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/core/services"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 )

@@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
 	Context("Test Read configuration functions", func() {
 		configFile = os.Getenv("CONFIG_FILE")
 		It("Test ReadConfigFile", func() {
-			config, err := ReadConfigFile(configFile)
+			config, err := schema.ReadConfigFile(configFile)
 			Expect(err).To(BeNil())
 			Expect(config).ToNot(BeNil())
 			// two configs in config.yaml

@@ -28,12 +27,8 @@ var _ = Describe("Test cases for config related functions", func() {
 		})

 		It("Test LoadConfigs", func() {
-			cm := NewConfigLoader()
-			opts := options.NewOptions()
-			modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
-			options.WithModelLoader(modelLoader)(opts)
-
-			err := cm.LoadConfigs(opts.Loader.ModelPath)
+			cm := services.NewConfigLoader()
+			err := cm.LoadConfigs(os.Getenv("MODELS_PATH"))
 			Expect(err).To(BeNil())
 			Expect(cm.ListConfigs()).ToNot(BeNil())
39	pkg/schema/localai.go (new file)
@@ -0,0 +1,39 @@
+package schema
+
+import (
+	"context"
+
+	gopsutil "github.com/shirou/gopsutil/v3/process"
+
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/metric"
+)
+
+type BackendMonitorRequest struct {
+	Model string `json:"model" yaml:"model"`
+}
+
+type BackendMonitorResponse struct {
+	MemoryInfo *gopsutil.MemoryInfoStat
+	MemoryPercent float32
+	CPUPercent float64
+}
+
+type TTSRequest struct {
+	Model string `json:"model" yaml:"model"`
+	Input string `json:"input" yaml:"input"`
+	Backend string `json:"backend" yaml:"backend"`
+}
+
+type LocalAIMetrics struct {
+	Meter metric.Meter
+	ApiTimeMetric metric.Float64Histogram
+}
+
+func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) {
+	opts := metric.WithAttributes(
+		attribute.String("method", method),
+		attribute.String("path", path),
+	)
+	m.ApiTimeMetric.Record(context.Background(), duration, opts)
+}
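The new LocalAIMetrics type wraps an OpenTelemetry Float64Histogram. A minimal usage sketch, assuming an instance has already been built with a Meter and histogram elsewhere; the method, path, and duration are illustrative:

	var m *schema.LocalAIMetrics // initialized elsewhere with Meter and ApiTimeMetric
	m.ObserveAPICall("GET", "/v1/models", 0.042) // records a 42 ms request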
@@ -3,8 +3,6 @@ package schema
 import (
 	"context"

-	config "github.com/go-skynet/LocalAI/api/config"
-
 	"github.com/go-skynet/LocalAI/pkg/grammar"
 )

@@ -90,7 +88,7 @@ type ChatCompletionResponseFormat struct {
 }

 type OpenAIRequest struct {
-	config.PredictionOptions
+	PredictionOptions

 	Context context.Context
 	Cancel context.CancelFunc

@@ -1,4 +1,4 @@
-package api_config
+package schema

 type PredictionOptions struct {

@@ -1,4 +1,4 @@
-package options
+package schema

 import (
 	"context"

@@ -6,16 +6,14 @@ import (
 	"encoding/json"
 	"time"

-	"github.com/go-skynet/LocalAI/metrics"
 	"github.com/go-skynet/LocalAI/pkg/gallery"
-	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 )

-type Option struct {
+type StartupOptions struct {
 	Context context.Context
 	ConfigFile string
-	Loader *model.ModelLoader
+	ModelPath string
 	UploadLimitMB, Threads, ContextSize int
 	F16 bool
 	Debug, DisableMessage bool

@@ -26,7 +24,7 @@ type Option struct {
 	PreloadModelsFromPath string
 	CORSAllowOrigins string
 	ApiKeys []string
-	Metrics *metrics.Metrics
+	Metrics *LocalAIMetrics

 	Galleries []gallery.Gallery

@@ -47,12 +45,14 @@ type Option struct {
 	ModelsURL []string

 	WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration

+	LocalAIConfigDir string
 }

-type AppOption func(*Option)
+type AppOption func(*StartupOptions)

-func NewOptions(o ...AppOption) *Option {
+func NewStartupOptions(o ...AppOption) *StartupOptions {
-	opt := &Option{
+	opt := &StartupOptions{
 		Context: context.Background(),
 		UploadLimitMB: 15,
 		Threads: 1,
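With Option renamed to StartupOptions and the ModelLoader field replaced by a plain ModelPath string, startup wiring becomes value-based. A sketch using option functions that appear later in this diff; the path and flag values are illustrative:

	opts := NewStartupOptions(
		WithModelPath("/models"), // illustrative path
		WithContextSize(2048),
		WithThreads(4),
		WithDebug(true),
	)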
@@ -67,57 +67,57 @@ func NewOptions(o ...AppOption) *Option {
 }

 func WithModelsURL(urls ...string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ModelsURL = urls
 	}
 }

 func WithCors(b bool) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.CORS = b
 	}
 }

-var EnableWatchDog = func(o *Option) {
+var EnableWatchDog = func(o *StartupOptions) {
 	o.WatchDog = true
 }

-var EnableWatchDogIdleCheck = func(o *Option) {
+var EnableWatchDogIdleCheck = func(o *StartupOptions) {
 	o.WatchDog = true
 	o.WatchDogIdle = true
 }

-var EnableWatchDogBusyCheck = func(o *Option) {
+var EnableWatchDogBusyCheck = func(o *StartupOptions) {
 	o.WatchDog = true
 	o.WatchDogBusy = true
 }

 func SetWatchDogBusyTimeout(t time.Duration) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.WatchDogBusyTimeout = t
 	}
 }

 func SetWatchDogIdleTimeout(t time.Duration) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.WatchDogIdleTimeout = t
 	}
 }

-var EnableSingleBackend = func(o *Option) {
+var EnableSingleBackend = func(o *StartupOptions) {
 	o.SingleBackend = true
 }

-var EnableParallelBackendRequests = func(o *Option) {
+var EnableParallelBackendRequests = func(o *StartupOptions) {
 	o.ParallelBackendRequests = true
 }

-var EnableGalleriesAutoload = func(o *Option) {
+var EnableGalleriesAutoload = func(o *StartupOptions) {
 	o.AutoloadGalleries = true
 }

 func WithExternalBackend(name string, uri string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		if o.ExternalGRPCBackends == nil {
 			o.ExternalGRPCBackends = make(map[string]string)
 		}

@@ -126,25 +126,25 @@ func WithExternalBackend(name string, uri string) AppOption {
 	}
 }

 func WithCorsAllowOrigins(b string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.CORSAllowOrigins = b
 	}
 }

 func WithBackendAssetsOutput(out string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.AssetsDestination = out
 	}
 }

 func WithBackendAssets(f embed.FS) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.BackendAssets = f
 	}
 }

 func WithStringGalleries(galls string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		if galls == "" {
 			log.Debug().Msgf("no galleries to load")
 			o.Galleries = []gallery.Gallery{}

@@ -159,96 +159,102 @@ func WithStringGalleries(galls string) AppOption {
 	}
 }

 func WithGalleries(galleries []gallery.Gallery) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.Galleries = append(o.Galleries, galleries...)
 	}
 }

 func WithContext(ctx context.Context) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.Context = ctx
 	}
 }

 func WithYAMLConfigPreload(configFile string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.PreloadModelsFromPath = configFile
 	}
 }

 func WithJSONStringPreload(configFile string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.PreloadJSONModels = configFile
 	}
 }

 func WithConfigFile(configFile string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ConfigFile = configFile
 	}
 }

-func WithModelLoader(loader *model.ModelLoader) AppOption {
+func WithModelPath(path string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
-		o.Loader = loader
+		o.ModelPath = path
 	}
 }

 func WithUploadLimitMB(limit int) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.UploadLimitMB = limit
 	}
 }

 func WithThreads(threads int) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.Threads = threads
 	}
 }

 func WithContextSize(ctxSize int) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ContextSize = ctxSize
 	}
 }

 func WithF16(f16 bool) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.F16 = f16
 	}
 }

 func WithDebug(debug bool) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.Debug = debug
 	}
 }

 func WithDisableMessage(disableMessage bool) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.DisableMessage = disableMessage
 	}
 }

 func WithAudioDir(audioDir string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.AudioDir = audioDir
 	}
 }

 func WithImageDir(imageDir string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ImageDir = imageDir
 	}
 }

 func WithApiKeys(apiKeys []string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ApiKeys = apiKeys
 	}
 }

-func WithMetrics(meter *metrics.Metrics) AppOption {
+func WithMetrics(metrics *LocalAIMetrics) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
-		o.Metrics = meter
+		o.Metrics = metrics
+	}
+}
+
+func WithLocalAIConfigDir(configDir string) AppOption {
+	return func(o *StartupOptions) {
+		o.LocalAIConfigDir = configDir
 	}
 }
@@ -2,7 +2,7 @@ package schema

 import "time"

-type Segment struct {
+type WhisperSegment struct {
 	Id int `json:"id"`
 	Start time.Duration `json:"start"`
 	End time.Duration `json:"end"`

@@ -10,7 +10,7 @@ type Segment struct {
 	Tokens []int `json:"tokens"`
 }

-type Result struct {
+type WhisperResult struct {
-	Segments []Segment `json:"segments"`
+	Segments []WhisperSegment `json:"segments"`
 	Text string `json:"text"`
 }
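The Whisper types keep their JSON tags through the rename, so serialized transcriptions are unchanged; only the Go identifiers move. An illustrative literal using just the fields visible in this hunk:

	res := schema.WhisperResult{
		Segments: []schema.WhisperSegment{{Id: 0, Start: 0, End: 2 * time.Second}},
		Text:     "hello world", // illustrative transcription
	}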
81	pkg/utils/file.go (new file)
@@ -0,0 +1,81 @@
+package utils
+
+import (
+	"bufio"
+	"encoding/base64"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"net/http"
+	"os"
+
+	"github.com/rs/zerolog/log"
+)
+
+func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) {
+	f, err := file.Open()
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	// Create a temporary file in the requested directory:
+	outputFile, err := os.CreateTemp(tempDir, tempPattern)
+	if err != nil {
+		return "", err
+	}
+	defer outputFile.Close()
+
+	if _, err := io.Copy(outputFile, f); err != nil {
+		log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err)
+		return "", err
+	}
+
+	return outputFile.Name(), nil
+}
+
+func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) {
+	if len(base64data) == 0 {
+		return "", fmt.Errorf("base64data empty?")
+	}
+	//base 64 decode the file and write it somewhere
+	// that we will cleanup
+	decoded, err := base64.StdEncoding.DecodeString(base64data)
+	if err != nil {
+		return "", err
+	}
+	// Create a temporary file in the requested directory:
+	outputFile, err := os.CreateTemp(tempDir, tempPattern)
+	if err != nil {
+		return "", err
+	}
+	defer outputFile.Close()
+	// write the base64 result
+	writer := bufio.NewWriter(outputFile)
+	_, err = writer.Write(decoded)
+	if err != nil {
+		return "", err
+	}
+	return outputFile.Name(), nil
+}
+
+func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) {
+	// Get the data
+	resp, err := http.Get(url)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// Create the file
+	out, err := os.CreateTemp(tempDir, tempPattern)
+	if err != nil {
+		return "", err
+	}
+	defer out.Close()
+
+	// Write the body to file
+	_, err = io.Copy(out, resp.Body)
+	return out.Name(), err
+}
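A usage sketch for the new temp-file helpers; the URL and name pattern are illustrative, and callers own cleanup of the returned path. One caution worth noting: CreateTempFileFromBase64 writes through a bufio.Writer that is never flushed, so buffered bytes can be silently dropped; calling writer.Flush() before returning would address that if buffering is intended at all.

	// Hypothetical call; an empty tempDir selects the OS default temp directory.
	path, err := CreateTempFileFromUrl("https://example.com/audio.wav", "", "localai-*.wav")
	if err == nil {
		defer os.Remove(path) // the helpers create files but never delete them
	}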
@@ -3,18 +3,38 @@ package utils
 import (
 	"crypto/md5"
 	"crypto/sha256"
+	"encoding/base64"
 	"fmt"
 	"hash"
 	"io"
 	"net/http"
 	"os"
 	"path/filepath"
+	"slices"
 	"strconv"
 	"strings"

 	"github.com/rs/zerolog/log"
 )

+const (
+	HuggingFacePrefix = "huggingface://"
+	HTTPPrefix = "http://"
+	HTTPSPrefix = "https://"
+	GithubURI = "github:"
+	GithubURI2 = "github://"
+)
+
+func getRecognizedURIPrefixes() []string {
+	return []string{
+		HuggingFacePrefix,
+		HTTPPrefix,
+		HTTPSPrefix,
+		GithubURI,
+		GithubURI2,
+	}
+}

 func GetURI(url string, f func(url string, i []byte) error) error {
 	url = ConvertURL(url)

@@ -52,20 +72,8 @@ func GetURI(url string, f func(url string, i []byte) error) error {
 	return f(url, body)
 }

-const (
-	HuggingFacePrefix = "huggingface://"
-	HTTPPrefix = "http://"
-	HTTPSPrefix = "https://"
-	GithubURI = "github:"
-	GithubURI2 = "github://"
-)

 func LooksLikeURL(s string) bool {
-	return strings.HasPrefix(s, HTTPPrefix) ||
-		strings.HasPrefix(s, HTTPSPrefix) ||
-		strings.HasPrefix(s, HuggingFacePrefix) ||
-		strings.HasPrefix(s, GithubURI) ||
-		strings.HasPrefix(s, GithubURI2)
+	return slices.Contains(getRecognizedURIPrefixes(), s)
 }

 func ConvertURL(s string) string {
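Note the semantics change in this hunk: the old body tested strings.HasPrefix against each prefix, while slices.Contains reports whether s is exactly equal to one of the prefixes, so "https://example.com" would no longer match. If prefix matching is still the intent, a sketch of an equivalent check (the helper name here is hypothetical):

	func looksLikeURLPrefix(s string) bool {
		for _, p := range getRecognizedURIPrefixes() {
			if strings.HasPrefix(s, p) {
				return true
			}
		}
		return false
	}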
@@ -241,6 +249,37 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
 	return nil
 }

+// this function check if the string is an URL, if it's an URL downloads the image in memory
+// encodes it in base64 and returns the base64 string
+func GetBase64Image(s string) (string, error) {
+	if strings.HasPrefix(s, "http") {
+		// download the image
+		resp, err := http.Get(s)
+		if err != nil {
+			return "", err
+		}
+		defer resp.Body.Close()
+
+		// read the image data into memory
+		data, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return "", err
+		}
+
+		// encode the image data in base64
+		encoded := base64.StdEncoding.EncodeToString(data)
+
+		// return the base64 string
+		return encoded, nil
+	}
+
+	// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
+	if strings.HasPrefix(s, "data:image/jpeg;base64,") {
+		return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
+	}
+	return "", fmt.Errorf("not valid string")
+}

 type progressWriter struct {
 	fileName string
 	total int64
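A usage sketch for the new helper; the URL is illustrative. Inputs must either start with "http" or carry a "data:image/jpeg;base64," prefix; anything else returns an error:

	b64, err := utils.GetBase64Image("https://example.com/cat.jpg") // URL illustrative
	if err != nil {
		log.Error().Err(err).Msg("could not encode image")
	}
	_ = b64 // ready to embed in a multimodal prompt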
@@ -3,16 +3,16 @@ package integration_test
 import (
 	"reflect"

-	config "github.com/go-skynet/LocalAI/api/config"
-	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/go-skynet/LocalAI/pkg/schema"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 )

 var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
-	Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
+	Context("schema.TemplateConfig and model.TemplateType must stay in sync", func() {

-		ttc := reflect.TypeOf(config.TemplateConfig{})
+		ttc := reflect.TypeOf(schema.TemplateConfig{})

 		It("TemplateConfig and TemplateType should have the same number of valid values", func() {
 			const lastValidTemplateType = model.IntegrationTestTemplate - 1