From 84946e92756af9e465d39a8657de263c82a2683b Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@users.noreply.github.com>
Date: Thu, 8 Jun 2023 21:33:18 +0200
Subject: [PATCH] feat: display download progress when installing models (#543)

---
 api/gallery.go             | 54 +++++++++++++++++++++++++------
 pkg/gallery/models.go      | 66 +++++++++++++++++++++++++++++---------
 pkg/gallery/models_test.go |  8 ++---
 3 files changed, 99 insertions(+), 29 deletions(-)

diff --git a/api/gallery.go b/api/gallery.go
index b5b74b0d..a9a87220 100644
--- a/api/gallery.go
+++ b/api/gallery.go
@@ -10,10 +10,12 @@ import (
 	"os"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/go-skynet/LocalAI/pkg/gallery"
 	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
+	"github.com/rs/zerolog/log"
 	"gopkg.in/yaml.v3"
 )
 
@@ -23,9 +25,12 @@ type galleryOp struct {
 }
 
 type galleryOpStatus struct {
-	Error     error  `json:"error"`
-	Processed bool   `json:"processed"`
-	Message   string `json:"message"`
+	Error              error   `json:"error"`
+	Processed          bool    `json:"processed"`
+	Message            string  `json:"message"`
+	Progress           float64 `json:"progress"`
+	TotalFileSize      string  `json:"file_size"`
+	DownloadedFileSize string  `json:"downloaded_size"`
 }
 
 type galleryApplier struct {
@@ -43,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
 	}
 }
 
-func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error {
+func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
 	url, err := req.DecodeURL()
 	if err != nil {
 		return err
@@ -71,7 +76,7 @@ func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerg
 
 	config.Files = append(config.Files, req.AdditionalFiles...)
 
-	if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil {
+	if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides, downloadStatus); err != nil {
 		return err
 	}
 
@@ -99,23 +104,51 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
 			case <-c.Done():
 				return
 			case op := <-g.C:
-				g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
+				g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
 
 				updateError := func(e error) {
 					g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
 				}
 
-				if err := applyGallery(g.modelPath, op.req, cm); err != nil {
+				if err := applyGallery(g.modelPath, op.req, cm, func(fileName string, current string, total string, percentage float64) {
+					g.updatestatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
+					displayDownload(fileName, current, total, percentage)
+				}); err != nil {
 					updateError(err)
 					continue
 				}
 
-				g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
+				g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
 			}
 		}
 	}()
 }
 
+var lastProgress, startTime = time.Now(), time.Now()
+var progressFile string // file whose download startTime refers to
+
+// displayDownload logs progress (and an ETA) at most once every 5 seconds.
+func displayDownload(fileName string, current string, total string, percentage float64) {
+	currentTime := time.Now()
+	if fileName != progressFile {
+		// New file: restart the ETA clock rather than timing from process start.
+		progressFile, startTime = fileName, currentTime
+	}
+	if currentTime.Sub(lastProgress) < 5*time.Second {
+		return
+	}
+	lastProgress = currentTime
+	if total == "" {
+		log.Debug().Msgf("Downloading: %s", current)
+		return
+	}
+	var eta time.Duration
+	if elapsed := currentTime.Sub(startTime); percentage > 0 {
+		eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
+	}
+	log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
+}
+
 func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
 	dat, err := os.ReadFile(s)
 	if err != nil {
@@ -128,13 +161,14 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error {
 	}
 
 	for _, r := range requests {
-		if err := applyGallery(modelPath, r, cm); err != nil {
+		if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
 			return err
 		}
 	}
 
 	return nil
 }
+
 func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
 	var requests []ApplyGalleryModelRequest
 	err := json.Unmarshal([]byte(s), &requests)
@@ -143,7 +177,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error {
 	}
 
 	for _, r := range requests {
-		if err := applyGallery(modelPath, r, cm); err != nil {
+		if err := applyGallery(modelPath, r, cm, displayDownload); err != nil {
 			return err
 		}
 	}
diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go
index f4f86ae7..14a7d6ac 100644
--- a/pkg/gallery/models.go
+++ b/pkg/gallery/models.go
@@ -3,10 +3,12 @@ package gallery
 import (
 	"crypto/sha256"
 	"fmt"
+	"hash"
 	"io"
 	"net/http"
 	"os"
 	"path/filepath"
+	"strconv"
 
 	"github.com/imdario/mergo"
 	"github.com/rs/zerolog/log"
@@ -93,7 +95,7 @@ func verifyPath(path, basePath string) error {
 	return inTrustedRoot(c, basePath)
 }
 
-func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error {
+func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
 	// Create base path if it doesn't exist
 	err := os.MkdirAll(basePath, 0755)
 	if err != nil {
@@ -168,27 +170,25 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
 		}
 		defer outFile.Close()
 
+		progress := &progressWriter{
+			fileName:       file.Filename,
+			total:          resp.ContentLength,
+			hash:           sha256.New(),
+			downloadStatus: downloadStatus,
+		}
+		_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
+		if err != nil {
+			return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
+		}
+
 		if file.SHA256 != "" {
-			log.Debug().Msgf("Download and verifying %q", file.Filename)
-
-			// Write file content and calculate SHA
-			hash := sha256.New()
-			_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
-			}
-
 			// Verify SHA
-			calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
+			calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
 			if calculatedSHA != file.SHA256 {
 				return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
 			}
 		} else {
 			log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
-			_, err = io.Copy(outFile, resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
-			}
 		}
 
 		log.Debug().Msgf("File %q downloaded and verified", file.Filename)
@@ -255,6 +255,42 @@ func Apply(basePath, nameOverride string, config *Config, configOverrides map[st
 	return nil
 }
 
+type progressWriter struct { // io.Writer that hashes data and reports progress
+	fileName       string    // name passed to downloadStatus
+	hash           hash.Hash // receives every byte written
+	downloadStatus func(string, string, string, float64)
+	written        int64 // bytes hashed so far
+	total          int64 // expected size; <= 0 means unknown
+}
+
+// Write hashes p and reports cumulative progress via downloadStatus, if set.
+func (pw *progressWriter) Write(p []byte) (n int, err error) {
+	n, err = pw.hash.Write(p)
+	pw.written += int64(n)
+	if pw.downloadStatus == nil {
+		return n, err // nil callback: hash and count only, don't panic
+	}
+	total, percentage := "", 0.0 // reported as-is when the size is unknown
+	if pw.total > 0 {
+		total, percentage = formatBytes(pw.total), float64(pw.written)/float64(pw.total)*100
+	}
+	pw.downloadStatus(pw.fileName, formatBytes(pw.written), total, percentage)
+	return n, err
+}
+
+// formatBytes renders a size in binary units, e.g. "512 B" or "1.5 KiB".
+func formatBytes(bytes int64) string {
+	const unit = 1024
+	if bytes < unit {
+		return strconv.FormatInt(bytes, 10) + " B"
+	}
+	prefixes, scale := "KMGTPE", int64(unit)
+	for rest := bytes / unit; rest >= unit; rest /= unit {
+		prefixes, scale = prefixes[1:], scale*unit
+	}
+	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(scale), prefixes[0])
+}
+
 func calculateSHA(filePath string) (string, error) {
 	file, err := os.Open(filePath)
 	if err != nil {
diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go
index f0e580e9..343bf6ab 100644
--- a/pkg/gallery/models_test.go
+++ b/pkg/gallery/models_test.go
@@ -19,7 +19,7 @@ var _ = Describe("Model test", func() {
 			c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
 			Expect(err).ToNot(HaveOccurred())
 
-			err = Apply(tempdir, "", c, map[string]interface{}{})
+			err = Apply(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
 			Expect(err).ToNot(HaveOccurred())
 
 			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -45,7 +45,7 @@ var _ = Describe("Model test", func() {
 			c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
 			Expect(err).ToNot(HaveOccurred())
 
-			err = Apply(tempdir, "foo", c, map[string]interface{}{})
+			err = Apply(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
 			Expect(err).ToNot(HaveOccurred())
 
 			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -61,7 +61,7 @@ var _ = Describe("Model test", func() {
 			c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
 			Expect(err).ToNot(HaveOccurred())
 
-			err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"})
+			err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
 			Expect(err).ToNot(HaveOccurred())
 
 			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
 			c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
 			Expect(err).ToNot(HaveOccurred())
 
-			err = Apply(tempdir, "../../../foo", c, map[string]interface{}{})
+			err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
 			Expect(err).To(HaveOccurred())
 		})
 	})