feat: auto load into memory on startup (#3627)

Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
Sertaç Özercan 2024-09-22 01:03:30 -07:00 committed by GitHub
parent 1f43678d53
commit ee21b00a8d
10 changed files with 259 additions and 213 deletions

View File

@@ -12,7 +12,7 @@ import (
 func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
 	modelFile := backendConfig.Model

-	grpcOpts := gRPCModelOpts(backendConfig)
+	grpcOpts := GRPCModelOpts(backendConfig)

 	var inferenceModel interface{}
 	var err error

View File

@@ -12,7 +12,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 	if *threads == 0 && appConfig.Threads != 0 {
 		threads = &appConfig.Threads
 	}
-	gRPCOpts := gRPCModelOpts(backendConfig)
+	gRPCOpts := GRPCModelOpts(backendConfig)
 	opts := modelOpts(backendConfig, appConfig, []model.Option{
 		model.WithBackendString(backendConfig.Backend),
 		model.WithAssetDir(appConfig.AssetsDestination),

View File

@@ -37,7 +37,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
 	if *threads == 0 && o.Threads != 0 {
 		threads = &o.Threads
 	}
-	grpcOpts := gRPCModelOpts(c)
+	grpcOpts := GRPCModelOpts(c)
 	var inferenceModel grpc.Backend
 	var err error

View File

@@ -44,7 +44,7 @@ func getSeed(c config.BackendConfig) int32 {
 	return seed
 }

-func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
+func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
 	b := 512
 	if c.Batch != 0 {
 		b = c.Batch

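This hunk is the source of the rename that the other backend/ hunks in this commit merely propagate to call sites: gRPCModelOpts becomes GRPCModelOpts. The capitalization is the point, since Go exports an identifier only if it begins with an upper-case letter, and the new preload code in the startup package below must call this function from outside package backend. A self-contained toy sketch of the visibility rule (illustrative names and a simplified body, not LocalAI's real types):

package main

import "fmt"

// Toy stand-in for the renamed function. A lower-case name such as
// gRPCModelOpts is visible only inside its own package; the upper-case
// GRPCModelOpts is exported and callable from other packages, e.g. as
// backend.GRPCModelOpts(cfg).
type modelOptions struct{ Batch int }

func GRPCModelOpts(batch int) *modelOptions {
	b := 512 // default fallback, mirroring the hunk above
	if batch != 0 {
		b = batch
	}
	return &modelOptions{Batch: b}
}

func main() {
	fmt.Println(GRPCModelOpts(0).Batch)    // 512
	fmt.Println(GRPCModelOpts(1024).Batch) // 1024
}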
View File

@@ -15,7 +15,7 @@ func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *mod
 		return nil, fmt.Errorf("backend is required")
 	}

-	grpcOpts := gRPCModelOpts(backendConfig)
+	grpcOpts := GRPCModelOpts(backendConfig)

 	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
 		model.WithBackendString(bb),

View File

@@ -29,7 +29,7 @@ func SoundGeneration(
 		return "", nil, fmt.Errorf("backend is a required parameter")
 	}

-	grpcOpts := gRPCModelOpts(backendConfig)
+	grpcOpts := GRPCModelOpts(backendConfig)
 	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
 		model.WithBackendString(backend),
 		model.WithModel(modelFile),

View File

@@ -28,7 +28,7 @@ func ModelTTS(
 		bb = model.PiperBackend
 	}

-	grpcOpts := gRPCModelOpts(backendConfig)
+	grpcOpts := GRPCModelOpts(backendConfig)

 	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
 		model.WithBackendString(bb),

View File

@@ -69,6 +69,7 @@ type RunCMD struct {
 	WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
 	Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
 	DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
+	LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
 }

 func (r *RunCMD) Run(ctx *cliContext.Context) error {
@@ -104,6 +105,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
 		config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
 		config.WithP2PNetworkID(r.Peer2PeerNetworkID),
+		config.WithLoadToMemory(r.LoadToMemory),
 	}

 	token := ""

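The struct tags above follow the alecthomas/kong conventions that LocalAI's CLI layer is built on, so the new field is settable either as a repeatable --load-to-memory flag or through the LOCALAI_LOAD_TO_MEMORY / LOAD_TO_MEMORY environment variables. A minimal standalone sketch, assuming kong's default behavior of comma-splitting slice values read from an environment variable (an illustration, not LocalAI's actual wiring):

package main

import (
	"fmt"

	"github.com/alecthomas/kong"
)

// Hypothetical stand-in for the RunCMD field above: a []string bound
// to an env var. Under kong's default slice mapping, a value such as
// LOCALAI_LOAD_TO_MEMORY=model-a,model-b should decode to
// []string{"model-a", "model-b"}.
type cli struct {
	LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY" help:"Models to preload at startup"`
}

func main() {
	var c cli
	kong.Parse(&c)
	fmt.Println(c.LoadToMemory)
}

Running the sketch as LOCALAI_LOAD_TO_MEMORY=model-a,model-b ./prog should print [model-a model-b].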
View File

@@ -41,6 +41,7 @@ type ApplicationConfig struct {
 	DisableApiKeyRequirementForHttpGet bool
 	HttpGetExemptedEndpoints []*regexp.Regexp
 	DisableGalleryEndpoint bool
+	LoadToMemory []string

 	ModelLibraryURL string
@@ -331,6 +332,12 @@ func WithOpaqueErrors(opaque bool) AppOption {
 	}
 }

+func WithLoadToMemory(models []string) AppOption {
+	return func(o *ApplicationConfig) {
+		o.LoadToMemory = models
+	}
+}
+
 func WithSubtleKeyComparison(subtle bool) AppOption {
 	return func(o *ApplicationConfig) {
 		o.UseSubtleKeyComparison = subtle

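WithLoadToMemory follows the functional-options pattern used for every other AppOption in this file: a With* helper returns a closure that captures its argument, and the config constructor applies the closures in order. A self-contained sketch of the pattern with toy names (not LocalAI's real types):

package main

import "fmt"

// appConfig and appOption mirror the shape of ApplicationConfig and
// AppOption in miniature.
type appConfig struct {
	LoadToMemory []string
}

type appOption func(*appConfig)

// withLoadToMemory returns a closure that records the model list.
func withLoadToMemory(models []string) appOption {
	return func(o *appConfig) { o.LoadToMemory = models }
}

// newAppConfig applies each option to a zero-valued config.
func newAppConfig(opts ...appOption) *appConfig {
	cfg := &appConfig{}
	for _, opt := range opts {
		opt(cfg)
	}
	return cfg
}

func main() {
	cfg := newAppConfig(withLoadToMemory([]string{"model-a", "model-b"}))
	fmt.Println(cfg.LoadToMemory) // [model-a model-b]
}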
View File

@@ -5,6 +5,7 @@ import (
 	"os"

 	"github.com/mudler/LocalAI/core"
+	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/internal"
@@ -144,6 +145,42 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
 		}()
 	}

+	if options.LoadToMemory != nil {
+		for _, m := range options.LoadToMemory {
+			cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath,
+				config.LoadOptionDebug(options.Debug),
+				config.LoadOptionThreads(options.Threads),
+				config.LoadOptionContextSize(options.ContextSize),
+				config.LoadOptionF16(options.F16),
+				config.ModelPath(options.ModelPath),
+			)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+
+			log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model)
+
+			grpcOpts := backend.GRPCModelOpts(*cfg)
+			o := []model.Option{
+				model.WithModel(cfg.Model),
+				model.WithAssetDir(options.AssetsDestination),
+				model.WithThreads(uint32(options.Threads)),
+				model.WithLoadGRPCLoadModelOpts(grpcOpts),
+			}
+
+			var backendErr error
+			if cfg.Backend != "" {
+				o = append(o, model.WithBackendString(cfg.Backend))
+				_, backendErr = ml.BackendLoader(o...)
+			} else {
+				_, backendErr = ml.GreedyLoader(o...)
+			}
+			if backendErr != nil {
+				return nil, nil, nil, backendErr
+			}
+		}
+	}
+
 	// Watch the configuration directory
 	startWatcher(options)
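The preload loop added to Startup resolves each requested model's backend config, then loads it in one of two ways: if the config pins a backend, BackendLoader is used with exactly that backend; otherwise GreedyLoader tries the available backends in turn. A runnable toy sketch of that control flow, with hypothetical stand-in loaders rather than LocalAI's ModelLoader API:

package main

import (
	"errors"
	"fmt"
)

// modelConfig is a toy stand-in for the resolved backend config.
type modelConfig struct {
	Name    string
	Backend string
}

// loadWithBackend pretends to load a model with a specific backend.
func loadWithBackend(backend, model string) error {
	fmt.Printf("loading %s via %s\n", model, backend)
	return nil
}

// greedyLoad tries every known backend until one succeeds.
func greedyLoad(model string) error {
	for _, b := range []string{"llama-cpp", "whisper", "piper"} {
		if loadWithBackend(b, model) == nil {
			return nil
		}
	}
	return errors.New("no backend could load " + model)
}

func preload(cfgs []modelConfig) error {
	for _, cfg := range cfgs {
		var backendErr error
		if cfg.Backend != "" {
			backendErr = loadWithBackend(cfg.Backend, cfg.Name)
		} else {
			backendErr = greedyLoad(cfg.Name)
		}
		if backendErr != nil {
			// Propagate the loader's own error so a failed preload
			// aborts startup with a meaningful message.
			return backendErr
		}
	}
	return nil
}

func main() {
	_ = preload([]modelConfig{{Name: "model-a"}})
}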