Mirror of https://github.com/mudler/LocalAI.git (synced 2025-06-06 01:01:35 +00:00)

feat: drop default model and llama-specific API (#26)

Signed-off-by: mudler <mudler@c3os.io>

This commit is contained in:
parent 1370b4482f
commit 63601fabd1
README.md (29 changed lines)

@@ -27,6 +27,7 @@ docker compose up -d --build

 # Now API is accessible at localhost:8080
 curl http://localhost:8080/v1/models

 # {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
 curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
      "model": "your-model.bin",
@@ -88,7 +89,7 @@ llama-cli --model <model_path> --instruction <instruction> [--input <input>] [--
 | template | TEMPLATE | | A file containing a template for output formatting (optional). |
 | instruction | INSTRUCTION | | Input prompt text or instruction. "-" for STDIN. |
 | input | INPUT | - | Path to text or "-" for STDIN. |
-| model | MODEL_PATH | | The path to the pre-trained GPT-based model. |
+| model | MODEL | | The path to the pre-trained GPT-based model. |
 | tokens | TOKENS | 128 | The maximum number of tokens to generate. |
 | threads | THREADS | NumCPU() | The number of threads to use for text generation. |
 | temperature | TEMPERATURE | 0.95 | Sampling temperature for model output. ( values between `0.1` and `1.0` ) |
@@ -216,32 +217,6 @@ python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
 # There will be a new model with the ".tmp" extension, you have to use that one!
 ```

-### Golang client API
-
-The `llama-cli` codebase has also a small client in go that can be used alongside with the api:
-
-```golang
-package main
-
-import (
-	"fmt"
-
-	client "github.com/go-skynet/llama-cli/client"
-)
-
-func main() {
-
-	cli := client.NewClient("http://ip:port")
-
-	out, err := cli.Predict("What's an alpaca?")
-	if err != nil {
-		panic(err)
-	}
-
-	fmt.Println(out)
-}
-```
-
 ### Windows compatibility

 It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/llama-cli/issues/2
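With the Golang client section removed above, the remaining way to drive the server programmatically is the OpenAI-compatible HTTP API. Below is a minimal sketch of the equivalent call using only Go's standard library; the base URL and model file name are placeholders, and the response is assumed to follow the OpenAI completions shape (`choices[].text`):

```golang
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body for the OpenAI-compatible completions endpoint.
	// The model name must match a file listed by GET /v1/models.
	body, err := json.Marshal(map[string]interface{}{
		"model":  "your-model.bin",
		"prompt": "What's an alpaca?",
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Assumed OpenAI-style response shape: {"choices":[{"text":"..."}]}
	var out struct {
		Choices []struct {
			Text string `json:"text"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Choices) > 0 {
		fmt.Println(out.Choices[0].Text)
	}
}
```

Note that the `model` field is mandatory after this change: as the api/api.go hunks below show, the server now returns a "no model specified" error instead of falling back to a default model.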
api/api.go (77 changed lines)

@@ -4,7 +4,6 @@ import (
 	"embed"
 	"fmt"
 	"net/http"
-	"strconv"
 	"strings"
 	"sync"

@@ -70,7 +69,7 @@ type OpenAIRequest struct {
 var indexHTML embed.FS

 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
@@ -82,10 +81,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
 		}

 		if input.Model == "" {
-			if defaultModel == nil {
-				return fmt.Errorf("no default model loaded, and no model specified")
-			}
-			model = defaultModel
+			return fmt.Errorf("no model specified")
 		} else {
 			model, err = loader.LoadModel(input.Model)
 			if err != nil {
@@ -204,7 +200,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
 	}
 }

-func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr string, threads int) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
 	app := fiber.New()

 	// Default middleware config
@@ -217,8 +213,8 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
 	var mumutex = &sync.Mutex{}

 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, defaultModel, loader, threads, mutex, mumutex, mu))
-	app.Post("/v1/completions", openAIEndpoint(false, defaultModel, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
 	app.Get("/v1/models", func(c *fiber.Ctx) error {
 		models, err := loader.ListModels()
 		if err != nil {
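Because the endpoints no longer fall back to a default model, the `/v1/models` route registered in this hunk is what clients use to discover valid values for the now-mandatory `model` field. A small sketch of querying it from Go; the base URL is a placeholder, and the response shape follows the README example above:

```golang
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// listModels asks /v1/models which model files the API can load.
// Expected shape (per the README): {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
func listModels(baseURL string) ([]string, error) {
	resp, err := http.Get(baseURL + "/v1/models")
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var out struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}

	ids := make([]string, 0, len(out.Data))
	for _, m := range out.Data {
		ids = append(ids, m.ID)
	}
	return ids, nil
}

func main() {
	models, err := listModels("http://localhost:8080")
	if err != nil {
		panic(err)
	}
	fmt.Println(models) // pick one of these as the "model" field of a request
}
```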
@@ -243,69 +239,6 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
 		NotFoundFile: "index.html",
 	}))

-	/*
-		curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
-			"text": "What is an alpaca?",
-			"topP": 0.8,
-			"topK": 50,
-			"temperature": 0.7,
-			"tokens": 100
-		}'
-	*/
-	// Endpoint to generate the prediction
-	app.Post("/predict", func(c *fiber.Ctx) error {
-		mutex.Lock()
-		defer mutex.Unlock()
-		// Get input data from the request body
-		input := new(struct {
-			Text string `json:"text"`
-		})
-		if err := c.BodyParser(input); err != nil {
-			return err
-		}
-
-		// Set the parameters for the language model prediction
-		topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
-		if err != nil {
-			return err
-		}
-
-		topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
-		if err != nil {
-			return err
-		}
-
-		temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
-		if err != nil {
-			return err
-		}
-
-		tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
-		if err != nil {
-			return err
-		}
-
-		// Generate the prediction using the language model
-		prediction, err := defaultModel.Predict(
-			input.Text,
-			llama.SetTemperature(temperature),
-			llama.SetTopP(topP),
-			llama.SetTopK(topK),
-			llama.SetTokens(tokens),
-			llama.SetThreads(threads),
-		)
-		if err != nil {
-			return err
-		}
-
-		// Return the prediction in the response body
-		return c.JSON(struct {
-			Prediction string `json:"prediction"`
-		}{
-			Prediction: prediction,
-		})
-	})
-
 	// Start the server
 	app.Listen(listenAddr)
 	return nil
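The `/predict` handler removed above drove the model directly through the go-llama binding's functional options. For readers of this change, here is a minimal sketch of that call pattern outside the server; the import path, model path, and parameter values are assumptions, while the option functions are exactly the ones visible in the removed code:

```golang
package main

import (
	"fmt"

	llama "github.com/go-skynet/go-llama.cpp" // assumed import path for the binding
)

func main() {
	// Load a model file; SetContext mirrors the context-size option used in main.go.
	model, err := llama.New("./models/your-model.bin", llama.SetContext(512))
	if err != nil {
		panic(err)
	}

	// Same sampling options the removed /predict handler passed through.
	out, err := model.Predict(
		"What is an alpaca?",
		llama.SetTemperature(0.7),
		llama.SetTopP(0.8),
		llama.SetTopK(50),
		llama.SetTokens(100),
		llama.SetThreads(4),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```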
@@ -1,75 +0,0 @@
-package client
-
-import (
-	"bytes"
-	"encoding/json"
-	"fmt"
-	"net/http"
-)
-
-type Prediction struct {
-	Prediction string `json:"prediction"`
-}
-
-type Client struct {
-	baseURL  string
-	client   *http.Client
-	endpoint string
-}
-
-func NewClient(baseURL string) *Client {
-	return &Client{
-		baseURL:  baseURL,
-		client:   &http.Client{},
-		endpoint: "/predict",
-	}
-}
-
-type InputData struct {
-	Text        string  `json:"text"`
-	TopP        float64 `json:"topP,omitempty"`
-	TopK        int     `json:"topK,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	Tokens      int     `json:"tokens,omitempty"`
-}
-
-func (c *Client) Predict(text string, opts ...InputOption) (string, error) {
-	input := NewInputData(opts...)
-	input.Text = text
-
-	// encode input data to JSON format
-	inputBytes, err := json.Marshal(input)
-	if err != nil {
-		return "", err
-	}
-
-	// create HTTP request
-	url := c.baseURL + c.endpoint
-	req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes))
-	if err != nil {
-		return "", err
-	}
-
-	// set request headers
-	req.Header.Set("Content-Type", "application/json")
-
-	// send request and get response
-	resp, err := c.client.Do(req)
-	if err != nil {
-		return "", err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return "", fmt.Errorf("request failed with status %d", resp.StatusCode)
-	}
-
-	// decode response body to Prediction struct
-	var prediction Prediction
-	err = json.NewDecoder(resp.Body).Decode(&prediction)
-	if err != nil {
-		return "", err
-	}
-
-	return prediction.Prediction, nil
-}
@@ -1,51 +0,0 @@
-package client
-
-import "net/http"
-
-type ClientOption func(c *Client)
-
-func WithHTTPClient(httpClient *http.Client) ClientOption {
-	return func(c *Client) {
-		c.client = httpClient
-	}
-}
-
-func WithEndpoint(endpoint string) ClientOption {
-	return func(c *Client) {
-		c.endpoint = endpoint
-	}
-}
-
-type InputOption func(d *InputData)
-
-func NewInputData(opts ...InputOption) *InputData {
-	data := &InputData{}
-	for _, opt := range opts {
-		opt(data)
-	}
-	return data
-}
-
-func WithTopP(topP float64) InputOption {
-	return func(d *InputData) {
-		d.TopP = topP
-	}
-}
-
-func WithTopK(topK int) InputOption {
-	return func(d *InputData) {
-		d.TopK = topK
-	}
-}
-
-func WithTemperature(temperature float64) InputOption {
-	return func(d *InputData) {
-		d.Temperature = temperature
-	}
-}
-
-func WithTokens(tokens int) InputOption {
-	return func(d *InputData) {
-		d.Tokens = tokens
-	}
-}
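The two deleted files above formed the old Go client: a `Client` posting to `/predict`, plus functional options for the sampling parameters. Before this commit they composed like the sketch below (host is a placeholder); since the `/predict` route and the `client` package are both removed here, this only documents the superseded API:

```golang
package main

import (
	"fmt"

	client "github.com/go-skynet/llama-cli/client" // import path as shown in the removed README section
)

func main() {
	cli := client.NewClient("http://localhost:8080")

	// InputOption values from the deleted options file tuned the old /predict call.
	out, err := cli.Predict(
		"What's an alpaca?",
		client.WithTemperature(0.7),
		client.WithTopP(0.8),
		client.WithTopK(50),
		client.WithTokens(100),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```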
main.go (20 changed lines)

@@ -57,7 +57,7 @@ func templateString(t string, in interface{}) (string, error) {
 var modelFlags = []cli.Flag{
 	&cli.StringFlag{
 		Name:    "model",
-		EnvVars: []string{"MODEL_PATH"},
+		EnvVars: []string{"MODEL"},
 	},
 	&cli.IntFlag{
 		Name: "tokens",
@@ -134,10 +134,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 			Name:    "models-path",
 			EnvVars: []string{"MODELS_PATH"},
 		},
-		&cli.StringFlag{
-			Name:    "default-model",
-			EnvVars: []string{"DEFAULT_MODEL"},
-		},
 		&cli.StringFlag{
 			Name:    "address",
 			EnvVars: []string{"ADDRESS"},
@@ -150,19 +146,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 			},
 		},
 		Action: func(ctx *cli.Context) error {
-			var defaultModel *llama.LLama
-			defModel := ctx.String("default-model")
-			if defModel != "" {
-				opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
-				var err error
-				defaultModel, err = llama.New(ctx.String("default-model"), opts...)
-				if err != nil {
-					return err
-				}
-			}
-
-			return api.Start(defaultModel, model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
+			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
 		},
 	},
 },
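With the default-model flag gone, embedding the server from Go reduces to building a model loader and calling the new `api.Start`. A minimal sketch under the assumption that the `api` and `model` packages live at the import paths below (this diff only shows the package names); the models directory, listen address, and thread count are placeholders:

```golang
package main

import (
	// Import paths are assumptions based on the repository layout; the diff does not show them.
	api "github.com/go-skynet/llama-cli/api"
	model "github.com/go-skynet/llama-cli/pkg/model"
)

func main() {
	// The loader scans a directory of model files, as main.go does via the models-path flag.
	loader := model.NewModelLoader("./models")

	// New signature from this commit: no default model is passed in,
	// so every request must name the model it wants.
	if err := api.Start(loader, ":8080", 4); err != nil {
		panic(err)
	}
}
```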