2023-07-15 01:19:43 +02:00
|
|
|
package localai
|
2023-05-18 15:59:03 +02:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
2023-05-27 09:26:33 +02:00
|
|
|
"os"
|
2023-09-02 03:00:44 -04:00
|
|
|
"slices"
|
2023-07-20 22:10:12 +02:00
|
|
|
"strings"
|
2023-05-18 15:59:03 +02:00
|
|
|
"sync"
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
json "github.com/json-iterator/go"
|
2023-07-31 21:13:16 +02:00
|
|
|
"gopkg.in/yaml.v3"
|
2023-06-24 08:18:17 +02:00
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
config "github.com/go-skynet/LocalAI/api/config"
|
2023-05-18 15:59:03 +02:00
|
|
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
2023-07-20 22:10:12 +02:00
|
|
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
|
|
|
2023-05-18 15:59:03 +02:00
|
|
|
"github.com/gofiber/fiber/v2"
|
|
|
|
"github.com/google/uuid"
|
2023-06-08 21:33:18 +02:00
|
|
|
"github.com/rs/zerolog/log"
|
2023-05-18 15:59:03 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
type galleryOp struct {
|
2023-06-24 08:18:17 +02:00
|
|
|
req gallery.GalleryModel
|
|
|
|
id string
|
|
|
|
galleries []gallery.Gallery
|
|
|
|
galleryName string
|
2023-05-18 15:59:03 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// galleryOpStatus is the JSON-serialisable snapshot of a gallery job's state,
// as served by the job status endpoints. The field order determines the key
// order of the encoded JSON, so it is deliberately kept as-is.
type galleryOpStatus struct {
	// FileName is the file the downloader last reported progress for.
	FileName string `json:"file_name"`
	// Error is the failure that ended the job, or nil.
	// NOTE(review): most error values have no exported fields, so this likely
	// encodes as {}; the readable text is carried in Message instead.
	Error error `json:"error"`
	// Processed becomes true once the job has finished, successfully or not.
	Processed bool `json:"processed"`
	// Message is a human-readable state, e.g. "processing", "completed" or
	// "error: <details>".
	Message string `json:"message"`
	// Progress is the percentage reported by the downloader (100 on completion).
	Progress float64 `json:"progress"`
	// TotalFileSize and DownloadedFileSize are the size strings passed to the
	// download progress callback.
	TotalFileSize      string `json:"file_size"`
	DownloadedFileSize string `json:"downloaded_size"`
}
|
|
|
|
|
|
|
|
type galleryApplier struct {
|
|
|
|
modelPath string
|
|
|
|
sync.Mutex
|
|
|
|
C chan galleryOp
|
|
|
|
statuses map[string]*galleryOpStatus
|
|
|
|
}
|
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
func NewGalleryService(modelPath string) *galleryApplier {
|
2023-05-18 15:59:03 +02:00
|
|
|
return &galleryApplier{
|
|
|
|
modelPath: modelPath,
|
|
|
|
C: make(chan galleryOp),
|
|
|
|
statuses: make(map[string]*galleryOpStatus),
|
|
|
|
}
|
|
|
|
}
|
2023-05-27 09:26:33 +02:00
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
2023-06-24 08:18:17 +02:00
|
|
|
|
2023-06-26 12:25:38 +02:00
|
|
|
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
2023-05-27 09:26:33 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
2023-05-27 09:26:33 +02:00
|
|
|
}
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
|
2023-05-18 15:59:03 +02:00
|
|
|
g.Lock()
|
|
|
|
defer g.Unlock()
|
|
|
|
g.statuses[s] = op
|
|
|
|
}
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
2023-05-18 15:59:03 +02:00
|
|
|
g.Lock()
|
|
|
|
defer g.Unlock()
|
|
|
|
|
|
|
|
return g.statuses[s]
|
|
|
|
}
|
|
|
|
|
2023-08-31 16:03:03 +01:00
|
|
|
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
|
|
|
|
g.Lock()
|
|
|
|
defer g.Unlock()
|
|
|
|
|
|
|
|
return g.statuses
|
|
|
|
}
|
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
2023-05-18 15:59:03 +02:00
|
|
|
go func() {
|
|
|
|
for {
|
|
|
|
select {
|
|
|
|
case <-c.Done():
|
|
|
|
return
|
|
|
|
case op := <-g.C:
|
2023-07-20 22:10:12 +02:00
|
|
|
utils.ResetDownloadTimers()
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
2023-05-18 15:59:03 +02:00
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
// updates the status with an error
|
2023-05-18 15:59:03 +02:00
|
|
|
updateError := func(e error) {
|
2023-06-24 08:18:17 +02:00
|
|
|
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
2023-05-18 15:59:03 +02:00
|
|
|
}
|
2023-05-20 09:06:30 +02:00
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
// displayDownload displays the download progress
|
|
|
|
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
2023-08-31 16:03:03 +01:00
|
|
|
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
2023-07-20 22:10:12 +02:00
|
|
|
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
2023-06-24 08:18:17 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
var err error
|
|
|
|
// if the request contains a gallery name, we apply the gallery from the gallery list
|
|
|
|
if op.galleryName != "" {
|
2023-07-20 22:10:12 +02:00
|
|
|
if strings.Contains(op.galleryName, "@") {
|
|
|
|
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
|
|
} else {
|
|
|
|
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
|
|
}
|
2023-06-24 08:18:17 +02:00
|
|
|
} else {
|
|
|
|
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err != nil {
|
2023-05-18 15:59:03 +02:00
|
|
|
updateError(err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2023-06-24 08:18:17 +02:00
|
|
|
// Reload models
|
|
|
|
err = cm.LoadConfigs(g.modelPath)
|
|
|
|
if err != nil {
|
|
|
|
updateError(err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
2023-05-27 09:26:33 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
2023-05-18 15:59:03 +02:00
|
|
|
|
2023-06-26 12:25:38 +02:00
|
|
|
type galleryModel struct {
|
2023-08-01 19:09:32 +02:00
|
|
|
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
|
|
|
ID string `json:"id"`
|
2023-06-26 12:25:38 +02:00
|
|
|
}
|
|
|
|
|
2023-07-31 21:13:16 +02:00
|
|
|
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
|
|
|
var err error
|
|
|
|
for _, r := range requests {
|
|
|
|
utils.ResetDownloadTimers()
|
|
|
|
if r.ID == "" {
|
|
|
|
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
|
|
|
} else {
|
|
|
|
if strings.Contains(r.ID, "@") {
|
|
|
|
err = gallery.InstallModelFromGallery(
|
|
|
|
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
|
|
} else {
|
|
|
|
err = gallery.InstallModelFromGalleryByName(
|
|
|
|
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
2023-05-27 09:26:33 +02:00
|
|
|
dat, err := os.ReadFile(s)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-07-31 21:13:16 +02:00
|
|
|
var requests []galleryModel
|
|
|
|
|
|
|
|
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return processRequests(modelPath, s, cm, galleries, requests)
|
2023-05-27 09:26:33 +02:00
|
|
|
}
|
2023-06-08 21:33:18 +02:00
|
|
|
|
2023-07-15 01:19:43 +02:00
|
|
|
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
2023-06-26 12:25:38 +02:00
|
|
|
var requests []galleryModel
|
2023-05-27 09:26:33 +02:00
|
|
|
err := json.Unmarshal([]byte(s), &requests)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-05-18 15:59:03 +02:00
|
|
|
|
2023-07-31 21:13:16 +02:00
|
|
|
return processRequests(modelPath, s, cm, galleries, requests)
|
2023-05-18 15:59:03 +02:00
|
|
|
}
|
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
/// Endpoint Service
|
2023-07-15 01:19:43 +02:00
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
type ModelGalleryService struct {
|
|
|
|
galleries []gallery.Gallery
|
|
|
|
modelPath string
|
|
|
|
galleryApplier *galleryApplier
|
|
|
|
}
|
|
|
|
|
|
|
|
type GalleryModel struct {
|
|
|
|
ID string `json:"id"`
|
|
|
|
gallery.GalleryModel
|
|
|
|
}
|
|
|
|
|
|
|
|
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
|
|
|
|
return ModelGalleryService{
|
|
|
|
galleries: galleries,
|
|
|
|
modelPath: modelPath,
|
|
|
|
galleryApplier: galleryApplier,
|
|
|
|
}
|
|
|
|
}
|
2023-05-18 15:59:03 +02:00
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
|
2023-05-18 15:59:03 +02:00
|
|
|
if status == nil {
|
|
|
|
return fmt.Errorf("could not find any status for ID")
|
|
|
|
}
|
|
|
|
return c.JSON(status)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
2023-08-31 16:03:03 +01:00
|
|
|
return func(c *fiber.Ctx) error {
|
2023-09-02 03:00:44 -04:00
|
|
|
return c.JSON(mgs.galleryApplier.getAllStatus())
|
2023-08-31 16:03:03 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
2023-05-18 15:59:03 +02:00
|
|
|
return func(c *fiber.Ctx) error {
|
2023-06-24 08:18:17 +02:00
|
|
|
input := new(GalleryModel)
|
2023-05-18 15:59:03 +02:00
|
|
|
// Get input data from the request body
|
|
|
|
if err := c.BodyParser(input); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
uuid, err := uuid.NewUUID()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-09-02 03:00:44 -04:00
|
|
|
mgs.galleryApplier.C <- galleryOp{
|
2023-06-24 08:18:17 +02:00
|
|
|
req: input.GalleryModel,
|
|
|
|
id: uuid.String(),
|
|
|
|
galleryName: input.ID,
|
2023-09-02 03:00:44 -04:00
|
|
|
galleries: mgs.galleries,
|
2023-05-18 15:59:03 +02:00
|
|
|
}
|
|
|
|
return c.JSON(struct {
|
2023-05-20 17:03:53 +02:00
|
|
|
ID string `json:"uuid"`
|
2023-05-18 15:59:03 +02:00
|
|
|
StatusURL string `json:"status"`
|
|
|
|
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
|
|
|
}
|
|
|
|
}
|
2023-06-24 08:18:17 +02:00
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
2023-06-24 08:18:17 +02:00
|
|
|
return func(c *fiber.Ctx) error {
|
2023-09-02 03:00:44 -04:00
|
|
|
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
2023-06-24 08:18:17 +02:00
|
|
|
|
2023-09-02 03:00:44 -04:00
|
|
|
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
2023-06-24 08:18:17 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
log.Debug().Msgf("Models found from galleries: %+v", models)
|
|
|
|
for _, m := range models {
|
|
|
|
log.Debug().Msgf("Model found from galleries: %+v", m)
|
|
|
|
}
|
|
|
|
dat, err := json.Marshal(models)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return c.Send(dat)
|
|
|
|
}
|
|
|
|
}
|
2023-09-02 03:00:44 -04:00
|
|
|
|
|
|
|
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
|
|
|
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
|
|
|
dat, err := json.Marshal(mgs.galleries)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return c.Send(dat)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
input := new(gallery.Gallery)
|
|
|
|
// Get input data from the request body
|
|
|
|
if err := c.BodyParser(input); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
|
|
return gallery.Name == input.Name
|
|
|
|
}) {
|
|
|
|
return fmt.Errorf("%s already exists", input.Name)
|
|
|
|
}
|
|
|
|
dat, err := json.Marshal(mgs.galleries)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
|
|
|
mgs.galleries = append(mgs.galleries, *input)
|
|
|
|
return c.Send(dat)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
input := new(gallery.Gallery)
|
|
|
|
// Get input data from the request body
|
|
|
|
if err := c.BodyParser(input); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
|
|
return gallery.Name == input.Name
|
|
|
|
}) {
|
|
|
|
return fmt.Errorf("%s is not currently registered", input.Name)
|
|
|
|
}
|
|
|
|
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
|
|
return gallery.Name == input.Name
|
|
|
|
})
|
|
|
|
return c.Send(nil)
|
|
|
|
}
|
|
|
|
}
|