feat(stores): Vector store backend (#1795)

Add simple vector store backend

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe 2024-03-22 20:14:04 +00:00 committed by GitHub
parent 4b1ee0c170
commit 643d85d2cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 3250 additions and 441 deletions

31
.editorconfig Normal file
View File

@ -0,0 +1,31 @@
root = true
[*]
indent_style = space
indent_size = 2
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.go]
indent_style = tab
[Makefile]
indent_style = tab
[*.proto]
indent_size = 2
[*.py]
indent_size = 4
[*.js]
indent_size = 2
[*.yaml]
indent_size = 2
[*.md]
trim_trailing_whitespace = false

View File

@ -159,6 +159,7 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
ALL_GRPC_BACKENDS+=backend-assets/grpc/gpt4all ALL_GRPC_BACKENDS+=backend-assets/grpc/gpt4all
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC) ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC) GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
@ -333,7 +334,7 @@ prepare-test: grpcs
test: prepare test-models/testmodel.ggml grpcs test: prepare test-models/testmodel.ggml grpcs
@echo 'Running tests' @echo 'Running tests'
export GO_TAGS="tts stablediffusion" export GO_TAGS="tts stablediffusion debug"
$(MAKE) prepare-test $(MAKE) prepare-test
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
@ -387,6 +388,11 @@ test-stablediffusion: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stores: backend-assets/grpc/local-store
mkdir -p tests/integration/backend-assets/grpc
cp -f backend-assets/grpc/local-store tests/integration/backend-assets/grpc/
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts 1 -v -r tests/integration
test-container: test-container:
docker build --target requirements -t local-ai-test-container . docker build --target requirements -t local-ai-test-container .
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
@ -536,6 +542,9 @@ backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH=$(CURDIR)/sources/whisper.cpp LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \ CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH=$(CURDIR)/sources/whisper.cpp LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/
backend-assets/grpc/local-store: backend-assets/grpc
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/local-store ./backend/go/stores/
grpcs: prepare $(GRPC_BACKENDS) grpcs: prepare $(GRPC_BACKENDS)
DOCKER_IMAGE?=local-ai DOCKER_IMAGE?=local-ai

View File

@ -18,6 +18,48 @@ service Backend {
rpc TTS(TTSRequest) returns (Result) {} rpc TTS(TTSRequest) returns (Result) {}
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {} rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
rpc Status(HealthMessage) returns (StatusResponse) {} rpc Status(HealthMessage) returns (StatusResponse) {}
rpc StoresSet(StoresSetOptions) returns (Result) {}
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
}
message StoresKey {
repeated float Floats = 1;
}
message StoresValue {
bytes Bytes = 1;
}
message StoresSetOptions {
repeated StoresKey Keys = 1;
repeated StoresValue Values = 2;
}
message StoresDeleteOptions {
repeated StoresKey Keys = 1;
}
message StoresGetOptions {
repeated StoresKey Keys = 1;
}
message StoresGetResult {
repeated StoresKey Keys = 1;
repeated StoresValue Values = 2;
}
message StoresFindOptions {
StoresKey Key = 1;
int32 TopK = 2;
}
message StoresFindResult {
repeated StoresKey Keys = 1;
repeated StoresValue Values = 2;
repeated float Similarities = 3;
} }
message HealthMessage {} message HealthMessage {}

View File

@ -0,0 +1,14 @@
//go:build debug
// +build debug
package main
import (
"github.com/rs/zerolog/log"
)
func assert(cond bool, msg string) {
if !cond {
log.Fatal().Stack().Msg(msg)
}
}

26
backend/go/stores/main.go Normal file
View File

@ -0,0 +1,26 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each store
import (
"flag"
"os"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
flag.Parse()
if err := grpc.StartServer(*addr, NewStore()); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,7 @@
//go:build !debug
// +build !debug
package main
func assert(cond bool, msg string) {
}

507
backend/go/stores/store.go Normal file
View File

@ -0,0 +1,507 @@
package main
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"container/heap"
"fmt"
"math"
"slices"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/rs/zerolog/log"
)
type Store struct {
base.SingleThread
// The sorted keys
keys [][]float32
// The sorted values
values [][]byte
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
// TODO: Should we normalize incoming keys if they are not instead?
keysAreNormalized bool
// The first key decides the length of the keys
keyLen int
}
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
type Pair struct {
Key []float32
Value []byte
}
func NewStore() *Store {
return &Store{
keys: make([][]float32, 0),
values: make([][]byte, 0),
keysAreNormalized: true,
keyLen: -1,
}
}
func compareSlices(k1, k2 []float32) int {
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
return slices.Compare(k1, k2)
}
func hasKey(unsortedSlice [][]float32, target []float32) bool {
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
return compareSlices(k, target) == 0
})
}
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
return compareSlices(k, t)
})
}
func isSortedPairs(kvs []Pair) bool {
for i := 1; i < len(kvs); i++ {
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
return false
}
}
return true
}
func isSortedKeys(keys [][]float32) bool {
for i := 1; i < len(keys); i++ {
if compareSlices(keys[i-1], keys[i]) > 0 {
return false
}
}
return true
}
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
ks := make([][]float32, len(keys))
for i, k := range keys {
ks[i] = k.Floats
}
slices.SortFunc(ks, compareSlices)
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
assert(isSortedKeys(ks), "keys are not sorted")
return ks
}
func (s *Store) Load(opts *pb.ModelOptions) error {
return nil
}
// Sort the incoming kvs and merge them with the existing sorted kvs
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to add")
}
if len(opts.Keys) != len(opts.Values) {
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
}
if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
kvs := make([]Pair, len(opts.Keys))
for i, k := range opts.Keys {
if s.keysAreNormalized && !isNormalized(k.Floats) {
s.keysAreNormalized = false
var sample []float32
if len(s.keys) > 5 {
sample = k.Floats[:5]
} else {
sample = k.Floats
}
log.Debug().Msgf("Key is not normalized: %v", sample)
}
kvs[i] = Pair{
Key: k.Floats,
Value: opts.Values[i].Bytes,
}
}
slices.SortFunc(kvs, func(a, b Pair) int {
return compareSlices(a.Key, b.Key)
})
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
assert(isSortedPairs(kvs), "keys are not sorted")
l := len(kvs) + len(s.keys)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)
i, j := 0, 0
for {
if i+j >= l {
break
}
if i >= len(kvs) {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
continue
}
if j >= len(s.keys) {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
continue
}
c := compareSlices(kvs[i].Key, s.keys[j])
if c < 0 {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
} else if c > 0 {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
} else {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
j++
}
}
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
s.keys = merge_ks
s.values = merge_vs
return nil
}
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to delete")
}
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to add")
}
if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
ks := sortIntoKeySlicese(opts.Keys)
l := len(s.keys) - len(ks)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)
tail_ks := s.keys
tail_vs := s.values
for _, k := range ks {
j, found := findInSortedSlice(tail_ks, k)
if found {
merge_ks = append(merge_ks, tail_ks[:j]...)
merge_vs = append(merge_vs, tail_vs[:j]...)
tail_ks = tail_ks[j+1:]
tail_vs = tail_vs[j+1:]
} else {
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
}
log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs))
}
merge_ks = append(merge_ks, tail_ks...)
merge_vs = append(merge_vs, tail_vs...)
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
s.keys = merge_ks
s.values = merge_vs
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
assert(isSortedKeys(s.keys), "keys are not sorted")
assert(func() bool {
for _, k := range ks {
if _, found := findInSortedSlice(s.keys, k); found {
return false
}
}
return true
}(), "Keys to delete still present")
if len(s.keys) != l {
log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l)
}
return nil
}
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
ks := sortIntoKeySlicese(opts.Keys)
if len(s.keys) == 0 {
log.Debug().Msgf("Get: No keys in store")
}
if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}
tail_k := s.keys
tail_v := s.values
for i, k := range ks {
j, found := findInSortedSlice(tail_k, k)
if found {
pbKeys = append(pbKeys, &pb.StoresKey{
Floats: k,
})
pbValues = append(pbValues, &pb.StoresValue{
Bytes: tail_v[j],
})
tail_k = tail_k[j+1:]
tail_v = tail_v[j+1:]
} else {
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
}
}
if len(pbKeys) != len(opts.Keys) {
log.Debug().Msgf("Get: Some keys not found: len(pbKeys) = %d, len(opts.Keys) = %d, len(s.Keys) = %d", len(pbKeys), len(opts.Keys), len(s.keys))
}
return pb.StoresGetResult{
Keys: pbKeys,
Values: pbValues,
}, nil
}
func isNormalized(k []float32) bool {
var sum float32
for _, v := range k {
sum += v
}
return sum == 1.0
}
// TODO: This we could replace with handwritten SIMD code
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot float32
for i := 0; i < len(k1); i++ {
dot += k1[i] * k2[i]
}
assert(dot >= -1 && dot <= 1, fmt.Sprintf("dot = %f", dot))
// 2.0 * (1.0 - dot) would be the Euclidean distance
return dot
}
type PriorityItem struct {
Similarity float32
Key []float32
Value []byte
}
type PriorityQueue []*PriorityItem
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
// Inverted because the most similar should be at the top
return pq[i].Similarity < pq[j].Similarity
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
}
func (pq *PriorityQueue) Push(x any) {
item := x.(*PriorityItem)
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() any {
old := *pq
n := len(old)
item := old[n-1]
*pq = old[0 : n-1]
return item
}
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
top_ks := make(PriorityQueue, 0, int(opts.TopK))
heap.Init(&top_ks)
for i, k := range s.keys {
sim := normalizedCosineSimilarity(tk, k)
heap.Push(&top_ks, &PriorityItem{
Similarity: sim,
Key: k,
Value: s.values[i],
})
if top_ks.Len() > int(opts.TopK) {
heap.Pop(&top_ks)
}
}
similarities := make([]float32, top_ks.Len())
pbKeys := make([]*pb.StoresKey, top_ks.Len())
pbValues := make([]*pb.StoresValue, top_ks.Len())
for i := top_ks.Len() - 1; i >= 0; i-- {
item := heap.Pop(&top_ks).(*PriorityItem)
similarities[i] = item.Similarity
pbKeys[i] = &pb.StoresKey{
Floats: item.Key,
}
pbValues[i] = &pb.StoresValue{
Bytes: item.Value,
}
}
return pb.StoresFindResult{
Keys: pbKeys,
Values: pbValues,
Similarities: similarities,
}, nil
}
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
var dot, mag2 float64
for i := 0; i < len(k1); i++ {
dot += float64(k1[i] * k2[i])
mag2 += float64(k2[i] * k2[i])
}
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
assert(sim >= -1 && sim <= 1, fmt.Sprintf("sim = %f", sim))
return sim
}
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
top_ks := make(PriorityQueue, 0, int(opts.TopK))
heap.Init(&top_ks)
var mag1 float64
for _, v := range tk {
mag1 += float64(v * v)
}
mag1 = math.Sqrt(mag1)
for i, k := range s.keys {
dist := cosineSimilarity(tk, k, mag1)
heap.Push(&top_ks, &PriorityItem{
Similarity: dist,
Key: k,
Value: s.values[i],
})
if top_ks.Len() > int(opts.TopK) {
heap.Pop(&top_ks)
}
}
similarities := make([]float32, top_ks.Len())
pbKeys := make([]*pb.StoresKey, top_ks.Len())
pbValues := make([]*pb.StoresValue, top_ks.Len())
for i := top_ks.Len() - 1; i >= 0; i-- {
item := heap.Pop(&top_ks).(*PriorityItem)
similarities[i] = item.Similarity
pbKeys[i] = &pb.StoresKey{
Floats: item.Key,
}
pbValues[i] = &pb.StoresValue{
Bytes: item.Value,
}
}
return pb.StoresFindResult{
Keys: pbKeys,
Values: pbValues,
Similarities: similarities,
}, nil
}
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
tk := opts.Key.Floats
if len(tk) != s.keyLen {
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
}
if opts.TopK < 1 {
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
}
if s.keyLen == -1 {
s.keyLen = len(opts.Key.Floats)
} else {
if len(opts.Key.Floats) != s.keyLen {
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
}
}
if s.keysAreNormalized && isNormalized(tk) {
return s.StoresFindNormalized(opts)
} else {
if s.keysAreNormalized {
var sample []float32
if len(s.keys) > 5 {
sample = tk[:5]
} else {
sample = tk
}
log.Debug().Msgf("Trying to compare non-normalized key with normalized keys: %v", sample)
}
return s.StoresFindFallback(opts)
}
}

23
core/backend/stores.go Normal file
View File

@ -0,0 +1,23 @@
package backend
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/model"
)
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) {
if storeName == "" {
storeName = "default"
}
sc := []model.Option{
model.WithBackendString(model.LocalStoreBackend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(storeName),
}
return sl.BackendLoader(sc...)
}

View File

@ -172,6 +172,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Elevenlabs // Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
// openAI compatible API endpoint // openAI compatible API endpoint
// chat // chat

View File

@ -15,6 +15,7 @@ import (
"github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http" . "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/startup" "github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/downloader"
@ -122,6 +123,75 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return return
} }
func postRequestJSON[B any](url string, bodyJson *B) error {
payload, err := json.Marshal(bodyJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return nil
}
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
payload, err := json.Marshal(reqJson)
if err != nil {
return err
}
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
return json.Unmarshal(body, respJson)
}
//go:embed backend-assets/* //go:embed backend-assets/*
var backendAssets embed.FS var backendAssets embed.FS
@ -836,6 +906,78 @@ var _ = Describe("API test", func() {
Expect(tokens).ToNot(Or(Equal(1), Equal(0))) Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
}) })
}) })
// See tests/integration/stores_test
Context("Stores", Label("stores"), func() {
It("sets, gets, finds and deletes entries", func() {
ks := [][]float32{
{0.1, 0.2, 0.3},
{0.4, 0.5, 0.6},
{0.7, 0.8, 0.9},
}
vs := []string{
"test1",
"test2",
"test3",
}
setBody := schema.StoresSet{
Keys: ks,
Values: vs,
}
url := "http://127.0.0.1:9090/stores/"
err := postRequestJSON(url+"set", &setBody)
Expect(err).ToNot(HaveOccurred())
getBody := schema.StoresGet{
Keys: ks,
}
var getRespBody schema.StoresGetResponse
err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
for i, v := range getRespBody.Keys {
if v[0] == 0.1 {
Expect(getRespBody.Values[i]).To(Equal("test1"))
} else if v[0] == 0.4 {
Expect(getRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(getRespBody.Values[i]).To(Equal("test3"))
}
}
deleteBody := schema.StoresDelete{
Keys: [][]float32{
{0.1, 0.2, 0.3},
},
}
err = postRequestJSON(url+"delete", &deleteBody)
Expect(err).ToNot(HaveOccurred())
findBody := schema.StoresFind{
Key: []float32{0.1, 0.3, 0.7},
Topk: 10,
}
var findRespBody schema.StoresFindResponse
err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
Expect(err).ToNot(HaveOccurred())
Expect(len(findRespBody.Keys)).To(Equal(2))
for i, v := range findRespBody.Keys {
if v[0] == 0.4 {
Expect(findRespBody.Values[i]).To(Equal("test2"))
} else {
Expect(findRespBody.Values[i]).To(Equal("test3"))
}
Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
}
})
})
}) })
Context("Config file", func() { Context("Config file", func() {

View File

@ -0,0 +1,121 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/store"
"github.com/gofiber/fiber/v2"
)
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresSet)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
vals := make([][]byte, len(input.Values))
for i, v := range input.Values {
vals[i] = []byte(v)
}
err = store.SetCols(c.Context(), sb, input.Keys, vals)
if err != nil {
return err
}
return c.Send(nil)
}
}
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresDelete)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
return err
}
return c.Send(nil)
}
}
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresGet)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
if err != nil {
return err
}
res := schema.StoresGetResponse{
Keys: keys,
Values: make([]string, len(vals)),
}
for i, v := range vals {
res.Values[i] = string(v)
}
return c.JSON(res)
}
}
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresFind)
if err := c.BodyParser(input); err != nil {
return err
}
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
if err != nil {
return err
}
res := schema.StoresFindResponse{
Keys: keys,
Values: make([]string, len(vals)),
Similarities: similarities,
}
for i, v := range vals {
res.Values[i] = string(v)
}
return c.JSON(res)
}
}

View File

@ -20,3 +20,40 @@ type TTSRequest struct {
Voice string `json:"voice" yaml:"voice"` Voice string `json:"voice" yaml:"voice"`
Backend string `json:"backend" yaml:"backend"` Backend string `json:"backend" yaml:"backend"`
} }
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
}
type StoresDelete struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Keys [][]float32 `json:"keys"`
}
type StoresGet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Keys [][]float32 `json:"keys" yaml:"keys"`
}
type StoresGetResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
}
type StoresFind struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Key []float32 `json:"key" yaml:"key"`
Topk int `json:"topk" yaml:"topk"`
}
type StoresFindResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
Similarities []float32 `json:"similarities" yaml:"similarities"`
}

View File

@ -0,0 +1,97 @@
+++
disableToc = false
title = "💾 Stores"
weight = 18
url = '/stores'
+++
Stores are an experimental feature to help with querying data using similarity search. It is
a low level API that consists of only `get`, `set`, `delete` and `find`.
For example if you have an embedding of some text and want to find text with similar embeddings.
You can create embeddings for chunks of all your text then compare them against the embedding of the text you
are searching on.
An embedding here meaning a vector of numbers that represent some information about the text. The
embeddings are created from an A.I. model such as BERT or a more traditional method such as word
frequency.
Previously you would have to integrate with an external vector database or library directly.
With the stores feature you can now do it through the LocalAI API.
Note however that doing a similarity search on embeddings is just one way to do retrieval. A higher level
API can take this into account, so this may not be the best place to start.
## API overview
There is an internal gRPC API and an external facing HTTP JSON API. We'll just discuss the external HTTP API,
however the HTTP API mirrors the gRPC API. Consult `pkg/store/client` for internal usage.
Everything is in columnar format meaning that instead of getting an array of objects with a key and a value each.
You instead get two separate arrays of keys and values.
Keys are arrays of floating point numbers with a maximum width of 32bits. Values are strings (in gRPC they are bytes).
The key vectors must all be the same length and it's best for search performance if they are normalized. When
addings keys it will be detected if they are not normalized and what length they are.
All endpoints accept a `store` field which specifies which store to operate on. Presently they are created
on the fly and there is only one store backend so no configuration is required.
## Set
To set some keys you can do
```
curl -X POST http://localhost:8080/stores/set \
-H "Content-Type: application/json" \
-d '{"keys": [[0.1, 0.2], [0.3, 0.4]], "values": ["foo", "bar"]}'
```
Setting the same keys again will update their values.
On success 200 OK is returned with no body.
## Get
To get some keys you can do
```
curl -X POST http://localhost:8080/stores/get \
-H "Content-Type: application/json" \
-d '{"keys": [[0.1, 0.2]]}'
```
Both the keys and values are returned, e.g: `{"keys":[[0.1,0.2]],"values":["foo"]}`
The order of the keys is not preserved! If a key does not exist then nothing is returned.
## Delete
To delete keys and values you can do
```
curl -X POST http://localhost:8080/stores/delete \
-H "Content-Type: application/json" \
-d '{"keys": [[0.1, 0.2]]}'
```
If a key doesn't exist then it is ignored.
On success 200 OK is returned with no body.
## Find
To do a similarity search you can do
```
curl -X POST http://localhost:8080/stores/find
-H "Content-Type: application/json" \
-d '{"topk": 2, "key": [0.2, 0.1]}'
```
`topk` limits the number of results returned. The result value is the same as `get`,
except that it also includes an array of `similarities`. Where `1.0` is the maximum similarity.
They are returned in the order of most similar to least.

View File

@ -73,6 +73,7 @@ Note that this started just as a fun weekend project by [mudler](https://github.
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/) - ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/) - 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
- 🆕 [Vision API](https://localai.io/features/gpt-vision/) - 🆕 [Vision API](https://localai.io/features/gpt-vision/)
- 💾 [Stores](https://localai.io/features/stores)
## Contribute and help ## Contribute and help

View File

@ -0,0 +1,15 @@
This demonstrates the vector store backend in its simplest form.
You can add tasks and then search/sort them using the TUI.
To build and run do
```bash
$ go get .
$ go run .
```
A seperate LocaAI instance is required of course. For e.g.
```bash
$ docker run -e DEBUG=true --rm -it -p 8080:8080 <LocalAI-image> bert-cpp
```

View File

@ -0,0 +1,18 @@
module semantic-todo
go 1.21.6
require (
github.com/gdamore/tcell/v2 v2.7.1
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1
)
require (
github.com/gdamore/encoding v1.0.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/term v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
)

View File

@ -0,0 +1,50 @@
github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko=
github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg=
github.com/gdamore/tcell/v2 v2.7.1 h1:TiCcmpWHiAU7F0rA2I3S2Y4mmLmO9KHxJ7E1QhYzQbc=
github.com/gdamore/tcell/v2 v2.7.1/go.mod h1:dSXtXTSK0VsW1biw65DZLZ2NKr7j0qP/0J7ONmsraWg=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1 h1:bWLHTRekAy497pE7+nXSuzXwwFHI0XauRzz6roUvY+s=
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1/go.mod h1:02iFIz7K/A9jGCvrizLPvoqr4cEIx7q54RH5Qudkrss=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -0,0 +1,352 @@
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
)
const (
localAI string = "http://localhost:8080"
rootStatus string = "[::b]<space>[::-]: Add Task [::b]/[::-]: Search Task [::b]<C-c>[::-]: Exit"
inputStatus string = "Press [::b]<enter>[::-] to submit the task, [::b]<esc>[::-] to cancel"
)
type Task struct {
Description string
Similarity float32
}
type AppState int
const (
StateRoot AppState = iota
StateInput
StateSearch
)
type App struct {
state AppState
tasks []Task
app *tview.Application
flex *tview.Flex
table *tview.Table
}
func NewApp() *App {
return &App{
state: StateRoot,
tasks: []Task{
{Description: "Take the dog for a walk (after I get a dog)"},
{Description: "Go to the toilet"},
{Description: "Allow TODOs to be marked completed or removed"},
},
}
}
func getEmbeddings(description string) ([]float32, error) {
// Define the request payload
payload := map[string]interface{}{
"model": "bert-cpp-minilm-v6",
"input": description,
}
// Marshal the payload into JSON
jsonPayload, err := json.Marshal(payload)
if err != nil {
return nil, err
}
// Make the HTTP request to the local OpenAI embeddings API
resp, err := http.Post(localAI+"/embeddings", "application/json", bytes.NewBuffer(jsonPayload))
if err != nil {
return nil, err
}
defer resp.Body.Close()
// Check if the request was successful
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("request to embeddings API failed with status code: %d", resp.StatusCode)
}
// Parse the response body
var result struct {
Data []struct {
Embedding []float32 `json:"embedding"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
// Return the embedding
if len(result.Data) > 0 {
return result.Data[0].Embedding, nil
}
return nil, errors.New("no embedding received from API")
}
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
}
func postTasksToExternalService(tasks []Task) error {
keys := make([][]float32, 0, len(tasks))
// Get the embeddings for the task description
for _, task := range tasks {
embedding, err := getEmbeddings(task.Description)
if err != nil {
return err
}
keys = append(keys, embedding)
}
values := make([]string, 0, len(tasks))
for _, task := range tasks {
values = append(values, task.Description)
}
// Construct the StoresSet object
storesSet := StoresSet{
Store: "tasks_store", // Assuming you have a specific store name
Keys: keys,
Values: values,
}
// Marshal the StoresSet object into JSON
jsonData, err := json.Marshal(storesSet)
if err != nil {
return err
}
// Make the HTTP POST request to the external service
resp, err := http.Post(localAI+"/stores/set", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return err
}
defer resp.Body.Close()
// Check if the request was successful
if resp.StatusCode != http.StatusOK {
// read resp body into string
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return fmt.Errorf("store request failed with status code: %d: %s", resp.StatusCode, body)
}
return nil
}
type StoresFind struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`
Key []float32 `json:"key" yaml:"key"`
Topk int `json:"topk" yaml:"topk"`
}
type StoresFindResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
Similarities []float32 `json:"similarities" yaml:"similarities"`
}
func findSimilarTexts(inputText string, topk int) (StoresFindResponse, error) {
// Initialize an empty response object
response := StoresFindResponse{}
// Get the embedding for the input text
embedding, err := getEmbeddings(inputText)
if err != nil {
return response, err
}
// Construct the StoresFind object
storesFind := StoresFind{
Store: "tasks_store", // Assuming you have a specific store name
Key: embedding,
Topk: topk,
}
// Marshal the StoresFind object into JSON
jsonData, err := json.Marshal(storesFind)
if err != nil {
return response, err
}
// Make the HTTP POST request to the external service's /stores/find endpoint
resp, err := http.Post(localAI+"/stores/find", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return response, err
}
defer resp.Body.Close()
// Check if the request was successful
if resp.StatusCode != http.StatusOK {
return response, fmt.Errorf("request to /stores/find failed with status code: %d", resp.StatusCode)
}
// Parse the response body to retrieve similar texts and similarities
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return response, err
}
return response, nil
}
func (app *App) updateUI() {
// Clear the flex layout
app.flex.Clear()
app.flex.SetDirection(tview.FlexColumn)
app.flex.AddItem(nil, 0, 1, false)
midCol := tview.NewFlex()
midCol.SetDirection(tview.FlexRow)
midCol.AddItem(nil, 0, 1, false)
// Create a new table.
app.table.Clear()
app.table.SetBorders(true)
// Set table headers
app.table.SetCell(0, 0, tview.NewTableCell("Description").SetAlign(tview.AlignLeft).SetExpansion(1).SetAttributes(tcell.AttrBold))
app.table.SetCell(0, 1, tview.NewTableCell("Similarity").SetAlign(tview.AlignCenter).SetExpansion(0).SetAttributes(tcell.AttrBold))
// Add the tasks to the table.
for i, task := range app.tasks {
row := i + 1
app.table.SetCell(row, 0, tview.NewTableCell(task.Description))
app.table.SetCell(row, 1, tview.NewTableCell(fmt.Sprintf("%.2f", task.Similarity)))
}
if app.state == StateInput {
inputField := tview.NewInputField()
inputField.
SetLabel("New Task: ").
SetFieldWidth(0).
SetDoneFunc(func(key tcell.Key) {
if key == tcell.KeyEnter {
task := Task{Description: inputField.GetText()}
app.tasks = append(app.tasks, task)
app.state = StateRoot
postTasksToExternalService([]Task{task})
}
app.updateUI()
})
midCol.AddItem(inputField, 3, 2, true)
app.app.SetFocus(inputField)
} else if app.state == StateSearch {
searchField := tview.NewInputField()
searchField.SetLabel("Search: ").
SetFieldWidth(0).
SetDoneFunc(func(key tcell.Key) {
if key == tcell.KeyEnter {
similar, err := findSimilarTexts(searchField.GetText(), 100)
if err != nil {
panic(err)
}
app.tasks = make([]Task, len(similar.Keys))
for i, v := range similar.Values {
app.tasks[i] = Task{Description: v, Similarity: similar.Similarities[i]}
}
}
app.updateUI()
})
midCol.AddItem(searchField, 3, 2, true)
app.app.SetFocus(searchField)
} else {
midCol.AddItem(nil, 3, 1, false)
}
midCol.AddItem(app.table, 0, 2, true)
// Add the status bar to the flex layout
statusBar := tview.NewTextView().
SetText(rootStatus).
SetDynamicColors(true).
SetTextAlign(tview.AlignCenter)
if app.state == StateInput {
statusBar.SetText(inputStatus)
}
midCol.AddItem(statusBar, 1, 1, false)
midCol.AddItem(nil, 0, 1, false)
app.flex.AddItem(midCol, 0, 10, true)
app.flex.AddItem(nil, 0, 1, false)
// Set the flex as the root element
app.app.SetRoot(app.flex, true)
}
func main() {
app := NewApp()
tApp := tview.NewApplication()
flex := tview.NewFlex().SetDirection(tview.FlexRow)
table := tview.NewTable()
app.app = tApp
app.flex = flex
app.table = table
app.updateUI() // Initial UI setup
app.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
switch app.state {
case StateRoot:
// Handle key events when in the root state
switch event.Key() {
case tcell.KeyRune:
switch event.Rune() {
case ' ':
app.state = StateInput
app.updateUI()
return nil // Event is handled
case '/':
app.state = StateSearch
app.updateUI()
return nil // Event is handled
}
}
case StateInput:
// Handle key events when in the input state
if event.Key() == tcell.KeyEsc {
// Exit input state without adding a task
app.state = StateRoot
app.updateUI()
return nil // Event is handled
}
case StateSearch:
// Handle key events when in the search state
if event.Key() == tcell.KeyEsc {
// Exit search state
app.state = StateRoot
app.updateUI()
return nil // Event is handled
}
}
// Return the event for further processing by tview
return event
})
if err := postTasksToExternalService(app.tasks); err != nil {
panic(err)
}
// Start the application
if err := app.app.Run(); err != nil {
panic(err)
}
}

View File

@ -44,4 +44,9 @@ type Backend interface {
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error)
StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error)
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
} }

View File

@ -72,6 +72,22 @@ func (llm *Base) Status() (pb.StatusResponse, error) {
}, nil }, nil
} }
func (llm *Base) StoresSet(*pb.StoresSetOptions) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
return pb.StoresGetResult{}, fmt.Errorf("unimplemented")
}
func (llm *Base) StoresDelete(*pb.StoresDeleteOptions) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
}
func memoryUsage() *pb.MemoryUsageData { func memoryUsage() *pb.MemoryUsageData {
mud := pb.MemoryUsageData{ mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64), Breakdown: make(map[string]uint64),

View File

@ -291,3 +291,67 @@ func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
client := pb.NewBackendClient(conn) client := pb.NewBackendClient(conn)
return client.Status(ctx, &pb.HealthMessage{}) return client.Status(ctx, &pb.HealthMessage{})
} }
func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StoresSet(ctx, in, opts...)
}
func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StoresDelete(ctx, in, opts...)
}
func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StoresGet(ctx, in, opts...)
}
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StoresFind(ctx, in, opts...)
}

View File

@ -85,6 +85,22 @@ func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) {
return e.s.Status(ctx, &pb.HealthMessage{}) return e.s.Status(ctx, &pb.HealthMessage{})
} }
func (e *embedBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
return e.s.StoresSet(ctx, in)
}
func (e *embedBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
return e.s.StoresDelete(ctx, in)
}
func (e *embedBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
return e.s.StoresGet(ctx, in)
}
func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
return e.s.StoresFind(ctx, in)
}
type embedBackendServerStream struct { type embedBackendServerStream struct {
ctx context.Context ctx context.Context
fn func(s []byte) fn func(s []byte)

View File

@ -19,6 +19,11 @@ type LLM interface {
TTS(*pb.TTSRequest) error TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error) Status() (pb.StatusResponse, error)
StoresSet(*pb.StoresSetOptions) error
StoresDelete(*pb.StoresDeleteOptions) error
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
} }
func newReply(s string) *pb.Reply { func newReply(s string) *pb.Reply {

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.2.0 // - protoc-gen-go-grpc v1.3.0
// - protoc v4.23.4 // - protoc v4.23.4
// source: backend.proto // source: backend.proto
@ -18,6 +18,23 @@ import (
// Requires gRPC-Go v1.32.0 or later. // Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
const (
Backend_Health_FullMethodName = "/backend.Backend/Health"
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
Backend_Status_FullMethodName = "/backend.Backend/Status"
Backend_StoresSet_FullMethodName = "/backend.Backend/StoresSet"
Backend_StoresDelete_FullMethodName = "/backend.Backend/StoresDelete"
Backend_StoresGet_FullMethodName = "/backend.Backend/StoresGet"
Backend_StoresFind_FullMethodName = "/backend.Backend/StoresFind"
)
// BackendClient is the client API for Backend service. // BackendClient is the client API for Backend service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
@ -32,6 +49,10 @@ type BackendClient interface {
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error)
TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error)
Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error)
StoresSet(ctx context.Context, in *StoresSetOptions, opts ...grpc.CallOption) (*Result, error)
StoresDelete(ctx context.Context, in *StoresDeleteOptions, opts ...grpc.CallOption) (*Result, error)
StoresGet(ctx context.Context, in *StoresGetOptions, opts ...grpc.CallOption) (*StoresGetResult, error)
StoresFind(ctx context.Context, in *StoresFindOptions, opts ...grpc.CallOption) (*StoresFindResult, error)
} }
type backendClient struct { type backendClient struct {
@ -44,7 +65,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -53,7 +74,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply) out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +83,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +91,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
} }
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -103,7 +124,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult) out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +133,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,7 +142,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult) out := new(TranscriptResult)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -130,7 +151,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result) out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,7 +160,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
out := new(TokenizationResponse) out := new(TokenizationResponse)
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...) err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -148,7 +169,43 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse) out := new(StatusResponse)
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...) err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) StoresSet(ctx context.Context, in *StoresSetOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, Backend_StoresSet_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) StoresDelete(ctx context.Context, in *StoresDeleteOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, Backend_StoresDelete_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) StoresGet(ctx context.Context, in *StoresGetOptions, opts ...grpc.CallOption) (*StoresGetResult, error) {
out := new(StoresGetResult)
err := c.cc.Invoke(ctx, Backend_StoresGet_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) StoresFind(ctx context.Context, in *StoresFindOptions, opts ...grpc.CallOption) (*StoresFindResult, error) {
out := new(StoresFindResult)
err := c.cc.Invoke(ctx, Backend_StoresFind_FullMethodName, in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -169,6 +226,10 @@ type BackendServer interface {
TTS(context.Context, *TTSRequest) (*Result, error) TTS(context.Context, *TTSRequest) (*Result, error)
TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error) TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error)
Status(context.Context, *HealthMessage) (*StatusResponse, error) Status(context.Context, *HealthMessage) (*StatusResponse, error)
StoresSet(context.Context, *StoresSetOptions) (*Result, error)
StoresDelete(context.Context, *StoresDeleteOptions) (*Result, error)
StoresGet(context.Context, *StoresGetOptions) (*StoresGetResult, error)
StoresFind(context.Context, *StoresFindOptions) (*StoresFindResult, error)
mustEmbedUnimplementedBackendServer() mustEmbedUnimplementedBackendServer()
} }
@ -206,6 +267,18 @@ func (UnimplementedBackendServer) TokenizeString(context.Context, *PredictOption
func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) { func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Status not implemented") return nil, status.Errorf(codes.Unimplemented, "method Status not implemented")
} }
func (UnimplementedBackendServer) StoresSet(context.Context, *StoresSetOptions) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method StoresSet not implemented")
}
func (UnimplementedBackendServer) StoresDelete(context.Context, *StoresDeleteOptions) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method StoresDelete not implemented")
}
func (UnimplementedBackendServer) StoresGet(context.Context, *StoresGetOptions) (*StoresGetResult, error) {
return nil, status.Errorf(codes.Unimplemented, "method StoresGet not implemented")
}
func (UnimplementedBackendServer) StoresFind(context.Context, *StoresFindOptions) (*StoresFindResult, error) {
return nil, status.Errorf(codes.Unimplemented, "method StoresFind not implemented")
}
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {}
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service. // UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
@ -229,7 +302,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Health", FullMethod: Backend_Health_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
@ -247,7 +320,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Predict", FullMethod: Backend_Predict_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
@ -265,7 +338,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/LoadModel", FullMethod: Backend_LoadModel_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
@ -304,7 +377,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Embedding", FullMethod: Backend_Embedding_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
@ -322,7 +395,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/GenerateImage", FullMethod: Backend_GenerateImage_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
@ -340,7 +413,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/AudioTranscription", FullMethod: Backend_AudioTranscription_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
@ -358,7 +431,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/TTS", FullMethod: Backend_TTS_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
@ -376,7 +449,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/TokenizeString", FullMethod: Backend_TokenizeString_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
@ -394,7 +467,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/backend.Backend/Status", FullMethod: Backend_Status_FullMethodName,
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Status(ctx, req.(*HealthMessage)) return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
@ -402,6 +475,78 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _Backend_StoresSet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StoresSetOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).StoresSet(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_StoresSet_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).StoresSet(ctx, req.(*StoresSetOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_StoresDelete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StoresDeleteOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).StoresDelete(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_StoresDelete_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).StoresDelete(ctx, req.(*StoresDeleteOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_StoresGet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StoresGetOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).StoresGet(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_StoresGet_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).StoresGet(ctx, req.(*StoresGetOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_StoresFind_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StoresFindOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).StoresFind(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: Backend_StoresFind_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).StoresFind(ctx, req.(*StoresFindOptions))
}
return interceptor(ctx, in, info, handler)
}
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service. // Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -445,6 +590,22 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
MethodName: "Status", MethodName: "Status",
Handler: _Backend_Status_Handler, Handler: _Backend_Status_Handler,
}, },
{
MethodName: "StoresSet",
Handler: _Backend_StoresSet_Handler,
},
{
MethodName: "StoresDelete",
Handler: _Backend_StoresDelete_Handler,
},
{
MethodName: "StoresGet",
Handler: _Backend_StoresGet_Handler,
},
{
MethodName: "StoresFind",
Handler: _Backend_StoresFind_Handler,
},
}, },
Streams: []grpc.StreamDesc{ Streams: []grpc.StreamDesc{
{ {

View File

@ -167,6 +167,54 @@ func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusRe
return &res, nil return &res, nil
} }
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StoresSet(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Set key", Success: true}, nil
}
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StoresDelete(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Deleted key", Success: true}, nil
}
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StoresGet(in)
if err != nil {
return nil, err
}
return &res, nil
}
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StoresFind(in)
if err != nil {
return nil, err
}
return &res, nil
}
func StartServer(address string, model LLM) error { func StartServer(address string, model LLM) error {
lis, err := net.Listen("tcp", address) lis, err := net.Listen("tcp", address)
if err != nil { if err != nil {

View File

@ -17,6 +17,7 @@ import (
var Aliases map[string]string = map[string]string{ var Aliases map[string]string = map[string]string{
"go-llama": LLamaCPP, "go-llama": LLamaCPP,
"llama": LLamaCPP, "llama": LLamaCPP,
"embedded-store": LocalStoreBackend,
} }
const ( const (
@ -34,6 +35,8 @@ const (
TinyDreamBackend = "tinydream" TinyDreamBackend = "tinydream"
PiperBackend = "piper" PiperBackend = "piper"
LCHuggingFaceBackend = "langchain-huggingface" LCHuggingFaceBackend = "langchain-huggingface"
LocalStoreBackend = "local-store"
) )
var AutoLoadBackends []string = []string{ var AutoLoadBackends []string = []string{

155
pkg/store/client.go Normal file
View File

@ -0,0 +1,155 @@
package store
import (
"context"
"fmt"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
)
// Wrapper for the GRPC client so that simple use cases are handled without verbosity
// SetCols sets multiple key-value pairs in the store
// It's in columnar format so that keys[i] is associated with values[i]
func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error {
protoKeys := make([]*proto.StoresKey, len(keys))
for i, k := range keys {
protoKeys[i] = &proto.StoresKey{
Floats: k,
}
}
protoValues := make([]*proto.StoresValue, len(values))
for i, v := range values {
protoValues[i] = &proto.StoresValue{
Bytes: v,
}
}
setOpts := &proto.StoresSetOptions{
Keys: protoKeys,
Values: protoValues,
}
res, err := c.StoresSet(ctx, setOpts)
if err != nil {
return err
}
if res.Success {
return nil
}
return fmt.Errorf("failed to set keys: %v", res.Message)
}
// SetSingle sets a single key-value pair in the store
// Don't call this in a tight loop, instead use SetCols
func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) error {
return SetCols(ctx, c, [][]float32{key}, [][]byte{value})
}
// DeleteCols deletes multiple key-value pairs from the store
// It's in columnar format so that keys[i] is associated with values[i]
func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error {
protoKeys := make([]*proto.StoresKey, len(keys))
for i, k := range keys {
protoKeys[i] = &proto.StoresKey{
Floats: k,
}
}
deleteOpts := &proto.StoresDeleteOptions{
Keys: protoKeys,
}
res, err := c.StoresDelete(ctx, deleteOpts)
if err != nil {
return err
}
if res.Success {
return nil
}
return fmt.Errorf("failed to delete keys: %v", res.Message)
}
// DeleteSingle deletes a single key-value pair from the store
// Don't call this in a tight loop, instead use DeleteCols
func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error {
return DeleteCols(ctx, c, [][]float32{key})
}
// GetCols gets multiple key-value pairs from the store
// It's in columnar format so that keys[i] is associated with values[i]
// Be warned the keys are sorted and will be returned in a different order than they were input
// There is no guarantee as to how the keys are sorted
func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) {
protoKeys := make([]*proto.StoresKey, len(keys))
for i, k := range keys {
protoKeys[i] = &proto.StoresKey{
Floats: k,
}
}
getOpts := &proto.StoresGetOptions{
Keys: protoKeys,
}
res, err := c.StoresGet(ctx, getOpts)
if err != nil {
return nil, nil, err
}
ks := make([][]float32, len(res.Keys))
for i, k := range res.Keys {
ks[i] = k.Floats
}
vs := make([][]byte, len(res.Values))
for i, v := range res.Values {
vs[i] = v.Bytes
}
return ks, vs, nil
}
// GetSingle gets a single key-value pair from the store
// Don't call this in a tight loop, instead use GetCols
func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, error) {
_, values, err := GetCols(ctx, c, [][]float32{key})
if err != nil {
return nil, err
}
if len(values) > 0 {
return values[0], nil
}
return nil, fmt.Errorf("failed to get key")
}
// Find similar keys to the given key. Returns the keys, values, and similarities
func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) {
findOpts := &proto.StoresFindOptions{
Key: &proto.StoresKey{
Floats: key,
},
TopK: int32(topk),
}
res, err := c.StoresFind(ctx, findOpts)
if err != nil {
return nil, nil, nil, err
}
ks := make([][]float32, len(res.Keys))
vs := make([][]byte, len(res.Values))
for i, k := range res.Keys {
ks[i] = k.Floats
}
for i, v := range res.Values {
vs[i] = v.Bytes
}
return ks, vs, res.Similarities, nil
}

View File

@ -0,0 +1,17 @@
package integration_test
import (
"os"
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func TestLocalAI(t *testing.T) {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
RegisterFailHandler(Fail)
RunSpecs(t, "LocalAI test suite")
}

View File

@ -0,0 +1,228 @@
package integration_test
import (
"context"
"embed"
"math"
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/store"
)
//go:embed backend-assets/*
var backendAssets embed.FS
var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
Context("Embedded Store get,set and delete", func() {
var sl *model.ModelLoader
var sc grpc.Backend
var tmpdir string
BeforeEach(func() {
var err error
zerolog.SetGlobalLevel(zerolog.DebugLevel)
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
err = assets.ExtractFiles(backendAssets, backendAssetsDir)
Expect(err).ToNot(HaveOccurred())
debug := true
bc := config.BackendConfig{
Name: "store test",
Debug: &debug,
Backend: model.LocalStoreBackend,
}
storeOpts := []model.Option{
model.WithBackendString(bc.Backend),
model.WithAssetDir(backendAssetsDir),
model.WithModel("test"),
}
sl = model.NewModelLoader("")
sc, err = sl.BackendLoader(storeOpts...)
Expect(err).ToNot(HaveOccurred())
Expect(sc).ToNot(BeNil())
})
AfterEach(func() {
sl.StopAllGRPC()
err := os.RemoveAll(tmpdir)
Expect(err).ToNot(HaveOccurred())
})
It("should be able to set a key", func() {
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
Expect(err).ToNot(HaveOccurred())
})
It("should be able to set keys", func() {
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}, [][]byte{[]byte("test1"), []byte("test2")})
Expect(err).ToNot(HaveOccurred())
err = store.SetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.10, 0.11, 0.12}}, [][]byte{[]byte("test3"), []byte("test4")})
Expect(err).ToNot(HaveOccurred())
})
It("should be able to get a key", func() {
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
Expect(err).ToNot(HaveOccurred())
val, err := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
Expect(err).ToNot(HaveOccurred())
Expect(val).To(Equal([]byte("test")))
})
It("should be able to get keys", func() {
//set 3 entries
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
Expect(err).ToNot(HaveOccurred())
//get 3 entries
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}})
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(3))
Expect(vals).To(HaveLen(3))
for i, k := range keys {
v := vals[i]
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
Expect(v).To(Equal([]byte("test1")))
} else if k[0] == 0.4 && k[1] == 0.5 && k[2] == 0.6 {
Expect(v).To(Equal([]byte("test2")))
} else {
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
Expect(v).To(Equal([]byte("test3")))
}
}
//get 2 entries
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}})
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(2))
Expect(vals).To(HaveLen(2))
for i, k := range keys {
v := vals[i]
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
Expect(v).To(Equal([]byte("test1")))
} else {
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
Expect(v).To(Equal([]byte("test3")))
}
}
})
It("should be able to delete a key", func() {
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
Expect(err).ToNot(HaveOccurred())
err = store.DeleteSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
Expect(err).ToNot(HaveOccurred())
val, _ := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
Expect(val).To(BeNil())
})
It("should be able to delete keys", func() {
//set 3 entries
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
Expect(err).ToNot(HaveOccurred())
//delete 2 entries
err = store.DeleteCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
Expect(err).ToNot(HaveOccurred())
//get 1 entry
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.4, 0.5, 0.6}})
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(1))
Expect(vals).To(HaveLen(1))
Expect(keys[0]).To(Equal([]float32{0.4, 0.5, 0.6}))
Expect(vals[0]).To(Equal([]byte("test2")))
//get deleted entries
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(0))
Expect(vals).To(HaveLen(0))
})
It("should be able to find smilar keys", func() {
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
err := store.SetCols(context.Background(), sc, [][]float32{{0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
Expect(err).ToNot(HaveOccurred())
// find similar keys
keys, vals, sims, err := store.Find(context.Background(), sc, []float32{0.1, 0.3, 0.5}, 2)
Expect(err).ToNot(HaveOccurred())
Expect(keys).To(HaveLen(2))
Expect(vals).To(HaveLen(2))
Expect(sims).To(HaveLen(2))
for i, k := range keys {
s := sims[i]
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
}
Expect(keys[0]).To(Equal([]float32{0.5, 0.5, 0.5}))
Expect(vals[0]).To(Equal([]byte("test1")))
Expect(keys[1]).To(Equal([]float32{0.6, 0.6, -0.6}))
})
It("should be able to find similar normalized keys", func() {
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
// normalize the keys
for i, k := range keys {
norm := float64(0)
for _, x := range k {
norm += float64(x * x)
}
norm = math.Sqrt(norm)
for j, x := range k {
keys[i][j] = x / float32(norm)
}
}
err := store.SetCols(context.Background(), sc, keys, vals)
Expect(err).ToNot(HaveOccurred())
// find similar keys
ks, vals, sims, err := store.Find(context.Background(), sc, keys[0], 3)
Expect(err).ToNot(HaveOccurred())
Expect(ks).To(HaveLen(3))
Expect(vals).To(HaveLen(3))
Expect(sims).To(HaveLen(3))
for i, k := range ks {
s := sims[i]
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
}
Expect(ks[0]).To(Equal(keys[0]))
Expect(vals[0]).To(Equal(vals[0]))
Expect(sims[0]).To(BeNumerically("~", 1, 0.0001))
Expect(ks[1]).To(Equal(keys[1]))
Expect(vals[1]).To(Equal(vals[1]))
})
})
})