feat(model-list): be consistent, skip known files from listing (#2760)


This changeset does the following:

- Removes the model listing's dependency on the OpenAI schema.
- Reduces the confusion between ListModels() in the model loader and in
  the service: there is now a single ListModels, which lives in services
  and no longer depends on the OpenAI schema.
- Moves the OpenAI-schema helpers next to the OpenAI-specific endpoints
  that need the schema.
- Drops the ListModelsService structure, as there was no real need for it.

A rough sketch of the resulting call shape follows the commit metadata below.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Authored by Ettore Di Giacinto on 2024-07-10 15:28:39 +02:00, committed by GitHub
commit 59ef426fbf (parent 28c6daf916)
22 changed files with 97 additions and 70 deletions
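Not part of the commit: a minimal, self-contained sketch of the call shape this changeset converges on, for orientation while reading the diff. The stub types below stand in for config.BackendConfigLoader, model.ModelLoader and schema.OpenAIModel from the packages touched here; the real ListModels (shown in the diff) also applies filtering and skip lists.

package main

import "fmt"

// Illustrative stubs only; not the real LocalAI types.
type BackendConfigLoader struct{ configured []string }
type ModelLoader struct{ looseFiles []string }
type OpenAIModel struct{ ID, Object string }

// Shape of the single service-level helper: configured model names first,
// then loose files, as plain strings (filter and skip lists elided here).
func ListModels(bcl *BackendConfigLoader, ml *ModelLoader, filter string, excludeConfigured bool) ([]string, error) {
	return append(append([]string{}, bcl.configured...), ml.looseFiles...), nil
}

func main() {
	bcl := &BackendConfigLoader{configured: []string{"phi-2"}}
	ml := &ModelLoader{looseFiles: []string{"mistral.gguf"}}

	names, _ := ListModels(bcl, ml, "", true)

	// Only the OpenAI endpoints wrap the plain names into the OpenAI schema:
	data := []OpenAIModel{}
	for _, m := range names {
		data = append(data, OpenAIModel{ID: m, Object: "model"})
	}
	fmt.Println(names, data)
}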


@@ -28,7 +28,6 @@ type Application struct {
 	// LocalAI System Services
 	BackendMonitorService *services.BackendMonitorService
 	GalleryService        *services.GalleryService
-	ListModelsService     *services.ListModelsService
 	LocalAIMetricsService *services.LocalAIMetricsService
 	// OpenAIService *services.OpenAIService
 }


@@ -5,6 +5,8 @@ import (
 	"strings"

 	"github.com/gofiber/fiber/v2"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 )
@@ -13,7 +15,7 @@ import (
 // If no model is specified, it will take the first available
 // Takes a model string as input which should be the one received from the user request.
 // It returns the model name resolved from the context and an error if any.
-func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
+func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
 	if ctx.Params("model") != "" {
 		modelInput = ctx.Params("model")
 	}
@@ -24,7 +26,7 @@ func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput stri
 	// If no model was specified, take the first available
 	if modelInput == "" && !bearerExists && firstModel {
-		models, _ := loader.ListModels()
+		models, _ := services.ListModels(cl, loader, "", true)
 		if len(models) > 0 {
 			modelInput = models[0]
 			log.Debug().Msgf("No model specified, using: %s", modelInput)


@@ -28,7 +28,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
 	if err != nil {
 		modelFile = input.ModelID
 		log.Warn().Msgf("Model not found in context: %s", input.ModelID)


@@ -28,7 +28,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
 	if err != nil {
 		modelFile = input.Model
 		log.Warn().Msgf("Model not found in context: %s", input.Model)


@@ -29,7 +29,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
 		return err
 	}

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
 	if err != nil {
 		modelFile = input.Model
 		log.Warn().Msgf("Model not found in context: %s", input.Model)


@@ -5,6 +5,7 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/core/p2p"
+	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
 	"github.com/mudler/LocalAI/pkg/model"
 )
@@ -12,7 +13,7 @@ import (
 func WelcomeEndpoint(appConfig *config.ApplicationConfig,
 	cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		models, _ := ml.ListModels()
+		models, _ := services.ListModels(cl, ml, "", true)
 		backendConfigs := cl.GetAllBackendConfigs()

 		galleryConfigs := map[string]*gallery.Config{}

@@ -11,6 +11,7 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/services"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/utils"
 	"github.com/rs/zerolog/log"
@@ -79,7 +80,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
 			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
 		}

-		if !modelExists(ml, request.Model) {
+		if !modelExists(cl, ml, request.Model) {
 			log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
 			return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
 		}
@@ -213,9 +214,9 @@ func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
 	return filteredAssistants
 }

-func modelExists(ml *model.ModelLoader, modelName string) (found bool) {
+func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
 	found = false
-	models, err := ml.ListModels()
+	models, err := services.ListModels(cl, ml, "", true)
 	if err != nil {
 		return
 	}


@@ -159,7 +159,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, startupOptions, true)
+		modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, appConfig, true)
+		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -18,7 +18,7 @@ import (
 func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readRequest(c, ml, appConfig, true)
+		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -23,7 +23,7 @@ import (
 // @Router /v1/embeddings [post]
 func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readRequest(c, ml, appConfig, true)
+		model, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -66,7 +66,7 @@ func downloadFile(url string) (string, error) {
 // @Router /v1/images/generations [post]
 func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, ml, appConfig, false)
+		m, input, err := readRequest(c, cl, ml, appConfig, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -2,15 +2,17 @@ package openai

 import (
 	"github.com/gofiber/fiber/v2"
+	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
+	model "github.com/mudler/LocalAI/pkg/model"
 )

 // ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
 // @Summary List and describe the various models available in the API.
 // @Success 200 {object} schema.ModelsDataResponse "Response"
 // @Router /v1/models [get]
-func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error {
+func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		// If blank, no filter is applied.
 		filter := c.Query("filter")
@@ -18,7 +20,7 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
 		// By default, exclude any loose files that are already referenced by a configuration file.
 		excludeConfigured := c.QueryBool("excludeConfigured", true)

-		dataModels, err := lms.ListModels(filter, excludeConfigured)
+		dataModels, err := modelList(bcl, ml, filter, excludeConfigured)
 		if err != nil {
 			return err
 		}
@@ -28,3 +30,20 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
 		})
 	}
 }
+
+func modelList(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
+	models, err := services.ListModels(bcl, ml, filter, excludeConfigured)
+	if err != nil {
+		return nil, err
+	}
+
+	dataModels := []schema.OpenAIModel{}
+
+	// Then iterate through the loose files:
+	for _, m := range models {
+		dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
+	}
+
+	return dataModels, nil
+}
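Not part of the diff: the endpoint above serializes into schema.ModelsDataResponse, named in the @Success annotation. Its exact field layout is an assumption here; a hedged, self-contained sketch of the serialization:

package main

import (
	"encoding/json"
	"fmt"
)

type OpenAIModel struct {
	ID     string `json:"id"`
	Object string `json:"object"`
}

// Assumed layout of schema.ModelsDataResponse; the real definition in
// core/schema may differ in detail.
type ModelsDataResponse struct {
	Object string        `json:"object"`
	Data   []OpenAIModel `json:"data"`
}

func main() {
	resp := ModelsDataResponse{
		Object: "list",
		Data:   []OpenAIModel{{ID: "phi-2", Object: "model"}},
	}
	out, _ := json.Marshal(resp)
	fmt.Println(string(out)) // {"object":"list","data":[{"id":"phi-2","object":"model"}]}
}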


@@ -15,7 +15,7 @@ import (
 	"github.com/rs/zerolog/log"
 )

-func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
+func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
 	input := new(schema.OpenAIRequest)

 	// Get input data from the request body
@@ -31,7 +31,7 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
 	log.Debug().Msgf("Request received: %s", string(received))

-	modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
+	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)

 	return modelFile, input, err
 }


@@ -25,7 +25,7 @@ import (
 // @Router /v1/audio/transcriptions [post]
 func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readRequest(c, ml, appConfig, false)
+		m, input, err := readRequest(c, cl, ml, appConfig, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}


@@ -5,7 +5,6 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai"
-	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/pkg/model"
 )
@@ -81,8 +80,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
 		app.Static("/generated-audio", appConfig.AudioDir)
 	}

-	// models
-	tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
-	app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
-	app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
+	// List models
+	app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
+	app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
 }


@@ -27,7 +27,6 @@ func RegisterUIRoutes(app *fiber.App,
 	appConfig *config.ApplicationConfig,
 	galleryService *services.GalleryService,
 	auth func(*fiber.Ctx) error) {

-	tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.

 	// keeps the state of models that are being installed from the UI
 	var processingModels = xsync.NewSyncedMap[string, string]()
@@ -270,7 +269,7 @@ func RegisterUIRoutes(app *fiber.App,
 	// Show the Chat page
 	app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
-		backendConfigs, _ := tmpLMS.ListModels("", true)
+		backendConfigs, _ := services.ListModels(cl, ml, "", true)

 		summary := fiber.Map{
 			"Title": "LocalAI - Chat with " + c.Params("model"),
@@ -285,7 +284,7 @@ func RegisterUIRoutes(app *fiber.App,
 	})

 	app.Get("/talk/", auth, func(c *fiber.Ctx) error {
-		backendConfigs, _ := tmpLMS.ListModels("", true)
+		backendConfigs, _ := services.ListModels(cl, ml, "", true)

 		if len(backendConfigs) == 0 {
 			// If no model is available redirect to the index which suggests how to install models
@@ -295,7 +294,7 @@ func RegisterUIRoutes(app *fiber.App,
 		summary := fiber.Map{
 			"Title":        "LocalAI - Talk",
 			"ModelsConfig": backendConfigs,
-			"Model":        backendConfigs[0].ID,
+			"Model":        backendConfigs[0],
 			"IsP2PEnabled": p2p.IsP2PEnabled(),
 			"Version":      internal.PrintableVersion(),
 		}
@@ -306,7 +305,7 @@ func RegisterUIRoutes(app *fiber.App,
 	app.Get("/chat/", auth, func(c *fiber.Ctx) error {
-		backendConfigs, _ := tmpLMS.ListModels("", true)
+		backendConfigs, _ := services.ListModels(cl, ml, "", true)

 		if len(backendConfigs) == 0 {
 			// If no model is available redirect to the index which suggests how to install models
@@ -314,9 +313,9 @@ func RegisterUIRoutes(app *fiber.App,
 		}

 		summary := fiber.Map{
-			"Title":        "LocalAI - Chat with " + backendConfigs[0].ID,
+			"Title":        "LocalAI - Chat with " + backendConfigs[0],
 			"ModelsConfig": backendConfigs,
-			"Model":        backendConfigs[0].ID,
+			"Model":        backendConfigs[0],
 			"Version":      internal.PrintableVersion(),
 			"IsP2PEnabled": p2p.IsP2PEnabled(),
 		}


@@ -100,10 +100,10 @@ SOFTWARE.
 	<option value="" disabled class="text-gray-400" >Select a model</option>
 	{{ $model:=.Model}}
 	{{ range .ModelsConfig }}
-	{{ if eq .ID $model }}
-	<option value="/chat/{{.ID}}" selected class="bg-gray-700 text-white">{{.ID}}</option>
+	{{ if eq . $model }}
+	<option value="/chat/{{.}}" selected class="bg-gray-700 text-white">{{.}}</option>
 	{{ else }}
-	<option value="/chat/{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
+	<option value="/chat/{{.}}" class="bg-gray-700 text-white">{{.}}</option>
 	{{ end }}
 	{{ end }}
 	</select>
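Not part of the diff: the template edits here and in the next file follow directly from the service change. ModelsConfig is now a plain []string of names rather than a slice of structs with an ID field, so {{.ID}} becomes {{.}}. A minimal illustration:

package main

import (
	"html/template"
	"os"
)

func main() {
	// The range variable used to be a struct addressed as {{.ID}};
	// it is now a plain string addressed as {{.}}.
	tmpl := template.Must(template.New("models").Parse(
		`{{ range . }}<option value="/chat/{{.}}">{{.}}</option>` + "\n{{ end }}"))

	models := []string{"phi-2", "mistral"} // what ModelsConfig now carries
	tmpl.Execute(os.Stdout, models)
	// <option value="/chat/phi-2">phi-2</option>
	// <option value="/chat/mistral">mistral</option>
}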


@@ -62,7 +62,7 @@
 	<option value="" disabled class="text-gray-400" >Select a model</option>
 	{{ range .ModelsConfig }}
-	<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
+	<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
 	{{ end }}
 	</select>
 </div>
@@ -76,7 +76,7 @@
 	<option value="" disabled class="text-gray-400" >Select a model</option>
 	{{ range .ModelsConfig }}
-	<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
+	<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
 	{{ end }}
 	</select>
 </div>
@@ -89,7 +89,7 @@
 	>
 	<option value="" disabled class="text-gray-400" >Select a model</option>
 	{{ range .ModelsConfig }}
-	<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
+	<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
 	{{ end }}
 	</select>
 </div>


@@ -4,34 +4,19 @@ import (
 	"regexp"

 	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/model"
 )

-type ListModelsService struct {
-	bcl       *config.BackendConfigLoader
-	ml        *model.ModelLoader
-	appConfig *config.ApplicationConfig
-}
-
-func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService {
-	return &ListModelsService{
-		bcl:       bcl,
-		ml:        ml,
-		appConfig: appConfig,
-	}
-}
-
-func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
-	models, err := lms.ml.ListModels()
+func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]string, error) {
+	models, err := ml.ListFilesInModelPath()
 	if err != nil {
 		return nil, err
 	}

 	var mm map[string]interface{} = map[string]interface{}{}

-	dataModels := []schema.OpenAIModel{}
+	dataModels := []string{}

 	var filterFn func(name string) bool
@@ -50,13 +35,13 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool)
 	}

 	// Start with the known configurations
-	for _, c := range lms.bcl.GetAllBackendConfigs() {
+	for _, c := range bcl.GetAllBackendConfigs() {
 		if excludeConfigured {
 			mm[c.Model] = nil
 		}

 		if filterFn(c.Name) {
-			dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
+			dataModels = append(dataModels, c.Name)
 		}
 	}
@@ -64,7 +49,7 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool)
 	for _, m := range models {
 		// And only adds them if they shouldn't be skipped.
 		if _, exists := mm[m]; !exists && filterFn(m) {
-			dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
+			dataModels = append(dataModels, m)
 		}
 	}
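Not part of the diff: the "regexp" import and the filterFn variable above suggest that a non-empty filter is compiled as a regular expression and matched against each candidate name, while an empty filter matches everything. A sketch under that assumption:

package main

import (
	"fmt"
	"regexp"
)

// Assumed filter semantics; inferred, not copied from the real code.
func buildFilterFn(filter string) (func(name string) bool, error) {
	if filter == "" {
		return func(string) bool { return true }, nil
	}
	rxp, err := regexp.Compile(filter)
	if err != nil {
		return nil, err
	}
	return rxp.MatchString, nil
}

func main() {
	filterFn, _ := buildFilterFn("^phi")
	for _, name := range []string{"phi-2", "mistral-7b"} {
		fmt.Println(name, filterFn(name)) // phi-2 true, mistral-7b false
	}
}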


@@ -195,7 +195,6 @@ func createApplication(appConfig *config.ApplicationConfig) *core.Application {
 	app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
 	app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
-	app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
 	// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)

 	app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()


@@ -30,7 +30,6 @@ type PromptTemplateData struct {
 	MessageIndex int
 }

-// TODO: Ask mudler about FunctionCall stuff being useful at the message level?
 type ChatMessageTemplateData struct {
 	SystemPrompt string
 	Role         string
@@ -87,22 +86,47 @@ func (ml *ModelLoader) ExistsInModelPath(s string) bool {
 	return utils.ExistsInPath(ml.ModelPath, s)
 }

-func (ml *ModelLoader) ListModels() ([]string, error) {
+var knownFilesToSkip []string = []string{
+	"MODEL_CARD",
+	"README",
+	"README.md",
+}
+
+var knownModelsNameSuffixToSkip []string = []string{
+	".tmpl",
+	".keep",
+	".yaml",
+	".yml",
+	".json",
+	".DS_Store",
+	".",
+}
+
+func (ml *ModelLoader) ListFilesInModelPath() ([]string, error) {
 	files, err := os.ReadDir(ml.ModelPath)
 	if err != nil {
 		return []string{}, err
 	}

 	models := []string{}
+FILE:
 	for _, file := range files {
-		// Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method?
-		if strings.HasSuffix(file.Name(), ".tmpl") ||
-			strings.HasSuffix(file.Name(), ".keep") ||
-			strings.HasSuffix(file.Name(), ".yaml") ||
-			strings.HasSuffix(file.Name(), ".yml") ||
-			strings.HasSuffix(file.Name(), ".json") ||
-			strings.HasSuffix(file.Name(), ".DS_Store") ||
-			strings.HasPrefix(file.Name(), ".") {
+		for _, skip := range knownFilesToSkip {
+			if strings.EqualFold(file.Name(), skip) {
+				continue FILE
+			}
+		}
+
+		// Skip templates, YAML, .keep, .json, and .DS_Store files
+		for _, skip := range knownModelsNameSuffixToSkip {
+			if strings.HasSuffix(file.Name(), skip) {
+				continue FILE
+			}
+		}
+
+		// Skip directories
+		if file.IsDir() {
 			continue
 		}
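Not part of the diff: a self-contained sketch of the two skip checks above applied to hypothetical directory contents, to make the new listing behavior concrete. The file names in main are invented; directories are skipped separately via file.IsDir().

package main

import (
	"fmt"
	"strings"
)

var knownFilesToSkip = []string{"MODEL_CARD", "README", "README.md"}

var knownModelsNameSuffixToSkip = []string{
	".tmpl", ".keep", ".yaml", ".yml", ".json", ".DS_Store", ".",
}

// skipped mirrors the loop body above: exact names first (case-insensitive),
// then known suffixes.
func skipped(name string) bool {
	for _, skip := range knownFilesToSkip {
		if strings.EqualFold(name, skip) {
			return true
		}
	}
	for _, skip := range knownModelsNameSuffixToSkip {
		if strings.HasSuffix(name, skip) {
			return true
		}
	}
	return false
}

func main() {
	for _, name := range []string{"readme.md", "mistral.gguf", "config.yaml", "luna.tmpl"} {
		fmt.Printf("%-14s skipped=%v\n", name, skipped(name))
	}
	// readme.md      skipped=true   (EqualFold matches README.md)
	// mistral.gguf   skipped=false  (listed as a model)
	// config.yaml    skipped=true
	// luna.tmpl      skipped=true
}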