feat(gallery): support model deletion (#2173)

* feat(gallery): op now supports deletion of models

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Wire things with WebUI(WIP)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* minor improvements

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-04-28 23:42:46 +02:00 committed by GitHub
parent a24cd4fda0
commit e8d44447ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 294 additions and 36 deletions

View File

@ -184,6 +184,36 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0 return len(c.functionCallNameString) > 0
} }
// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
modelURL := downloader.ConvertURL(c.MMProj)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}
return c.MMProj
}
func (c *BackendConfig) IsMMProjURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj))
}
func (c *BackendConfig) IsModelURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.Model))
}
// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
modelURL := downloader.ConvertURL(c.Model)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}
return c.Model
}
func (c *BackendConfig) FunctionToCall() string { func (c *BackendConfig) FunctionToCall() string {
if c.functionCallNameString != "" && if c.functionCallNameString != "" &&
c.functionCallNameString != "none" && c.functionCallNameString != "auto" { c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
@ -532,16 +562,13 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
} }
} }
modelURL := config.PredictionOptions.Model // If the model is an URL, expand it, and download the file
modelURL = downloader.ConvertURL(modelURL) if config.IsModelURL() {
modelFileName := config.ModelFileName()
if downloader.LooksLikeURL(modelURL) { modelURL := downloader.ConvertURL(config.Model)
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists // check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", 0, 0, status) err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil { if err != nil {
return err return err
} }
@ -549,9 +576,27 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
cc := cl.configs[i] cc := cl.configs[i]
c := &cc c := &cc
c.PredictionOptions.Model = md5Name c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c cl.configs[i] = *c
} }
if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
}
if cl.configs[i].Name != "" { if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
} }
@ -586,7 +631,8 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
} }
for _, file := range files { for _, file := range files {
// Skip templates, YAML and .keep files // Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
strings.HasPrefix(file.Name(), ".") {
continue continue
} }
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)

View File

@ -13,7 +13,7 @@ const (
NoImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg" NoImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg"
) )
func DoneProgress(uid string) string { func DoneProgress(uid, text string) string {
return elem.Div( return elem.Div(
attrs.Props{}, attrs.Props{},
elem.H3( elem.H3(
@ -23,7 +23,7 @@ func DoneProgress(uid string) string {
"tabindex": "-1", "tabindex": "-1",
"autofocus": "", "autofocus": "",
}, },
elem.Text("Installation completed"), elem.Text(text),
), ),
).Render() ).Render()
} }
@ -60,7 +60,7 @@ func ProgressBar(progress string) string {
).Render() ).Render()
} }
func StartProgressBar(uid, progress string) string { func StartProgressBar(uid, progress, text string) string {
if progress == "" { if progress == "" {
progress = "0" progress = "0"
} }
@ -77,7 +77,7 @@ func StartProgressBar(uid, progress string) string {
"tabindex": "-1", "tabindex": "-1",
"autofocus": "", "autofocus": "",
}, },
elem.Text("Installing"), elem.Text(text),
// This is a simple example of how to use the HTMLX library to create a progress bar that updates every 600ms. // This is a simple example of how to use the HTMLX library to create a progress bar that updates every 600ms.
elem.Div(attrs.Props{ elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid, "hx-get": "/browse/job/progress/" + uid,
@ -106,14 +106,33 @@ func cardSpan(text, icon string) elem.Node {
func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string { func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string {
//StartProgressBar(uid, "0") //StartProgressBar(uid, "0")
modelsElements := []elem.Node{} modelsElements := []elem.Node{}
span := func(s string) elem.Node { // span := func(s string) elem.Node {
return elem.Span( // return elem.Span(
// attrs.Props{
// "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs",
// },
// elem.Text(s),
// )
// }
deleteButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button(
attrs.Props{ attrs.Props{
"class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs", "data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/delete/model/" + m.Name,
}, },
elem.Text(s), elem.I(
attrs.Props{
"class": "fa-solid fa-cancel pr-2",
},
),
elem.Text("Delete"),
) )
} }
installButton := func(m *gallery.GalleryModel) elem.Node { installButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button( return elem.Button(
attrs.Props{ attrs.Props{
@ -202,10 +221,14 @@ func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[stri
elem.If( elem.If(
currentlyInstalling, currentlyInstalling,
elem.Node( // If currently installing, show progress bar elem.Node( // If currently installing, show progress bar
elem.Raw(StartProgressBar(installing.Get(galleryID), "0")), elem.Raw(StartProgressBar(installing.Get(galleryID), "0", "Installing")),
), // Otherwise, show install button (if not installed) or display "Installed" ), // Otherwise, show install button (if not installed) or display "Installed"
elem.If(m.Installed, elem.If(m.Installed,
span("Installed"), //elem.Node(elem.Div(
// attrs.Props{},
// span("Installed"), deleteButton(m),
// )),
deleteButton(m),
installButton(m), installButton(m),
), ),
), ),

View File

@ -74,6 +74,27 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
} }
} }
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
mgs.galleryApplier.C <- gallery.GalleryOp{
Delete: true,
GalleryName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)

View File

@ -23,6 +23,8 @@ func RegisterLocalAIRoutes(app *fiber.App,
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint()) app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())

View File

@ -66,6 +66,12 @@ func RegisterUIRoutes(app *fiber.App,
return c.SendString(elements.ListModels(filteredModels, installingModels)) return c.SendString(elements.ListModels(filteredModels, installingModels))
}) })
/*
Install routes
*/
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/ // https://htmx.org/examples/progress-bar/
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error { app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
@ -89,7 +95,33 @@ func RegisterUIRoutes(app *fiber.App,
galleryService.C <- op galleryService.C <- op
}() }()
return c.SendString(elements.StartProgressBar(uid, "0")) return c.SendString(elements.StartProgressBar(uid, "0", "Installation"))
})
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
id, err := uuid.NewUUID()
if err != nil {
return err
}
uid := id.String()
installingModels.Set(galleryID, uid)
op := gallery.GalleryOp{
Id: uid,
Delete: true,
GalleryName: galleryID,
}
go func() {
galleryService.C <- op
}()
return c.SendString(elements.StartProgressBar(uid, "0", "Deletion"))
}) })
// Display the job current progress status // Display the job current progress status
@ -118,12 +150,20 @@ func RegisterUIRoutes(app *fiber.App,
// this route is hit when the job is done, and we display the // this route is hit when the job is done, and we display the
// final state (for now just displays "Installation completed") // final state (for now just displays "Installation completed")
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error { app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {
status := galleryService.GetStatus(c.Params("uid"))
for _, k := range installingModels.Keys() { for _, k := range installingModels.Keys() {
if installingModels.Get(k) == c.Params("uid") { if installingModels.Get(k) == c.Params("uid") {
installingModels.Delete(k) installingModels.Delete(k)
} }
} }
return c.SendString(elements.DoneProgress(c.Params("uid"))) displayText := "Installation completed"
if status.Deletion {
displayText = "Deletion completed"
}
return c.SendString(elements.DoneProgress(c.Params("uid"), displayText))
}) })
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"os" "os"
"path/filepath"
"strings" "strings"
"sync" "sync"
@ -84,18 +85,47 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
} }
var err error var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" { // delete a model
if strings.Contains(op.GalleryName, "@") { if op.Delete {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) modelConfig := &config.BackendConfig{}
} else { // Galleryname is the name of the model in this case
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) dat, err := os.ReadFile(filepath.Join(g.modelPath, op.GalleryName+".yaml"))
if err != nil {
updateError(err)
continue
} }
} else if op.ConfigURL != "" { err = yaml.Unmarshal(dat, modelConfig)
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) if err != nil {
err = cl.Preload(g.modelPath) updateError(err)
continue
}
files := []string{}
// Remove the model from the config
if modelConfig.Model != "" {
files = append(files, modelConfig.ModelFileName())
}
if modelConfig.MMProj != "" {
files = append(files, modelConfig.MMProjFileName())
}
err = gallery.DeleteModelFromSystem(g.modelPath, op.GalleryName, files)
} else { } else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback) // if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else if op.ConfigURL != "" {
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
err = cl.Preload(g.modelPath)
} else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
}
} }
if err != nil { if err != nil {
@ -116,7 +146,12 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
continue continue
} }
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100}) g.UpdateStatus(op.Id,
&gallery.GalleryOpStatus{
Deletion: op.Delete,
Processed: true,
Message: "completed",
Progress: 100})
} }
} }
}() }()

View File

@ -1,6 +1,7 @@
package gallery package gallery
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -184,3 +185,48 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error)
} }
return models, nil return models, nil
} }
func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
configFile := filepath.Join(basePath, fmt.Sprintf("%s.yaml", name))
galleryFile := filepath.Join(basePath, galleryFileName(name))
var err error
// Delete all the files associated to the model
// read the model config
galleryconfig, err := ReadConfigFile(galleryFile)
if err != nil {
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
}
// Remove additional files
if galleryconfig != nil {
for _, f := range galleryconfig.Files {
fullPath := filepath.Join(basePath, f.Filename)
log.Debug().Msgf("Removing file %s", fullPath)
if e := os.Remove(fullPath); e != nil {
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f.Filename, e))
}
}
}
for _, f := range additionalFiles {
fullPath := filepath.Join(filepath.Join(basePath, f))
log.Debug().Msgf("Removing additional file %s", fullPath)
if e := os.Remove(fullPath); e != nil {
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e))
}
}
log.Debug().Msgf("Removing model config file %s", configFile)
// Delete the model config file
if e := os.Remove(configFile); e != nil {
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", configFile, e))
}
return err
}

View File

@ -1,6 +1,7 @@
package gallery_test package gallery_test
import ( import (
"os"
"testing" "testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -11,3 +12,9 @@ func TestGallery(t *testing.T) {
RegisterFailHandler(Fail) RegisterFailHandler(Fail)
RunSpecs(t, "Gallery test suite") RunSpecs(t, "Gallery test suite")
} }
var _ = BeforeSuite(func() {
if os.Getenv("FIXTURES") == "" {
Fail("FIXTURES env var not set")
}
})

View File

@ -178,5 +178,20 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
log.Debug().Msgf("Written config file %s", configFilePath) log.Debug().Msgf("Written config file %s", configFilePath)
} }
return nil // Save the model gallery file for further reference
modelFile := filepath.Join(basePath, galleryFileName(name))
data, err := yaml.Marshal(config)
if err != nil {
return err
}
log.Debug().Msgf("Written gallery file %s", modelFile)
return os.WriteFile(modelFile, data, 0600)
//return nil
}
func galleryFileName(name string) string {
return "._gallery_" + name + ".yaml"
} }

View File

@ -1,6 +1,7 @@
package gallery_test package gallery_test
import ( import (
"errors"
"os" "os"
"path/filepath" "path/filepath"
@ -11,6 +12,7 @@ import (
) )
var _ = Describe("Model test", func() { var _ = Describe("Model test", func() {
Context("Downloading", func() { Context("Downloading", func() {
It("applies model correctly", func() { It("applies model correctly", func() {
tempdir, err := os.MkdirTemp("", "test") tempdir, err := os.MkdirTemp("", "test")
@ -80,6 +82,19 @@ var _ = Describe("Model test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1)) Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeTrue()) Expect(models[0].Installed).To(BeTrue())
// delete
err = DeleteModelFromSystem(tempdir, "bert", []string{})
Expect(err).ToNot(HaveOccurred())
models, err = AvailableGalleryModels(galleries, tempdir)
Expect(err).ToNot(HaveOccurred())
Expect(len(models)).To(Equal(1))
Expect(models[0].Installed).To(BeFalse())
_, err = os.Stat(filepath.Join(tempdir, "bert.yaml"))
Expect(err).To(HaveOccurred())
Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
}) })
It("renames model correctly", func() { It("renames model correctly", func() {

View File

@ -4,12 +4,14 @@ type GalleryOp struct {
Id string Id string
GalleryName string GalleryName string
ConfigURL string ConfigURL string
Delete bool
Req GalleryModel Req GalleryModel
Galleries []Gallery Galleries []Gallery
} }
type GalleryOpStatus struct { type GalleryOpStatus struct {
Deletion bool `json:"deletion"` // Deletion is true if the operation is a deletion
FileName string `json:"file_name"` FileName string `json:"file_name"`
Error error `json:"error"` Error error `json:"error"`
Processed bool `json:"processed"` Processed bool `json:"processed"`

View File

@ -96,7 +96,13 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
models := []string{} models := []string{}
for _, file := range files { for _, file := range files {
// Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method? // Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method?
if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") || strings.HasSuffix(file.Name(), ".json") || strings.HasSuffix(file.Name(), ".DS_Store") { if strings.HasSuffix(file.Name(), ".tmpl") ||
strings.HasSuffix(file.Name(), ".keep") ||
strings.HasSuffix(file.Name(), ".yaml") ||
strings.HasSuffix(file.Name(), ".yml") ||
strings.HasSuffix(file.Name(), ".json") ||
strings.HasSuffix(file.Name(), ".DS_Store") ||
strings.HasPrefix(file.Name(), ".") {
continue continue
} }