feat: add rwkv support (#158)
Signed-off-by: mudler <mudler@mocaccino.org>

This commit is contained in:
parent 1ae7150810
commit 751b7eca62

Makefile (53 changes)
@@ -9,14 +9,18 @@ GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
 # renovate: datasource=git-refs packageNameTemplate=https://github.com/go-skynet/go-gpt2.cpp currentValueTemplate=master depNameTemplate=go-gpt2.cpp
 GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
+
+# here until https://github.com/donomii/go-rwkv.cpp/pull/1 is merged
+RWKV_REPO?=https://github.com/mudler/go-rwkv.cpp
+RWKV_VERSION?=6ba15255b03016b5ecce36529b500d21815399a7
 
 GREEN := $(shell tput -Txterm setaf 2)
 YELLOW := $(shell tput -Txterm setaf 3)
 WHITE := $(shell tput -Txterm setaf 7)
 CYAN := $(shell tput -Txterm setaf 6)
 RESET := $(shell tput -Txterm sgr0)
 
-C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
-LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
+C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
+LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
 
 # Use this if you want to set the default behavior
 ifndef BUILD_TYPE
@@ -33,16 +37,6 @@ endif
 
 all: help
 
-## Build:
-
-build: prepare ## Build the project
-	$(info ${GREEN}I local-ai build info:${RESET})
-	$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
-	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./
-
-generic-build: ## Build the project using generic
-	BUILD_TYPE="generic" $(MAKE) build
-
 ## GPT4ALL-J
 go-gpt4all-j:
 	git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
@@ -57,10 +51,18 @@ go-gpt4all-j:
 	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
 	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
 
+## RWKV
+go-rwkv:
+	git clone --recurse-submodules $(RWKV_REPO) go-rwkv
+	cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
+
+go-rwkv/librwkv.a: go-rwkv
+	cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a ..
+
 go-gpt4all-j/libgptj.a: go-gpt4all-j
 	$(MAKE) -C go-gpt4all-j $(GENERIC_PREFIX)libgptj.a
 
-# CEREBRAS GPT
+## CEREBRAS GPT
 go-gpt2:
 	git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2
 	cd go-gpt2 && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
@@ -75,7 +77,6 @@ go-gpt2:
 go-gpt2/libgpt2.a: go-gpt2
 	$(MAKE) -C go-gpt2 $(GENERIC_PREFIX)libgpt2.a
 
-
 go-llama:
 	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
 
@@ -86,26 +87,40 @@ replace:
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2
+	$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv
 
-prepare-sources: go-llama go-gpt2 go-gpt4all-j
+prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv
 	$(GOCMD) mod download
 
-rebuild:
+## GENERIC
+
+rebuild: ## Rebuilds the project
 	$(MAKE) -C go-llama clean
 	$(MAKE) -C go-gpt4all-j clean
 	$(MAKE) -C go-gpt2 clean
+	$(MAKE) -C go-rwkv clean
 	$(MAKE) build
 
-prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a replace
+prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a go-rwkv/librwkv.a replace ## Prepares for building
 
 clean: ## Remove build related file
 	rm -fr ./go-llama
 	rm -rf ./go-gpt4all-j
 	rm -rf ./go-gpt2
+	rm -rf ./go-rwkv
 	rm -rf $(BINARY_NAME)
 
-## Run:
-run: prepare
+## Build:
+
+build: prepare ## Build the project
+	$(info ${GREEN}I local-ai build info:${RESET})
+	$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
+	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./
+
+generic-build: ## Build the project using generic
+	BUILD_TYPE="generic" $(MAKE) build
+
+## Run
+
+run: prepare ## run local-ai
 	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./main.go
 
 test-models/testmodel:
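In short, the Makefile now treats RWKV like the other bindings: go-rwkv is cloned at a pinned revision, librwkv.a (together with libggml.a) is built statically via cmake, and the checkout directory is appended to C_INCLUDE_PATH and LIBRARY_PATH so cgo can find the binding's headers and archives. The build, generic-build and run targets are functionally unchanged; they are only moved below clean: and gain help-text comments.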
README.md (10 changes)

@@ -15,7 +15,7 @@
 - Supports multiple-models
 - Once loaded the first time, it keep models loaded in memory for faster inference
 - Support for prompt templates
-- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
+- Doesn't shell-out, but uses C bindings for a faster inference and better performance.
 
 LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
 
@@ -39,6 +39,7 @@ Tested with:
 - [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
 - Koala
 - [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
+- [RWKV](https://github.com/BlinkDL/RWKV-LM) with [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp)
 
 It should also be compatible with StableLM and GPTNeoX ggml models (untested)
 
@@ -506,6 +507,13 @@ LocalAI is a community-driven project. It was initially created by [mudler](http
 
 MIT
 
+## Golang bindings used
+
+- [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
+- [go-skynet/go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp)
+- [go-skynet/go-gpt2.cpp](https://github.com/go-skynet/go-gpt2.cpp)
+- [donomii/go-rwkv.cpp](https://github.com/donomii/go-rwkv.cpp)
+
 ## Acknowledgements
 
 - [llama.cpp](https://github.com/ggerganov/llama.cpp)
@@ -79,7 +79,7 @@ var _ = Describe("API test", func() {
 	It("returns errors", func() {
 		_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
 		Expect(err).To(HaveOccurred())
-		Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:"))
+		Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
 	})
 
 })
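The expected count in this API test rises from 4 to 5 errors because the greedy loader (changed below) now also tries the RWKV backend, and every backend that fails to load the bogus model contributes one error to the aggregated multierror.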
@@ -6,6 +6,7 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/donomii/go-rwkv.cpp"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
@@ -13,6 +14,8 @@ import (
 	"github.com/hashicorp/go-multierror"
 )
 
+const tokenizerSuffix = ".tokenizer.json"
+
 // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
@@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 var loadedModels map[string]interface{} = map[string]interface{}{}
 var muModels sync.Mutex
 
-func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	switch strings.ToLower(backendString) {
 	case "llama":
 		return loader.LoadLLaMAModel(modelFile, llamaOpts...)
@@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
 		return loader.LoadGPT2Model(modelFile)
 	case "gptj":
 		return loader.LoadGPTJModel(modelFile)
+	case "rwkv":
+		return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
 	default:
 		return nil, fmt.Errorf("backend unsupported: %s", backendString)
 	}
 }
 
-func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	updateModels := func(model interface{}) {
 		muModels.Lock()
 		defer muModels.Unlock()
@@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
 		err = multierror.Append(err, modelerr)
 	}
 
+	model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
+	if modelerr == nil {
+		updateModels(model)
+		return model, nil
+	} else {
+		err = multierror.Append(err, modelerr)
+	}
+
 	return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
 }
 
@@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var inferenceModel interface{}
 	var err error
 	if c.Backend == "" {
-		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
+		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
 	} else {
-		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
+		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
 	}
 	if err != nil {
 		return nil, err
@@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var fn func() (string, error)
 
 	switch model := inferenceModel.(type) {
+	case *rwkv.RwkvState:
+		supportStreams = true
+
+		fn = func() (string, error) {
+			//model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
+			stopWord := "\n"
+			if len(c.StopWords) > 0 {
+				stopWord = c.StopWords[0]
+			}
+
+			response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
+
+			return response, nil
+		}
 	case *gpt2.StableLM:
 		fn = func() (string, error) {
 			// Generate the prediction using the language model
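For illustration, a minimal hypothetical sketch of the generation path the new *rwkv.RwkvState case wires up, using only the binding calls that appear in this diff (rwkv.LoadFiles, GenerateResponse). The file paths, parameter values, and the exact callback type are assumptions inferred from how the diff uses them, not confirmed API documentation:

package main

import (
	"fmt"

	rwkv "github.com/donomii/go-rwkv.cpp"
)

func main() {
	// Mirrors what LoadRWKV does internally: model file, its ".tokenizer.json"
	// sidecar, and the thread count. LoadFiles returns nil when loading fails.
	model := rwkv.LoadFiles("models/rwkv.bin", "models/rwkv.bin.tokenizer.json", 4)
	if model == nil {
		panic("could not load model")
	}

	// The same call the new case makes: max tokens, stop word, temperature,
	// top-p, and a per-token callback used for streaming.
	response := model.GenerateResponse(128, "\n", 0.9, 0.8, func(token string) bool {
		fmt.Print(token) // stream each token as it is produced
		return true      // keep generating
	})
	fmt.Println()
	fmt.Println(response)
}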
go.mod (1 change)

@@ -23,6 +23,7 @@ require (
 	github.com/StackExchange/wmi v1.2.1 // indirect
 	github.com/andybalholm/brotli v1.0.5 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
+	github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d // indirect
 	github.com/ghodss/yaml v1.0.0 // indirect
 	github.com/go-logr/logr v1.2.3 // indirect
 	github.com/go-ole/go-ole v1.2.6 // indirect
go.sum (2 changes)

@@ -12,6 +12,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d h1:lSHwlYf1H4WAWYgf7rjEVTGen1qmigUq2Egpu8mnQiY=
+github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d/go.mod h1:H6QBF7/Tz6DAEBDXQged4H1BvsmqY/K5FG9wQRGa01g=
 github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
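Note that in a regular source build these go.mod/go.sum entries are effectively shadowed: the replace target added to the Makefile rewrites github.com/donomii/go-rwkv.cpp to the local go-rwkv checkout (and prepare depends on replace), so the pinned pseudo-version is only consulted when building against the upstream module directly.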
@@ -12,6 +12,7 @@ import (
 
 	"github.com/rs/zerolog/log"
 
+	rwkv "github.com/donomii/go-rwkv.cpp"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
@@ -25,8 +26,8 @@ type ModelLoader struct {
 	gptmodels         map[string]*gptj.GPTJ
 	gpt2models        map[string]*gpt2.GPT2
 	gptstablelmmodels map[string]*gpt2.StableLM
+	rwkv              map[string]*rwkv.RwkvState
 	promptsTemplates  map[string]*template.Template
 }
 
 func NewModelLoader(modelPath string) *ModelLoader {
@@ -36,6 +37,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
 		gptmodels:         make(map[string]*gptj.GPTJ),
 		gptstablelmmodels: make(map[string]*gpt2.StableLM),
 		models:            make(map[string]*llama.LLama),
+		rwkv:              make(map[string]*rwkv.RwkvState),
 		promptsTemplates:  make(map[string]*template.Template),
 	}
 }
@@ -218,6 +220,36 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 	return model, err
 }
 
+func (ml *ModelLoader) LoadRWKV(modelName, tokenFile string, threads uint32) (*rwkv.RwkvState, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	log.Debug().Msgf("Loading model name: %s", modelName)
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	if m, ok := ml.rwkv[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.ModelPath, modelName)
+	tokenPath := filepath.Join(ml.ModelPath, tokenFile)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model := rwkv.LoadFiles(modelFile, tokenPath, threads)
+	if model == nil {
+		return nil, fmt.Errorf("could not load model")
+	}
+
+	ml.rwkv[modelName] = model
+	return model, nil
+}
+
 func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
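To show how the new loader method is meant to be called, a small hypothetical sketch, assuming the signatures introduced above; the file names are placeholders, resolved against the loader's model path, with the tokenizer JSON expected to sit next to the model file:

package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	loader := model.NewModelLoader("./models")

	// Loads ./models/rwkv.bin plus ./models/rwkv.bin.tokenizer.json with 4 threads.
	m, err := loader.LoadRWKV("rwkv.bin", "rwkv.bin.tokenizer.json", 4)
	if err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Printf("loaded %T\n", m)

	// A second call is served from the loader's in-memory cache
	// instead of re-reading the files.
	if _, err := loader.LoadRWKV("rwkv.bin", "rwkv.bin.tokenizer.json", 4); err == nil {
		fmt.Println("reused cached model")
	}
}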