feat: enhance API, expose more parameters (#24)

Signed-off-by: mudler <mudler@c3os.io>

Repository: https://github.com/mudler/LocalAI.git
Commit: b062f3142b
Parent: c37175271f
Changed files: api/api.go (82 lines)
```diff
--- a/api/api.go
+++ b/api/api.go
@@ -26,10 +26,10 @@ type OpenAIResponse struct {
 }
 
 type Choice struct {
 	Index        int     `json:"index,omitempty"`
 	FinishReason string  `json:"finish_reason,omitempty"`
-	Message      Message `json:"message,omitempty"`
+	Message      *Message `json:"message,omitempty"`
 	Text         string  `json:"text,omitempty"`
 }
 
 type Message struct {
```
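The switch from `Message` to `*Message` matters because of how `encoding/json` handles `omitempty`: a struct value is never considered empty, so completion responses (which only set `Text`) would still serialize a `"message":{}` object, while a nil pointer is dropped entirely. A minimal standalone sketch of the difference (the types here are local to the example, not the ones in api.go):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Message struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

// ByValue mimics the old Choice: omitempty cannot omit a struct value.
type ByValue struct {
	Message Message `json:"message,omitempty"`
	Text    string  `json:"text,omitempty"`
}

// ByPointer mimics the new Choice: a nil pointer is omitted.
type ByPointer struct {
	Message *Message `json:"message,omitempty"`
	Text    string   `json:"text,omitempty"`
}

func main() {
	v, _ := json.Marshal(ByValue{Text: "hi"})
	p, _ := json.Marshal(ByPointer{Text: "hi"})
	fmt.Println(string(v)) // {"message":{},"text":"hi"}
	fmt.Println(string(p)) // {"text":"hi"}
}
```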
```diff
@@ -47,20 +47,29 @@ type OpenAIRequest struct {
 
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt"`
 
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages"`
 
+	Echo bool `json:"echo"`
 	// Common options between all the API calls
 	TopP        float64 `json:"top_p"`
 	TopK        int     `json:"top_k"`
 	Temperature float64 `json:"temperature"`
 	Maxtokens   int     `json:"max_tokens"`
 
+	N int `json:"n"`
+
+	// Custom parameters - not present in the OpenAI API
+	Batch     int  `json:"batch"`
+	F16       bool `json:"f16kv"`
+	IgnoreEOS bool `json:"ignore_eos"`
 }
 
 //go:embed index.html
 var indexHTML embed.FS
 
+// https://platform.openai.com/docs/api-reference/completions
 func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
```
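With the request struct extended, a client can pass the new knobs next to the standard OpenAI parameters. A hedged sketch of a completion call exercising them; the endpoint path, port, and model name are assumptions about a local deployment, not taken from this diff:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Body mirrors OpenAIRequest; the last five fields are the ones this commit adds.
	body, err := json.Marshal(map[string]any{
		"model":       "ggml-model.bin", // assumed model name known to the loader
		"prompt":      "An apple a day",
		"max_tokens":  64,
		"temperature": 0.7,
		"n":           2,    // new: ask for two choices
		"echo":        true, // new: prepend the prompt to each completion
		"batch":       512,  // new: llama.cpp batch size
		"f16kv":       true, // new: 16-bit key/value cache
		"ignore_eos":  true, // new: keep sampling past the EOS token
	})
	if err != nil {
		panic(err)
	}
	// Assumed local address; adjust to wherever the API is listening.
	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```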
```diff
@@ -139,31 +148,58 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
 			predInput = templatedInput
 		}
 
-		// Generate the prediction using the language model
-		prediction, err := model.Predict(
-			predInput,
-			llama.SetTemperature(temperature),
-			llama.SetTopP(topP),
-			llama.SetTopK(topK),
-			llama.SetTokens(tokens),
-			llama.SetThreads(threads),
-		)
-		if err != nil {
-			return err
-		}
+		result := []Choice{}
+
+		n := input.N
+
+		if input.N == 0 {
+			n = 1
+		}
 
-		if chat {
-			// Return the chat prediction in the response body
-			return c.JSON(OpenAIResponse{
-				Model:   input.Model,
-				Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
-			})
+		for i := 0; i < n; i++ {
+			// Generate the prediction using the language model
+			predictOptions := []llama.PredictOption{
+				llama.SetTemperature(temperature),
+				llama.SetTopP(topP),
+				llama.SetTopK(topK),
+				llama.SetTokens(tokens),
+				llama.SetThreads(threads),
+			}
+
+			if input.Batch != 0 {
+				predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
+			}
+
+			if input.F16 {
+				predictOptions = append(predictOptions, llama.EnableF16KV)
+			}
+
+			if input.IgnoreEOS {
+				predictOptions = append(predictOptions, llama.IgnoreEOS)
+			}
+
+			prediction, err := model.Predict(
+				predInput,
+				predictOptions...,
+			)
+			if err != nil {
+				return err
+			}
+
+			if input.Echo {
+				prediction = predInput + prediction
+			}
+			if chat {
+				result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
+			} else {
+				result = append(result, Choice{Text: prediction})
+			}
 		}
 
 		// Return the prediction in the response body
 		return c.JSON(OpenAIResponse{
 			Model:   input.Model,
-			Choices: []Choice{{Text: prediction}},
+			Choices: result,
 		})
 	}
 }
```
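The loop builds its `predictOptions` slice conditionally and then splats it into `model.Predict` with `predictOptions...`, the usual Go functional-options pattern. A minimal sketch of that pattern in isolation; the `config`, `Option`, and `WithBatch` names here are illustrative, not go-llama.cpp's API:

```go
package main

import "fmt"

// config collects the tunables an Option can set.
type config struct {
	batch     int
	f16KV     bool
	ignoreEOS bool
}

// Option mutates a config; llama.PredictOption plays the same role above.
type Option func(*config)

// WithBatch returns an Option, so it is called when appended.
func WithBatch(n int) Option { return func(c *config) { c.batch = n } }

// EnableF16KV and IgnoreEOS are already Options, so they are appended bare.
func EnableF16KV(c *config) { c.f16KV = true }
func IgnoreEOS(c *config)   { c.ignoreEOS = true }

func newConfig(opts ...Option) config {
	c := config{batch: 8} // defaults
	for _, o := range opts {
		o(&c)
	}
	return c
}

func main() {
	// Options are collected conditionally, exactly like predictOptions above.
	opts := []Option{}
	opts = append(opts, WithBatch(512), EnableF16KV)
	fmt.Printf("%+v\n", newConfig(opts...)) // {batch:512 f16KV:true ignoreEOS:false}
}
```

This is why the hunk above appends `llama.EnableF16KV` and `llama.IgnoreEOS` as bare values while `llama.SetBatch(input.Batch)` is called: the former are already functions of the option type, the latter is a constructor that returns one.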