mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 12:26:26 +00:00
parent
a9a875ee2b
commit
7fec26f5d3
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,5 +1,6 @@
|
||||
# go-llama build artifacts
|
||||
go-llama
|
||||
go-gpt4all-j
|
||||
|
||||
# llama-cli build binary
|
||||
llama-cli
|
||||
|
12
Dockerfile
12
Dockerfile
@ -2,16 +2,10 @@ ARG GO_VERSION=1.20
|
||||
ARG DEBIAN_VERSION=11
|
||||
FROM golang:$GO_VERSION as builder
|
||||
WORKDIR /build
|
||||
ARG GO_LLAMA_CPP_TAG=llama.cpp-4ad7313
|
||||
RUN git clone -b $GO_LLAMA_CPP_TAG --recurse-submodules https://github.com/go-skynet/go-llama.cpp
|
||||
RUN cd go-llama.cpp && make libbinding.a
|
||||
COPY go.mod ./
|
||||
COPY go.sum ./
|
||||
RUN go mod download
|
||||
RUN apt-get update
|
||||
RUN apt-get update && apt-get install -y cmake
|
||||
COPY . .
|
||||
RUN go mod edit -replace github.com/go-skynet/go-llama.cpp=/build/go-llama.cpp
|
||||
RUN C_INCLUDE_PATH=/build/go-llama.cpp LIBRARY_PATH=/build/go-llama.cpp go build -o llama-cli ./
|
||||
ARG BUILD_TYPE=
|
||||
RUN make build${BUILD_TYPE}
|
||||
|
||||
FROM debian:$DEBIAN_VERSION
|
||||
COPY --from=builder /build/llama-cli /usr/bin/llama-cli
|
||||
|
41
Makefile
41
Makefile
@ -2,7 +2,7 @@ GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
GOVET=$(GOCMD) vet
|
||||
BINARY_NAME=llama-cli
|
||||
GOLLAMA_VERSION?=llama.cpp-4ad7313
|
||||
GOLLAMA_VERSION?=llama.cpp-5ecff35
|
||||
|
||||
GREEN := $(shell tput -Txterm setaf 2)
|
||||
YELLOW := $(shell tput -Txterm setaf 3)
|
||||
@ -17,23 +17,50 @@ all: help
|
||||
## Build:
|
||||
|
||||
build: prepare ## Build the project
|
||||
$(GOCMD) build -o $(BINARY_NAME) ./
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
|
||||
|
||||
buildgeneric: prepare-generic ## Build the project
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
|
||||
|
||||
go-gpt4all-j:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
|
||||
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
|
||||
@find ./go-gpt4all-j -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
|
||||
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
|
||||
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
|
||||
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
|
||||
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
|
||||
|
||||
go-gpt4all-j/libgptj.a: go-gpt4all-j
|
||||
$(MAKE) -C go-gpt4all-j libgptj.a
|
||||
|
||||
go-gpt4all-j/libgptj.a-generic: go-gpt4all-j
|
||||
$(MAKE) -C go-gpt4all-j generic-libgptj.a
|
||||
|
||||
go-llama:
|
||||
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
|
||||
|
||||
prepare: go-llama
|
||||
$(MAKE) -C go-llama libbinding.a
|
||||
|
||||
go-llama-generic:
|
||||
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
|
||||
$(MAKE) -C go-llama generic-libbinding.a
|
||||
|
||||
prepare: go-llama go-gpt4all-j/libgptj.a
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
|
||||
|
||||
prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
|
||||
|
||||
clean: ## Remove build related file
|
||||
$(MAKE) -C go-llama clean
|
||||
rm -fr ./go-llama
|
||||
rm -f $(BINARY_NAME)
|
||||
rm -rf ./go-gpt4all-j
|
||||
rm -rf $(BINARY_NAME)
|
||||
|
||||
## Run:
|
||||
run: prepare
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp $(GOCMD) run ./ api
|
||||
$(GOCMD) run ./ api
|
||||
|
||||
## Test:
|
||||
test: ## Run the tests of the project
|
||||
|
21
README.md
21
README.md
@ -1,15 +1,24 @@
|
||||
## :camel: llama-cli
|
||||
|
||||
|
||||
llama-cli is a straightforward golang CLI interface and API compatible with OpenAI for [llama.cpp](https://github.com/ggerganov/llama.cpp), it supports multiple-models and also provides a simple command line interface that allows text generation using a GPT-based model like llama directly from the terminal.
|
||||
llama-cli is a straightforward, drop-in replacement API compatible with OpenAI for local CPU inferencing, based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all) and [ggml](https://github.com/ggerganov/ggml), including support GPT4ALL-J which is Apache 2.0 Licensed and can be used for commercial purposes.
|
||||
|
||||
It is compatible with the models supported by `llama.cpp`. You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
|
||||
- OpenAI compatible API
|
||||
- Supports multiple-models
|
||||
- Once loaded the first time, it keep models loaded in memory for faster inference
|
||||
- Provides a simple command line interface that allows text generation directly from the terminal
|
||||
- Support for prompt templates
|
||||
- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
|
||||
|
||||
`llama-cli` doesn't shell-out, it uses https://github.com/go-skynet/go-llama.cpp, which is a golang binding of [llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
## Model compatibility
|
||||
|
||||
It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) and also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all).
|
||||
|
||||
Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
|
||||
|
||||
## Usage
|
||||
|
||||
You can use `docker-compose`:
|
||||
The easiest way to run llama-cli is by using `docker-compose`:
|
||||
|
||||
```bash
|
||||
|
||||
@ -27,15 +36,13 @@ docker compose up -d --build
|
||||
|
||||
# Now API is accessible at localhost:8080
|
||||
curl http://localhost:8080/v1/models
|
||||
|
||||
# {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
|
||||
|
||||
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "your-model.bin",
|
||||
"prompt": "A long time ago in a galaxy far, far away",
|
||||
"temperature": 0.7
|
||||
}'
|
||||
|
||||
|
||||
```
|
||||
|
||||
Note: The API doesn't inject a default prompt for talking to the model, while the CLI does. You have to use a prompt similar to what's described in the standford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release.
|
||||
|
118
api/api.go
118
api/api.go
@ -7,6 +7,7 @@ import (
|
||||
|
||||
model "github.com/go-skynet/llama-cli/pkg/model"
|
||||
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
@ -60,13 +61,16 @@ type OpenAIRequest struct {
|
||||
Batch int `json:"batch"`
|
||||
F16 bool `json:"f16kv"`
|
||||
IgnoreEOS bool `json:"ignore_eos"`
|
||||
|
||||
Seed int `json:"seed"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var err error
|
||||
var model *llama.LLama
|
||||
var gptModel *gptj.GPTJ
|
||||
|
||||
input := new(OpenAIRequest)
|
||||
// Get input data from the request body
|
||||
@ -77,9 +81,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
|
||||
if input.Model == "" {
|
||||
return fmt.Errorf("no model specified")
|
||||
} else {
|
||||
model, err = loader.LoadModel(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
// Try to load the model with both
|
||||
var llamaerr error
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
if ctx != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
|
||||
}
|
||||
if f16 {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
|
||||
model, llamaerr = loader.LoadLLaMAModel(input.Model, llamaOpts...)
|
||||
if llamaerr != nil {
|
||||
gptModel, err = loader.LoadGPTJModel(input.Model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,32 +163,70 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
|
||||
n = 1
|
||||
}
|
||||
|
||||
var predFunc func() (string, error)
|
||||
switch {
|
||||
case gptModel != nil:
|
||||
predFunc = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gptj.PredictOption{
|
||||
gptj.SetTemperature(temperature),
|
||||
gptj.SetTopP(topP),
|
||||
gptj.SetTopK(topK),
|
||||
gptj.SetTokens(tokens),
|
||||
gptj.SetThreads(threads),
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gptj.SetBatch(input.Batch))
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
predictOptions = append(predictOptions, gptj.SetSeed(input.Seed))
|
||||
}
|
||||
|
||||
return gptModel.Predict(
|
||||
predInput,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case model != nil:
|
||||
predFunc = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(input.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
predInput,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
}
|
||||
var prediction string
|
||||
|
||||
if input.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
prediction, err := model.Predict(
|
||||
predInput,
|
||||
predictOptions...,
|
||||
)
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -179,6 +234,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
|
||||
if input.Echo {
|
||||
prediction = predInput + prediction
|
||||
}
|
||||
|
||||
if chat {
|
||||
result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
|
||||
} else {
|
||||
@ -194,7 +250,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMu
|
||||
}
|
||||
}
|
||||
|
||||
func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
|
||||
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
|
||||
app := fiber.New()
|
||||
|
||||
// Default middleware config
|
||||
@ -207,8 +263,8 @@ func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
|
||||
var mumutex = &sync.Mutex{}
|
||||
|
||||
// openAI compatible API endpoint
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mutex, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mutex, mumutex, mu))
|
||||
app.Get("/v1/models", func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
if err != nil {
|
||||
|
@ -16,6 +16,8 @@ services:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
# args:
|
||||
# BUILD_TYPE: generic # Uncomment to build CPU generic code that works on most HW
|
||||
ports:
|
||||
- 8080:8080
|
||||
environment:
|
||||
|
1
go.mod
1
go.mod
@ -11,6 +11,7 @@ require (
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
|
||||
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/klauspost/compress v1.15.9 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -3,6 +3,8 @@ github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHG
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
||||
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
|
||||
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
|
14
main.go
14
main.go
@ -34,11 +34,6 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
|
||||
### Response:
|
||||
`
|
||||
|
||||
func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
return llama.New(ctx.String("model"), opts...)
|
||||
}
|
||||
|
||||
func templateString(t string, in interface{}) (string, error) {
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(t)
|
||||
@ -125,6 +120,10 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
|
||||
Name: "api",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "f16",
|
||||
EnvVars: []string{"F16"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "threads",
|
||||
EnvVars: []string{"THREADS"},
|
||||
@ -146,7 +145,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
},
|
||||
},
|
||||
Action: func(ctx *cli.Context) error {
|
||||
return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
|
||||
return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"))
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -201,7 +200,8 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l, err := llamaFromOptions(ctx)
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
l, err := llama.New(ctx.String("model"), opts...)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
@ -17,11 +18,12 @@ type ModelLoader struct {
|
||||
modelPath string
|
||||
mu sync.Mutex
|
||||
models map[string]*llama.LLama
|
||||
gptmodels map[string]*gptj.GPTJ
|
||||
promptsTemplates map[string]*template.Template
|
||||
}
|
||||
|
||||
func NewModelLoader(modelPath string) *ModelLoader {
|
||||
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
|
||||
return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||
@ -62,16 +64,81 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
|
||||
func (ml *ModelLoader) loadTemplate(modelName, modelFile string) error {
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
|
||||
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelTemplateFile); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dat, err := os.ReadFile(modelTemplateFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ml.promptsTemplates[modelName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
// Check if we already have a loaded model
|
||||
modelFile := filepath.Join(ml.modelPath, modelName)
|
||||
|
||||
if m, ok := ml.gptmodels[modelFile]; ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
// try to find a s.bin
|
||||
modelBin := fmt.Sprintf("%s.bin", modelFile)
|
||||
if _, err := os.Stat(modelBin); os.IsNotExist(err) {
|
||||
return nil, err
|
||||
} else {
|
||||
modelName = fmt.Sprintf("%s.bin", modelName)
|
||||
modelFile = modelBin
|
||||
}
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
model, err := gptj.New(modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there is a prompt template, load it
|
||||
if err := ml.loadTemplate(modelName, modelFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ml.gptmodels[modelFile] = model
|
||||
return model, err
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
// Check if we already have a loaded model
|
||||
modelFile := filepath.Join(ml.modelPath, modelName)
|
||||
if m, ok := ml.models[modelFile]; ok {
|
||||
return m, nil
|
||||
}
|
||||
// TODO: This needs refactoring, it's really bad to have it in here
|
||||
// Check if we have a GPTJ model loaded instead
|
||||
if _, ok := ml.gptmodels[modelFile]; ok {
|
||||
return nil, fmt.Errorf("this model is a GPTJ one")
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
@ -92,21 +159,8 @@ func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*
|
||||
}
|
||||
|
||||
// If there is a prompt template, load it
|
||||
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelTemplateFile); err == nil {
|
||||
dat, err := os.ReadFile(modelTemplateFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ml.promptsTemplates[modelName] = tmpl
|
||||
if err := ml.loadTemplate(modelName, modelFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ml.models[modelFile] = model
|
||||
|
4
prompt-templates/ggml-gpt4all-j.tmpl
Normal file
4
prompt-templates/ggml-gpt4all-j.tmpl
Normal file
@ -0,0 +1,4 @@
|
||||
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
|
||||
### Prompt:
|
||||
{{.Input}}
|
||||
### Response:
|
Loading…
Reference in New Issue
Block a user