diff --git a/README.md b/README.md
index a3197603..28eb7777 100644
--- a/README.md
+++ b/README.md
@@ -93,6 +93,32 @@ Example of starting the API with `docker`:
 docker run -ti --rm quay.io/go-skynet/llama-cli:latest api
 ```
 
+### Golang client API
+
+The `llama-cli` codebase has also a small client in go that can be used alongside with the api:
+
+```golang
+package main
+
+import (
+	"fmt"
+
+	client "github.com/go-skynet/llama-cli/client"
+)
+
+func main() {
+
+	cli := client.NewClient("http://ip:30007")
+
+	out, err := cli.Predict("What's an alpaca?")
+	if err != nil {
+		panic(err)
+	}
+
+	fmt.Println(out)
+}
+```
+
 ### Kubernetes
 
 You can run the API directly in Kubernetes:
diff --git a/client/client.go b/client/client.go
new file mode 100644
index 00000000..785a46d9
--- /dev/null
+++ b/client/client.go
@@ -0,0 +1,87 @@
+// Package client provides a small Go client for the llama-cli HTTP API.
+package client
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+)
+
+// Prediction is the JSON response body returned by the prediction endpoint.
+type Prediction struct {
+	Prediction string `json:"prediction"`
+}
+
+// Client talks to a llama-cli API server.
+type Client struct {
+	baseURL  string
+	client   *http.Client
+	endpoint string
+}
+
+// NewClient returns a Client pointed at baseURL (e.g. "http://ip:30007").
+// Optional ClientOptions (WithHTTPClient, WithEndpoint) override the defaults.
+func NewClient(baseURL string, opts ...ClientOption) *Client {
+	c := &Client{
+		baseURL:  baseURL,
+		client:   &http.Client{},
+		endpoint: "/predict",
+	}
+	for _, opt := range opts {
+		opt(c)
+	}
+	return c
+}
+
+// InputData is the JSON request body sent to the prediction endpoint.
+type InputData struct {
+	Text        string  `json:"text"`
+	TopP        float64 `json:"topP,omitempty"`
+	TopK        int     `json:"topK,omitempty"`
+	Temperature float64 `json:"temperature,omitempty"`
+	Tokens      int     `json:"tokens,omitempty"`
+}
+
+// Predict sends text to the API and returns the model's prediction.
+// Optional InputOptions (WithTopP, WithTopK, ...) tune the request.
+func (c *Client) Predict(text string, opts ...InputOption) (string, error) {
+	input := NewInputData(opts...)
+	input.Text = text
+
+	// encode input data to JSON format
+	inputBytes, err := json.Marshal(input)
+	if err != nil {
+		return "", err
+	}
+
+	// create HTTP request
+	url := c.baseURL + c.endpoint
+	req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes))
+	if err != nil {
+		return "", err
+	}
+
+	// set request headers
+	req.Header.Set("Content-Type", "application/json")
+
+	// send request and get response
+	resp, err := c.client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return "", fmt.Errorf("request failed with status %d", resp.StatusCode)
+	}
+
+	// decode response body to Prediction struct
+	var prediction Prediction
+	err = json.NewDecoder(resp.Body).Decode(&prediction)
+	if err != nil {
+		return "", err
+	}
+
+	return prediction.Prediction, nil
+}
diff --git a/client/options.go b/client/options.go
new file mode 100644
index 00000000..6635763e
--- /dev/null
+++ b/client/options.go
@@ -0,0 +1,60 @@
+package client
+
+import "net/http"
+
+// ClientOption configures a Client created by NewClient.
+type ClientOption func(c *Client)
+
+// WithHTTPClient replaces the default http.Client (e.g. to set a timeout).
+func WithHTTPClient(httpClient *http.Client) ClientOption {
+	return func(c *Client) {
+		c.client = httpClient
+	}
+}
+
+// WithEndpoint overrides the default "/predict" endpoint path.
+func WithEndpoint(endpoint string) ClientOption {
+	return func(c *Client) {
+		c.endpoint = endpoint
+	}
+}
+
+// InputOption tweaks the InputData sent by Predict.
+type InputOption func(d *InputData)
+
+// NewInputData builds an InputData with the given options applied.
+func NewInputData(opts ...InputOption) *InputData {
+	data := &InputData{}
+	for _, opt := range opts {
+		opt(data)
+	}
+	return data
+}
+
+// WithTopP sets the top-p sampling parameter.
+func WithTopP(topP float64) InputOption {
+	return func(d *InputData) {
+		d.TopP = topP
+	}
+}
+
+// WithTopK sets the top-k sampling parameter.
+func WithTopK(topK int) InputOption {
+	return func(d *InputData) {
+		d.TopK = topK
+	}
+}
+
+// WithTemperature sets the sampling temperature.
+func WithTemperature(temperature float64) InputOption {
+	return func(d *InputData) {
+		d.Temperature = temperature
+	}
+}
+
+// WithTokens sets the maximum number of tokens to generate.
+func WithTokens(tokens int) InputOption {
+	return func(d *InputData) {
+		d.Tokens = tokens
+	}
+}