diff --git a/Makefile b/Makefile
index b6471f3e..4be5604b 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
 RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
-BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13
+BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
 BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1
@@ -182,6 +182,7 @@ test-models/testmodel:
 	mkdir test-dir
 	wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
 	wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
+	wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O test-models/bert
 	wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
 	cp tests/fixtures/* test-models
diff --git a/api/api_test.go b/api/api_test.go
index e4433112..def4e206 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -47,8 +47,7 @@ var _ = Describe("API test", func() {
 		It("returns the models list", func() {
 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
-			Expect(len(models.Models)).To(Equal(5))
-			Expect(models.Models[0].ID).To(Equal("testmodel"))
+			Expect(len(models.Models)).To(Equal(7))
 		})
 		It("can generate completions", func() {
 			resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@@ -97,6 +96,33 @@ var _ = Describe("API test", func() {
 			Expect(err).ToNot(HaveOccurred())
 			Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
 		})
+
+		It("calculate embeddings", func() {
+			if runtime.GOOS != "linux" {
+				Skip("test supported only on linux")
+			}
+			resp, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaEmbeddingV2,
+					Input: []string{"sun", "cat"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
+			Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
+
+			sunEmbedding := resp.Data[0].Embedding
+			resp2, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaEmbeddingV2,
+					Input: []string{"sun"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+		})
 	})
 	Context("Config file", func() {
@@ -123,8 +149,7 @@
 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
-			Expect(len(models.Models)).To(Equal(7))
-			Expect(models.Models[0].ID).To(Equal("testmodel"))
+			Expect(len(models.Models)).To(Equal(9))
 		})
 		It("can generate chat completions from config file", func() {
 			resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
diff --git a/api/prediction.go b/api/prediction.go
index f31ffd57..3dfb45fd 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	case *bert.Bert:
 		fn = func() ([]float32, error) {
 			if len(tokens) > 0 {
-				return nil, fmt.Errorf("embeddings endpoint for this model supports only string")
+				return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
 			}
 			return model.Embeddings(s, bert.SetThreads(c.Threads))
 		}
diff --git a/tests/fixtures/embeddings.yaml b/tests/fixtures/embeddings.yaml
new file mode 100644
index 00000000..b90ca75a
--- /dev/null
+++ b/tests/fixtures/embeddings.yaml
@@ -0,0 +1,6 @@
+name: text-embedding-ada-002
+parameters:
+  model: bert
+threads: 14
+backend: bert-embeddings
+embeddings: true
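For context, a minimal client-side sketch (not part of the diff) of how the new bert embeddings path can be exercised once a server is running with the fixture above. It mirrors the added test, uses the same github.com/sashabaranov/go-openai client as the test suite, and assumes the API listens on http://localhost:8080 and exposes the fixture model as "text-embedding-ada-002".

// embeddings_example.go - hypothetical usage sketch, not part of the change set.
package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// The API key is not checked by the local server, but the client needs a config.
	cfg := openai.DefaultConfig("not-needed")
	cfg.BaseURL = "http://localhost:8080/v1" // assumed listen address
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Model: openai.AdaEmbeddingV2, // resolves to "text-embedding-ada-002", the fixture's name
		Input: []string{"sun"},
	})
	if err != nil {
		log.Fatal(err)
	}
	// all-MiniLM-L6-v2 yields 384-dimensional vectors, matching the test's assertion.
	fmt.Println("dimensions:", len(resp.Data[0].Embedding))
}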