mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-20 09:26:15 +00:00
feat: drop default model and llama-specific API (#26)
Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
parent
1370b4482f
commit
63601fabd1
29
README.md
29
README.md
@ -27,6 +27,7 @@ docker compose up -d --build
|
||||
|
||||
# Now API is accessible at localhost:8080
|
||||
curl http://localhost:8080/v1/models
|
||||
|
||||
# {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
|
||||
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "your-model.bin",
|
||||
@ -88,7 +89,7 @@ llama-cli --model <model_path> --instruction <instruction> [--input <input>] [--
|
||||
| template | TEMPLATE | | A file containing a template for output formatting (optional). |
|
||||
| instruction | INSTRUCTION | | Input prompt text or instruction. "-" for STDIN. |
|
||||
| input | INPUT | - | Path to text or "-" for STDIN. |
|
||||
| model | MODEL_PATH | | The path to the pre-trained GPT-based model. |
|
||||
| model | MODEL | | The path to the pre-trained GPT-based model. |
|
||||
| tokens | TOKENS | 128 | The maximum number of tokens to generate. |
|
||||
| threads | THREADS | NumCPU() | The number of threads to use for text generation. |
|
||||
| temperature | TEMPERATURE | 0.95 | Sampling temperature for model output. ( values between `0.1` and `1.0` ) |
|
||||
@ -216,32 +217,6 @@ python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
|
||||
# There will be a new model with the ".tmp" extension, you have to use that one!
|
||||
```
|
||||
|
||||
### Golang client API
|
||||
|
||||
The `llama-cli` codebase has also a small client in go that can be used alongside with the api:
|
||||
|
||||
```golang
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
client "github.com/go-skynet/llama-cli/client"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
cli := client.NewClient("http://ip:port")
|
||||
|
||||
out, err := cli.Predict("What's an alpaca?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(out)
|
||||
}
|
||||
```
|
||||
|
||||
### Windows compatibility
|
||||
|
||||
It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/llama-cli/issues/2
|
||||
|
77
api/api.go
77
api/api.go
@ -4,7 +4,6 @@ import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -70,7 +69,7 @@ type OpenAIRequest struct {
|
||||
var indexHTML embed.FS
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var err error
|
||||
var model *llama.LLama
|
||||
@ -82,10 +81,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
|
||||
}
|
||||
|
||||
if input.Model == "" {
|
||||
if defaultModel == nil {
|
||||
return fmt.Errorf("no default model loaded, and no model specified")
|
||||
}
|
||||
model = defaultModel
|
||||
return fmt.Errorf("no model specified")
|
||||
} else {
|
||||
model, err = loader.LoadModel(input.Model)
|
||||
if err != nil {
|
||||
@ -204,7 +200,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
|
||||
}
|
||||
}
|
||||
|
||||
func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr string, threads int) error {
|
||||
func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
|
||||
app := fiber.New()
|
||||
|
||||
// Default middleware config
|
||||
@ -217,8 +213,8 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
|
||||
var mumutex = &sync.Mutex{}
|
||||
|
||||
// openAI compatible API endpoint
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, defaultModel, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, defaultModel, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
|
||||
app.Get("/v1/models", func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
if err != nil {
|
||||
@ -243,69 +239,6 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
|
||||
NotFoundFile: "index.html",
|
||||
}))
|
||||
|
||||
/*
|
||||
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
|
||||
"text": "What is an alpaca?",
|
||||
"topP": 0.8,
|
||||
"topK": 50,
|
||||
"temperature": 0.7,
|
||||
"tokens": 100
|
||||
}'
|
||||
*/
|
||||
// Endpoint to generate the prediction
|
||||
app.Post("/predict", func(c *fiber.Ctx) error {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
// Get input data from the request body
|
||||
input := new(struct {
|
||||
Text string `json:"text"`
|
||||
})
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate the prediction using the language model
|
||||
prediction, err := defaultModel.Predict(
|
||||
input.Text,
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(struct {
|
||||
Prediction string `json:"prediction"`
|
||||
}{
|
||||
Prediction: prediction,
|
||||
})
|
||||
})
|
||||
|
||||
// Start the server
|
||||
app.Listen(listenAddr)
|
||||
return nil
|
||||
|
@ -1,75 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Prediction struct {
|
||||
Prediction string `json:"prediction"`
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{},
|
||||
endpoint: "/predict",
|
||||
}
|
||||
}
|
||||
|
||||
type InputData struct {
|
||||
Text string `json:"text"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Tokens int `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) Predict(text string, opts ...InputOption) (string, error) {
|
||||
input := NewInputData(opts...)
|
||||
input.Text = text
|
||||
|
||||
// encode input data to JSON format
|
||||
inputBytes, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// create HTTP request
|
||||
url := c.baseURL + c.endpoint
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// set request headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// send request and get response
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("request failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// decode response body to Prediction struct
|
||||
var prediction Prediction
|
||||
err = json.NewDecoder(resp.Body).Decode(&prediction)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return prediction.Prediction, nil
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
package client
|
||||
|
||||
import "net/http"
|
||||
|
||||
type ClientOption func(c *Client)
|
||||
|
||||
func WithHTTPClient(httpClient *http.Client) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.client = httpClient
|
||||
}
|
||||
}
|
||||
|
||||
func WithEndpoint(endpoint string) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.endpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
type InputOption func(d *InputData)
|
||||
|
||||
func NewInputData(opts ...InputOption) *InputData {
|
||||
data := &InputData{}
|
||||
for _, opt := range opts {
|
||||
opt(data)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func WithTopP(topP float64) InputOption {
|
||||
return func(d *InputData) {
|
||||
d.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
func WithTopK(topK int) InputOption {
|
||||
return func(d *InputData) {
|
||||
d.TopK = topK
|
||||
}
|
||||
}
|
||||
|
||||
func WithTemperature(temperature float64) InputOption {
|
||||
return func(d *InputData) {
|
||||
d.Temperature = temperature
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokens(tokens int) InputOption {
|
||||
return func(d *InputData) {
|
||||
d.Tokens = tokens
|
||||
}
|
||||
}
|
20
main.go
20
main.go
@ -57,7 +57,7 @@ func templateString(t string, in interface{}) (string, error) {
|
||||
var modelFlags = []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "model",
|
||||
EnvVars: []string{"MODEL_PATH"},
|
||||
EnvVars: []string{"MODEL"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "tokens",
|
||||
@ -134,10 +134,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
Name: "models-path",
|
||||
EnvVars: []string{"MODELS_PATH"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "default-model",
|
||||
EnvVars: []string{"DEFAULT_MODEL"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "address",
|
||||
EnvVars: []string{"ADDRESS"},
|
||||
@ -150,19 +146,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
},
|
||||
},
|
||||
Action: func(ctx *cli.Context) error {
|
||||
|
||||
var defaultModel *llama.LLama
|
||||
defModel := ctx.String("default-model")
|
||||
if defModel != "" {
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
var err error
|
||||
defaultModel, err = llama.New(ctx.String("default-model"), opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return api.Start(defaultModel, model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
|
||||
return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
|
||||
},
|
||||
},
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user