mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(model-list): be consistent, skip known files from listing (#2760)
fix(model-list): be consistent, skip known files from listing This changeset does two things: - Removes the dependency of listing models from the OpenAI schema. - Tries to reduce confusion between ListModels() in model loader and in the service - now there is only one ListModels which is in services and does not depend anymore on the OpenAI schema - The OpenAI-schema functions were moved nearby the OpenAI specific endpoints that needs the schema - Drops the ListModel Service structure as there was no real need for it. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
28c6daf916
commit
59ef426fbf
@ -28,7 +28,6 @@ type Application struct {
|
||||
// LocalAI System Services
|
||||
BackendMonitorService *services.BackendMonitorService
|
||||
GalleryService *services.GalleryService
|
||||
ListModelsService *services.ListModelsService
|
||||
LocalAIMetricsService *services.LocalAIMetricsService
|
||||
// OpenAIService *services.OpenAIService
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@ -13,7 +15,7 @@ import (
|
||||
// If no model is specified, it will take the first available
|
||||
// Takes a model string as input which should be the one received from the user request.
|
||||
// It returns the model name resolved from the context and an error if any.
|
||||
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
|
||||
func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
|
||||
if ctx.Params("model") != "" {
|
||||
modelInput = ctx.Params("model")
|
||||
}
|
||||
@ -24,7 +26,7 @@ func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput stri
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelInput == "" && !bearerExists && firstModel {
|
||||
models, _ := loader.ListModels()
|
||||
models, _ := services.ListModels(cl, loader, "", true)
|
||||
if len(models) > 0 {
|
||||
modelInput = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelInput)
|
||||
|
@ -28,7 +28,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
|
||||
if err != nil {
|
||||
modelFile = input.ModelID
|
||||
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||
|
@ -28,7 +28,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
|
@ -29,7 +29,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
|
||||
if err != nil {
|
||||
modelFile = input.Model
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/p2p"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
@ -12,7 +13,7 @@ import (
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
models, _ := ml.ListModels()
|
||||
models, _ := services.ListModels(cl, ml, "", true)
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
|
||||
galleryConfigs := map[string]*gallery.Config{}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -79,7 +80,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||
}
|
||||
|
||||
if !modelExists(ml, request.Model) {
|
||||
if !modelExists(cl, ml, request.Model) {
|
||||
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
|
||||
return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
|
||||
}
|
||||
@ -213,9 +214,9 @@ func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
|
||||
return filteredAssistants
|
||||
}
|
||||
|
||||
func modelExists(ml *model.ModelLoader, modelName string) (found bool) {
|
||||
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
|
||||
found = false
|
||||
models, err := ml.ListModels()
|
||||
models, err := services.ListModels(cl, ml, "", true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ import (
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readRequest(c, ml, appConfig, true)
|
||||
model, input, err := readRequest(c, cl, ml, appConfig, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, ml, appConfig, false)
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -2,15 +2,17 @@ package openai
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
@ -18,7 +20,7 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||
|
||||
dataModels, err := lms.ListModels(filter, excludeConfigured)
|
||||
dataModels, err := modelList(bcl, ml, filter, excludeConfigured)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -28,3 +30,20 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func modelList(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
|
||||
|
||||
models, err := services.ListModels(bcl, ml, filter, excludeConfigured)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
|
||||
// Then iterate through the loose files:
|
||||
for _, m := range models {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
|
||||
return dataModels, nil
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
input := new(schema.OpenAIRequest)
|
||||
|
||||
// Get input data from the request body
|
||||
@ -31,7 +31,7 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
||||
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
|
||||
|
||||
return modelFile, input, err
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ import (
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readRequest(c, ml, appConfig, false)
|
||||
m, input, err := readRequest(c, cl, ml, appConfig, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
@ -81,8 +80,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
||||
app.Static("/generated-audio", appConfig.AudioDir)
|
||||
}
|
||||
|
||||
// models
|
||||
tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
|
||||
// List models
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
}
|
||||
|
@ -27,7 +27,6 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService,
|
||||
auth func(*fiber.Ctx) error) {
|
||||
tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
|
||||
|
||||
// keeps the state of models that are being installed from the UI
|
||||
var processingModels = xsync.NewSyncedMap[string, string]()
|
||||
@ -270,7 +269,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
// Show the Chat page
|
||||
app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
|
||||
backendConfigs, _ := tmpLMS.ListModels("", true)
|
||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Chat with " + c.Params("model"),
|
||||
@ -285,7 +284,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/talk/", auth, func(c *fiber.Ctx) error {
|
||||
backendConfigs, _ := tmpLMS.ListModels("", true)
|
||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||
|
||||
if len(backendConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
@ -295,7 +294,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Talk",
|
||||
"ModelsConfig": backendConfigs,
|
||||
"Model": backendConfigs[0].ID,
|
||||
"Model": backendConfigs[0],
|
||||
"IsP2PEnabled": p2p.IsP2PEnabled(),
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
@ -306,7 +305,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
app.Get("/chat/", auth, func(c *fiber.Ctx) error {
|
||||
|
||||
backendConfigs, _ := tmpLMS.ListModels("", true)
|
||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||
|
||||
if len(backendConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
@ -314,9 +313,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Chat with " + backendConfigs[0].ID,
|
||||
"Title": "LocalAI - Chat with " + backendConfigs[0],
|
||||
"ModelsConfig": backendConfigs,
|
||||
"Model": backendConfigs[0].ID,
|
||||
"Model": backendConfigs[0],
|
||||
"Version": internal.PrintableVersion(),
|
||||
"IsP2PEnabled": p2p.IsP2PEnabled(),
|
||||
}
|
||||
|
@ -100,10 +100,10 @@ SOFTWARE.
|
||||
<option value="" disabled class="text-gray-400" >Select a model</option>
|
||||
{{ $model:=.Model}}
|
||||
{{ range .ModelsConfig }}
|
||||
{{ if eq .ID $model }}
|
||||
<option value="/chat/{{.ID}}" selected class="bg-gray-700 text-white">{{.ID}}</option>
|
||||
{{ if eq . $model }}
|
||||
<option value="/chat/{{.}}" selected class="bg-gray-700 text-white">{{.}}</option>
|
||||
{{ else }}
|
||||
<option value="/chat/{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
|
||||
<option value="/chat/{{.}}" class="bg-gray-700 text-white">{{.}}</option>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
</select>
|
||||
|
@ -62,7 +62,7 @@
|
||||
<option value="" disabled class="text-gray-400" >Select a model</option>
|
||||
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
|
||||
<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
@ -76,7 +76,7 @@
|
||||
<option value="" disabled class="text-gray-400" >Select a model</option>
|
||||
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
|
||||
<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
@ -89,7 +89,7 @@
|
||||
>
|
||||
<option value="" disabled class="text-gray-400" >Select a model</option>
|
||||
{{ range .ModelsConfig }}
|
||||
<option value="{{.ID}}" class="bg-gray-700 text-white">{{.ID}}</option>
|
||||
<option value="{{.}}" class="bg-gray-700 text-white">{{.}}</option>
|
||||
{{ end }}
|
||||
</select>
|
||||
</div>
|
||||
|
@ -4,34 +4,19 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
type ListModelsService struct {
|
||||
bcl *config.BackendConfigLoader
|
||||
ml *model.ModelLoader
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]string, error) {
|
||||
|
||||
func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService {
|
||||
return &ListModelsService{
|
||||
bcl: bcl,
|
||||
ml: ml,
|
||||
appConfig: appConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
|
||||
|
||||
models, err := lms.ml.ListModels()
|
||||
models, err := ml.ListFilesInModelPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
dataModels := []string{}
|
||||
|
||||
var filterFn func(name string) bool
|
||||
|
||||
@ -50,13 +35,13 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool)
|
||||
}
|
||||
|
||||
// Start with the known configurations
|
||||
for _, c := range lms.bcl.GetAllBackendConfigs() {
|
||||
for _, c := range bcl.GetAllBackendConfigs() {
|
||||
if excludeConfigured {
|
||||
mm[c.Model] = nil
|
||||
}
|
||||
|
||||
if filterFn(c.Name) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
|
||||
dataModels = append(dataModels, c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +49,7 @@ func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool)
|
||||
for _, m := range models {
|
||||
// And only adds them if they shouldn't be skipped.
|
||||
if _, exists := mm[m]; !exists && filterFn(m) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
dataModels = append(dataModels, m)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -195,7 +195,6 @@ func createApplication(appConfig *config.ApplicationConfig) *core.Application {
|
||||
|
||||
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
app.GalleryService = services.NewGalleryService(app.ApplicationConfig)
|
||||
app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
|
||||
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
|
||||
|
||||
app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
|
||||
|
@ -30,7 +30,6 @@ type PromptTemplateData struct {
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
// TODO: Ask mudler about FunctionCall stuff being useful at the message level?
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
@ -87,22 +86,47 @@ func (ml *ModelLoader) ExistsInModelPath(s string) bool {
|
||||
return utils.ExistsInPath(ml.ModelPath, s)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||
var knownFilesToSkip []string = []string{
|
||||
"MODEL_CARD",
|
||||
"README",
|
||||
"README.md",
|
||||
}
|
||||
|
||||
var knownModelsNameSuffixToSkip []string = []string{
|
||||
".tmpl",
|
||||
".keep",
|
||||
".yaml",
|
||||
".yml",
|
||||
".json",
|
||||
".DS_Store",
|
||||
".",
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ListFilesInModelPath() ([]string, error) {
|
||||
files, err := os.ReadDir(ml.ModelPath)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
models := []string{}
|
||||
FILE:
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method?
|
||||
if strings.HasSuffix(file.Name(), ".tmpl") ||
|
||||
strings.HasSuffix(file.Name(), ".keep") ||
|
||||
strings.HasSuffix(file.Name(), ".yaml") ||
|
||||
strings.HasSuffix(file.Name(), ".yml") ||
|
||||
strings.HasSuffix(file.Name(), ".json") ||
|
||||
strings.HasSuffix(file.Name(), ".DS_Store") ||
|
||||
strings.HasPrefix(file.Name(), ".") {
|
||||
|
||||
for _, skip := range knownFilesToSkip {
|
||||
if strings.EqualFold(file.Name(), skip) {
|
||||
continue FILE
|
||||
}
|
||||
}
|
||||
|
||||
// Skip templates, YAML, .keep, .json, and .DS_Store files
|
||||
for _, skip := range knownModelsNameSuffixToSkip {
|
||||
if strings.HasSuffix(file.Name(), skip) {
|
||||
continue FILE
|
||||
}
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user