From f43aeeb4a1faa3f6c074c56f5766d53364062235 Mon Sep 17 00:00:00 2001
From: mudler
Date: Sun, 9 Apr 2023 12:30:55 +0200
Subject: [PATCH] Add both API endpoints (completion, chat)

---
 api.go | 194 ++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 145 insertions(+), 49 deletions(-)

diff --git a/api.go b/api.go
index 98a6d3f8..ce37f340 100644
--- a/api.go
+++ b/api.go
@@ -16,55 +16,38 @@ import (
 )
 
 type OpenAIResponse struct {
-	Created int      `json:"created"`
-	Object  string   `json:"chat.completion"`
-	ID      string   `json:"id"`
-	Model   string   `json:"model"`
-	Choices []Choice `json:"choices"`
+	Created int      `json:"created,omitempty"`
+	Object  string   `json:"chat.completion,omitempty"`
+	ID      string   `json:"id,omitempty"`
+	Model   string   `json:"model,omitempty"`
+	Choices []Choice `json:"choices,omitempty"`
 }
 
 type Choice struct {
-	Index        int     `json:"index"`
-	FinishReason string  `json:"finish_reason"`
-	Message      Message `json:"message"`
+	Index        int     `json:"index,omitempty"`
+	FinishReason string  `json:"finish_reason,omitempty"`
+	Message      Message `json:"message,omitempty"`
+	Text         string  `json:"text,omitempty"`
 }
 
 type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role    string `json:"role,omitempty"`
+	Content string `json:"content,omitempty"`
 }
 
 //go:embed index.html
 var indexHTML embed.FS
 
-func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error {
-	app := fiber.New()
-
-	// Default middleware config
-	app.Use(recover.New())
-	app.Use(cors.New())
-
-	app.Use("/", filesystem.New(filesystem.Config{
-		Root:         http.FS(indexHTML),
-		NotFoundFile: "index.html",
-	}))
-
-	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
-	var mutex = &sync.Mutex{}
-	mu := map[string]*sync.Mutex{}
-	var mumutex = &sync.Mutex{}
-
-	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", func(c *fiber.Ctx) error {
+func completionEndpoint(defaultModel *llama.LLama, loader *ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
 
 		// Get input data from the request body
 		input := new(struct {
-			Messages []Message `json:"messages"`
-			Model    string    `json:"model"`
-			Prompt   string    `json:"prompt"`
+			Model  string `json:"model"`
+			Prompt string `json:"prompt"`
 		})
 		if err := c.BodyParser(input); err != nil {
 			return err
 		}
@@ -84,19 +67,114 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
 
 		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
 		if input.Model != "" {
-			mumutex.Lock()
-			l, ok := mu[input.Model]
+			mutexMap.Lock()
+			l, ok := mutexes[input.Model]
 			if !ok {
 				m := &sync.Mutex{}
-				mu[input.Model] = m
+				mutexes[input.Model] = m
 				l = m
 			}
-			mumutex.Unlock()
+			mutexMap.Unlock()
 			l.Lock()
 			defer l.Unlock()
 		} else {
-			mutex.Lock()
-			defer mutex.Unlock()
+			defaultMutex.Lock()
+			defer defaultMutex.Unlock()
+		}
+
+		// Set the parameters for the language model prediction
+		topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
+		if err != nil {
+			return err
+		}
+
+		topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
+		if err != nil {
+			return err
+		}
+
+		temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
+		if err != nil {
+			return err
+		}
+
+		tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
+		if err != nil {
+			return err
+		}
+
+		predInput := input.Prompt
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := loader.TemplatePrefix(input.Model, struct {
+			Input string
+		}{Input: input.Prompt})
+		if err == nil {
+			predInput = templatedInput
+		}
+
+		// Generate the prediction using the language model
+		prediction, err := model.Predict(
+			predInput,
+			llama.SetTemperature(temperature),
+			llama.SetTopP(topP),
+			llama.SetTopK(topK),
+			llama.SetTokens(tokens),
+			llama.SetThreads(threads),
+		)
+		if err != nil {
+			return err
+		}
+
+		// Return the prediction in the response body
+		return c.JSON(OpenAIResponse{
+			Model:   input.Model,
+			Choices: []Choice{{Text: prediction}},
+		})
+	}
+}
+
+func chatEndpoint(defaultModel *llama.LLama, loader *ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+
+		var err error
+		var model *llama.LLama
+
+		// Get input data from the request body
+		input := new(struct {
+			Messages []Message `json:"messages"`
+			Model    string    `json:"model"`
+		})
+		if err := c.BodyParser(input); err != nil {
+			return err
+		}
+
+		if input.Model == "" {
+			if defaultModel == nil {
+				return fmt.Errorf("no default model loaded, and no model specified")
+			}
+			model = defaultModel
+		} else {
+			model, err = loader.LoadModel(input.Model)
+			if err != nil {
+				return err
+			}
+		}
+
+		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
+		if input.Model != "" {
+			mutexMap.Lock()
+			l, ok := mutexes[input.Model]
+			if !ok {
+				m := &sync.Mutex{}
+				mutexes[input.Model] = m
+				l = m
+			}
+			mutexMap.Unlock()
+			l.Lock()
+			defer l.Unlock()
+		} else {
+			defaultMutex.Lock()
+			defer defaultMutex.Unlock()
 		}
 
 		// Set the parameters for the language model prediction
@@ -127,16 +205,12 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
 
 		predInput := strings.Join(mess, "\n")
 
-		if input.Prompt == "" {
-			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-			templatedInput, err := loader.TemplatePrefix(input.Model, struct {
-				Input string
-			}{Input: predInput})
-			if err == nil {
-				predInput = templatedInput
-			}
-		} else {
-			predInput = input.Prompt + predInput
+		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+		templatedInput, err := loader.TemplatePrefix(input.Model, struct {
+			Input string
+		}{Input: predInput})
+		if err == nil {
+			predInput = templatedInput
 		}
 
 		// Generate the prediction using the language model
@@ -157,7 +231,29 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
 			Model:   input.Model,
 			Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
 		})
-	})
+	}
+}
+
+func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error {
+	app := fiber.New()
+
+	// Default middleware config
+	app.Use(recover.New())
+	app.Use(cors.New())
+
+	app.Use("/", filesystem.New(filesystem.Config{
+		Root:         http.FS(indexHTML),
+		NotFoundFile: "index.html",
+	}))
+
+	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
+	var mutex = &sync.Mutex{}
+	mu := map[string]*sync.Mutex{}
+	var mumutex = &sync.Mutex{}
+
+	// openAI compatible API endpoint
+	app.Post("/v1/chat/completions", chatEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/completions", completionEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
 
 	/*
 		curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{