Add support for cerebras (#45)

Signed-off-by: mudler <mudler@c3os.io>
This commit is contained in:
Ettore Di Giacinto 2023-04-20 19:33:36 +02:00 committed by GitHub
parent d517a54e28
commit 1c4fbaae20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 124 additions and 20 deletions

View File

@ -17,11 +17,12 @@ all: help
## Build: ## Build:
build: prepare ## Build the project build: prepare ## Build the project
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./ C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./
buildgeneric: prepare-generic ## Build the project buildgeneric: prepare-generic ## Build the project
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./ C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./
## GPT4ALL-J
go-gpt4all-j: go-gpt4all-j:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@ -30,6 +31,9 @@ go-gpt4all-j:
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} + @find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} + @find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} + @find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
go-gpt4all-j/libgptj.a: go-gpt4all-j go-gpt4all-j/libgptj.a: go-gpt4all-j
$(MAKE) -C go-gpt4all-j libgptj.a $(MAKE) -C go-gpt4all-j libgptj.a
@ -37,6 +41,25 @@ go-gpt4all-j/libgptj.a: go-gpt4all-j
go-gpt4all-j/libgptj.a-generic: go-gpt4all-j go-gpt4all-j/libgptj.a-generic: go-gpt4all-j
$(MAKE) -C go-gpt4all-j generic-libgptj.a $(MAKE) -C go-gpt4all-j generic-libgptj.a
# CEREBRAS GPT
go-gpt2.cpp:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2.cpp
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@find ./go-gpt2.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gpt2_replace/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gpt2_replace/g' {} +
go-gpt2.cpp/libgpt2.a: go-gpt2.cpp
$(MAKE) -C go-gpt2.cpp libgpt2.a
go-gpt2.cpp/libgpt2.a-generic: go-gpt2.cpp
$(MAKE) -C go-gpt2.cpp generic-libgpt2.a
go-llama: go-llama:
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
$(MAKE) -C go-llama libbinding.a $(MAKE) -C go-llama libbinding.a
@ -45,17 +68,19 @@ go-llama-generic:
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
$(MAKE) -C go-llama generic-libbinding.a $(MAKE) -C go-llama generic-libbinding.a
prepare: go-llama go-gpt4all-j/libgptj.a replace:
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j $(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2.cpp
prepare: go-llama go-gpt4all-j/libgptj.a go-gpt2.cpp/libgpt2.a replace
prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic go-gpt2.cpp/libgpt2.a-generic replace
prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
clean: ## Remove build related file clean: ## Remove build related file
rm -fr ./go-llama rm -fr ./go-llama
rm -rf ./go-gpt4all-j rm -rf ./go-gpt4all-j
rm -rf ./go-gpt2.cpp
rm -rf $(BINARY_NAME) rm -rf $(BINARY_NAME)
## Run: ## Run:

View File

@ -12,13 +12,12 @@ LocalAI is a straightforward, drop-in replacement API compatible with OpenAI for
- OpenAI compatible API - OpenAI compatible API
- Supports multiple-models - Supports multiple-models
- Once loaded the first time, it keep models loaded in memory for faster inference - Once loaded the first time, it keep models loaded in memory for faster inference
- Provides a simple command line interface that allows text generation directly from the terminal
- Support for prompt templates - Support for prompt templates
- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp). - Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
## Model compatibility ## Model compatibility
It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) and also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all). It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) supports also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml).
Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`. Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
@ -97,8 +96,6 @@ And you'll see:
└───────────────────────────────────────────────────┘ └───────────────────────────────────────────────────┘
``` ```
Note: Models have to end up with `.bin` so can be listed by the `/models` endpoint.
You can control the API server options with command line arguments: You can control the API server options with command line arguments:
``` ```

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp" llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -73,6 +74,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var err error var err error
var model *llama.LLama var model *llama.LLama
var gptModel *gptj.GPTJ var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
input := new(OpenAIRequest) input := new(OpenAIRequest)
// Get input data from the request body // Get input data from the request body
@ -97,7 +99,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
} }
// Try to load the model with both // Try to load the model with both
var llamaerr error var llamaerr, gpt2err, gptjerr error
llamaOpts := []llama.ModelOption{} llamaOpts := []llama.ModelOption{}
if ctx != 0 { if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx)) llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@ -106,11 +108,15 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
llamaOpts = append(llamaOpts, llama.EnableF16Memory) llamaOpts = append(llamaOpts, llama.EnableF16Memory)
} }
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil { if llamaerr != nil {
gptModel, err = loader.LoadGPTJModel(modelFile) gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
if err != nil { if gptjerr != nil {
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
}
} }
} }
@ -176,6 +182,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var predFunc func() (string, error) var predFunc func() (string, error)
switch { switch {
case gpt2Model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return gpt2Model.Predict(
predInput,
predictOptions...,
)
}
case gptModel != nil: case gptModel != nil:
predFunc = func() (string, error) { predFunc = func() (string, error) {
// Generate the prediction using the language model // Generate the prediction using the language model

1
go.mod
View File

@ -13,6 +13,7 @@ require (
require ( require (
github.com/andybalholm/brotli v1.0.4 // indirect github.com/andybalholm/brotli v1.0.4 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/klauspost/compress v1.15.9 // indirect github.com/klauspost/compress v1.15.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect

4
go.sum
View File

@ -4,6 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d h1:8crcrVuvpRzf6wejPtIFYGmMrSTfW94CYPJZIssT8zo=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d h1:Jabxk0NI5CLbY7PVODkRp1AQbEovS9gM6jGAOwyy5FI=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E= github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=

View File

@ -12,20 +12,24 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp" llama "github.com/go-skynet/go-llama.cpp"
) )
type ModelLoader struct { type ModelLoader struct {
modelPath string modelPath string
mu sync.Mutex mu sync.Mutex
models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2
promptsTemplates map[string]*template.Template promptsTemplates map[string]*template.Template
} }
func NewModelLoader(modelPath string) *ModelLoader { func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)} return &ModelLoader{modelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
} }
func (ml *ModelLoader) ExistsInModelPath(s string) bool { func (ml *ModelLoader) ExistsInModelPath(s string) bool {
@ -98,6 +102,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil return nil
} }
func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.New(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gpt2models[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) { func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
ml.mu.Lock() ml.mu.Lock()
defer ml.mu.Unlock() defer ml.mu.Unlock()
@ -112,6 +148,13 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
return m, nil return m, nil
} }
// TODO: This needs refactoring, it's really bad to have it in here
// Check if we have a GPT2 model loaded instead - if we do we return an error so the API tries with GPT2
if _, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one")
}
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@ -152,6 +195,10 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
log.Debug().Msgf("Model is GPTJ: %s", modelName) log.Debug().Msgf("Model is GPTJ: %s", modelName)
return nil, fmt.Errorf("this model is a GPTJ one") return nil, fmt.Errorf("this model is a GPTJ one")
} }
if _, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one")
}
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.modelPath, modelName)