Enhancements (#34)

Ettore Di Giacinto, 2023-04-19 17:10:29 +02:00, committed via GitHub
Signed-off-by: mudler <mudler@c3os.io>
parent a9a875ee2b
commit 7fec26f5d3
11 changed files with 226 additions and 78 deletions

.gitignore (vendored)

@@ -1,5 +1,6 @@
 # go-llama build artifacts
 go-llama
+go-gpt4all-j

 # llama-cli build binary
 llama-cli

Dockerfile

@@ -2,16 +2,10 @@ ARG GO_VERSION=1.20
 ARG DEBIAN_VERSION=11

 FROM golang:$GO_VERSION as builder

 WORKDIR /build

-ARG GO_LLAMA_CPP_TAG=llama.cpp-4ad7313
-RUN git clone -b $GO_LLAMA_CPP_TAG --recurse-submodules https://github.com/go-skynet/go-llama.cpp
-RUN cd go-llama.cpp && make libbinding.a
-COPY go.mod ./
-COPY go.sum ./
-RUN go mod download
-RUN apt-get update
+RUN apt-get update && apt-get install -y cmake

 COPY . .

-RUN go mod edit -replace github.com/go-skynet/go-llama.cpp=/build/go-llama.cpp
-RUN C_INCLUDE_PATH=/build/go-llama.cpp LIBRARY_PATH=/build/go-llama.cpp go build -o llama-cli ./
+ARG BUILD_TYPE=
+RUN make build${BUILD_TYPE}

 FROM debian:$DEBIAN_VERSION
 COPY --from=builder /build/llama-cli /usr/bin/llama-cli
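
The builder stage now defers entirely to the Makefile, so native and generic builds share one code path. A minimal sketch of driving it (image tags are placeholders):

```bash
# Build the image with the default (native) optimizations
docker build -t llama-cli .

# Opt into the CPU-generic build: BUILD_TYPE=generic makes the builder run
# `make buildgeneric` instead of `make build`
docker build --build-arg BUILD_TYPE=generic -t llama-cli:generic .
```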

Makefile

@@ -2,7 +2,7 @@ GOCMD=go
 GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=llama-cli
-GOLLAMA_VERSION?=llama.cpp-4ad7313
+GOLLAMA_VERSION?=llama.cpp-5ecff35

 GREEN  := $(shell tput -Txterm setaf 2)
 YELLOW := $(shell tput -Txterm setaf 3)
@@ -17,23 +17,50 @@ all: help

 ## Build:
 build: prepare ## Build the project
-	$(GOCMD) build -o $(BINARY_NAME) ./
+	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
+
+buildgeneric: prepare-generic ## Build the project
+	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
+
+go-gpt4all-j:
+	git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
+	# This is hackish, but needed as both go-llama and go-gpt4all-j have their own version of ggml..
+	@find ./go-gpt4all-j -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
+
+go-gpt4all-j/libgptj.a: go-gpt4all-j
+	$(MAKE) -C go-gpt4all-j libgptj.a
+
+go-gpt4all-j/libgptj.a-generic: go-gpt4all-j
+	$(MAKE) -C go-gpt4all-j generic-libgptj.a

 go-llama:
 	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
-
-prepare: go-llama
 	$(MAKE) -C go-llama libbinding.a

+go-llama-generic:
+	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
+	$(MAKE) -C go-llama generic-libbinding.a
+
+prepare: go-llama go-gpt4all-j/libgptj.a
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
+	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
+
+prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic
+	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
+	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j

 clean: ## Remove build related file
+	$(MAKE) -C go-llama clean
 	rm -fr ./go-llama
-	rm -f $(BINARY_NAME)
+	rm -rf ./go-gpt4all-j
+	rm -rf $(BINARY_NAME)

 ## Run:
 run: prepare
-	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp $(GOCMD) run ./ api
+	$(GOCMD) run ./ api

 ## Test:
 test: ## Run the tests of the project
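
Both bindings statically link their own copy of ggml, so without the `sed` rewrite the two archives would collide at link time on duplicate `ggml_*` symbols; the patch gives the GPT-J copy a `ggml_gptj_` prefix. A quick sanity check, as a sketch (the exact symbol names below are hypothetical and `nm` output varies by platform):

```bash
# Clone/patch the bindings and build both static libraries
make prepare

# The patched GPT-J archive should only export prefixed ggml symbols,
# e.g. ggml_gptj_init rather than ggml_init
nm go-gpt4all-j/libgptj.a | grep 'ggml_' | head

# Or build the CPU-generic variant (no host-specific instruction sets)
make buildgeneric
```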

README.md

@@ -1,15 +1,24 @@
 ## :camel: llama-cli

-llama-cli is a straightforward golang CLI interface and API compatible with OpenAI for [llama.cpp](https://github.com/ggerganov/llama.cpp), it supports multiple-models and also provides a simple command line interface that allows text generation using a GPT-based model like llama directly from the terminal.
+llama-cli is a straightforward, drop-in-replacement API compatible with OpenAI for local CPU inferencing, based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all) and [ggml](https://github.com/ggerganov/ggml), including support for GPT4ALL-J, which is Apache 2.0 licensed and can be used for commercial purposes.

-It is compatible with the models supported by `llama.cpp`. You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
+- OpenAI-compatible API
+- Supports multiple models
+- Once loaded the first time, it keeps models in memory for faster inference
+- Provides a simple command line interface that allows text generation directly from the terminal
+- Support for prompt templates
+- Doesn't shell-out, but uses C bindings for faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).

-`llama-cli` doesn't shell-out, it uses https://github.com/go-skynet/go-llama.cpp, which is a golang binding of [llama.cpp](https://github.com/ggerganov/llama.cpp).
+## Model compatibility
+
+It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) and also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all).
+
+Note: You might need to convert older models to the new format; see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all), for instance, to run `gpt4all`.

 ## Usage

-You can use `docker-compose`:
+The easiest way to run llama-cli is by using `docker-compose`:

 ```bash
@@ -27,15 +36,13 @@ docker compose up -d --build

 # Now API is accessible at localhost:8080
 curl http://localhost:8080/v1/models
 # {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}

 curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
      "model": "your-model.bin",
      "prompt": "A long time ago in a galaxy far, far away",
      "temperature": 0.7
    }'
 ```

 Note: The API doesn't inject a default prompt for talking to the model, while the CLI does. You have to use a prompt similar to what's described in the stanford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release.
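
Since the server also exposes `/v1/chat/completions` (wired up in api.go below), a chat-style request follows the same shape. A sketch, assuming the OpenAI-style `messages` array the endpoint mirrors:

```bash
curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
     "model": "your-model.bin",
     "messages": [{"role": "user", "content": "How are you?"}],
     "temperature": 0.7
   }'
```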

api/api.go

@@ -7,6 +7,7 @@ import (
 	model "github.com/go-skynet/llama-cli/pkg/model"

+	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/fiber/v2/middleware/cors"
@@ -60,13 +61,16 @@
 	Batch     int  `json:"batch"`
 	F16       bool `json:"f16kv"`
 	IgnoreEOS bool `json:"ignore_eos"`
+	Seed      int  `json:"seed"`
 }

 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
+		var gptModel *gptj.GPTJ

 		input := new(OpenAIRequest)
 		// Get input data from the request body
@@ -77,9 +81,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
 		if input.Model == "" {
 			return fmt.Errorf("no model specified")
 		} else {
-			model, err = loader.LoadModel(input.Model)
-			if err != nil {
-				return err
+			// Try to load the model with both
+			var llamaerr error
+			llamaOpts := []llama.ModelOption{}
+			if ctx != 0 {
+				llamaOpts = append(llamaOpts, llama.SetContext(ctx))
+			}
+			if f16 {
+				llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+			}
+
+			model, llamaerr = loader.LoadLLaMAModel(input.Model, llamaOpts...)
+			if llamaerr != nil {
+				gptModel, err = loader.LoadGPTJModel(input.Model)
+				if err != nil {
+					return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
+				}
 			}
 		}
@@ -146,32 +163,70 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
 			n = 1
 		}

+		var predFunc func() (string, error)
+		switch {
+		case gptModel != nil:
+			predFunc = func() (string, error) {
+				// Generate the prediction using the language model
+				predictOptions := []gptj.PredictOption{
+					gptj.SetTemperature(temperature),
+					gptj.SetTopP(topP),
+					gptj.SetTopK(topK),
+					gptj.SetTokens(tokens),
+					gptj.SetThreads(threads),
+				}
+
+				if input.Batch != 0 {
+					predictOptions = append(predictOptions, gptj.SetBatch(input.Batch))
+				}
+
+				if input.Seed != 0 {
+					predictOptions = append(predictOptions, gptj.SetSeed(input.Seed))
+				}
+
+				return gptModel.Predict(
+					predInput,
+					predictOptions...,
+				)
+			}
+		case model != nil:
+			predFunc = func() (string, error) {
+				// Generate the prediction using the language model
+				predictOptions := []llama.PredictOption{
+					llama.SetTemperature(temperature),
+					llama.SetTopP(topP),
+					llama.SetTopK(topK),
+					llama.SetTokens(tokens),
+					llama.SetThreads(threads),
+				}
+
+				if input.Batch != 0 {
+					predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
+				}
+
+				if input.F16 {
+					predictOptions = append(predictOptions, llama.EnableF16KV)
+				}
+
+				if input.IgnoreEOS {
+					predictOptions = append(predictOptions, llama.IgnoreEOS)
+				}
+
+				if input.Seed != 0 {
+					predictOptions = append(predictOptions, llama.SetSeed(input.Seed))
+				}
+
+				return model.Predict(
+					predInput,
+					predictOptions...,
+				)
+			}
+		}
+
 		for i := 0; i < n; i++ {
-			// Generate the prediction using the language model
-			predictOptions := []llama.PredictOption{
-				llama.SetTemperature(temperature),
-				llama.SetTopP(topP),
-				llama.SetTopK(topK),
-				llama.SetTokens(tokens),
-				llama.SetThreads(threads),
-			}
-
-			if input.Batch != 0 {
-				predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
-			}
-
-			if input.F16 {
-				predictOptions = append(predictOptions, llama.EnableF16KV)
-			}
-
-			if input.IgnoreEOS {
-				predictOptions = append(predictOptions, llama.IgnoreEOS)
-			}
-
-			prediction, err := model.Predict(
-				predInput,
-				predictOptions...,
-			)
+			var prediction string
+			prediction, err := predFunc()
 			if err != nil {
 				return err
 			}
@@ -179,6 +234,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
 		if input.Echo {
 			prediction = predInput + prediction
 		}
+
 		if chat {
 			result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
 		} else {
@@ -194,7 +250,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
 	}
 }

-func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
 	app := fiber.New()

 	// Default middleware config
@@ -207,8 +263,8 @@ func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
 	var mumutex = &sync.Mutex{}

 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
-	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mutex, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mutex, mumutex, mu))
 	app.Get("/v1/models", func(c *fiber.Ctx) error {
 		models, err := loader.ListModels()
 		if err != nil {
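
The new `Seed` field means generations can be pinned for reproducibility straight from the request body. A sketch of a request exercising the new knobs (field names taken from the struct tags above; the model name is a placeholder):

```bash
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
     "model": "your-model.bin",
     "prompt": "A long time ago in a galaxy far, far away",
     "temperature": 0.7,
     "seed": 1234,
     "batch": 512
   }'
```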

docker-compose.yaml

@@ -16,6 +16,8 @@ services:
     build:
       context: .
       dockerfile: Dockerfile
+      # args:
+      #   BUILD_TYPE: generic # Uncomment to build CPU generic code that works on most HW
     ports:
       - 8080:8080
     environment:
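
A minimal sketch of opting into the generic build via compose, assuming the commented `args` block is uncommented as the hint suggests:

```bash
# After uncommenting in docker-compose.yaml:
#   build:
#     context: .
#     dockerfile: Dockerfile
#     args:
#       BUILD_TYPE: generic
docker compose up -d --build   # rebuilds the image via `make buildgeneric`
```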

go.mod

@@ -11,6 +11,7 @@ require (
 require (
 	github.com/andybalholm/brotli v1.0.4 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
+	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 // indirect
 	github.com/google/uuid v1.3.0 // indirect
 	github.com/klauspost/compress v1.15.9 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect

go.sum

@@ -3,6 +3,8 @@ github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG
 github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
+github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
+github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=

main.go

@@ -34,11 +34,6 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
 ### Response:
 `

-func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
-	opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
-	return llama.New(ctx.String("model"), opts...)
-}

 func templateString(t string, in interface{}) (string, error) {
 	// Parse the template
 	tmpl, err := template.New("prompt").Parse(t)
@@ -125,6 +120,10 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 			Name:  "api",
 			Flags: []cli.Flag{
+				&cli.BoolFlag{
+					Name:    "f16",
+					EnvVars: []string{"F16"},
+				},
 				&cli.IntFlag{
 					Name:    "threads",
 					EnvVars: []string{"THREADS"},
@@ -146,7 +145,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 				},
 			},
 			Action: func(ctx *cli.Context) error {
-				return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
+				return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"))
 			},
 		},
 	},
@@ -201,7 +200,8 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 				os.Exit(1)
 			}

-			l, err := llamaFromOptions(ctx)
+			opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
+			l, err := llama.New(ctx.String("model"), opts...)
 			if err != nil {
 				fmt.Println("Loading the model failed:", err.Error())
 				os.Exit(1)
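
With the new flag wired through, F16 memory and the context size can be set at server start, as flags or via the environment variables declared above (paths are placeholders):

```bash
# Flags...
llama-cli api --models-path ./models --threads 4 --context-size 512 --f16

# ...or the equivalent environment variables for the flags that declare them
F16=true THREADS=4 llama-cli api --models-path ./models
```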

pkg/model/loader.go

@@ -10,6 +10,7 @@ import (
 	"sync"
 	"text/template"

+	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 )
@@ -17,11 +18,12 @@
 	modelPath        string
 	mu               sync.Mutex
 	models           map[string]*llama.LLama
+	gptmodels        map[string]*gptj.GPTJ
 	promptsTemplates map[string]*template.Template
 }

 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
+	return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
 }

 func (ml *ModelLoader) ListModels() ([]string, error) {
@@ -62,16 +64,81 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
 	return buf.String(), nil
 }

-func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
+func (ml *ModelLoader) loadTemplate(modelName, modelFile string) error {
+	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
+
+	// Check if the model path exists
+	if _, err := os.Stat(modelTemplateFile); err != nil {
+		return nil
+	}
+
+	dat, err := os.ReadFile(modelTemplateFile)
+	if err != nil {
+		return err
+	}
+
+	// Parse the template
+	tmpl, err := template.New("prompt").Parse(string(dat))
+	if err != nil {
+		return err
+	}
+	ml.promptsTemplates[modelName] = tmpl
+
+	return nil
+}
+
+func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()

 	// Check if we already have a loaded model
 	modelFile := filepath.Join(ml.modelPath, modelName)
+	if m, ok := ml.gptmodels[modelFile]; ok {
+		return m, nil
+	}
+
+	// Check if the model path exists
+	if _, err := os.Stat(modelFile); os.IsNotExist(err) {
+		// try to find a s.bin
+		modelBin := fmt.Sprintf("%s.bin", modelFile)
+		if _, err := os.Stat(modelBin); os.IsNotExist(err) {
+			return nil, err
+		} else {
+			modelName = fmt.Sprintf("%s.bin", modelName)
+			modelFile = modelBin
+		}
+	}
+
+	// Load the model and keep it in memory for later use
+	model, err := gptj.New(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	// If there is a prompt template, load it
+	if err := ml.loadTemplate(modelName, modelFile); err != nil {
+		return nil, err
+	}
+
+	ml.gptmodels[modelFile] = model
+	return model, err
+}
+
+func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check if we already have a loaded model
+	modelFile := filepath.Join(ml.modelPath, modelName)
 	if m, ok := ml.models[modelFile]; ok {
 		return m, nil
 	}
+	// TODO: This needs refactoring, it's really bad to have it in here
+	// Check if we have a GPTJ model loaded instead
+	if _, ok := ml.gptmodels[modelFile]; ok {
+		return nil, fmt.Errorf("this model is a GPTJ one")
+	}

 	// Check if the model path exists
 	if _, err := os.Stat(modelFile); os.IsNotExist(err) {
@@ -92,21 +159,8 @@ func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*
 	}

 	// If there is a prompt template, load it
-	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
-	// Check if the model path exists
-	if _, err := os.Stat(modelTemplateFile); err == nil {
-		dat, err := os.ReadFile(modelTemplateFile)
-		if err != nil {
-			return nil, err
-		}
-
-		// Parse the template
-		tmpl, err := template.New("prompt").Parse(string(dat))
-		if err != nil {
-			return nil, err
-		}
-		ml.promptsTemplates[modelName] = tmpl
-	}
+	if err := ml.loadTemplate(modelName, modelFile); err != nil {
+		return nil, err
+	}

 	ml.models[modelFile] = model
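
The `.bin` fallback plus the `.tmpl` lookup mean a model directory only needs consistent file names for the loader to pick everything up: a request for `"model": "x"` first tries `x`, then `x.bin`, and the prompt template is searched as `<resolved model file>.tmpl`. A sketch of the expected layout (file names are hypothetical):

```bash
ls models/
# ggml-gpt4all-j.bin        <- weights; a request for "ggml-gpt4all-j"
#                              also resolves here via the .bin fallback
# ggml-gpt4all-j.bin.tmpl   <- optional prompt template, parsed when the model loads
```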

New prompt template (`.tmpl`, file path not shown)

@@ -0,0 +1,4 @@
+The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
+### Prompt:
+{{.Input}}
+### Response:
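
Dropped next to a model file as `<model file>.tmpl`, this file is parsed with Go's `text/template`; as far as this diff shows, the loader's template machinery substitutes the request prompt for `{{.Input}}` before inference. A sketch of wiring it up (file names are hypothetical):

```bash
# The loader looks for "<model file>.tmpl" next to the weights
cp gpt4all-j.tmpl models/ggml-gpt4all-j.bin.tmpl

curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
     "model": "ggml-gpt4all-j",
     "prompt": "What is an alpaca?"
   }'
# The template is rendered with .Input set to the prompt before prediction
```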