mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-11 13:15:20 +00:00
feat: display download progress when installing models (#543)
This commit is contained in:
parent
c9bbba4872
commit
84946e9275
@ -10,10 +10,12 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,6 +28,9 @@ type galleryOpStatus struct {
|
|||||||
Error error `json:"error"`
|
Error error `json:"error"`
|
||||||
Processed bool `json:"processed"`
|
Processed bool `json:"processed"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
TotalFileSize string `json:"file_size"`
|
||||||
|
DownloadedFileSize string `json:"downloaded_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type galleryApplier struct {
|
type galleryApplier struct {
|
||||||
@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error {
|
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||||
url, err := req.DecodeURL()
|
url, err := req.DecodeURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg
|
|||||||
|
|
||||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||||
|
|
||||||
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil {
|
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
|||||||
case <-c.Done():
|
case <-c.Done():
|
||||||
return
|
return
|
||||||
case op := <-g.C:
|
case op := <-g.C:
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||||
|
|
||||||
updateError := func(e error) {
|
updateError := func(e error) {
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
|
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := applyGallery(g.modelPath, op.req, cm); err != nil {
|
if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
|
||||||
|
g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
|
displayDownload(fileName, current, total, percentage)
|
||||||
|
}); err != nil {
|
||||||
updateError(err)
|
updateError(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
|
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lastProgress time.Time = time.Now()
|
||||||
|
var startTime time.Time = time.Now()
|
||||||
|
|
||||||
|
func displayDownload(fileName string, current string, total string, percentage float64) {
|
||||||
|
currentTime := time.Now()
|
||||||
|
|
||||||
|
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||||
|
|
||||||
|
lastProgress = currentTime
|
||||||
|
|
||||||
|
// calculate ETA based on percentage and elapsed time
|
||||||
|
var eta time.Duration
|
||||||
|
if percentage > 0 {
|
||||||
|
elapsed := currentTime.Sub(startTime)
|
||||||
|
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||||
|
}
|
||||||
|
|
||||||
|
if total != "" {
|
||||||
|
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("Downloading: %s", current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
||||||
dat, err := os.ReadFile(s)
|
dat, err := os.ReadFile(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range requests {
|
for _, r := range requests {
|
||||||
if err := applyGallery(modelPath, r, cm); err != nil {
|
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
||||||
var requests []ApplyGalleryModelRequest
|
var requests []ApplyGalleryModelRequest
|
||||||
err := json.Unmarshal([]byte(s), &requests)
|
err := json.Unmarshal([]byte(s), &requests)
|
||||||
@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range requests {
|
for _, r := range requests {
|
||||||
if err := applyGallery(modelPath, r, cm); err != nil {
|
if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,12 @@ package gallery
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/imdario/mergo"
|
"github.com/imdario/mergo"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -93,7 +95,7 @@ func verifyPath(path, basePath string) error {
|
|||||||
return inTrustedRoot(c, basePath)
|
return inTrustedRoot(c, basePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error {
|
func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
||||||
// Create base path if it doesn't exist
|
// Create base path if it doesn't exist
|
||||||
err := os.MkdirAll(basePath, 0755)
|
err := os.MkdirAll(basePath, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -168,27 +170,25 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
|
|||||||
}
|
}
|
||||||
defer outFile.Close()
|
defer outFile.Close()
|
||||||
|
|
||||||
if file.SHA256 != "" {
|
progress := &progressWriter{
|
||||||
log.Debug().Msgf("Download and verifying %q", file.Filename)
|
fileName: file.Filename,
|
||||||
|
total: resp.ContentLength,
|
||||||
// Write file content and calculate SHA
|
hash: sha256.New(),
|
||||||
hash := sha256.New()
|
downloadStatus: downloadStatus,
|
||||||
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
|
}
|
||||||
|
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if file.SHA256 != "" {
|
||||||
// Verify SHA
|
// Verify SHA
|
||||||
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
|
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
|
||||||
if calculatedSHA != file.SHA256 {
|
if calculatedSHA != file.SHA256 {
|
||||||
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
|
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
|
||||||
_, err = io.Copy(outFile, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
|
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
|
||||||
@ -255,6 +255,42 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type progressWriter struct {
|
||||||
|
fileName string
|
||||||
|
total int64
|
||||||
|
written int64
|
||||||
|
downloadStatus func(string, string, string, float64)
|
||||||
|
hash hash.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = pw.hash.Write(p)
|
||||||
|
pw.written += int64(n)
|
||||||
|
|
||||||
|
if pw.total > 0 {
|
||||||
|
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||||
|
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||||
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||||
|
} else {
|
||||||
|
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatBytes(bytes int64) string {
|
||||||
|
const unit = 1024
|
||||||
|
if bytes < unit {
|
||||||
|
return strconv.FormatInt(bytes, 10) + " B"
|
||||||
|
}
|
||||||
|
div, exp := int64(unit), 0
|
||||||
|
for n := bytes / unit; n >= unit; n /= unit {
|
||||||
|
div *= unit
|
||||||
|
exp++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||||
|
}
|
||||||
|
|
||||||
func calculateSHA(filePath string) (string, error) {
|
func calculateSHA(filePath string) (string, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -19,7 +19,7 @@ var _ = Describe("Model test", func() {
|
|||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = Apply(tempdir, "", c, map[string]interface{}{})
|
err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
|
||||||
@ -45,7 +45,7 @@ var _ = Describe("Model test", func() {
|
|||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = Apply(tempdir, "foo", c, map[string]interface{}{})
|
err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||||
@ -61,7 +61,7 @@ var _ = Describe("Model test", func() {
|
|||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"})
|
err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
|
||||||
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
|
|||||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{})
|
err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user