mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat: add LangChainGo Huggingface backend (#446)
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
7282668da1
commit
3ba07a5928
@ -1,3 +1,5 @@
|
|||||||
|
.git
|
||||||
|
.idea
|
||||||
models
|
models
|
||||||
examples/chatbot-ui/models
|
examples/chatbot-ui/models
|
||||||
examples/rwkv/models
|
examples/rwkv/models
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -24,3 +24,4 @@ release/
|
|||||||
|
|
||||||
# just in case
|
# just in case
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.idea
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/donomii/go-rwkv.cpp"
|
"github.com/donomii/go-rwkv.cpp"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||||
"github.com/go-skynet/bloomz.cpp"
|
"github.com/go-skynet/bloomz.cpp"
|
||||||
@ -494,6 +495,23 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
|
|||||||
model.SetTokenCallback(nil)
|
model.SetTokenCallback(nil)
|
||||||
return str, er
|
return str, er
|
||||||
}
|
}
|
||||||
|
case *langchain.HuggingFace:
|
||||||
|
fn = func() (string, error) {
|
||||||
|
|
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []langchain.PredictOption{
|
||||||
|
langchain.SetModel(c.Model),
|
||||||
|
langchain.SetMaxTokens(c.Maxtokens),
|
||||||
|
langchain.SetTemperature(c.Temperature),
|
||||||
|
langchain.SetStopWords(c.StopWords),
|
||||||
|
}
|
||||||
|
|
||||||
|
pred, er := model.PredictHuggingFace(s, predictOptions...)
|
||||||
|
if er != nil {
|
||||||
|
return "", er
|
||||||
|
}
|
||||||
|
return pred.Completion, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func() (string, error) {
|
return func() (string, error) {
|
||||||
|
68
examples/langchain-huggingface/README.md
Normal file
68
examples/langchain-huggingface/README.md
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Data query example
|
||||||
|
|
||||||
|
Example of integration with HuggingFace Inference API with help of [langchaingo](https://github.com/tmc/langchaingo).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Download the LocalAI and start the API:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone LocalAI
|
||||||
|
git clone https://github.com/go-skynet/LocalAI
|
||||||
|
|
||||||
|
cd LocalAI/examples/langchain-huggingface
|
||||||
|
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
Node: Ensure you've set `HUGGINGFACEHUB_API_TOKEN` environment variable, you can generate it
|
||||||
|
on [Settings / Access Tokens](https://huggingface.co/settings/tokens) page of HuggingFace site.
|
||||||
|
|
||||||
|
This is an example `.env` file for LocalAI:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
MODELS_PATH=/models
|
||||||
|
CONTEXT_SIZE=512
|
||||||
|
HUGGINGFACEHUB_API_TOKEN=hg_123456
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using remote models
|
||||||
|
|
||||||
|
Now you can use any remote models available via HuggingFace API, for example let's enable using of
|
||||||
|
[gpt2](https://huggingface.co/gpt2) model in `gpt-3.5-turbo.yaml` config:
|
||||||
|
|
||||||
|
```yml
|
||||||
|
name: gpt-3.5-turbo
|
||||||
|
parameters:
|
||||||
|
model: gpt2
|
||||||
|
top_k: 80
|
||||||
|
temperature: 0.2
|
||||||
|
top_p: 0.7
|
||||||
|
context_size: 1024
|
||||||
|
backend: "langchain-huggingface"
|
||||||
|
stopwords:
|
||||||
|
- "HUMAN:"
|
||||||
|
- "GPT:"
|
||||||
|
roles:
|
||||||
|
user: " "
|
||||||
|
system: " "
|
||||||
|
template:
|
||||||
|
completion: completion
|
||||||
|
chat: gpt4all
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is you can see in field `parameters.model` equal `gpt2` and `backend` equal `langchain-huggingface`.
|
||||||
|
|
||||||
|
## How to use
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Now API is accessible at localhost:8080
|
||||||
|
curl http://localhost:8080/v1/models
|
||||||
|
# {"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model"}]}
|
||||||
|
|
||||||
|
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"prompt": "A long time ago in a galaxy far, far away",
|
||||||
|
"temperature": 0.7
|
||||||
|
}'
|
||||||
|
```
|
15
examples/langchain-huggingface/docker-compose.yml
Normal file
15
examples/langchain-huggingface/docker-compose.yml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
version: '3.6'
|
||||||
|
|
||||||
|
services:
|
||||||
|
api:
|
||||||
|
image: quay.io/go-skynet/local-ai:latest
|
||||||
|
build:
|
||||||
|
context: ../../
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
ports:
|
||||||
|
- 8080:8080
|
||||||
|
env_file:
|
||||||
|
- ../../.env
|
||||||
|
volumes:
|
||||||
|
- ./models:/models:cached
|
||||||
|
command: ["/usr/bin/local-ai"]
|
1
examples/langchain-huggingface/models/completion.tmpl
Normal file
1
examples/langchain-huggingface/models/completion.tmpl
Normal file
@ -0,0 +1 @@
|
|||||||
|
{{.Input}}
|
17
examples/langchain-huggingface/models/gpt-3.5-turbo.yaml
Normal file
17
examples/langchain-huggingface/models/gpt-3.5-turbo.yaml
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
name: gpt-3.5-turbo
|
||||||
|
parameters:
|
||||||
|
model: gpt2
|
||||||
|
top_k: 80
|
||||||
|
temperature: 0.2
|
||||||
|
top_p: 0.7
|
||||||
|
context_size: 1024
|
||||||
|
backend: "langchain-huggingface"
|
||||||
|
stopwords:
|
||||||
|
- "HUMAN:"
|
||||||
|
- "GPT:"
|
||||||
|
roles:
|
||||||
|
user: " "
|
||||||
|
system: " "
|
||||||
|
template:
|
||||||
|
completion: completion
|
||||||
|
chat: gpt4all
|
4
examples/langchain-huggingface/models/gpt4all.tmpl
Normal file
4
examples/langchain-huggingface/models/gpt4all.tmpl
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
|
||||||
|
### Prompt:
|
||||||
|
{{.Input}}
|
||||||
|
### Response:
|
1
go.mod
1
go.mod
@ -59,6 +59,7 @@ require (
|
|||||||
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect
|
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect
|
||||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
||||||
github.com/tinylib/msgp v1.1.8 // indirect
|
github.com/tinylib/msgp v1.1.8 // indirect
|
||||||
|
github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -192,6 +192,8 @@ github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKik
|
|||||||
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
|
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
|
||||||
github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0=
|
github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0=
|
||||||
github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=
|
github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=
|
||||||
|
github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841 h1:IVlfKPZzq3W1G+CkhZgN5VjmHnAeB3YqEvxyNPPCZXY=
|
||||||
|
github.com/tmc/langchaingo v0.0.0-20230530193922-fb062652f841/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI=
|
||||||
github.com/urfave/cli/v2 v2.25.3 h1:VJkt6wvEBOoSjPFQvOkv6iWIrsJyCrKGtCtxXWwmGeY=
|
github.com/urfave/cli/v2 v2.25.3 h1:VJkt6wvEBOoSjPFQvOkv6iWIrsJyCrKGtCtxXWwmGeY=
|
||||||
github.com/urfave/cli/v2 v2.25.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
|
github.com/urfave/cli/v2 v2.25.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
47
pkg/langchain/huggingface.go
Normal file
47
pkg/langchain/huggingface.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package langchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
|
"github.com/tmc/langchaingo/llms/huggingface"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HuggingFace struct {
|
||||||
|
modelPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHuggingFace(repoId string) (*HuggingFace, error) {
|
||||||
|
return &HuggingFace{
|
||||||
|
modelPath: repoId,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HuggingFace) PredictHuggingFace(text string, opts ...PredictOption) (*Predict, error) {
|
||||||
|
po := NewPredictOptions(opts...)
|
||||||
|
|
||||||
|
// Init client
|
||||||
|
llm, err := huggingface.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert from LocalAI to LangChainGo format of options
|
||||||
|
co := []llms.CallOption{
|
||||||
|
llms.WithModel(po.Model),
|
||||||
|
llms.WithMaxTokens(po.MaxTokens),
|
||||||
|
llms.WithTemperature(po.Temperature),
|
||||||
|
llms.WithStopWords(po.StopWords),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Inference API
|
||||||
|
ctx := context.Background()
|
||||||
|
completion, err := llm.Call(ctx, text, co...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Predict{
|
||||||
|
Completion: completion,
|
||||||
|
}, nil
|
||||||
|
}
|
57
pkg/langchain/langchain.go
Normal file
57
pkg/langchain/langchain.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package langchain
|
||||||
|
|
||||||
|
type PredictOptions struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
// MaxTokens is the maximum number of tokens to generate.
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
// Temperature is the temperature for sampling, between 0 and 1.
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
// StopWords is a list of words to stop on.
|
||||||
|
StopWords []string `json:"stop_words"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PredictOption func(p *PredictOptions)
|
||||||
|
|
||||||
|
var DefaultOptions = PredictOptions{
|
||||||
|
Model: "gpt2",
|
||||||
|
MaxTokens: 200,
|
||||||
|
Temperature: 0.96,
|
||||||
|
StopWords: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Predict struct {
|
||||||
|
Completion string
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetModel(model string) PredictOption {
|
||||||
|
return func(o *PredictOptions) {
|
||||||
|
o.Model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetTemperature(temperature float64) PredictOption {
|
||||||
|
return func(o *PredictOptions) {
|
||||||
|
o.Temperature = temperature
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetMaxTokens(maxTokens int) PredictOption {
|
||||||
|
return func(o *PredictOptions) {
|
||||||
|
o.MaxTokens = maxTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetStopWords(stopWords []string) PredictOption {
|
||||||
|
return func(o *PredictOptions) {
|
||||||
|
o.StopWords = stopWords
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPredictOptions Create a new PredictOptions object with the given options.
|
||||||
|
func NewPredictOptions(opts ...PredictOption) PredictOptions {
|
||||||
|
p := DefaultOptions
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&p)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
rwkv "github.com/donomii/go-rwkv.cpp"
|
rwkv "github.com/donomii/go-rwkv.cpp"
|
||||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||||
bloomz "github.com/go-skynet/bloomz.cpp"
|
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||||
bert "github.com/go-skynet/go-bert.cpp"
|
bert "github.com/go-skynet/go-bert.cpp"
|
||||||
@ -36,6 +37,7 @@ const (
|
|||||||
RwkvBackend = "rwkv"
|
RwkvBackend = "rwkv"
|
||||||
WhisperBackend = "whisper"
|
WhisperBackend = "whisper"
|
||||||
StableDiffusionBackend = "stablediffusion"
|
StableDiffusionBackend = "stablediffusion"
|
||||||
|
LCHuggingFaceBackend = "langchain-huggingface"
|
||||||
)
|
)
|
||||||
|
|
||||||
var backends []string = []string{
|
var backends []string = []string{
|
||||||
@ -100,6 +102,10 @@ var whisperModel = func(modelFile string) (interface{}, error) {
|
|||||||
return whisper.New(modelFile)
|
return whisper.New(modelFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lcHuggingFace = func(repoId string) (interface{}, error) {
|
||||||
|
return langchain.NewHuggingFace(repoId)
|
||||||
|
}
|
||||||
|
|
||||||
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
|
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
|
||||||
return func(s string) (interface{}, error) {
|
return func(s string) (interface{}, error) {
|
||||||
return llama.New(s, opts...)
|
return llama.New(s, opts...)
|
||||||
@ -159,6 +165,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
|||||||
return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads))
|
return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads))
|
||||||
case WhisperBackend:
|
case WhisperBackend:
|
||||||
return ml.LoadModel(modelFile, whisperModel)
|
return ml.LoadModel(modelFile, whisperModel)
|
||||||
|
case LCHuggingFaceBackend:
|
||||||
|
return ml.LoadModel(modelFile, lcHuggingFace)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("backend unsupported: %s", backendString)
|
return nil, fmt.Errorf("backend unsupported: %s", backendString)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user