feat: bert.cpp token embeddings (#241)
parent b4241d0a0d
commit 2488c445b6
Makefile (3 changed lines)

@@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
 RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
-BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13
+BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
 BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1
@@ -182,6 +182,7 @@ test-models/testmodel:
 	mkdir test-dir
 	wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
 	wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
+	wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O test-models/bert
 	wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
 	cp tests/fixtures/* test-models
@@ -47,8 +47,7 @@ var _ = Describe("API test", func() {
 		It("returns the models list", func() {
 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
-			Expect(len(models.Models)).To(Equal(5))
-			Expect(models.Models[0].ID).To(Equal("testmodel"))
+			Expect(len(models.Models)).To(Equal(7))
 		})
 		It("can generate completions", func() {
 			resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@@ -97,6 +96,33 @@ var _ = Describe("API test", func() {
 			Expect(err).ToNot(HaveOccurred())
 			Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
 		})
+
+		It("calculate embeddings", func() {
+			if runtime.GOOS != "linux" {
+				Skip("test supported only on linux")
+			}
+			resp, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaEmbeddingV2,
+					Input: []string{"sun", "cat"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
+			Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
+
+			sunEmbedding := resp.Data[0].Embedding
+			resp2, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaEmbeddingV2,
+					Input: []string{"sun"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+		})
 	})

 	Context("Config file", func() {
@@ -123,8 +149,7 @@ var _ = Describe("API test", func() {

 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
-			Expect(len(models.Models)).To(Equal(7))
-			Expect(models.Models[0].ID).To(Equal("testmodel"))
+			Expect(len(models.Models)).To(Equal(9))
 		})
 		It("can generate chat completions from config file", func() {
 			resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
@@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	case *bert.Bert:
 		fn = func() ([]float32, error) {
 			if len(tokens) > 0 {
-				return nil, fmt.Errorf("embeddings endpoint for this model supports only string")
+				return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
 			}
 			return model.Embeddings(s, bert.SetThreads(c.Threads))
 		}
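
In short, pre-tokenized input is no longer rejected for bert.cpp models: token IDs go through the new TokenEmbeddings call, while plain text keeps using Embeddings. A minimal self-contained sketch of that dispatch follows; the helper name is invented for illustration and the binding import path is assumed (go-skynet's go-bert.cpp), while the real logic lives inside the ModelEmbedding closure shown above.

package sketch

import (
	bert "github.com/go-skynet/go-bert.cpp" // assumed import path for the bert.cpp bindings
)

// bertEmbeddings is a hypothetical helper mirroring the dispatch this patch introduces.
func bertEmbeddings(model *bert.Bert, s string, tokens []int, threads int) ([]float32, error) {
	if len(tokens) > 0 {
		// The request already carries token IDs: embed them directly via the new token-embeddings path.
		return model.TokenEmbeddings(tokens, bert.SetThreads(threads))
	}
	// Otherwise embed the raw string; bert.cpp tokenizes the text internally.
	return model.Embeddings(s, bert.SetThreads(threads))
}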
tests/fixtures/embeddings.yaml (new file, 6 lines)

@@ -0,0 +1,6 @@
+name: text-embedding-ada-002
+parameters:
+  model: bert
+threads: 14
+backend: bert-embeddings
+embeddings: true
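
This fixture wires the bert backend to the OpenAI model name text-embedding-ada-002, which is what the test above selects through openai.AdaEmbeddingV2. Below is a short usage sketch against a running LocalAI instance, assuming the sashabaranov/go-openai client that the tests appear to use; the address, empty API key, and program scaffolding are placeholders.

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Point the OpenAI client at a local LocalAI instance (placeholder address).
	cfg := openai.DefaultConfig("")
	cfg.BaseURL = "http://localhost:8080/v1"
	client := openai.NewClientWithConfig(cfg)

	// "text-embedding-ada-002" is the name the fixture above gives the bert backend.
	resp, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Model: openai.AdaEmbeddingV2,
		Input: []string{"sun", "cat"},
	})
	if err != nil {
		panic(err)
	}
	// all-MiniLM-L6-v2 (fetched as test-models/bert in the Makefile) returns 384-dimensional vectors.
	for i, d := range resp.Data {
		fmt.Printf("input %d -> %d dimensions\n", i, len(d.Embedding))
	}
}

Determinism for identical input is what the second half of the test checks: requesting "sun" twice yields the same embedding.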