mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-23 14:32:25 +00:00
268 lines
6.2 KiB
Go
268 lines
6.2 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog/log"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
type galleryOp struct {
|
|
req ApplyGalleryModelRequest
|
|
id string
|
|
}
|
|
|
|
// galleryOpStatus is the state of a gallery job as reported to API clients.
type galleryOpStatus struct {
	// Error is the failure that ended the job, or nil.
	Error error `json:"error"`
	// Processed is true once the job has finished, successfully or not.
	Processed bool `json:"processed"`
	// Message is a short human-readable state ("processing", "completed").
	Message string `json:"message"`
	// Progress is the completion percentage reported by the download callback.
	Progress float64 `json:"progress"`
	// TotalFileSize is the total size string reported by the download callback.
	TotalFileSize string `json:"file_size"`
	// DownloadedFileSize is the downloaded size string reported by the
	// download callback.
	DownloadedFileSize string `json:"downloaded_size"`
}

// MarshalJSON encodes the status with the error rendered as its message.
//
// The default encoding of an error interface value only sees the concrete
// type's exported fields, so ordinary errors come out as "{}" and the client
// never learns why a job failed. Here "error" is the result of Error() when
// an error is set and null otherwise; all other keys are unchanged.
func (s galleryOpStatus) MarshalJSON() ([]byte, error) {
	var errMsg *string
	if s.Error != nil {
		msg := s.Error.Error()
		errMsg = &msg
	}

	return json.Marshal(struct {
		Error              *string `json:"error"`
		Processed          bool    `json:"processed"`
		Message            string  `json:"message"`
		Progress           float64 `json:"progress"`
		TotalFileSize      string  `json:"file_size"`
		DownloadedFileSize string  `json:"downloaded_size"`
	}{
		Error:              errMsg,
		Processed:          s.Processed,
		Message:            s.Message,
		Progress:           s.Progress,
		TotalFileSize:      s.TotalFileSize,
		DownloadedFileSize: s.DownloadedFileSize,
	})
}
|
|
|
|
type galleryApplier struct {
|
|
modelPath string
|
|
sync.Mutex
|
|
C chan galleryOp
|
|
statuses map[string]*galleryOpStatus
|
|
}
|
|
|
|
func newGalleryApplier(modelPath string) *galleryApplier {
|
|
return &galleryApplier{
|
|
modelPath: modelPath,
|
|
C: make(chan galleryOp),
|
|
statuses: make(map[string]*galleryOpStatus),
|
|
}
|
|
}
|
|
|
|
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
|
url, err := req.DecodeURL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send a GET request to the URL
|
|
response, err := http.Get(url)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
// Read the response body
|
|
body, err := ioutil.ReadAll(response.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Unmarshal YAML data into a Config struct
|
|
var config gallery.Config
|
|
err = yaml.Unmarshal(body, &config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
|
|
|
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Reload models
|
|
return cm.LoadConfigs(modelPath)
|
|
}
|
|
|
|
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
|
|
g.Lock()
|
|
defer g.Unlock()
|
|
g.statuses[s] = op
|
|
}
|
|
|
|
func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
|
|
g.Lock()
|
|
defer g.Unlock()
|
|
|
|
return g.statuses[s]
|
|
}
|
|
|
|
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-c.Done():
|
|
return
|
|
case op := <-g.C:
|
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
|
|
|
updateError := func(e error) {
|
|
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
|
|
}
|
|
|
|
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
|
displayDownload(fileName, current, total, percentage)
|
|
}); err != nil {
|
|
updateError(err)
|
|
continue
|
|
}
|
|
|
|
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
var (
	// lastProgress is the time of the most recent progress log line written
	// by displayDownload; it is used to throttle logging.
	lastProgress = time.Now()
	// startTime is the reference instant from which displayDownload measures
	// elapsed time for its ETA estimate.
	startTime = time.Now()
)
|
|
|
|
func displayDownload(fileName string, current string, total string, percentage float64) {
|
|
currentTime := time.Now()
|
|
|
|
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
|
|
|
lastProgress = currentTime
|
|
|
|
// calculate ETA based on percentage and elapsed time
|
|
var eta time.Duration
|
|
if percentage > 0 {
|
|
elapsed := currentTime.Sub(startTime)
|
|
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
|
}
|
|
|
|
if total != "" {
|
|
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
|
} else {
|
|
log.Debug().Msgf("Downloading: %s", current)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
|
dat, err := os.ReadFile(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var requests []ApplyGalleryModelRequest
|
|
err = json.Unmarshal(dat, &requests)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, r := range requests {
|
|
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
|
var requests []ApplyGalleryModelRequest
|
|
err := json.Unmarshal([]byte(s), &requests)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, r := range requests {
|
|
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endpoints
|
|
|
|
type ApplyGalleryModelRequest struct {
|
|
URL string `json:"url"`
|
|
Name string `json:"name"`
|
|
Overrides map[string]interface{} `json:"overrides"`
|
|
AdditionalFiles []gallery.File `json:"files"`
|
|
}
|
|
|
|
// githubURI is the prefix that marks a request URL as a GitHub shorthand.
const githubURI = "github:"
|
|
|
|
func (request ApplyGalleryModelRequest) DecodeURL() (string, error) {
|
|
input := request.URL
|
|
var rawURL string
|
|
|
|
if strings.HasPrefix(input, githubURI) {
|
|
parts := strings.Split(input, ":")
|
|
repoParts := strings.Split(parts[1], "@")
|
|
branch := "main"
|
|
|
|
if len(repoParts) > 1 {
|
|
branch = repoParts[1]
|
|
}
|
|
|
|
repoPath := strings.Split(repoParts[0], "/")
|
|
org := repoPath[0]
|
|
project := repoPath[1]
|
|
projectPath := strings.Join(repoPath[2:], "/")
|
|
|
|
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath)
|
|
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
|
|
// Handle regular URLs
|
|
u, err := url.Parse(input)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
rawURL = u.String()
|
|
} else {
|
|
return "", fmt.Errorf("invalid URL format")
|
|
}
|
|
|
|
return rawURL, nil
|
|
}
|
|
|
|
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
|
return func(c *fiber.Ctx) error {
|
|
|
|
status := g.getstatus(c.Params("uuid"))
|
|
if status == nil {
|
|
return fmt.Errorf("could not find any status for ID")
|
|
}
|
|
|
|
return c.JSON(status)
|
|
}
|
|
}
|
|
|
|
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
|
|
return func(c *fiber.Ctx) error {
|
|
input := new(ApplyGalleryModelRequest)
|
|
// Get input data from the request body
|
|
if err := c.BodyParser(input); err != nil {
|
|
return err
|
|
}
|
|
|
|
uuid, err := uuid.NewUUID()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
g <- galleryOp{
|
|
req: *input,
|
|
id: uuid.String(),
|
|
}
|
|
return c.JSON(struct {
|
|
ID string `json:"uuid"`
|
|
StatusURL string `json:"status"`
|
|
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
|
}
|
|
}
|