Make it compatible with openAI api, support multiple models
Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
parent b33d015b8c
commit 12eee097b7

api.go (112 lines changed)
@@ -2,24 +2,128 @@ package main
 
 import (
     "embed"
+    "fmt"
     "net/http"
     "strconv"
+    "strings"
     "sync"
 
     llama "github.com/go-skynet/go-llama.cpp"
     "github.com/gofiber/fiber/v2"
+    "github.com/gofiber/fiber/v2/middleware/cors"
     "github.com/gofiber/fiber/v2/middleware/filesystem"
+    "github.com/gofiber/fiber/v2/middleware/recover"
 )
 
+type OpenAIResponse struct {
+    Created int      `json:"created"`
+    Object  string   `json:"chat.completion"`
+    ID      string   `json:"id"`
+    Model   string   `json:"model"`
+    Choices []Choice `json:"choices"`
+}
+
+type Choice struct {
+    Index        int     `json:"index"`
+    FinishReason string  `json:"finish_reason"`
+    Message      Message `json:"message"`
+}
+
+type Message struct {
+    Role    string `json:"role"`
+    Content string `json:"content"`
+}
+
 //go:embed index.html
 var indexHTML embed.FS
 
-func api(l *llama.LLama, listenAddr string, threads int) error {
+func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error {
     app := fiber.New()
 
+    // Default middleware config
+    app.Use(recover.New())
+    app.Use(cors.New())
+
     app.Use("/", filesystem.New(filesystem.Config{
         Root:         http.FS(indexHTML),
         NotFoundFile: "index.html",
     }))
 
+    var mutex = &sync.Mutex{}
+
+    // openAI compatible API endpoint
+    app.Post("/v1/chat/completions", func(c *fiber.Ctx) error {
+        var err error
+        var model *llama.LLama
+
+        // Get input data from the request body
+        input := new(struct {
+            Messages []Message `json:"messages"`
+            Model    string    `json:"model"`
+        })
+        if err := c.BodyParser(input); err != nil {
+            return err
+        }
+
+        if input.Model == "" {
+            if defaultModel == nil {
+                return fmt.Errorf("no default model loaded, and no model specified")
+            }
+            model = defaultModel
+        } else {
+            model, err = loader.LoadModel(input.Model)
+            if err != nil {
+                return err
+            }
+        }
+
+        // Set the parameters for the language model prediction
+        topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
+        if err != nil {
+            return err
+        }
+
+        topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
+        if err != nil {
+            return err
+        }
+
+        temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
+        if err != nil {
+            return err
+        }
+
+        tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
+        if err != nil {
+            return err
+        }
+
+        mess := []string{}
+        for _, i := range input.Messages {
+            mess = append(mess, i.Content)
+        }
+
+        fmt.Println("Received", input, input.Model)
+        // Generate the prediction using the language model
+        prediction, err := model.Predict(
+            strings.Join(mess, "\n"),
+            llama.SetTemperature(temperature),
+            llama.SetTopP(topP),
+            llama.SetTopK(topK),
+            llama.SetTokens(tokens),
+            llama.SetThreads(threads),
+        )
+        if err != nil {
+            return err
+        }
+
+        // Return the prediction in the response body
+        return c.JSON(OpenAIResponse{
+            Model:   input.Model,
+            Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
+        })
+    })
+
     /*
         curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
             "text": "What is an alpaca?",
@@ -29,8 +133,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
             "tokens": 100
         }'
     */
-    var mutex = &sync.Mutex{}
 
     // Endpoint to generate the prediction
     app.Post("/predict", func(c *fiber.Ctx) error {
         mutex.Lock()
@@ -65,7 +167,7 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
         }
 
         // Generate the prediction using the language model
-        prediction, err := l.Predict(
+        prediction, err := defaultModel.Predict(
             input.Text,
             llama.SetTemperature(temperature),
             llama.SetTopP(topP),
@@ -86,6 +188,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
     })
 
     // Start the server
-    app.Listen(":8080")
+    app.Listen(listenAddr)
     return nil
 }
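For reference, a request against the new endpoint might look like the following. This is only a sketch derived from the handler above: the model name and prompt are placeholders, and topP/topK/temperature/tokens are read from the query string (falling back to the defaults shown in the code), not from the JSON body.

    # sketch only: model name and prompt are placeholders
    curl --location --request POST 'http://localhost:8080/v1/chat/completions?temperature=0.5&tokens=100' \
      --header 'Content-Type: application/json' \
      --data-raw '{
        "model": "ggml-alpaca-7b",
        "messages": [{"role": "user", "content": "What is an alpaca?"}]
      }'

Omitting "model" falls back to the model loaded via --default-model; any other value is resolved by the new ModelLoader (see model_loader.go below).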
main.go (22 lines changed)

@@ -146,8 +146,12 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
             Value:   runtime.NumCPU(),
         },
         &cli.StringFlag{
-            Name:    "model",
-            EnvVars: []string{"MODEL_PATH"},
+            Name:    "models-path",
+            EnvVars: []string{"MODELS_PATH"},
+        },
+        &cli.StringFlag{
+            Name:    "default-model",
+            EnvVars: []string{"default-model"},
         },
         &cli.StringFlag{
             Name:    "address",
@@ -161,13 +165,19 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
             },
         },
         Action: func(ctx *cli.Context) error {
-            l, err := llamaFromOptions(ctx)
+            var defaultModel *llama.LLama
+            defModel := ctx.String("default-model")
+            if defModel != "" {
+                opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
+                var err error
+                defaultModel, err = llama.New(ctx.String("default-model"), opts...)
             if err != nil {
-                fmt.Println("Loading the model failed:", err.Error())
-                os.Exit(1)
+                return err
+            }
             }
 
-            return api(l, ctx.String("address"), ctx.Int("threads"))
+            return api(defaultModel, NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
         },
     },
 },
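A possible invocation with the new flags, purely for illustration: the binary name local-ai and all paths are placeholders; --models-path (or the MODELS_PATH environment variable) points at the directory the ModelLoader scans, while --default-model preloads a single model for requests that omit "model".

    # placeholders: binary name and paths are illustrative only
    MODELS_PATH=./models ./local-ai \
      --default-model ./models/ggml-alpaca-7b.bin \
      --address ":8080" \
      --threads 4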
model_loader.go (new file, 52 lines)

@@ -0,0 +1,52 @@
+package main
+
+import (
+    "fmt"
+    "os"
+    "path/filepath"
+    "sync"
+
+    llama "github.com/go-skynet/go-llama.cpp"
+)
+
+type ModelLoader struct {
+    modelPath string
+    mu        sync.Mutex
+    models    map[string]*llama.LLama
+}
+
+func NewModelLoader(modelPath string) *ModelLoader {
+    return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)}
+}
+
+func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) {
+    ml.mu.Lock()
+    defer ml.mu.Unlock()
+
+    // Check if we already have a loaded model
+    modelFile := filepath.Join(ml.modelPath, s)
+
+    if m, ok := ml.models[modelFile]; ok {
+        return m, nil
+    }
+
+    // Check if the model path exists
+    if _, err := os.Stat(modelFile); os.IsNotExist(err) {
+        // try to find a s.bin
+        modelBin := fmt.Sprintf("%s.bin", modelFile)
+        if _, err := os.Stat(modelBin); os.IsNotExist(err) {
+            return nil, err
+        } else {
+            modelFile = modelBin
+        }
+    }
+
+    // Load the model and keep it in memory for later use
+    model, err := llama.New(modelFile, opts...)
+    if err != nil {
+        return nil, err
+    }
+
+    ml.models[modelFile] = model
+    return model, err
+}
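To make the lookup order in LoadModel concrete: the request's "model" value is joined onto the models path, and if that exact file is missing the loader retries with a .bin suffix. With a layout like the one below (file names are made up), both spellings resolve:

    # example directory only; file names are placeholders
    $ ls ./models
    ggml-alpaca-7b    gpt-3.5-turbo.bin

    # "model": "ggml-alpaca-7b"  -> ./models/ggml-alpaca-7b     (exact match)
    # "model": "gpt-3.5-turbo"   -> ./models/gpt-3.5-turbo.bin  (.bin fallback)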