diff --git a/backend/backend.proto b/backend/backend.proto index fea4214f..e31aa34d 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -21,8 +21,7 @@ service Backend { rpc Status(HealthMessage) returns (StatusResponse) {} rpc StoresSet(StoresSetOptions) returns (Result) {} - rpc StoresDelete(StoresDeleteOptions) returns (Result) {} - rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {} + rpc StoresReset(StoresResetOptions) returns (Result) {} rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {} rpc Rerank(RerankRequest) returns (RerankResult) {} @@ -78,19 +77,10 @@ message StoresSetOptions { repeated StoresValue Values = 2; } -message StoresDeleteOptions { +message StoresResetOptions { repeated StoresKey Keys = 1; } -message StoresGetOptions { - repeated StoresKey Keys = 1; -} - -message StoresGetResult { - repeated StoresKey Keys = 1; - repeated StoresValue Values = 2; -} - message StoresFindOptions { StoresKey Key = 1; int32 TopK = 2; diff --git a/backend/go/stores/store.go b/backend/go/stores/store.go index a4849b57..d7900e13 100644 --- a/backend/go/stores/store.go +++ b/backend/go/stores/store.go @@ -4,101 +4,36 @@ package main // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "container/heap" + "context" "fmt" "math" - "slices" + "runtime" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" + chromem "github.com/philippgille/chromem-go" "github.com/rs/zerolog/log" ) type Store struct { base.SingleThread - - // The sorted keys - keys [][]float32 - // The sorted values - values [][]byte - - // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions - // TODO: Should we normalize incoming keys if they are not instead? - keysAreNormalized bool - // The first key decides the length of the keys - keyLen int -} - -// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because -// that's theoretically best for memory layout and cache locality, but this isn't optimized yet. -type Pair struct { - Key []float32 - Value []byte + *chromem.DB + *chromem.Collection } func NewStore() *Store { - return &Store{ - keys: make([][]float32, 0), - values: make([][]byte, 0), - keysAreNormalized: true, - keyLen: -1, - } -} - -func compareSlices(k1, k2 []float32) int { - assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - return slices.Compare(k1, k2) -} - -func hasKey(unsortedSlice [][]float32, target []float32) bool { - return slices.ContainsFunc(unsortedSlice, func(k []float32) bool { - return compareSlices(k, target) == 0 - }) -} - -func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) { - return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int { - return compareSlices(k, t) - }) -} - -func isSortedPairs(kvs []Pair) bool { - for i := 1; i < len(kvs); i++ { - if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 { - return false - } - } - - return true -} - -func isSortedKeys(keys [][]float32) bool { - for i := 1; i < len(keys); i++ { - if compareSlices(keys[i-1], keys[i]) > 0 { - return false - } - } - - return true -} - -func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 { - ks := make([][]float32, len(keys)) - - for i, k := range keys { - ks[i] = k.Floats - } - - slices.SortFunc(ks, compareSlices) - - assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys))) - assert(isSortedKeys(ks), "keys are not sorted") - - return ks + return &Store{} } func (s *Store) Load(opts *pb.ModelOptions) error { + db := chromem.NewDB() + collection, err := db.CreateCollection("all-documents", nil, nil) + if err != nil { + return err + } + s.DB = db + s.Collection = collection return nil } @@ -111,156 +46,25 @@ func (s *Store) StoresSet(opts *pb.StoresSetOptions) error { if len(opts.Keys) != len(opts.Values) { return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values)) } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - kvs := make([]Pair, len(opts.Keys)) + docs := []chromem.Document{} for i, k := range opts.Keys { - if s.keysAreNormalized && !isNormalized(k.Floats) { - s.keysAreNormalized = false - var sample []float32 - if len(s.keys) > 5 { - sample = k.Floats[:5] - } else { - sample = k.Floats - } - log.Debug().Msgf("Key is not normalized: %v", sample) - } - - kvs[i] = Pair{ - Key: k.Floats, - Value: opts.Values[i].Bytes, - } + docs = append(docs, chromem.Document{ + ID: k.String(), + Content: opts.Values[i].String(), + }) } - slices.SortFunc(kvs, func(a, b Pair) int { - return compareSlices(a.Key, b.Key) - }) - - assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys))) - assert(isSortedPairs(kvs), "keys are not sorted") - - l := len(kvs) + len(s.keys) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - i, j := 0, 0 - for { - if i+j >= l { - break - } - - if i >= len(kvs) { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - continue - } - - if j >= len(s.keys) { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - continue - } - - c := compareSlices(kvs[i].Key, s.keys[j]) - if c < 0 { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - } else if c > 0 { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - } else { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - j++ - } - } - - assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l)) - assert(isSortedKeys(merge_ks), "merge keys are not sorted") - - s.keys = merge_ks - s.values = merge_vs - - return nil + return s.Collection.AddDocuments(context.Background(), docs, runtime.NumCPU()) } -func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to delete") +func (s *Store) StoresReset(opts *pb.StoresResetOptions) error { + err := s.DB.DeleteCollection("all-documents") + if err != nil { + return err } - - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - ks := sortIntoKeySlicese(opts.Keys) - - l := len(s.keys) - len(ks) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - tail_ks := s.keys - tail_vs := s.values - for _, k := range ks { - j, found := findInSortedSlice(tail_ks, k) - - if found { - merge_ks = append(merge_ks, tail_ks[:j]...) - merge_vs = append(merge_vs, tail_vs[:j]...) - tail_ks = tail_ks[j+1:] - tail_vs = tail_vs[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k)) - } - - log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs)) - } - - merge_ks = append(merge_ks, tail_ks...) - merge_vs = append(merge_vs, tail_vs...) - - assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys))) - - s.keys = merge_ks - s.values = merge_vs - - assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l)) - assert(isSortedKeys(s.keys), "keys are not sorted") - assert(func() bool { - for _, k := range ks { - if _, found := findInSortedSlice(s.keys, k); found { - return false - } - } - return true - }(), "Keys to delete still present") - - if len(s.keys) != l { - log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l) - } - - return nil + s.Collection, err = s.CreateCollection("all-documents", nil, nil) + return err } func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) { diff --git a/core/http/app_test.go b/core/http/app_test.go index a2e2f758..77b896b5 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -1000,7 +1000,7 @@ var _ = Describe("API test", func() { } } - deleteBody := schema.StoresDelete{ + deleteBody := schema.StoresReset{ Keys: [][]float32{ {0.1, 0.2, 0.3}, }, diff --git a/core/http/endpoints/localai/stores.go b/core/http/endpoints/localai/stores.go index f417c580..f773840f 100644 --- a/core/http/endpoints/localai/stores.go +++ b/core/http/endpoints/localai/stores.go @@ -36,9 +36,9 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi } } -func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { +func StoresResetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(schema.StoresDelete) + input := new(schema.StoresReset) if err := c.BodyParser(input); err != nil { return err @@ -49,7 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo return err } - if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil { + if _, err := sb.StoresReset(c.Context(), nil); err != nil { return err } @@ -57,37 +57,6 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo } } -func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(schema.StoresGet) - - if err := c.BodyParser(input); err != nil { - return err - } - - sb, err := backend.StoreBackend(sl, appConfig, input.Store) - if err != nil { - return err - } - - keys, vals, err := store.GetCols(c.Context(), sb, input.Keys) - if err != nil { - return err - } - - res := schema.StoresGetResponse{ - Keys: keys, - Values: make([]string, len(vals)), - } - - for i, v := range vals { - res.Values[i] = string(v) - } - - return c.JSON(res) - } -} - func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.StoresFind) diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 2ea9896a..669209bc 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -39,8 +39,7 @@ func RegisterLocalAIRoutes(router *fiber.App, // Stores sl := model.NewModelLoader("") router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig)) - router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig)) - router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig)) + router.Post("/stores/reset", localai.StoresDeleteEndpoint(sl, appConfig)) router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig)) if !appConfig.DisableMetrics { diff --git a/core/schema/localai.go b/core/schema/localai.go index 08afc6df..6141ba26 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -47,21 +47,8 @@ type StoresSet struct { Values []string `json:"values" yaml:"values"` } -type StoresDelete struct { +type StoresReset struct { Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Keys [][]float32 `json:"keys"` -} - -type StoresGet struct { - Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Keys [][]float32 `json:"keys" yaml:"keys"` -} - -type StoresGetResponse struct { - Keys [][]float32 `json:"keys" yaml:"keys"` - Values []string `json:"values" yaml:"values"` } type StoresFind struct { diff --git a/go.mod b/go.mod index adfa7357..9564d73f 100644 --- a/go.mod +++ b/go.mod @@ -93,6 +93,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect + github.com/philippgille/chromem-go v0.7.0 // indirect github.com/pion/datachannel v1.5.10 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/ice/v2 v2.3.37 // indirect diff --git a/go.sum b/go.sum index 4a744ed8..7a6eef40 100644 --- a/go.sum +++ b/go.sum @@ -611,6 +611,8 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= +github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= +github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM= github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pion/datachannel v1.5.8 h1:ph1P1NsGkazkjrvyMfhRBUAWMxugJjq2HfQifaOoSNo= diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index fabc0268..354a9de3 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -46,8 +46,7 @@ type Backend interface { Status(ctx context.Context) (*pb.StatusResponse, error) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) - StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) - StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) + StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 2e1fb209..457c24e3 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -80,11 +80,7 @@ func (llm *Base) StoresSet(*pb.StoresSetOptions) error { return fmt.Errorf("unimplemented") } -func (llm *Base) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) { - return pb.StoresGetResult{}, fmt.Errorf("unimplemented") -} - -func (llm *Base) StoresDelete(*pb.StoresDeleteOptions) error { +func (llm *Base) StoresReset(*pb.StoresResetOptions) error { return fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index ca207c3f..cb02b71c 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -303,7 +303,7 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts .. return client.StoresSet(ctx, in, opts...) } -func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) { +func (c *Client) StoreReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -318,25 +318,7 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o } defer conn.Close() client := pb.NewBackendClient(conn) - return client.StoresDelete(ctx, in, opts...) -} - -func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) { - if !c.parallel { - c.opMutex.Lock() - defer c.opMutex.Unlock() - } - c.setBusy(true) - defer c.setBusy(false) - c.wdMark() - defer c.wdUnMark() - conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - return nil, err - } - defer conn.Close() - client := pb.NewBackendClient(conn) - return client.StoresGet(ctx, in, opts...) + return client.StoresReset(ctx, in, opts...) } func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 79648c5a..db43676a 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -71,12 +71,8 @@ func (e *embedBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions, o return e.s.StoresSet(ctx, in) } -func (e *embedBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) { - return e.s.StoresDelete(ctx, in) -} - -func (e *embedBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) { - return e.s.StoresGet(ctx, in) +func (e *embedBackend) StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.StoresReset(ctx, in) } func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 9214e3cf..e426a4ca 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -21,8 +21,7 @@ type LLM interface { Status() (pb.StatusResponse, error) StoresSet(*pb.StoresSetOptions) error - StoresDelete(*pb.StoresDeleteOptions) error - StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) + StoresReset(*pb.StoresResetOptions) error StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) VAD(*pb.VADRequest) (pb.VADResponse, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 0b2a167f..8dd203f8 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -191,28 +191,16 @@ func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Re return &pb.Result{Message: "Set key", Success: true}, nil } -func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) { +func (s *server) StoresReset(ctx context.Context, in *pb.StoresResetOptions) (*pb.Result, error) { if s.llm.Locking() { s.llm.Lock() defer s.llm.Unlock() } - err := s.llm.StoresDelete(in) + err := s.llm.StoresReset(in) if err != nil { return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err } - return &pb.Result{Message: "Deleted key", Success: true}, nil -} - -func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) { - if s.llm.Locking() { - s.llm.Lock() - defer s.llm.Unlock() - } - res, err := s.llm.StoresGet(in) - if err != nil { - return nil, err - } - return &res, nil + return &pb.Result{Message: "Deleted mem db", Success: true}, nil } func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) { diff --git a/pkg/store/client.go b/pkg/store/client.go deleted file mode 100644 index 1a1f46cc..00000000 --- a/pkg/store/client.go +++ /dev/null @@ -1,155 +0,0 @@ -package store - -import ( - "context" - "fmt" - - grpc "github.com/mudler/LocalAI/pkg/grpc" - "github.com/mudler/LocalAI/pkg/grpc/proto" -) - -// Wrapper for the GRPC client so that simple use cases are handled without verbosity - -// SetCols sets multiple key-value pairs in the store -// It's in columnar format so that keys[i] is associated with values[i] -func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - protoValues := make([]*proto.StoresValue, len(values)) - for i, v := range values { - protoValues[i] = &proto.StoresValue{ - Bytes: v, - } - } - setOpts := &proto.StoresSetOptions{ - Keys: protoKeys, - Values: protoValues, - } - - res, err := c.StoresSet(ctx, setOpts) - if err != nil { - return err - } - - if res.Success { - return nil - } - - return fmt.Errorf("failed to set keys: %v", res.Message) -} - -// SetSingle sets a single key-value pair in the store -// Don't call this in a tight loop, instead use SetCols -func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) error { - return SetCols(ctx, c, [][]float32{key}, [][]byte{value}) -} - -// DeleteCols deletes multiple key-value pairs from the store -// It's in columnar format so that keys[i] is associated with values[i] -func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - deleteOpts := &proto.StoresDeleteOptions{ - Keys: protoKeys, - } - - res, err := c.StoresDelete(ctx, deleteOpts) - if err != nil { - return err - } - - if res.Success { - return nil - } - - return fmt.Errorf("failed to delete keys: %v", res.Message) -} - -// DeleteSingle deletes a single key-value pair from the store -// Don't call this in a tight loop, instead use DeleteCols -func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error { - return DeleteCols(ctx, c, [][]float32{key}) -} - -// GetCols gets multiple key-value pairs from the store -// It's in columnar format so that keys[i] is associated with values[i] -// Be warned the keys are sorted and will be returned in a different order than they were input -// There is no guarantee as to how the keys are sorted -func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) { - protoKeys := make([]*proto.StoresKey, len(keys)) - for i, k := range keys { - protoKeys[i] = &proto.StoresKey{ - Floats: k, - } - } - getOpts := &proto.StoresGetOptions{ - Keys: protoKeys, - } - - res, err := c.StoresGet(ctx, getOpts) - if err != nil { - return nil, nil, err - } - - ks := make([][]float32, len(res.Keys)) - for i, k := range res.Keys { - ks[i] = k.Floats - } - vs := make([][]byte, len(res.Values)) - for i, v := range res.Values { - vs[i] = v.Bytes - } - - return ks, vs, nil -} - -// GetSingle gets a single key-value pair from the store -// Don't call this in a tight loop, instead use GetCols -func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, error) { - _, values, err := GetCols(ctx, c, [][]float32{key}) - if err != nil { - return nil, err - } - - if len(values) > 0 { - return values[0], nil - } - - return nil, fmt.Errorf("failed to get key") -} - -// Find similar keys to the given key. Returns the keys, values, and similarities -func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) { - findOpts := &proto.StoresFindOptions{ - Key: &proto.StoresKey{ - Floats: key, - }, - TopK: int32(topk), - } - - res, err := c.StoresFind(ctx, findOpts) - if err != nil { - return nil, nil, nil, err - } - - ks := make([][]float32, len(res.Keys)) - vs := make([][]byte, len(res.Values)) - - for i, k := range res.Keys { - ks[i] = k.Floats - } - - for i, v := range res.Values { - vs[i] = v.Bytes - } - - return ks, vs, res.Similarities, nil -} diff --git a/tests/integration/stores_test.go b/tests/integration/stores_test.go index 5ed46b19..2d932694 100644 --- a/tests/integration/stores_test.go +++ b/tests/integration/stores_test.go @@ -70,6 +70,10 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" }) It("should be able to set a key", func() { + sc.StoresSet(context.Background(), &store.StoresSetOptions{ + Keys: [][]float32{{0.1, 0.2, 0.3}}, + Values: [][]byte{[]byte("test")}, + }) err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test")) Expect(err).ToNot(HaveOccurred()) })