mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-13 22:21:57 +00:00
Stores to chromem (WIP)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
2f09aa1b85
commit
a1d5462ad0
@ -21,8 +21,7 @@ service Backend {
|
||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||
|
||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
|
||||
rpc StoresReset(StoresResetOptions) returns (Result) {}
|
||||
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
||||
|
||||
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||
@ -78,19 +77,10 @@ message StoresSetOptions {
|
||||
repeated StoresValue Values = 2;
|
||||
}
|
||||
|
||||
message StoresDeleteOptions {
|
||||
message StoresResetOptions {
|
||||
repeated StoresKey Keys = 1;
|
||||
}
|
||||
|
||||
message StoresGetOptions {
|
||||
repeated StoresKey Keys = 1;
|
||||
}
|
||||
|
||||
message StoresGetResult {
|
||||
repeated StoresKey Keys = 1;
|
||||
repeated StoresValue Values = 2;
|
||||
}
|
||||
|
||||
message StoresFindOptions {
|
||||
StoresKey Key = 1;
|
||||
int32 TopK = 2;
|
||||
|
@ -4,101 +4,36 @@ package main
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"runtime"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
chromem "github.com/philippgille/chromem-go"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
|
||||
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
|
||||
type Pair struct {
|
||||
Key []float32
|
||||
Value []byte
|
||||
*chromem.DB
|
||||
*chromem.Collection
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
keys: make([][]float32, 0),
|
||||
values: make([][]byte, 0),
|
||||
keysAreNormalized: true,
|
||||
keyLen: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
return &Store{}
|
||||
}
|
||||
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
db := chromem.NewDB()
|
||||
collection, err := db.CreateCollection("all-documents", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.DB = db
|
||||
s.Collection = collection
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -111,156 +46,25 @@ func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
docs := []chromem.Document{}
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
log.Debug().Msgf("Key is not normalized: %v", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
docs = append(docs, chromem.Document{
|
||||
ID: k.String(),
|
||||
Content: opts.Values[i].String(),
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
return nil
|
||||
return s.Collection.AddDocuments(context.Background(), docs, runtime.NumCPU())
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
func (s *Store) StoresReset(opts *pb.StoresResetOptions) error {
|
||||
err := s.DB.DeleteCollection("all-documents")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l)
|
||||
}
|
||||
|
||||
return nil
|
||||
s.Collection, err = s.CreateCollection("all-documents", nil, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
|
@ -1000,7 +1000,7 @@ var _ = Describe("API test", func() {
|
||||
}
|
||||
}
|
||||
|
||||
deleteBody := schema.StoresDelete{
|
||||
deleteBody := schema.StoresReset{
|
||||
Keys: [][]float32{
|
||||
{0.1, 0.2, 0.3},
|
||||
},
|
||||
|
@ -36,9 +36,9 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
|
||||
}
|
||||
}
|
||||
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func StoresResetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresDelete)
|
||||
input := new(schema.StoresReset)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
@ -49,7 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||
return err
|
||||
}
|
||||
|
||||
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
|
||||
if _, err := sb.StoresReset(c.Context(), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -57,37 +57,6 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
|
||||
}
|
||||
}
|
||||
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresGet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := schema.StoresGetResponse{
|
||||
Keys: keys,
|
||||
Values: make([]string, len(vals)),
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresFind)
|
||||
|
@ -39,8 +39,7 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
// Stores
|
||||
sl := model.NewModelLoader("")
|
||||
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
||||
router.Post("/stores/reset", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||
|
||||
if !appConfig.DisableMetrics {
|
||||
|
@ -47,21 +47,8 @@ type StoresSet struct {
|
||||
Values []string `json:"values" yaml:"values"`
|
||||
}
|
||||
|
||||
type StoresDelete struct {
|
||||
type StoresReset struct {
|
||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
||||
|
||||
Keys [][]float32 `json:"keys"`
|
||||
}
|
||||
|
||||
type StoresGet struct {
|
||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
||||
|
||||
Keys [][]float32 `json:"keys" yaml:"keys"`
|
||||
}
|
||||
|
||||
type StoresGetResponse struct {
|
||||
Keys [][]float32 `json:"keys" yaml:"keys"`
|
||||
Values []string `json:"values" yaml:"values"`
|
||||
}
|
||||
|
||||
type StoresFind struct {
|
||||
|
1
go.mod
1
go.mod
@ -93,6 +93,7 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
||||
github.com/philippgille/chromem-go v0.7.0 // indirect
|
||||
github.com/pion/datachannel v1.5.10 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||
github.com/pion/ice/v2 v2.3.37 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -611,6 +611,8 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw=
|
||||
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
|
||||
github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY=
|
||||
github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
|
||||
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
||||
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/datachannel v1.5.8 h1:ph1P1NsGkazkjrvyMfhRBUAWMxugJjq2HfQifaOoSNo=
|
||||
|
@ -46,8 +46,7 @@ type Backend interface {
|
||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||
|
||||
StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
|
||||
StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
||||
|
||||
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
||||
|
@ -80,11 +80,7 @@ func (llm *Base) StoresSet(*pb.StoresSetOptions) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) StoresDelete(*pb.StoresDeleteOptions) error {
|
||||
func (llm *Base) StoresReset(*pb.StoresResetOptions) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
|
@ -303,7 +303,7 @@ func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ..
|
||||
return client.StoresSet(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
func (c *Client) StoreReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
@ -318,25 +318,7 @@ func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, o
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresDelete(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
c.wdMark()
|
||||
defer c.wdUnMark()
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresGet(ctx, in, opts...)
|
||||
return client.StoresReset(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||
|
@ -71,12 +71,8 @@ func (e *embedBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions, o
|
||||
return e.s.StoresSet(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.StoresDelete(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
||||
return e.s.StoresGet(ctx, in)
|
||||
func (e *embedBackend) StoresReset(ctx context.Context, in *pb.StoresResetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.StoresReset(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||
|
@ -21,8 +21,7 @@ type LLM interface {
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
||||
StoresSet(*pb.StoresSetOptions) error
|
||||
StoresDelete(*pb.StoresDeleteOptions) error
|
||||
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
|
||||
StoresReset(*pb.StoresResetOptions) error
|
||||
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
||||
|
||||
VAD(*pb.VADRequest) (pb.VADResponse, error)
|
||||
|
@ -191,28 +191,16 @@ func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Re
|
||||
return &pb.Result{Message: "Set key", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
|
||||
func (s *server) StoresReset(ctx context.Context, in *pb.StoresResetOptions) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.StoresDelete(in)
|
||||
err := s.llm.StoresReset(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
|
||||
}
|
||||
return &pb.Result{Message: "Deleted key", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
res, err := s.llm.StoresGet(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
return &pb.Result{Message: "Deleted mem db", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
|
||||
|
@ -1,155 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Wrapper for the GRPC client so that simple use cases are handled without verbosity
|
||||
|
||||
// SetCols sets multiple key-value pairs in the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
protoValues := make([]*proto.StoresValue, len(values))
|
||||
for i, v := range values {
|
||||
protoValues[i] = &proto.StoresValue{
|
||||
Bytes: v,
|
||||
}
|
||||
}
|
||||
setOpts := &proto.StoresSetOptions{
|
||||
Keys: protoKeys,
|
||||
Values: protoValues,
|
||||
}
|
||||
|
||||
res, err := c.StoresSet(ctx, setOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Success {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to set keys: %v", res.Message)
|
||||
}
|
||||
|
||||
// SetSingle sets a single key-value pair in the store
|
||||
// Don't call this in a tight loop, instead use SetCols
|
||||
func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) error {
|
||||
return SetCols(ctx, c, [][]float32{key}, [][]byte{value})
|
||||
}
|
||||
|
||||
// DeleteCols deletes multiple key-value pairs from the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
deleteOpts := &proto.StoresDeleteOptions{
|
||||
Keys: protoKeys,
|
||||
}
|
||||
|
||||
res, err := c.StoresDelete(ctx, deleteOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Success {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to delete keys: %v", res.Message)
|
||||
}
|
||||
|
||||
// DeleteSingle deletes a single key-value pair from the store
|
||||
// Don't call this in a tight loop, instead use DeleteCols
|
||||
func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error {
|
||||
return DeleteCols(ctx, c, [][]float32{key})
|
||||
}
|
||||
|
||||
// GetCols gets multiple key-value pairs from the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
// Be warned the keys are sorted and will be returned in a different order than they were input
|
||||
// There is no guarantee as to how the keys are sorted
|
||||
func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
getOpts := &proto.StoresGetOptions{
|
||||
Keys: protoKeys,
|
||||
}
|
||||
|
||||
res, err := c.StoresGet(ctx, getOpts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ks := make([][]float32, len(res.Keys))
|
||||
for i, k := range res.Keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
vs := make([][]byte, len(res.Values))
|
||||
for i, v := range res.Values {
|
||||
vs[i] = v.Bytes
|
||||
}
|
||||
|
||||
return ks, vs, nil
|
||||
}
|
||||
|
||||
// GetSingle gets a single key-value pair from the store
|
||||
// Don't call this in a tight loop, instead use GetCols
|
||||
func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, error) {
|
||||
_, values, err := GetCols(ctx, c, [][]float32{key})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get key")
|
||||
}
|
||||
|
||||
// Find similar keys to the given key. Returns the keys, values, and similarities
|
||||
func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) {
|
||||
findOpts := &proto.StoresFindOptions{
|
||||
Key: &proto.StoresKey{
|
||||
Floats: key,
|
||||
},
|
||||
TopK: int32(topk),
|
||||
}
|
||||
|
||||
res, err := c.StoresFind(ctx, findOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ks := make([][]float32, len(res.Keys))
|
||||
vs := make([][]byte, len(res.Values))
|
||||
|
||||
for i, k := range res.Keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
for i, v := range res.Values {
|
||||
vs[i] = v.Bytes
|
||||
}
|
||||
|
||||
return ks, vs, res.Similarities, nil
|
||||
}
|
@ -70,6 +70,10 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
|
||||
})
|
||||
|
||||
It("should be able to set a key", func() {
|
||||
sc.StoresSet(context.Background(), &store.StoresSetOptions{
|
||||
Keys: [][]float32{{0.1, 0.2, 0.3}},
|
||||
Values: [][]byte{[]byte("test")},
|
||||
})
|
||||
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user