[Refactor]: Core/API Split (#1506)

Refactors api folder to core, creates firm split between backend code and api frontend.
This commit is contained in:
Dave 2024-01-05 09:34:56 -05:00 committed by GitHub
parent bcf02449b3
commit ab7b4d5ee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
77 changed files with 3441 additions and 3117 deletions

4
.gitignore vendored
View File

@ -19,8 +19,8 @@ LocalAI
local-ai local-ai
# prevent above rules from omitting the helm chart # prevent above rules from omitting the helm chart
!charts/* !charts/*
# prevent above rules from omitting the api/localai folder # prevent above rules from omitting the core/**/localai folder
!api/localai !core/**/localai
# Ignore models # Ignore models
models/* models/*

View File

@ -88,7 +88,7 @@ ENV NVIDIA_VISIBLE_DEVICES=all
WORKDIR /build WORKDIR /build
COPY . . COPY . .
COPY .git . COPY .git/ .git/
RUN make prepare RUN make prepare
# stablediffusion does not tolerate a newer version of abseil, build it first # stablediffusion does not tolerate a newer version of abseil, build it first

View File

@ -1,302 +0,0 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/localai"
"github.com/go-skynet/LocalAI/api/openai"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
modelPath := options.Loader.ModelPath
if len(options.ModelsURL) > 0 {
for _, url := range options.ModelsURL {
if utils.LooksLikeURL(url) {
// md5 of model name
md5Name := utils.MD5(url)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if err != nil {
log.Error().Msgf("error loading model: %s", err.Error())
}
}
}
}
}
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if options.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
app.Use(recover.New())
if options.Metrics != nil {
app.Use(metrics.APIMiddleware(options.Metrics))
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 {
return c.Next()
}
// Check for api_keys.json file
fileContent, err := os.ReadFile("api_keys.json")
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
}
// Add file keys to options.ApiKeys
options.ApiKeys = append(options.ApiKeys, fileKeys...)
}
if len(options.ApiKeys) == 0 {
return c.Next()
}
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/metrics", metrics.MetricsHandler())
return app, nil
}

View File

@ -1,61 +0,0 @@
package backend
import (
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
}),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}

View File

@ -1,167 +0,0 @@
package backend
import (
"context"
"os"
"regexp"
"strings"
"sync"
"unicode/utf8"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}
type TokenUsage struct {
Prompt int
Completion int
}
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel *grpc.Client
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Images = images
tokenUsage := TokenUsage{}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
if r == utf8.RuneError {
// incomplete rune, wait for more bytes
break
}
tokenCallback(string(r), tokenUsage)
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return LLMResponse{}, err
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}
return fn, nil
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.Config, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
for _, c := range config.TrimSuffix {
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
}
return prediction
}

View File

@ -1,39 +0,0 @@
package backend
import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
})
whisperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}

View File

@ -1,162 +0,0 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
return bm.options.Loader.ShutdownModel(backendId)
}
}

View File

@ -1,326 +0,0 @@
package localai
import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func NewGalleryService(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cm.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoint Service
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View File

@ -1,32 +0,0 @@
package localai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
)
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View File

@ -1,399 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream && !processFunctions
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
MessageIndex: messageIndex,
}
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
if toStream {
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s = utils.EscapeNewLines(s)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
ss["arguments"] = string(d)
ss["name"] = func_name
// if do nothing, reply with a message
if func_name == noActionName {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(d), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
return
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -1,199 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
log.Debug().Msgf("Parameter Config: %+v", config)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
}
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -1,94 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -1,78 +0,0 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -1,239 +0,0 @@
package openai
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}
// https://platform.openai.com/docs/api-reference/images/create
/*
*
curl http://localhost:8080/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cute baby sea otter",
"n": 1,
"size": "512x512"
}'
*
*/
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if m == "" {
m = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
src := ""
if input.File != "" {
fileData := []byte{}
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
out, err := downloadFile(input.File)
if err != nil {
return fmt.Errorf("failed downloading file:%w", err)
}
defer os.RemoveAll(out)
fileData, err = os.ReadFile(out)
if err != nil {
return fmt.Errorf("failed reading file:%w", err)
}
} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(input.File)
if err != nil {
return err
}
}
// Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
if err != nil {
return err
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
return err
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return fmt.Errorf("Invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = o.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return err
}
baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
if err != nil {
return err
}
if err := fn(); err != nil {
return err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
}
result = append(result, *item)
}
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -1,55 +0,0 @@
package openai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.Config,
o *options.Option,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}

View File

@ -1,336 +0,0 @@
package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
options "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) {
loader := o.Loader
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
modelFile := input.Model
if c.Params("model") != "" {
modelFile = c.Params("model")
}
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
// Set model from bearer token, if available
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" && !bearerExists && randomModel {
models, _ := loader.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", nil, fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
return modelFile, input, nil
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}
func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
var cfg *config.Config
defaults := func() {
cfg = config.DefaultConfig(modelFile)
cfg.ContextSize = ctx
cfg.Threads = threads
cfg.F16 = f16
cfg.Debug = debug
}
cfgExisting, exists := cm.GetConfig(modelFile)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelFile)
if exists {
cfg = &cfgExisting
} else {
defaults()
}
} else {
defaults()
}
} else {
cfg = &cfgExisting
}
// Set the parameters for the language model prediction
updateConfig(cfg, input)
// Don't allow 0 as setting
if cfg.Threads == 0 {
if threads != 0 {
cfg.Threads = threads
} else {
cfg.Threads = 4
}
}
// Enforce debug flag if passed from CLI
if debug {
cfg.Debug = true
}
return cfg, input, nil
}

View File

@ -1,71 +0,0 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
f, err := file.Open()
if err != nil {
return err
}
defer f.Close()
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return err
}
defer os.RemoveAll(dir)
dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
if err != nil {
return err
}
if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
if err != nil {
return err
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
}
}

View File

@ -8,7 +8,7 @@ import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav" "github.com/go-audio/wav"
"github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/pkg/schema"
) )
func sh(c string) (string, error) { func sh(c string) (string, error) {
@ -29,8 +29,8 @@ func audioToWav(src, dst string) error {
return nil return nil
} }
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.WhisperResult, error) {
res := schema.Result{} res := schema.WhisperResult{}
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
if err != nil { if err != nil {
@ -90,7 +90,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
tokens = append(tokens, t.Id) tokens = append(tokens, t.Id)
} }
segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} segment := schema.WhisperSegment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
res.Segments = append(res.Segments, segment) res.Segments = append(res.Segments, segment)
res.Text += s.Text res.Text += s.Text

View File

@ -4,9 +4,9 @@ package main
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import ( import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/base" "github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
) )
type Whisper struct { type Whisper struct {
@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.WhisperResult, error) {
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
} }

0
config/.keep Normal file
View File

View File

@ -2,14 +2,17 @@ package backend
import ( import (
"fmt" "fmt"
"time"
config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
) )
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() ([]float32, error), error) {
if !c.Embeddings { if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
} }
@ -27,6 +30,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile), model.WithModel(modelFile),
model.WithContext(o.Context), model.WithContext(o.Context),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
}) })
if c.Backend == "" { if c.Backend == "" {
@ -90,3 +94,51 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
return embeds, nil return embeds, nil
}, nil }, nil
} }
func EmbeddingOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions)
if err != nil {
return nil, err
}
embeddings, err := embedFn()
if err != nil {
return nil, err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions)
if err != nil {
return nil, err
}
embeddings, err := embedFn()
if err != nil {
return nil, err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}, nil
}

210
core/backend/image.go Normal file
View File

@ -0,0 +1,210 @@
package backend
import (
"encoding/base64"
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
}),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}
func ImageGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
id := uuid.New().String()
created := int(time.Now().Unix())
if modelName == "" {
modelName = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", modelName)
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request: %w", err)
}
src := ""
if input.File != "" {
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src")
if err != nil {
return nil, fmt.Errorf("failed downloading file:%w", err)
}
} else {
src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src")
if err != nil {
return nil, fmt.Errorf("error creating temporary image source file: %w", err)
}
}
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return nil, fmt.Errorf("invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return nil, fmt.Errorf("invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return nil, fmt.Errorf("invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = startupOptions.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return nil, err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return nil, err
}
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions)
if err != nil {
return nil, err
}
if err := fn(); err != nil {
return nil, err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return nil, err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = path.Join(startupOptions.ImageDir, base)
}
result = append(result, *item)
}
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}, nil
}

861
core/backend/llm.go Normal file
View File

@ -0,0 +1,861 @@
package backend
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
////////// TYPES //////////////
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}
// TODO: Test removing this and using the variant in pkg/schema someday?
type TokenUsage struct {
Prompt int
Completion int
}
type TemplateConfigBindingFn func(*schema.Config) *string
// type LLMStreamProcessor func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse)
/////// CONSTS ///////////
const DEFAULT_NO_ACTION_NAME = "answer"
const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action"
////// INFERENCE /////////
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
var inferenceModel *grpc.Client
var err error
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return nil, err
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Images = images
tokenUsage := TokenUsage{}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
r, size := utf8.DecodeRune(partialRune)
if r == utf8.RuneError {
// incomplete rune, wait for more bytes
break
}
tokenCallback(string(r), tokenUsage)
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return LLMResponse{}, err
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}
return fn, nil
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config schema.Config, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
for _, c := range config.TrimSuffix {
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
}
return prediction
}
////// CONFIG AND REQUEST HANDLING ///////////////
func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *schema.OpenAIRequest, cm *services.ConfigLoader, startupOptions *schema.StartupOptions) (*schema.Config, *schema.OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml")
var cfg *schema.Config
defaults := func() {
cfg = schema.DefaultConfig(modelFile)
cfg.ContextSize = startupOptions.ContextSize
cfg.Threads = startupOptions.Threads
cfg.F16 = startupOptions.F16
cfg.Debug = startupOptions.Debug
}
cfgExisting, exists := cm.GetConfig(modelFile)
if !exists {
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelFile)
if exists {
cfg = &cfgExisting
} else {
defaults()
}
} else {
defaults()
}
} else {
cfg = &cfgExisting
}
// Set the parameters for the language model prediction
schema.UpdateConfigFromOpenAIRequest(cfg, input)
// Don't allow 0 as setting
if cfg.Threads == 0 {
if startupOptions.Threads != 0 {
cfg.Threads = startupOptions.Threads
} else {
cfg.Threads = 4
}
}
// Enforce debug flag if passed from CLI
if startupOptions.Debug {
cfg.Debug = true
}
return cfg, input, nil
}
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *schema.Config,
o *schema.StartupOptions,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, TokenUsage) bool) ([]schema.Choice, TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
if err != nil {
return result, TokenUsage{}, err
}
tokenUsage := TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}
// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below?
func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
log.Debug().Msgf("Parameter Config: %+v", config)
configTemplate := bindingFn(config)
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) {
*configTemplate = config.Model
}
if *configTemplate == "" {
return nil, fmt.Errorf(("failed to find templateConfig"))
}
return config, nil
}
////////// SPECIFIC REQUESTS //////////////
// TODO: For round one of the refactor, give each of the three primary text endpoints their own function?
// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split
// Can cleanup into a common form later if possible easier if they are all here for now
// If they remain different, extract each of these named segments to a seperate file
func prepareChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, string, bool, error) {
// IMPORTANT DEFS
funcs := grammar.Functions{}
// The Basic Begining
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, "", false, fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Special Input/Config Handling
// Allow the user to set custom actions via config file
// to be "embedded" in each model - but if they are missing, use defaults.
if config.FunctionsConfig.NoActionFunctionName == "" {
config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME
}
if config.FunctionsConfig.NoActionDescriptionName == "" {
config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions()
if processFunctions {
log.Debug().Msgf("Response needs to process functions")
noActionGrammar := grammar.Function{
Name: config.FunctionsConfig.NoActionFunctionName,
Description: config.FunctionsConfig.NoActionDescriptionName,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
}
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
return config, predInput, processFunctions, nil
}
func EditGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Edit
}
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
var result []schema.Choice
totalTokenUsage := TokenUsage{}
for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return nil, err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}, nil
}
func ChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
// DEFS
id := uuid.New().String()
created := int(time.Now().Unix())
// Prepare
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s = utils.EscapeNewLines(s)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
ss["arguments"] = string(d)
ss["name"] = func_name
// if do nothing, reply with a message
if func_name == config.FunctionsConfig.NoActionFunctionName {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(d), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
return
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
fineTunedResponse := Finetune(*config, predInput, prediction.Response)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return nil, err
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}, nil
}
func CompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
// Prepare
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Completion
}
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
var result []schema.Choice
totalTokenUsage := TokenUsage{}
for k, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
r, tokenUsage, err := ComputeChoices(
input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return nil, err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
return &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}, nil
}
func StreamingChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
// DEFS
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
// Prepare
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
if processFunctions {
// TODO: unused variable means I did something wrong. investigate once stable
log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name)
}
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel")
responses := make(chan schema.OpenAIResponse)
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine")
go processor(predInput, input, config, ml, responses)
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!")
return responses, nil
}
func StreamingCompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
// DEFS
id := uuid.New().String()
created := int(time.Now().Unix())
binding := func(config *schema.Config) *string {
return &config.TemplateConfig.Completion
}
// Prepare
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
if err != nil {
return nil, err
}
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
if len(config.PromptStrings) > 1 {
return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
//A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel")
responses := make(chan schema.OpenAIResponse)
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine")
go processor(predInput, input, config, ml, responses)
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!")
return responses, nil
}

View File

@ -5,13 +5,11 @@ import (
"path/filepath" "path/filepath"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
) )
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) []model.Option {
if o.SingleBackend { if o.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend()) opts = append(opts, model.WithSingleActiveBackend())
} }
@ -35,7 +33,7 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
return opts return opts
} }
func gRPCModelOpts(c config.Config) *pb.ModelOptions { func gRPCModelOpts(c schema.Config) *pb.ModelOptions {
b := 512 b := 512
if c.Batch != 0 { if c.Batch != 0 {
b = c.Batch b = c.Batch
@ -82,7 +80,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
} }
} }
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { func gRPCPredictOpts(c schema.Config, modelPath string) *pb.PredictOptions {
promptCachePath := "" promptCachePath := ""
if c.PromptCachePath != "" { if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath) p := filepath.Join(modelPath, c.PromptCachePath)

View File

@ -0,0 +1,52 @@
package backend
import (
"context"
"fmt"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (*schema.WhisperResult, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
})
whisperModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}
func TranscriptionOpenAIRequest(modelName string, input *schema.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.WhisperResult, error) {
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
if err != nil {
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
}
tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions)
if err != nil {
return nil, err
}
return tr, nil
}

View File

@ -6,10 +6,9 @@ import (
"os" "os"
"path/filepath" "path/filepath"
api_config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
) )
@ -29,18 +28,19 @@ func generateUniqueFileName(dir, baseName, ext string) string {
} }
} }
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *schema.StartupOptions) (string, *proto.Result, error) {
bb := backend bb := backend
if bb == "" { if bb == "" {
bb = model.PiperBackend bb = model.PiperBackend
} }
opts := modelOpts(api_config.Config{}, o, []model.Option{ opts := modelOpts(schema.Config{}, o, []model.Option{
model.WithBackendString(bb), model.WithBackendString(bb),
model.WithModel(modelFile), model.WithModel(modelFile),
model.WithContext(o.Context), model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithExternalBackends(o.ExternalGRPCBackends, false),
}) })
piperModel, err := o.Loader.BackendLoader(opts...) piperModel, err := loader.BackendLoader(opts...)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -60,8 +60,8 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
modelPath := "" modelPath := ""
if modelFile != "" { if modelFile != "" {
if bb != model.TransformersMusicGen { if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.Loader.ModelPath, modelFile) modelPath = filepath.Join(o.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil {
return "", nil, err return "", nil, err
} }
} else { } else {

169
core/http/api.go Normal file
View File

@ -0,0 +1,169 @@
package http
import (
"errors"
"strings"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
)
func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if options.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
app.Use(recover.New())
if options.Metrics != nil {
app.Use(localai.MetricsAPIMiddleware(options.Metrics))
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 {
return c.Next()
}
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := services.NewGalleryApplier(options.ModelPath)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.Get("/metrics", localai.MetricsHandler())
backendMonitor := services.NewBackendMonitor(cl, ml, options)
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// model listing
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
return app, nil
}

View File

@ -1,4 +1,4 @@
package api_test package http_test
import ( import (
"bytes" "bytes"
@ -13,11 +13,12 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
. "github.com/go-skynet/LocalAI/api" server "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -118,16 +119,15 @@ var backendAssets embed.FS
var _ = Describe("API test", func() { var _ = Describe("API test", func() {
var app *fiber.App var app *fiber.App
var modelLoader *model.ModelLoader
var client *openai.Client var client *openai.Client
var client2 *openaigo.Client var client2 *openaigo.Client
var c context.Context var c context.Context
var cancel context.CancelFunc var cancel context.CancelFunc
var tmpdir string var tmpdir string
commonOpts := []options.AppOption{ commonOpts := []schema.AppOption{
options.WithDebug(true), schema.WithDebug(true),
options.WithDisableMessage(true), schema.WithDisableMessage(true),
} }
Context("API with ephemeral models", func() { Context("API with ephemeral models", func() {
@ -136,7 +136,6 @@ var _ = Describe("API test", func() {
tmpdir, err = os.MkdirTemp("", "") tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{ g := []gallery.GalleryModel{
@ -163,15 +162,20 @@ var _ = Describe("API test", func() {
}, },
} }
metricsService, err := metrics.SetupMetrics() metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App( cl, ml, options, err := startup.Startup(
append(commonOpts, append(commonOpts,
options.WithMetrics(metricsService), schema.WithMetrics(metricsService),
options.WithContext(c), schema.WithContext(c),
options.WithGalleries(galleries), schema.WithGalleries(galleries),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...) schema.WithModelPath(tmpdir),
schema.WithBackendAssets(backendAssets),
schema.WithBackendAssetsOutput(tmpdir))...)
Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -475,7 +479,6 @@ var _ = Describe("API test", func() {
tmpdir, err = os.MkdirTemp("", "") tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{ galleries := []gallery.Gallery{
@ -485,21 +488,22 @@ var _ = Describe("API test", func() {
}, },
} }
metricsService, err := metrics.SetupMetrics() metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App( cl, ml, options, err := startup.Startup(
append(commonOpts, append(commonOpts,
options.WithContext(c), schema.WithContext(c),
options.WithMetrics(metricsService), schema.WithMetrics(metricsService),
options.WithAudioDir(tmpdir), schema.WithAudioDir(tmpdir),
options.WithImageDir(tmpdir), schema.WithImageDir(tmpdir),
options.WithGalleries(galleries), schema.WithGalleries(galleries),
options.WithModelLoader(modelLoader), schema.WithModelPath(tmpdir),
options.WithBackendAssets(backendAssets), schema.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(tmpdir))..., schema.WithBackendAssetsOutput(tmpdir))...,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -590,20 +594,21 @@ var _ = Describe("API test", func() {
Context("API query", func() { Context("API query", func() {
BeforeEach(func() { BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics() metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App( cl, ml, options, err := startup.Startup(
append(commonOpts, append(commonOpts,
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), schema.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c), schema.WithContext(c),
options.WithModelLoader(modelLoader), schema.WithModelPath(os.Getenv("MODELS_PATH")),
options.WithMetrics(metricsService), schema.WithMetrics(metricsService),
)...) )...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -802,20 +807,21 @@ var _ = Describe("API test", func() {
Context("Config file", func() { Context("Config file", func() {
BeforeEach(func() { BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics() metricsService, err := services.SetupMetrics()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App( cl, ml, options, err := startup.Startup(
append(commonOpts, append(commonOpts,
options.WithContext(c), schema.WithContext(c),
options.WithMetrics(metricsService), schema.WithMetrics(metricsService),
options.WithModelLoader(modelLoader), schema.WithModelPath(os.Getenv("MODELS_PATH")),
options.WithConfigFile(os.Getenv("CONFIG_FILE")))..., schema.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = server.App(cl, ml, options)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")

View File

@ -1,4 +1,4 @@
package api_test package http_test
import ( import (
"testing" "testing"

View File

@ -0,0 +1,34 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
)
func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}

View File

@ -0,0 +1,148 @@
package localai
import (
"encoding/json"
"fmt"
"slices"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
/// Endpoint Service
type ModelGalleryEndpointService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *services.GalleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- gallery.GalleryOp{
Req: input.GalleryModel,
Id: uuid.String(),
GalleryName: input.ID,
Galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View File

@ -0,0 +1,42 @@
package localai
import (
"time"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func MetricsHandler() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metrics *schema.LocalAIMetrics
}
func MetricsAPIMiddleware(metrics *schema.LocalAIMetrics) fiber.Handler {
cfg := apiMiddlewareConfig{
metrics: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metrics.ObserveAPICall(method, path, elapsed)
return err
}
}

View File

@ -0,0 +1,25 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
)
func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so)
if err != nil {
return err
}
return c.Download(filePath)
}
}

View File

@ -0,0 +1,97 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) func(c *fiber.Ctx) error {
emptyMessage := ""
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, startupOptions, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// The scary comment I feel like I forgot about along the way:
//
// functions are not supported in stream mode (yet?)
//
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return fmt.Errorf("failed establishing streaming chat request :%w", err)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
id := ""
created := 0
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
id = ev.ID
created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
//////////////////////////////////////////
resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
if err != nil {
return fmt.Errorf("error generating chat request: +%w", err)
}
respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this?
log.Debug().Msgf("Response: %s", respData)
return c.JSON(resp)
}
}

View File

@ -0,0 +1,91 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("failed establishing streaming completion request :%w", err)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
///////////
resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating completion request: +%w", err)
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -0,0 +1,34 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so)
if err != nil {
return err
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -0,0 +1,35 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so)
if err != nil {
return err
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -0,0 +1,48 @@
package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/images/create
/*
*
curl http://localhost:8080/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cute baby sea otter",
"n": 1,
"size": "512x512"
}'
*
*/
func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating image request: +%w", err)
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
}
}

View File

@ -3,21 +3,21 @@ package openai
import ( import (
"regexp" "regexp"
config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/pkg/model"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
models, err := loader.ListModels() models, err := ml.ListModels()
if err != nil { if err != nil {
return err return err
} }
var mm map[string]interface{} = map[string]interface{}{} var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{} openAIModels := []schema.OpenAIModel{}
var filterFn func(name string) bool var filterFn func(name string) bool
filter := c.Query("filter") filter := c.Query("filter")
@ -40,13 +40,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
excludeConfigured := c.QueryBool("excludeConfigured", true) excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations // Start with the known configurations
for _, c := range cm.GetAllConfigs() { for _, c := range cl.GetAllConfigs() {
if excludeConfigured { if excludeConfigured {
mm[c.Model] = nil mm[c.Model] = nil
} }
if filterFn(c.Name) { if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) openAIModels = append(openAIModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
} }
} }
@ -54,7 +54,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
for _, m := range models { for _, m := range models {
// And only adds them if they shouldn't be skipped. // And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) { if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) openAIModels = append(openAIModels, schema.OpenAIModel{ID: m, Object: "model"})
} }
} }
@ -63,7 +63,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
Data []schema.OpenAIModel `json:"data"` Data []schema.OpenAIModel `json:"data"`
}{ }{
Object: "list", Object: "list",
Data: dataModels, Data: openAIModels,
}) })
} }
} }

View File

@ -0,0 +1,57 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readInput(c *fiber.Ctx, o *schema.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
modelFile := input.Model
if c.Params("model") != "" {
modelFile = c.Params("model")
}
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
// Set model from bearer token, if available
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && ml.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" && !bearerExists && randomModel {
models, _ := ml.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", nil, fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
return modelFile, input, nil
}

View File

@ -0,0 +1,49 @@
package openai
import (
"fmt"
"net/http"
"os"
"path"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName, input, err := readInput(c, so, ml, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper
if err != nil {
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
defer os.RemoveAll(path.Dir(dst))
tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so)
if err != nil {
return fmt.Errorf("error generating transcription request: +%w", err)
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
}
}

24
core/mqtt/manager.go Normal file
View File

@ -0,0 +1,24 @@
package mqtt
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
)
// PLACEHOLDER DURING PART 1 OF THE REFACTOR
type MQTTManager struct {
configLoader *services.ConfigLoader
modelLoader *model.ModelLoader
startupOptions *schema.StartupOptions
}
func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*MQTTManager, error) {
return &MQTTManager{
configLoader: cl,
modelLoader: ml,
startupOptions: options,
}, nil
}

View File

@ -0,0 +1,138 @@
package services
import (
"context"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitor struct {
configLoader *ConfigLoader
modelLoader *model.ModelLoader
options *schema.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *schema.StartupOptions) *BackendMonitor {
return &BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.modelLoader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &schema.BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetConfig(modelName)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = modelName
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return nil, err
}
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return &proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
}, nil
}
return status, nil
}
func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return err
}
return bm.modelLoader.ShutdownModel(backendId)
}

157
core/services/config.go Normal file
View File

@ -0,0 +1,157 @@
package services
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type ConfigLoader struct {
configs map[string]schema.Config
sync.Mutex
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]schema.Config),
}
}
// TODO: check this is correct post-merge
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := schema.ReadSingleConfigFile(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm.configs[c.Name] = *c
return nil
}
func (cm *ConfigLoader) GetConfig(m string) (schema.Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm *ConfigLoader) GetAllConfigs() []schema.Config {
cm.Lock()
defer cm.Unlock()
var res []schema.Config
for _, v := range cm.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
func (cm *ConfigLoader) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
return err
}
files = append(files, info)
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue
}
c, err := schema.ReadSingleConfigFile(filepath.Join(path, file.Name()))
if err == nil {
cm.configs[c.Name] = *c
}
}
return nil
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}
log.Info().Msgf("Preloading models from %s", modelPath)
for _, config := range cm.configs {
// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}
modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)
if utils.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}
}
}
return nil
}
func (cl *ConfigLoader) LoadConfigFile(file string) error {
cl.Lock()
defer cl.Unlock()
c, err := schema.ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cl.configs[cc.Name] = *cc
}
return nil
}

160
core/services/gallery.go Normal file
View File

@ -0,0 +1,160 @@
package services
import (
"context"
"encoding/json"
"os"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)
type GalleryApplier struct {
modelPath string
sync.Mutex
C chan gallery.GalleryOp
statuses map[string]*gallery.GalleryOpStatus
}
func NewGalleryApplier(modelPath string) *GalleryApplier {
return &GalleryApplier{
modelPath: modelPath,
C: make(chan gallery.GalleryOp),
statuses: make(map[string]*gallery.GalleryOpStatus),
}
}
func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else {
err = PrepareModel(g.modelPath, op.Req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetInstallableModelFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}

29
core/services/metrics.go Normal file
View File

@ -0,0 +1,29 @@
package services
import (
"github.com/go-skynet/LocalAI/pkg/schema"
"go.opentelemetry.io/otel/exporters/prometheus"
api "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
)
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*schema.LocalAIMetrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
meter := provider.Meter("github.com/go-skynet/LocalAI")
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
if err != nil {
return nil, err
}
return &schema.LocalAIMetrics{
Meter: meter,
ApiTimeMetric: apiTimeMetric,
}, nil
}

View File

@ -0,0 +1,100 @@
package startup
import (
"encoding/json"
"fmt"
"os"
"path"
"github.com/fsnotify/fsnotify"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
)
type WatchConfigDirectoryCloser func() error
func ReadApiKeysJson(configDir string, options *schema.StartupOptions) error {
fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err == nil {
options.ApiKeys = append(options.ApiKeys, fileKeys...)
return nil
}
return err
}
return err
}
func ReadExternalBackendsJson(configDir string, options *schema.StartupOptions) error {
fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
if err != nil {
return err
}
// Parse JSON content from the file
var fileBackends map[string]string
err = json.Unmarshal(fileContent, &fileBackends)
if err != nil {
return err
}
err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends)
if err != nil {
return err
}
return nil
}
var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *schema.StartupOptions) error{
"api_keys.json": ReadApiKeysJson,
"external_backends.json": ReadExternalBackendsJson,
}
func WatchConfigDirectory(configDir string, options *schema.StartupOptions) (WatchConfigDirectoryCloser, error) {
if len(configDir) == 0 {
return nil, fmt.Errorf("configDir blank")
}
configWatcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
}
ret := func() error {
configWatcher.Close()
return nil
}
// Start listening for events.
go func() {
for {
select {
case event, ok := <-configWatcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
for targetName, watchFn := range CONFIG_FILE_UPDATES {
if event.Name == targetName {
err := watchFn(configDir, options)
log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
}
}
}
case _, ok := <-configWatcher.Errors:
if !ok {
return
}
log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
}
}
}()
// Add a path.
err = configWatcher.Add(configDir)
if err != nil {
return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
}
return ret, nil
}

93
core/startup/startup.go Normal file
View File

@ -0,0 +1,93 @@
package startup
import (
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...schema.AppOption) (*services.ConfigLoader, *model.ModelLoader, *schema.StartupOptions, error) {
options := schema.NewStartupOptions(opts...)
ml := model.NewModelLoader(options.ModelPath)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cl := services.NewConfigLoader()
if err := cl.LoadConfigs(options.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
ml.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
ml,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
ml.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return cl, ml, options, nil
}

View File

@ -17,6 +17,53 @@ This section will collect how-to, notes and development documentation
We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages. We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages.
## LocalAI Project Structure
**LocalAI is made of multiple components, developed in multiple repositories:**
The core repository, containing the primary `local-ai` server code, gRPC stubs, this documentation website, and docker container building resources are all located at [mudler/LocalAI](https://github.com/mudler/LocalAI).
As LocalAI is designed to make use of multiple, independent model galleries, those are maintained seperately. The following public model galleries are available for use:
* [go-skynet/model-gallery](https://github.com/go-skynet/model-gallery) - The original gallery, the `golang` huggingface scraper ran into limits and was largely retired, so this now holds handmade yaml configs
* [dave-gray101/model-gallery](https://github.com/dave-gray101/model-gallery) - An automated gallery designed to track HuggingFace uploads and produce best-effort automatically generated configurations for LocalAI. It is designed to produce one LocalAI gallery per repository on HuggingFace.
### Directory Structure of this Repo
The core repository is broken up into the following primary chunks:
* `/backend`: gRPC protobuf specification and gRPC backends. Subfolders for each language.
* **`/core`**: golang sourcecode for the core LocalAI application. Broken down below.
* `/docs`: localai.io website that you are reading now
* `/examples`: example code integrating LocalAI to other projects and/or developer samples and tools
* `/internal`: **here be dragons**. Don't touch this, it's used for automatic versioning.
* `/models`: _No code here!_ This is where models are installed!
* **`/pkg`**: golang sourcecode that is intended to be reusable or at least widely imported across LocalAI. Broken down below
* `/prompt-templates`: _No code here!_ This is where **example** prompt templates were historically stored. Somewhat obsolete these days, model-galleries tend to replace manually creating these?
* `/tests`: Does what it says on the tin. Please write tests and put them here when you do.
The `core` folder is broken down further:
* **`/core/backend`**: code that interacts with a gRPC backend to perform AI tasks.
* `/core/http`: code specifically related to the REST server
* `/core/http/endpoints`: Has two subdirectories, `openai` and `localai` for binding the respective endpoints to the correct backend or service.
* `/core/mqtt`: core specifically related to the MQTT server. Stub for now. Coming soon!
* **`/core/services`**: code implementing functionality performed by `local-ai` itself, rather than delegated to a backend.
* `/core/startup`: code related specifically to application startup of `local-ai`. Potentially to be refactored to become a part of `/core/services` at a later date, or not.
The `pkg` folder is broken down further:
* `/pkg/assets`: Currently contains a single function related to extracting files from archives. Potentially to be refactored to become a part of `/core/utils` at a later date?
* **`/pkg/datamodel`**: Contains the data types and definitions used by the LocalAI project. Imported widely!
* `/pkg/gallery`: Code related to interacting with a `model-gallery`
* `/pkg/grammar`: Code related to BNF / functions for LLM
* `/pkg/grpc`: base classes and interfaces for gRPC backends to implement
* `/pkg/langchain`: langchain related code in golang
* **`/pkg/model`**: Code related to loading and initializing a model and creating the appropriate gRPC backend.
* `/pkg/stablediffusion`: Code related to stablediffusion in golang.
* `/pkg/utils`: Every real programmer knows what they are going to find in here... it's our junk drawer of utility functions.
## Creating a gRPC backend ## Creating a gRPC backend
LocalAI backends are `gRPC` servers. LocalAI backends are `gRPC` servers.

View File

@ -20,7 +20,7 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
Returns an `audio/wav` file. Returns an `audio/wav` file.
#### Setup #### Text-To-Speech Setup
LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`: LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`:
@ -52,6 +52,8 @@ Note:
- The model name is case sensitive. - The model name is case sensitive.
- LocalAI must be compiled with the `GO_TAGS=tts` flag. - LocalAI must be compiled with the `GO_TAGS=tts` flag.
#### Music
LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech: LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech:
``` ```
@ -62,7 +64,8 @@ curl --request POST \
"backend": "transformers-musicgen", "backend": "transformers-musicgen",
"model": "facebook/musicgen-medium", "model": "facebook/musicgen-medium",
"input": "Cello Rave" "input": "Cello Rave"
}' | aplay``` }' | aplay
```
Future versions of LocalAI will expose additional control over audio generation beyond the text prompt. Future versions of LocalAI will expose additional control over audio generation beyond the text prompt.

4
go.mod
View File

@ -5,6 +5,7 @@ go 1.21
require ( require (
github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
github.com/fsnotify/fsnotify v1.7.0
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
github.com/go-audio/wav v1.1.0 github.com/go-audio/wav v1.1.0
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
@ -15,7 +16,6 @@ require (
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hpcloud/tail v1.0.0 github.com/hpcloud/tail v1.0.0
github.com/imdario/mergo v0.3.16 github.com/imdario/mergo v0.3.16
github.com/json-iterator/go v1.1.12
github.com/mholt/archiver/v3 v3.5.1 github.com/mholt/archiver/v3 v3.5.1
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
@ -63,8 +63,6 @@ require (
github.com/klauspost/pgzip v1.2.5 // indirect github.com/klauspost/pgzip v1.2.5 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/nwaples/rardecode v1.1.0 // indirect github.com/nwaples/rardecode v1.1.0 // indirect
github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect
github.com/pkoukk/tiktoken-go v0.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect

11
go.sum
View File

@ -24,8 +24,9 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
@ -74,7 +75,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
@ -88,8 +88,6 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
@ -119,11 +117,6 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4= github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI=

120
main.go
View File

@ -12,14 +12,14 @@ import (
"syscall" "syscall"
"time" "time"
api "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/api/backend" "github.com/go-skynet/LocalAI/core/http"
config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/schema"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
progressbar "github.com/schollz/progressbar/v3" progressbar "github.com/schollz/progressbar/v3"
@ -190,6 +190,12 @@ func main() {
EnvVars: []string{"PRELOAD_BACKEND_ONLY"}, EnvVars: []string{"PRELOAD_BACKEND_ONLY"},
Value: false, Value: false,
}, },
&cli.StringFlag{
Name: "localai-config-dir",
Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.",
EnvVars: []string{"LOCALAI_CONFIG_DIR"},
Value: "./config",
},
}, },
Description: ` Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally. LocalAI is a drop-in replacement OpenAI API which runs inference locally.
@ -208,54 +214,54 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "Ettore Di Giacinto", Copyright: "Ettore Di Giacinto",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
opts := []options.AppOption{ opts := []schema.AppOption{
options.WithConfigFile(ctx.String("config-file")), schema.WithConfigFile(ctx.String("config-file")),
options.WithJSONStringPreload(ctx.String("preload-models")), schema.WithJSONStringPreload(ctx.String("preload-models")),
options.WithYAMLConfigPreload(ctx.String("preload-models-config")), schema.WithYAMLConfigPreload(ctx.String("preload-models-config")),
options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), schema.WithModelPath(ctx.String("models-path")),
options.WithContextSize(ctx.Int("context-size")), schema.WithContextSize(ctx.Int("context-size")),
options.WithDebug(ctx.Bool("debug")), schema.WithDebug(ctx.Bool("debug")),
options.WithImageDir(ctx.String("image-path")), schema.WithImageDir(ctx.String("image-path")),
options.WithAudioDir(ctx.String("audio-path")), schema.WithAudioDir(ctx.String("audio-path")),
options.WithF16(ctx.Bool("f16")), schema.WithF16(ctx.Bool("f16")),
options.WithStringGalleries(ctx.String("galleries")), schema.WithStringGalleries(ctx.String("galleries")),
options.WithDisableMessage(false), schema.WithDisableMessage(false),
options.WithCors(ctx.Bool("cors")), schema.WithCors(ctx.Bool("cors")),
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), schema.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
options.WithThreads(ctx.Int("threads")), schema.WithThreads(ctx.Int("threads")),
options.WithBackendAssets(backendAssets), schema.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), schema.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
options.WithUploadLimitMB(ctx.Int("upload-limit")), schema.WithUploadLimitMB(ctx.Int("upload-limit")),
options.WithApiKeys(ctx.StringSlice("api-keys")), schema.WithApiKeys(ctx.StringSlice("api-keys")),
options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), schema.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
} }
idleWatchDog := ctx.Bool("enable-watchdog-idle") idleWatchDog := ctx.Bool("enable-watchdog-idle")
busyWatchDog := ctx.Bool("enable-watchdog-busy") busyWatchDog := ctx.Bool("enable-watchdog-busy")
if idleWatchDog || busyWatchDog { if idleWatchDog || busyWatchDog {
opts = append(opts, options.EnableWatchDog) opts = append(opts, schema.EnableWatchDog)
if idleWatchDog { if idleWatchDog {
opts = append(opts, options.EnableWatchDogIdleCheck) opts = append(opts, schema.EnableWatchDogIdleCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, options.SetWatchDogIdleTimeout(dur)) opts = append(opts, schema.SetWatchDogIdleTimeout(dur))
} }
if busyWatchDog { if busyWatchDog {
opts = append(opts, options.EnableWatchDogBusyCheck) opts = append(opts, schema.EnableWatchDogBusyCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout")) dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, options.SetWatchDogBusyTimeout(dur)) opts = append(opts, schema.SetWatchDogBusyTimeout(dur))
} }
} }
if ctx.Bool("parallel-requests") { if ctx.Bool("parallel-requests") {
opts = append(opts, options.EnableParallelBackendRequests) opts = append(opts, schema.EnableParallelBackendRequests)
} }
if ctx.Bool("single-active-backend") { if ctx.Bool("single-active-backend") {
opts = append(opts, options.EnableSingleBackend) opts = append(opts, schema.EnableSingleBackend)
} }
externalgRPC := ctx.StringSlice("external-grpc-backends") externalgRPC := ctx.StringSlice("external-grpc-backends")
@ -263,30 +269,42 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
for _, v := range externalgRPC { for _, v := range externalgRPC {
backend := v[:strings.IndexByte(v, ':')] backend := v[:strings.IndexByte(v, ':')]
uri := v[strings.IndexByte(v, ':')+1:] uri := v[strings.IndexByte(v, ':')+1:]
opts = append(opts, options.WithExternalBackend(backend, uri)) opts = append(opts, schema.WithExternalBackend(backend, uri))
} }
if ctx.Bool("autoload-galleries") { if ctx.Bool("autoload-galleries") {
opts = append(opts, options.EnableGalleriesAutoload) opts = append(opts, schema.EnableGalleriesAutoload)
} }
if ctx.Bool("preload-backend-only") { if ctx.Bool("preload-backend-only") {
_, _, err := api.Startup(opts...) _, _, _, err := startup.Startup(opts...)
return err return err
} }
metrics, err := metrics.SetupMetrics() metrics, err := services.SetupMetrics()
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, options.WithMetrics(metrics)) opts = append(opts, schema.WithMetrics(metrics))
app, err := api.App(opts...) cl, ml, options, err := startup.Startup(opts...)
if err != nil {
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options)
defer closeConfigWatcherFn()
if err != nil {
return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir"))
}
appHTTP, err := http.App(cl, ml, options)
if err != nil { if err != nil {
return err return err
} }
return app.Listen(ctx.String("address")) return appHTTP.Listen(ctx.String("address"))
}, },
Commands: []*cli.Command{ Commands: []*cli.Command{
{ {
@ -384,16 +402,18 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
text := strings.Join(ctx.Args().Slice(), " ") text := strings.Join(ctx.Args().Slice(), " ")
opts := &options.Option{ opts := &schema.StartupOptions{
Loader: model.NewModelLoader(ctx.String("models-path")), ModelPath: ctx.String("models-path"),
Context: context.Background(), Context: context.Background(),
AudioDir: outputDir, AudioDir: outputDir,
AssetsDestination: ctx.String("backend-assets-path"), AssetsDestination: ctx.String("backend-assets-path"),
} }
defer opts.Loader.StopAllGRPC() loader := model.NewModelLoader(opts.ModelPath)
filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts) defer loader.StopAllGRPC()
filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, loader, opts)
if err != nil { if err != nil {
return err return err
} }
@ -446,13 +466,15 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
language := ctx.String("language") language := ctx.String("language")
threads := ctx.Int("threads") threads := ctx.Int("threads")
opts := &options.Option{ opts := &schema.StartupOptions{
Loader: model.NewModelLoader(ctx.String("models-path")), ModelPath: ctx.String("models-path"),
Context: context.Background(), Context: context.Background(),
AssetsDestination: ctx.String("backend-assets-path"), AssetsDestination: ctx.String("backend-assets-path"),
} }
cl := config.NewConfigLoader() ml := model.NewModelLoader(opts.ModelPath)
cl := services.NewConfigLoader()
if err := cl.LoadConfigs(ctx.String("models-path")); err != nil { if err := cl.LoadConfigs(ctx.String("models-path")); err != nil {
return err return err
} }
@ -464,9 +486,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
c.Threads = threads c.Threads = threads
defer opts.Loader.StopAllGRPC() defer ml.StopAllGRPC()
tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts) tr, err := backend.ModelTranscription(filename, language, ml, c, opts)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,83 +0,0 @@
package metrics
import (
"context"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/prometheus"
api "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
)
type Metrics struct {
meter api.Meter
apiTimeMetric api.Float64Histogram
}
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*Metrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
meter := provider.Meter("github.com/go-skynet/LocalAI")
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
if err != nil {
return nil, err
}
return &Metrics{
meter: meter,
apiTimeMetric: apiTimeMetric,
}, nil
}
func MetricsHandler() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metrics *Metrics
}
func APIMiddleware(metrics *Metrics) fiber.Handler {
cfg := apiMiddlewareConfig{
metrics: metrics,
Filter: func(c *fiber.Ctx) bool {
if c.Path() == "/metrics" {
return true
}
return false
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metrics.ObserveAPICall(method, path, elapsed)
return err
}
}
func (m *Metrics) ObserveAPICall(method string, path string, duration float64) {
opts := api.WithAttributes(
attribute.String("method", method),
attribute.String("path", path),
)
m.apiTimeMetric.Record(context.Background(), duration, opts)
}

View File

@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
name = strings.ReplaceAll(name, string(os.PathSeparator), "__") name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
var config Config var config InstallableModel
if len(model.URL) > 0 { if len(model.URL) > 0 {
var err error var err error
config, err = GetGalleryConfigFromURL(model.URL) config, err = GetInstallableModelFromURL(model.URL)
if err != nil { if err != nil {
return err return err
} }
@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
if err != nil { if err != nil {
return err return err
} }
config = Config{ config = InstallableModel{
ConfigFile: string(reYamlConfig), ConfigFile: string(reYamlConfig),
Description: model.Description, Description: model.Description,
License: model.License, License: model.License,

View File

@ -1,13 +1,9 @@
package gallery package gallery
import ( import (
"crypto/sha256"
"fmt" "fmt"
"hash"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo" "github.com/imdario/mergo"
@ -41,9 +37,9 @@ prompt_templates:
content: "" content: ""
*/ */
// Config is the model configuration which contains all the model details // InstallableModel is the model configuration which contains all the model details
// This configuration is read from the gallery endpoint and is used to download and install the model // This configuration is read from the gallery endpoint and is used to download and install the model
type Config struct { type InstallableModel struct {
Description string `yaml:"description"` Description string `yaml:"description"`
License string `yaml:"license"` License string `yaml:"license"`
URLs []string `yaml:"urls"` URLs []string `yaml:"urls"`
@ -64,8 +60,8 @@ type PromptTemplate struct {
Content string `yaml:"content"` Content string `yaml:"content"`
} }
func GetGalleryConfigFromURL(url string) (Config, error) { func GetInstallableModelFromURL(url string) (InstallableModel, error) {
var config Config var config InstallableModel
err := utils.GetURI(url, func(url string, d []byte) error { err := utils.GetURI(url, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config) return yaml.Unmarshal(d, &config)
}) })
@ -76,7 +72,7 @@ func GetGalleryConfigFromURL(url string) (Config, error) {
return config, nil return config, nil
} }
func ReadConfigFile(filePath string) (*Config, error) { func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
// Read the YAML file // Read the YAML file
yamlFile, err := os.ReadFile(filePath) yamlFile, err := os.ReadFile(filePath)
if err != nil { if err != nil {
@ -84,7 +80,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
} }
// Unmarshal YAML data into a Config struct // Unmarshal YAML data into a Config struct
var config Config var config InstallableModel
err = yaml.Unmarshal(yamlFile, &config) err = yaml.Unmarshal(yamlFile, &config)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
@ -93,7 +89,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
return &config, nil return &config, nil
} }
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755) err := os.MkdirAll(basePath, 0755)
if err != nil { if err != nil {
@ -183,54 +179,3 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
return nil return nil
} }
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + " B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

View File

@ -16,7 +16,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir) defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir) defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
@ -103,7 +103,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir) defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
@ -129,7 +129,7 @@ var _ = Describe("Model test", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir) defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})

18
pkg/gallery/op.go Normal file
View File

@ -0,0 +1,18 @@
package gallery
type GalleryOp struct {
Req GalleryModel
Id string
Galleries []Gallery
GalleryName string
}
type GalleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}

View File

@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
Context("requests", func() { Context("requests", func() {
It("parses github with a branch", func() { It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
e, err := GetGalleryConfigFromURL(req.URL) e, err := GetInstallableModelFromURL(req.URL)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j")) Expect(e.Name).To(Equal("gpt4all-j"))
}) })

View File

@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
gopsutil "github.com/shirou/gopsutil/v3/process" gopsutil "github.com/shirou/gopsutil/v3/process"
) )
@ -53,8 +53,9 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented") return fmt.Errorf("unimplemented")
} }
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { // TODO CHECK THIS
return schema.Result{}, fmt.Errorf("unimplemented") func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) {
return schema.WhisperResult{}, fmt.Errorf("unimplemented")
} }
func (llm *Base) TTS(*pb.TTSRequest) error { func (llm *Base) TTS(*pb.TTSRequest) error {

View File

@ -7,8 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
) )
@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
return client.TTS(ctx, in, opts...) return client.TTS(ctx, in, opts...)
} }
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) {
if !c.parallel { if !c.parallel {
c.opMutex.Lock() c.opMutex.Lock()
defer c.opMutex.Unlock() defer c.opMutex.Unlock()
@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
if err != nil { if err != nil {
return nil, err return nil, err
} }
tresult := &schema.Result{} tresult := &schema.WhisperResult{}
for _, s := range res.Segments { for _, s := range res.Segments {
tks := []int{} tks := []int{}
for _, t := range s.Tokens { for _, t := range s.Tokens {
tks = append(tks, int(t)) tks = append(tks, int(t))
} }
tresult.Segments = append(tresult.Segments, tresult.Segments = append(tresult.Segments,
schema.Segment{ schema.WhisperSegment{
Text: s.Text, Text: s.Text,
Id: int(s.Id), Id: int(s.Id),
Start: time.Duration(s.Start), Start: time.Duration(s.Start),

View File

@ -1,8 +1,8 @@
package grpc package grpc
import ( import (
"github.com/go-skynet/LocalAI/api/schema"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/schema"
) )
type LLM interface { type LLM interface {
@ -15,7 +15,7 @@ type LLM interface {
Load(*pb.ModelOptions) error Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error) Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error)
TTS(*pb.TTSRequest) error TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error) Status() (pb.StatusResponse, error)

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.28.1 // protoc-gen-go v1.26.0
// protoc v3.6.1 // protoc v4.26.0
// source: backend.proto // source: backend.proto
package proto package proto

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.2.0 // - protoc-gen-go-grpc v1.3.0
// - protoc v3.6.1 // - protoc v4.26.0
// source: backend.proto // source: backend.proto
package proto package proto
@ -18,6 +18,19 @@ import (
// Requires gRPC-Go v1.32.0 or later. // Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
)
// BackendClient is the client API for Backend service. // BackendClient is the client API for Backend service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -44,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -53,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
} }
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -103,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult) out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult) out := new(TranscriptResult)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -130,7 +143,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,7 +152,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse) out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...) err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -148,7 +161,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse) out := new(StatusResponse)
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -229,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Health", FullMethod: Backend_Health_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -247,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Predict", FullMethod: Backend_Predict_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -265,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/LoadModel", FullMethod: Backend_LoadModel_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -304,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Embedding", FullMethod: Backend_Embedding_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -322,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/GenerateImage", FullMethod: Backend_GenerateImage_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -340,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/AudioTranscription", FullMethod: Backend_AudioTranscription_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -358,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/TTS", FullMethod: Backend_TTS_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -376,7 +389,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/TokenizeString", FullMethod: Backend_TokenizeString_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
@ -394,7 +407,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Status", FullMethod: Backend_Status_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage)) return srv.(BackendServer).Status(ctx, req.(*HealthMessage))

View File

@ -8,7 +8,7 @@ import (
"strings" "strings"
"time" "time"
grpc "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/phayes/freeport" "github.com/phayes/freeport"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{
// starts the grpcModelProcess for the backend, and returns a grpc client // starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model // It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) { func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) {
return func(modelName, modelFile string) (ModelAddress, error) { return func(modelName, modelFile string) (ModelAddress, error) {
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o) log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)

View File

@ -10,7 +10,7 @@ import (
"sync" "sync"
"text/template" "text/template"
grammar "github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
process "github.com/mudler/go-processmanager" process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"

View File

@ -6,7 +6,7 @@ import (
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
) )
type Options struct { type ModelOptions struct {
backendString string backendString string
model string model string
threads uint32 threads uint32
@ -23,14 +23,14 @@ type Options struct {
parallelRequests bool parallelRequests bool
} }
type Option func(*Options) type Option func(*ModelOptions)
var EnableParallelRequests = func(o *Options) { var EnableParallelRequests = func(o *ModelOptions) {
o.parallelRequests = true o.parallelRequests = true
} }
func WithExternalBackend(name string, uri string) Option { func WithExternalBackend(name string, uri string) Option {
return func(o *Options) { return func(o *ModelOptions) {
if o.externalBackends == nil { if o.externalBackends == nil {
o.externalBackends = make(map[string]string) o.externalBackends = make(map[string]string)
} }
@ -38,62 +38,81 @@ func WithExternalBackend(name string, uri string) Option {
} }
} }
// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates
func WithExternalBackends(backends map[string]string, overwrite bool) Option {
return func(o *ModelOptions) {
if backends == nil {
return
}
if o.externalBackends == nil {
o.externalBackends = backends
return
}
for name, url := range backends {
_, exists := o.externalBackends[name]
if !exists || overwrite {
o.externalBackends[name] = url
}
}
}
}
func WithGRPCAttempts(attempts int) Option { func WithGRPCAttempts(attempts int) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.grpcAttempts = attempts o.grpcAttempts = attempts
} }
} }
func WithGRPCAttemptsDelay(delay int) Option { func WithGRPCAttemptsDelay(delay int) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.grpcAttemptsDelay = delay o.grpcAttemptsDelay = delay
} }
} }
func WithBackendString(backend string) Option { func WithBackendString(backend string) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.backendString = backend o.backendString = backend
} }
} }
func WithModel(modelFile string) Option { func WithModel(modelFile string) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.model = modelFile o.model = modelFile
} }
} }
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option { func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.gRPCOptions = opts o.gRPCOptions = opts
} }
} }
func WithThreads(threads uint32) Option { func WithThreads(threads uint32) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.threads = threads o.threads = threads
} }
} }
func WithAssetDir(assetDir string) Option { func WithAssetDir(assetDir string) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.assetDir = assetDir o.assetDir = assetDir
} }
} }
func WithContext(ctx context.Context) Option { func WithContext(ctx context.Context) Option {
return func(o *Options) { return func(o *ModelOptions) {
o.context = ctx o.context = ctx
} }
} }
func WithSingleActiveBackend() Option { func WithSingleActiveBackend() Option {
return func(o *Options) { return func(o *ModelOptions) {
o.singleActiveBackend = true o.singleActiveBackend = true
} }
} }
func NewOptions(opts ...Option) *Options { func NewOptions(opts ...Option) *ModelOptions {
o := &Options{ o := &ModelOptions{
gRPCOptions: &pb.ModelOptions{}, gRPCOptions: &pb.ModelOptions{},
context: context.Background(), context: context.Background(),
grpcAttempts: 20, grpcAttempts: 20,

View File

@ -1,16 +1,11 @@
package api_config package schema
import ( import (
"errors" "encoding/json"
"fmt" "fmt"
"io/fs"
"os" "os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -152,11 +147,6 @@ type TemplateConfig struct {
Functions string `yaml:"function"` Functions string `yaml:"function"`
} }
type ConfigLoader struct {
configs map[string]Config
sync.Mutex
}
func (c *Config) SetFunctionCallString(s string) { func (c *Config) SetFunctionCallString(s string) {
c.functionCallString = s c.functionCallString = s
} }
@ -193,11 +183,6 @@ func DefaultConfig(modelFile string) *Config {
} }
} }
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) { func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{} c := &[]*Config{}
f, err := os.ReadFile(file) f, err := os.ReadFile(file)
@ -211,7 +196,7 @@ func ReadConfigFile(file string) ([]*Config, error) {
return *c, nil return *c, nil
} }
func ReadConfig(file string) (*Config, error) { func ReadSingleConfigFile(file string) (*Config, error) {
c := &Config{} c := &Config{}
f, err := os.ReadFile(file) f, err := os.ReadFile(file)
if err != nil { if err != nil {
@ -224,136 +209,192 @@ func ReadConfig(file string) (*Config, error) {
return c, nil return c, nil
} }
func (cm *ConfigLoader) LoadConfigFile(file string) error { func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
cm.Lock() if input.Echo {
defer cm.Unlock() config.Echo = input.Echo
c, err := ReadConfigFile(file) }
if err != nil { if input.TopK != 0 {
return fmt.Errorf("cannot load config file: %w", err) config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
} }
for _, cc := range c { if input.Backend != "" {
cm.configs[cc.Name] = *cc config.Backend = input.Backend
}
return nil
}
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
} }
cm.configs[c.Name] = *c if input.ClipSkip != 0 {
return nil config.Diffusers.ClipSkip = input.ClipSkip
}
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm *ConfigLoader) GetAllConfigs() []Config {
cm.Lock()
defer cm.Unlock()
var res []Config
for _, v := range cm.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
} }
log.Info().Msgf("Preloading models from %s", modelPath) if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
for i, config := range cm.configs { if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
// Download files and verify their SHA if input.UseFastTokenizer {
for _, file := range config.DownloadFiles { config.UseFastTokenizer = input.UseFastTokenizer
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) }
if err := utils.VerifyPath(file.Filename, modelPath); err != nil { if input.NegativePrompt != "" {
return err config.NegativePrompt = input.NegativePrompt
} }
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { if input.RopeFreqBase != 0 {
return err config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
} }
} }
}
modelURL := config.PredictionOptions.Model // Decode each request's message content
modelURL = utils.ConvertURL(modelURL) index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
if utils.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
} }
} }
cc := cm.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
cm.configs[i] = *c
} }
} }
return nil
} // TODO: check that this was merged correctly? I _think_ it is?
switch inputs := input.Input.(type) {
func (cm *ConfigLoader) LoadConfigs(path string) error { case string:
cm.Lock() if inputs != "" {
defer cm.Unlock() config.InputStrings = append(config.InputStrings, inputs)
entries, err := os.ReadDir(path) }
if err != nil { case []interface{}:
return err for _, pp := range inputs {
} switch i := pp.(type) {
files := make([]fs.FileInfo, 0, len(entries)) case string:
for _, entry := range entries { config.InputStrings = append(config.InputStrings, i)
info, err := entry.Info() case []interface{}:
if err != nil { tokens := []int{}
return err for _, ii := range i {
} tokens = append(tokens, int(ii.(float64)))
files = append(files, info) }
} config.InputToken = append(config.InputToken, tokens)
for _, file := range files { }
// Skip templates, YAML and .keep files }
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { }
continue
} // Can be either a string or an object
c, err := ReadConfig(filepath.Join(path, file.Name())) switch fnc := input.FunctionCall.(type) {
if err == nil { case string:
cm.configs[c.Name] = *c if fnc != "" {
} config.SetFunctionCallString(fnc)
} }
case map[string]interface{}:
return nil var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
} }

View File

@ -1,11 +1,10 @@
package api_config_test package schema_test
import ( import (
"os" "os"
. "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/schema"
"github.com/go-skynet/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
Context("Test Read configuration functions", func() { Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE") configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() { It("Test ReadConfigFile", func() {
config, err := ReadConfigFile(configFile) config, err := schema.ReadConfigFile(configFile)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
Expect(config).ToNot(BeNil()) Expect(config).ToNot(BeNil())
// two configs in config.yaml // two configs in config.yaml
@ -28,12 +27,8 @@ var _ = Describe("Test cases for config related functions", func() {
}) })
It("Test LoadConfigs", func() { It("Test LoadConfigs", func() {
cm := NewConfigLoader() cm := services.NewConfigLoader()
opts := options.NewOptions() err := cm.LoadConfigs(os.Getenv("MODELS_PATH"))
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
options.WithModelLoader(modelLoader)(opts)
err := cm.LoadConfigs(opts.Loader.ModelPath)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
Expect(cm.ListConfigs()).ToNot(BeNil()) Expect(cm.ListConfigs()).ToNot(BeNil())

39
pkg/schema/localai.go Normal file
View File

@ -0,0 +1,39 @@
package schema
import (
"context"
gopsutil "github.com/shirou/gopsutil/v3/process"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
type LocalAIMetrics struct {
Meter metric.Meter
ApiTimeMetric metric.Float64Histogram
}
func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) {
opts := metric.WithAttributes(
attribute.String("method", method),
attribute.String("path", path),
)
m.ApiTimeMetric.Record(context.Background(), duration, opts)
}

View File

@ -3,8 +3,6 @@ package schema
import ( import (
"context" "context"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
) )
@ -90,7 +88,7 @@ type ChatCompletionResponseFormat struct {
} }
type OpenAIRequest struct { type OpenAIRequest struct {
config.PredictionOptions PredictionOptions
Context context.Context Context context.Context
Cancel context.CancelFunc Cancel context.CancelFunc

View File

@ -1,4 +1,4 @@
package api_config package schema
type PredictionOptions struct { type PredictionOptions struct {

View File

@ -1,4 +1,4 @@
package options package schema
import ( import (
"context" "context"
@ -6,16 +6,14 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Option struct { type StartupOptions struct {
Context context.Context Context context.Context
ConfigFile string ConfigFile string
Loader *model.ModelLoader ModelPath string
UploadLimitMB, Threads, ContextSize int UploadLimitMB, Threads, ContextSize int
F16 bool F16 bool
Debug, DisableMessage bool Debug, DisableMessage bool
@ -26,7 +24,7 @@ type Option struct {
PreloadModelsFromPath string PreloadModelsFromPath string
CORSAllowOrigins string CORSAllowOrigins string
ApiKeys []string ApiKeys []string
Metrics *metrics.Metrics Metrics *LocalAIMetrics
Galleries []gallery.Gallery Galleries []gallery.Gallery
@ -47,12 +45,14 @@ type Option struct {
ModelsURL []string ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
LocalAIConfigDir string
} }
type AppOption func(*Option) type AppOption func(*StartupOptions)
func NewOptions(o ...AppOption) *Option { func NewStartupOptions(o ...AppOption) *StartupOptions {
opt := &Option{ opt := &StartupOptions{
Context: context.Background(), Context: context.Background(),
UploadLimitMB: 15, UploadLimitMB: 15,
Threads: 1, Threads: 1,
@ -67,57 +67,57 @@ func NewOptions(o ...AppOption) *Option {
} }
func WithModelsURL(urls ...string) AppOption { func WithModelsURL(urls ...string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.ModelsURL = urls o.ModelsURL = urls
} }
} }
func WithCors(b bool) AppOption { func WithCors(b bool) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.CORS = b o.CORS = b
} }
} }
var EnableWatchDog = func(o *Option) { var EnableWatchDog = func(o *StartupOptions) {
o.WatchDog = true o.WatchDog = true
} }
var EnableWatchDogIdleCheck = func(o *Option) { var EnableWatchDogIdleCheck = func(o *StartupOptions) {
o.WatchDog = true o.WatchDog = true
o.WatchDogIdle = true o.WatchDogIdle = true
} }
var EnableWatchDogBusyCheck = func(o *Option) { var EnableWatchDogBusyCheck = func(o *StartupOptions) {
o.WatchDog = true o.WatchDog = true
o.WatchDogBusy = true o.WatchDogBusy = true
} }
func SetWatchDogBusyTimeout(t time.Duration) AppOption { func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.WatchDogBusyTimeout = t o.WatchDogBusyTimeout = t
} }
} }
func SetWatchDogIdleTimeout(t time.Duration) AppOption { func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.WatchDogIdleTimeout = t o.WatchDogIdleTimeout = t
} }
} }
var EnableSingleBackend = func(o *Option) { var EnableSingleBackend = func(o *StartupOptions) {
o.SingleBackend = true o.SingleBackend = true
} }
var EnableParallelBackendRequests = func(o *Option) { var EnableParallelBackendRequests = func(o *StartupOptions) {
o.ParallelBackendRequests = true o.ParallelBackendRequests = true
} }
var EnableGalleriesAutoload = func(o *Option) { var EnableGalleriesAutoload = func(o *StartupOptions) {
o.AutoloadGalleries = true o.AutoloadGalleries = true
} }
func WithExternalBackend(name string, uri string) AppOption { func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
if o.ExternalGRPCBackends == nil { if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string) o.ExternalGRPCBackends = make(map[string]string)
} }
@ -126,25 +126,25 @@ func WithExternalBackend(name string, uri string) AppOption {
} }
func WithCorsAllowOrigins(b string) AppOption { func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.CORSAllowOrigins = b o.CORSAllowOrigins = b
} }
} }
func WithBackendAssetsOutput(out string) AppOption { func WithBackendAssetsOutput(out string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.AssetsDestination = out o.AssetsDestination = out
} }
} }
func WithBackendAssets(f embed.FS) AppOption { func WithBackendAssets(f embed.FS) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.BackendAssets = f o.BackendAssets = f
} }
} }
func WithStringGalleries(galls string) AppOption { func WithStringGalleries(galls string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
if galls == "" { if galls == "" {
log.Debug().Msgf("no galleries to load") log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{} o.Galleries = []gallery.Gallery{}
@ -159,96 +159,102 @@ func WithStringGalleries(galls string) AppOption {
} }
func WithGalleries(galleries []gallery.Gallery) AppOption { func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Galleries = append(o.Galleries, galleries...) o.Galleries = append(o.Galleries, galleries...)
} }
} }
func WithContext(ctx context.Context) AppOption { func WithContext(ctx context.Context) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Context = ctx o.Context = ctx
} }
} }
func WithYAMLConfigPreload(configFile string) AppOption { func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.PreloadModelsFromPath = configFile o.PreloadModelsFromPath = configFile
} }
} }
func WithJSONStringPreload(configFile string) AppOption { func WithJSONStringPreload(configFile string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.PreloadJSONModels = configFile o.PreloadJSONModels = configFile
} }
} }
func WithConfigFile(configFile string) AppOption { func WithConfigFile(configFile string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.ConfigFile = configFile o.ConfigFile = configFile
} }
} }
func WithModelLoader(loader *model.ModelLoader) AppOption { func WithModelPath(path string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Loader = loader o.ModelPath = path
} }
} }
func WithUploadLimitMB(limit int) AppOption { func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.UploadLimitMB = limit o.UploadLimitMB = limit
} }
} }
func WithThreads(threads int) AppOption { func WithThreads(threads int) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Threads = threads o.Threads = threads
} }
} }
func WithContextSize(ctxSize int) AppOption { func WithContextSize(ctxSize int) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.ContextSize = ctxSize o.ContextSize = ctxSize
} }
} }
func WithF16(f16 bool) AppOption { func WithF16(f16 bool) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.F16 = f16 o.F16 = f16
} }
} }
func WithDebug(debug bool) AppOption { func WithDebug(debug bool) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Debug = debug o.Debug = debug
} }
} }
func WithDisableMessage(disableMessage bool) AppOption { func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.DisableMessage = disableMessage o.DisableMessage = disableMessage
} }
} }
func WithAudioDir(audioDir string) AppOption { func WithAudioDir(audioDir string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.AudioDir = audioDir o.AudioDir = audioDir
} }
} }
func WithImageDir(imageDir string) AppOption { func WithImageDir(imageDir string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.ImageDir = imageDir o.ImageDir = imageDir
} }
} }
func WithApiKeys(apiKeys []string) AppOption { func WithApiKeys(apiKeys []string) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.ApiKeys = apiKeys o.ApiKeys = apiKeys
} }
} }
func WithMetrics(meter *metrics.Metrics) AppOption { func WithMetrics(metrics *LocalAIMetrics) AppOption {
return func(o *Option) { return func(o *StartupOptions) {
o.Metrics = meter o.Metrics = metrics
}
}
func WithLocalAIConfigDir(configDir string) AppOption {
return func(o *StartupOptions) {
o.LocalAIConfigDir = configDir
} }
} }

View File

@ -2,7 +2,7 @@ package schema
import "time" import "time"
type Segment struct { type WhisperSegment struct {
Id int `json:"id"` Id int `json:"id"`
Start time.Duration `json:"start"` Start time.Duration `json:"start"`
End time.Duration `json:"end"` End time.Duration `json:"end"`
@ -10,7 +10,7 @@ type Segment struct {
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
} }
type Result struct { type WhisperResult struct {
Segments []Segment `json:"segments"` Segments []WhisperSegment `json:"segments"`
Text string `json:"text"` Text string `json:"text"`
} }

81
pkg/utils/file.go Normal file
View File

@ -0,0 +1,81 @@
package utils
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"github.com/rs/zerolog/log"
)
func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) {
f, err := file.Open()
if err != nil {
return "", err
}
defer f.Close()
// Create a temporary file in the requested directory:
outputFile, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer outputFile.Close()
if _, err := io.Copy(outputFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err)
return "", err
}
return outputFile.Name(), nil
}
func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) {
if len(base64data) == 0 {
return "", fmt.Errorf("base64data empty?")
}
//base 64 decode the file and write it somewhere
// that we will cleanup
decoded, err := base64.StdEncoding.DecodeString(base64data)
if err != nil {
return "", err
}
// Create a temporary file in the requested directory:
outputFile, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer outputFile.Close()
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(decoded)
if err != nil {
return "", err
}
return outputFile.Name(), nil
}
func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp(tempDir, tempPattern)
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}

View File

@ -3,18 +3,38 @@ package utils
import ( import (
"crypto/md5" "crypto/md5"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"hash" "hash"
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strconv" "strconv"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func getRecognizedURIPrefixes() []string {
return []string{
HuggingFacePrefix,
HTTPPrefix,
HTTPSPrefix,
GithubURI,
GithubURI2,
}
}
func GetURI(url string, f func(url string, i []byte) error) error { func GetURI(url string, f func(url string, i []byte) error) error {
url = ConvertURL(url) url = ConvertURL(url)
@ -52,20 +72,8 @@ func GetURI(url string, f func(url string, i []byte) error) error {
return f(url, body) return f(url, body)
} }
const (
HuggingFacePrefix = "huggingface://"
HTTPPrefix = "http://"
HTTPSPrefix = "https://"
GithubURI = "github:"
GithubURI2 = "github://"
)
func LooksLikeURL(s string) bool { func LooksLikeURL(s string) bool {
return strings.HasPrefix(s, HTTPPrefix) || return slices.Contains(getRecognizedURIPrefixes(), s)
strings.HasPrefix(s, HTTPSPrefix) ||
strings.HasPrefix(s, HuggingFacePrefix) ||
strings.HasPrefix(s, GithubURI) ||
strings.HasPrefix(s, GithubURI2)
} }
func ConvertURL(s string) string { func ConvertURL(s string) string {
@ -241,6 +249,37 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return nil return nil
} }
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func GetBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
type progressWriter struct { type progressWriter struct {
fileName string fileName string
total int64 total int64

View File

@ -3,16 +3,16 @@ package integration_test
import ( import (
"reflect" "reflect"
config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/model"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/schema"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Integration Tests involving reflection in liue of code generation", func() { var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
Context("config.TemplateConfig and model.TemplateType must stay in sync", func() { Context("schema.TemplateConfig and model.TemplateType must stay in sync", func() {
ttc := reflect.TypeOf(config.TemplateConfig{}) ttc := reflect.TypeOf(schema.TemplateConfig{})
It("TemplateConfig and TemplateType should have the same number of valid values", func() { It("TemplateConfig and TemplateType should have the same number of valid values", func() {
const lastValidTemplateType = model.IntegrationTestTemplate - 1 const lastValidTemplateType = model.IntegrationTestTemplate - 1