mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 12:26:26 +00:00
feat(stores): Vector store backend (#1795)
Add simple vector store backend Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
parent
4b1ee0c170
commit
643d85d2cc
31
.editorconfig
Normal file
31
.editorconfig
Normal file
@ -0,0 +1,31 @@
|
||||
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
|
||||
[Makefile]
|
||||
indent_style = tab
|
||||
|
||||
[*.proto]
|
||||
indent_size = 2
|
||||
|
||||
[*.py]
|
||||
indent_size = 4
|
||||
|
||||
[*.js]
|
||||
indent_size = 2
|
||||
|
||||
[*.yaml]
|
||||
indent_size = 2
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
11
Makefile
11
Makefile
@ -159,6 +159,7 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/gpt4all
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||
|
||||
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
|
||||
@ -333,7 +334,7 @@ prepare-test: grpcs
|
||||
|
||||
test: prepare test-models/testmodel.ggml grpcs
|
||||
@echo 'Running tests'
|
||||
export GO_TAGS="tts stablediffusion"
|
||||
export GO_TAGS="tts stablediffusion debug"
|
||||
$(MAKE) prepare-test
|
||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||
@ -387,6 +388,11 @@ test-stablediffusion: prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
|
||||
|
||||
test-stores: backend-assets/grpc/local-store
|
||||
mkdir -p tests/integration/backend-assets/grpc
|
||||
cp -f backend-assets/grpc/local-store tests/integration/backend-assets/grpc/
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts 1 -v -r tests/integration
|
||||
|
||||
test-container:
|
||||
docker build --target requirements -t local-ai-test-container .
|
||||
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
|
||||
@ -536,6 +542,9 @@ backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH=$(CURDIR)/sources/whisper.cpp LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/
|
||||
|
||||
backend-assets/grpc/local-store: backend-assets/grpc
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/local-store ./backend/go/stores/
|
||||
|
||||
grpcs: prepare $(GRPC_BACKENDS)
|
||||
|
||||
DOCKER_IMAGE?=local-ai
|
||||
|
@ -18,6 +18,48 @@ service Backend {
|
||||
rpc TTS(TTSRequest) returns (Result) {}
|
||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||
|
||||
rpc StoresSet(StoresSetOptions) returns (Result) {}
|
||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
|
||||
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
||||
}
|
||||
|
||||
// StoresKey is a single store key: a vector of 32-bit floats.
message StoresKey {
  // Components of the key vector.
  repeated float Floats = 1;
}
|
||||
|
||||
// StoresValue is the opaque payload stored under a key.
message StoresValue {
  // Raw value bytes.
  bytes Bytes = 1;
}
|
||||
|
||||
// StoresSetOptions is the request of StoresSet. Keys and Values are parallel
// lists: Values[i] is stored under Keys[i].
message StoresSetOptions {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
}
|
||||
|
||||
// StoresDeleteOptions is the request of StoresDelete.
message StoresDeleteOptions {
  // Keys whose entries should be removed.
  repeated StoresKey Keys = 1;
}
|
||||
|
||||
// StoresGetOptions is the request of StoresGet.
message StoresGetOptions {
  // Keys to look up by exact match.
  repeated StoresKey Keys = 1;
}
|
||||
|
||||
// StoresGetResult is the response of StoresGet. Keys and Values are parallel
// lists: Values[i] belongs to Keys[i].
message StoresGetResult {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
}
|
||||
|
||||
// StoresFindOptions is the request of StoresFind.
message StoresFindOptions {
  // Query vector to compare the stored keys against.
  StoresKey Key = 1;
  // Maximum number of entries to return.
  int32 TopK = 2;
}
|
||||
|
||||
// StoresFindResult is the response of StoresFind. The three lists are
// parallel: entry i has key Keys[i], value Values[i] and similarity
// Similarities[i] to the query key.
message StoresFindResult {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
  repeated float Similarities = 3;
}
|
||||
|
||||
message HealthMessage {}
|
||||
|
14
backend/go/stores/debug.go
Normal file
14
backend/go/stores/debug.go
Normal file
@ -0,0 +1,14 @@
|
||||
//go:build debug
|
||||
// +build debug
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func assert(cond bool, msg string) {
|
||||
if !cond {
|
||||
log.Fatal().Stack().Msg(msg)
|
||||
}
|
||||
}
|
26
backend/go/stores/main.go
Normal file
26
backend/go/stores/main.go
Normal file
@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each store
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, NewStore()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
7
backend/go/stores/production.go
Normal file
7
backend/go/stores/production.go
Normal file
@ -0,0 +1,7 @@
|
||||
//go:build !debug
|
||||
// +build !debug
|
||||
|
||||
package main
|
||||
|
||||
// assert is the non-debug variant of the invariant check: it deliberately
// does nothing, so a violated condition is ignored in builds without the
// "debug" tag. Note that the arguments are still evaluated by the caller.
func assert(cond bool, msg string) {}
|
507
backend/go/stores/store.go
Normal file
507
backend/go/stores/store.go
Normal file
@ -0,0 +1,507 @@
|
||||
package main
|
||||
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
base.SingleThread
|
||||
|
||||
// The sorted keys
|
||||
keys [][]float32
|
||||
// The sorted values
|
||||
values [][]byte
|
||||
|
||||
// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
|
||||
// TODO: Should we normalize incoming keys if they are not instead?
|
||||
keysAreNormalized bool
|
||||
// The first key decides the length of the keys
|
||||
keyLen int
|
||||
}
|
||||
|
||||
// Pair couples one key with its value so that both can be sorted together
// with Go's builtin sort.
//
// TODO: The interfaces are columnar because that is theoretically best for
// memory layout and cache locality, but this isn't optimized yet; Pair only
// exists for sorting.
type Pair struct {
	Key   []float32
	Value []byte
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
keys: make([][]float32, 0),
|
||||
values: make([][]byte, 0),
|
||||
keysAreNormalized: true,
|
||||
keyLen: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func compareSlices(k1, k2 []float32) int {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
return slices.Compare(k1, k2)
|
||||
}
|
||||
|
||||
func hasKey(unsortedSlice [][]float32, target []float32) bool {
|
||||
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
|
||||
return compareSlices(k, target) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
|
||||
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
|
||||
return compareSlices(k, t)
|
||||
})
|
||||
}
|
||||
|
||||
func isSortedPairs(kvs []Pair) bool {
|
||||
for i := 1; i < len(kvs); i++ {
|
||||
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isSortedKeys(keys [][]float32) bool {
|
||||
for i := 1; i < len(keys); i++ {
|
||||
if compareSlices(keys[i-1], keys[i]) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
|
||||
ks := make([][]float32, len(keys))
|
||||
|
||||
for i, k := range keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
slices.SortFunc(ks, compareSlices)
|
||||
|
||||
assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
|
||||
assert(isSortedKeys(ks), "keys are not sorted")
|
||||
|
||||
return ks
|
||||
}
|
||||
|
||||
func (s *Store) Load(opts *pb.ModelOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort the incoming kvs and merge them with the existing sorted kvs
|
||||
func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if len(opts.Keys) != len(opts.Values) {
|
||||
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
kvs := make([]Pair, len(opts.Keys))
|
||||
|
||||
for i, k := range opts.Keys {
|
||||
if s.keysAreNormalized && !isNormalized(k.Floats) {
|
||||
s.keysAreNormalized = false
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = k.Floats[:5]
|
||||
} else {
|
||||
sample = k.Floats
|
||||
}
|
||||
log.Debug().Msgf("Key is not normalized: %v", sample)
|
||||
}
|
||||
|
||||
kvs[i] = Pair{
|
||||
Key: k.Floats,
|
||||
Value: opts.Values[i].Bytes,
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(kvs, func(a, b Pair) int {
|
||||
return compareSlices(a.Key, b.Key)
|
||||
})
|
||||
|
||||
assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
|
||||
assert(isSortedPairs(kvs), "keys are not sorted")
|
||||
|
||||
l := len(kvs) + len(s.keys)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
i, j := 0, 0
|
||||
for {
|
||||
if i+j >= l {
|
||||
break
|
||||
}
|
||||
|
||||
if i >= len(kvs) {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
continue
|
||||
}
|
||||
|
||||
if j >= len(s.keys) {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
c := compareSlices(kvs[i].Key, s.keys[j])
|
||||
if c < 0 {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
} else if c > 0 {
|
||||
merge_ks = append(merge_ks, s.keys[j])
|
||||
merge_vs = append(merge_vs, s.values[j])
|
||||
j++
|
||||
} else {
|
||||
merge_ks = append(merge_ks, kvs[i].Key)
|
||||
merge_vs = append(merge_vs, kvs[i].Value)
|
||||
i++
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
|
||||
assert(isSortedKeys(merge_ks), "merge keys are not sorted")
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to delete")
|
||||
}
|
||||
|
||||
if len(opts.Keys) == 0 {
|
||||
return fmt.Errorf("no keys to add")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
l := len(s.keys) - len(ks)
|
||||
merge_ks := make([][]float32, 0, l)
|
||||
merge_vs := make([][]byte, 0, l)
|
||||
|
||||
tail_ks := s.keys
|
||||
tail_vs := s.values
|
||||
for _, k := range ks {
|
||||
j, found := findInSortedSlice(tail_ks, k)
|
||||
|
||||
if found {
|
||||
merge_ks = append(merge_ks, tail_ks[:j]...)
|
||||
merge_vs = append(merge_vs, tail_vs[:j]...)
|
||||
tail_ks = tail_ks[j+1:]
|
||||
tail_vs = tail_vs[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs))
|
||||
}
|
||||
|
||||
merge_ks = append(merge_ks, tail_ks...)
|
||||
merge_vs = append(merge_vs, tail_vs...)
|
||||
|
||||
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))
|
||||
|
||||
s.keys = merge_ks
|
||||
s.values = merge_vs
|
||||
|
||||
assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
|
||||
assert(isSortedKeys(s.keys), "keys are not sorted")
|
||||
assert(func() bool {
|
||||
for _, k := range ks {
|
||||
if _, found := findInSortedSlice(s.keys, k); found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}(), "Keys to delete still present")
|
||||
|
||||
if len(s.keys) != l {
|
||||
log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys))
|
||||
pbValues := make([]*pb.StoresValue, 0, len(opts.Keys))
|
||||
ks := sortIntoKeySlicese(opts.Keys)
|
||||
|
||||
if len(s.keys) == 0 {
|
||||
log.Debug().Msgf("Get: No keys in store")
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Keys[0].Floats)
|
||||
} else {
|
||||
if len(opts.Keys[0].Floats) != s.keyLen {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
tail_k := s.keys
|
||||
tail_v := s.values
|
||||
for i, k := range ks {
|
||||
j, found := findInSortedSlice(tail_k, k)
|
||||
|
||||
if found {
|
||||
pbKeys = append(pbKeys, &pb.StoresKey{
|
||||
Floats: k,
|
||||
})
|
||||
pbValues = append(pbValues, &pb.StoresValue{
|
||||
Bytes: tail_v[j],
|
||||
})
|
||||
|
||||
tail_k = tail_k[j+1:]
|
||||
tail_v = tail_v[j+1:]
|
||||
} else {
|
||||
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k))
|
||||
}
|
||||
}
|
||||
|
||||
if len(pbKeys) != len(opts.Keys) {
|
||||
log.Debug().Msgf("Get: Some keys not found: len(pbKeys) = %d, len(opts.Keys) = %d, len(s.Keys) = %d", len(pbKeys), len(opts.Keys), len(s.keys))
|
||||
}
|
||||
|
||||
return pb.StoresGetResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isNormalized reports whether k is a unit vector, i.e. whether the sum of
// its squared components is 1 within a small tolerance. The tolerance absorbs
// float32 rounding in keys that were normalized elsewhere. An empty key is
// not normalized.
func isNormalized(k []float32) bool {
	// Allowed deviation of ||k||^2 from 1.
	const tolerance = 1e-3

	// Accumulate in float64 so long keys do not pile up rounding error.
	var sumSq float64
	for _, v := range k {
		f := float64(v)
		sumSq += f * f
	}

	return math.Abs(sumSq-1) <= tolerance
}
|
||||
|
||||
// TODO: This we could replace with handwritten SIMD code
|
||||
func normalizedCosineSimilarity(k1, k2 []float32) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot float32
|
||||
for i := 0; i < len(k1); i++ {
|
||||
dot += k1[i] * k2[i]
|
||||
}
|
||||
|
||||
assert(dot >= -1 && dot <= 1, fmt.Sprintf("dot = %f", dot))
|
||||
|
||||
// 2.0 * (1.0 - dot) would be the Euclidean distance
|
||||
return dot
|
||||
}
|
||||
|
||||
// PriorityItem is one candidate of a similarity search: a stored key/value
// pair together with its similarity to the query key.
type PriorityItem struct {
	Similarity float32
	Key        []float32
	Value      []byte
}

// PriorityQueue implements heap.Interface as a min-heap on Similarity: the
// least similar item sits at the top. A top-k search can therefore evict its
// worst candidate with heap.Pop whenever it holds more than k items.
type PriorityQueue []*PriorityItem

// Len returns the number of queued items.
func (pq PriorityQueue) Len() int { return len(pq) }

// Less orders items by ascending similarity, which makes the least similar
// item the heap root.
func (pq PriorityQueue) Less(i, j int) bool {
	return pq[i].Similarity < pq[j].Similarity
}

// Swap exchanges the items at positions i and j.
func (pq PriorityQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
}

// Push appends x, which must be a *PriorityItem. Call it through heap.Push,
// not directly, so that the heap order is restored.
func (pq *PriorityQueue) Push(x any) {
	item := x.(*PriorityItem)
	*pq = append(*pq, item)
}

// Pop removes and returns the last item. Call it through heap.Pop, which
// first moves the root to the last position.
func (pq *PriorityQueue) Pop() any {
	old := *pq
	n := len(old)
	item := old[n-1]
	// Clear the slot so the backing array does not keep the item alive.
	old[n-1] = nil
	*pq = old[:n-1]

	return item
}
|
||||
|
||||
func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
for i, k := range s.keys {
|
||||
sim := normalizedCosineSimilarity(tk, k)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: sim,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 {
|
||||
assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))
|
||||
|
||||
var dot, mag2 float64
|
||||
for i := 0; i < len(k1); i++ {
|
||||
dot += float64(k1[i] * k2[i])
|
||||
mag2 += float64(k2[i] * k2[i])
|
||||
}
|
||||
|
||||
sim := float32(dot / (mag1 * math.Sqrt(mag2)))
|
||||
|
||||
assert(sim >= -1 && sim <= 1, fmt.Sprintf("sim = %f", sim))
|
||||
|
||||
return sim
|
||||
}
|
||||
|
||||
func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
top_ks := make(PriorityQueue, 0, int(opts.TopK))
|
||||
heap.Init(&top_ks)
|
||||
|
||||
var mag1 float64
|
||||
for _, v := range tk {
|
||||
mag1 += float64(v * v)
|
||||
}
|
||||
mag1 = math.Sqrt(mag1)
|
||||
|
||||
for i, k := range s.keys {
|
||||
dist := cosineSimilarity(tk, k, mag1)
|
||||
heap.Push(&top_ks, &PriorityItem{
|
||||
Similarity: dist,
|
||||
Key: k,
|
||||
Value: s.values[i],
|
||||
})
|
||||
|
||||
if top_ks.Len() > int(opts.TopK) {
|
||||
heap.Pop(&top_ks)
|
||||
}
|
||||
}
|
||||
|
||||
similarities := make([]float32, top_ks.Len())
|
||||
pbKeys := make([]*pb.StoresKey, top_ks.Len())
|
||||
pbValues := make([]*pb.StoresValue, top_ks.Len())
|
||||
|
||||
for i := top_ks.Len() - 1; i >= 0; i-- {
|
||||
item := heap.Pop(&top_ks).(*PriorityItem)
|
||||
|
||||
similarities[i] = item.Similarity
|
||||
pbKeys[i] = &pb.StoresKey{
|
||||
Floats: item.Key,
|
||||
}
|
||||
pbValues[i] = &pb.StoresValue{
|
||||
Bytes: item.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return pb.StoresFindResult{
|
||||
Keys: pbKeys,
|
||||
Values: pbValues,
|
||||
Similarities: similarities,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
tk := opts.Key.Floats
|
||||
|
||||
if len(tk) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen)
|
||||
}
|
||||
|
||||
if opts.TopK < 1 {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK)
|
||||
}
|
||||
|
||||
if s.keyLen == -1 {
|
||||
s.keyLen = len(opts.Key.Floats)
|
||||
} else {
|
||||
if len(opts.Key.Floats) != s.keyLen {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
if s.keysAreNormalized && isNormalized(tk) {
|
||||
return s.StoresFindNormalized(opts)
|
||||
} else {
|
||||
if s.keysAreNormalized {
|
||||
var sample []float32
|
||||
if len(s.keys) > 5 {
|
||||
sample = tk[:5]
|
||||
} else {
|
||||
sample = tk
|
||||
}
|
||||
log.Debug().Msgf("Trying to compare non-normalized key with normalized keys: %v", sample)
|
||||
}
|
||||
|
||||
return s.StoresFindFallback(opts)
|
||||
}
|
||||
}
|
23
core/backend/stores.go
Normal file
23
core/backend/stores.go
Normal file
@ -0,0 +1,23 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) {
|
||||
if storeName == "" {
|
||||
storeName = "default"
|
||||
}
|
||||
|
||||
sc := []model.Option{
|
||||
model.WithBackendString(model.LocalStoreBackend),
|
||||
model.WithAssetDir(appConfig.AssetsDestination),
|
||||
model.WithModel(storeName),
|
||||
}
|
||||
|
||||
return sl.BackendLoader(sc...)
|
||||
}
|
||||
|
@ -172,6 +172,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||
// Elevenlabs
|
||||
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||
|
||||
// Stores
|
||||
sl := model.NewModelLoader("")
|
||||
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
|
||||
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
|
||||
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
|
||||
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
. "github.com/go-skynet/LocalAI/core/http"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/core/startup"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||
@ -122,6 +123,75 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
|
||||
return
|
||||
}
|
||||
|
||||
func postRequestJSON[B any](url string, bodyJson *B) error {
|
||||
payload, err := json.Marshal(bodyJson)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error {
|
||||
payload, err := json.Marshal(reqJson)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
GinkgoWriter.Printf("POST %s: %s\n", url, string(payload))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return json.Unmarshal(body, respJson)
|
||||
}
|
||||
|
||||
//go:embed backend-assets/*
|
||||
var backendAssets embed.FS
|
||||
|
||||
@ -836,6 +906,78 @@ var _ = Describe("API test", func() {
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
})
|
||||
|
||||
// See tests/integration/stores_test
|
||||
Context("Stores", Label("stores"), func() {
|
||||
|
||||
It("sets, gets, finds and deletes entries", func() {
|
||||
ks := [][]float32{
|
||||
{0.1, 0.2, 0.3},
|
||||
{0.4, 0.5, 0.6},
|
||||
{0.7, 0.8, 0.9},
|
||||
}
|
||||
vs := []string{
|
||||
"test1",
|
||||
"test2",
|
||||
"test3",
|
||||
}
|
||||
setBody := schema.StoresSet{
|
||||
Keys: ks,
|
||||
Values: vs,
|
||||
}
|
||||
|
||||
url := "http://127.0.0.1:9090/stores/"
|
||||
err := postRequestJSON(url+"set", &setBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
getBody := schema.StoresGet{
|
||||
Keys: ks,
|
||||
}
|
||||
var getRespBody schema.StoresGetResponse
|
||||
err = postRequestResponseJSON(url+"get", &getBody, &getRespBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(getRespBody.Keys)).To(Equal(len(ks)))
|
||||
|
||||
for i, v := range getRespBody.Keys {
|
||||
if v[0] == 0.1 {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test1"))
|
||||
} else if v[0] == 0.4 {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test2"))
|
||||
} else {
|
||||
Expect(getRespBody.Values[i]).To(Equal("test3"))
|
||||
}
|
||||
}
|
||||
|
||||
deleteBody := schema.StoresDelete{
|
||||
Keys: [][]float32{
|
||||
{0.1, 0.2, 0.3},
|
||||
},
|
||||
}
|
||||
err = postRequestJSON(url+"delete", &deleteBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
findBody := schema.StoresFind{
|
||||
Key: []float32{0.1, 0.3, 0.7},
|
||||
Topk: 10,
|
||||
}
|
||||
|
||||
var findRespBody schema.StoresFindResponse
|
||||
err = postRequestResponseJSON(url+"find", &findBody, &findRespBody)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(findRespBody.Keys)).To(Equal(2))
|
||||
|
||||
for i, v := range findRespBody.Keys {
|
||||
if v[0] == 0.4 {
|
||||
Expect(findRespBody.Values[i]).To(Equal("test2"))
|
||||
} else {
|
||||
Expect(findRespBody.Values[i]).To(Equal("test3"))
|
||||
}
|
||||
|
||||
Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1))
|
||||
Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Config file", func() {
|
||||
|
121
core/http/endpoints/localai/stores.go
Normal file
121
core/http/endpoints/localai/stores.go
Normal file
@ -0,0 +1,121 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/store"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresSet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vals := make([][]byte, len(input.Values))
|
||||
for i, v := range input.Values {
|
||||
vals[i] = []byte(v)
|
||||
}
|
||||
|
||||
err = store.SetCols(c.Context(), sb, input.Keys, vals)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresDelete)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Send(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresGet)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := schema.StoresGetResponse{
|
||||
Keys: keys,
|
||||
Values: make([]string, len(vals)),
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
}
|
||||
}
|
||||
|
||||
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.StoresFind)
|
||||
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sb, err := backend.StoreBackend(sl, appConfig, input.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := schema.StoresFindResponse{
|
||||
Keys: keys,
|
||||
Values: make([]string, len(vals)),
|
||||
Similarities: similarities,
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
res.Values[i] = string(v)
|
||||
}
|
||||
|
||||
return c.JSON(res)
|
||||
}
|
||||
}
|
@ -20,3 +20,40 @@ type TTSRequest struct {
|
||||
Voice string `json:"voice" yaml:"voice"`
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
}
|
||||
|
||||
// StoresSet is the request body of /stores/set.
type StoresSet struct {
	// Store names the store to write to; it may be left empty.
	Store string `json:"store,omitempty" yaml:"store,omitempty"`

	// Keys and Values are parallel lists: Values[i] is stored under Keys[i].
	Keys   [][]float32 `json:"keys" yaml:"keys"`
	Values []string    `json:"values" yaml:"values"`
}
|
||||
|
||||
// StoresDelete is the request body of /stores/delete.
type StoresDelete struct {
	// Store names the store to delete from; it may be left empty.
	Store string `json:"store,omitempty" yaml:"store,omitempty"`

	// Keys lists the keys whose entries should be removed.
	Keys [][]float32 `json:"keys" yaml:"keys"`
}
|
||||
|
||||
// StoresGet is the request body of /stores/get.
type StoresGet struct {
	// Store names the store to read from; it may be left empty.
	Store string `json:"store,omitempty" yaml:"store,omitempty"`

	// Keys lists the keys to look up.
	Keys [][]float32 `json:"keys" yaml:"keys"`
}
|
||||
|
||||
// StoresGetResponse is the response body of /stores/get. Keys and Values are
// parallel lists: Values[i] belongs to Keys[i].
type StoresGetResponse struct {
	Keys   [][]float32 `json:"keys" yaml:"keys"`
	Values []string    `json:"values" yaml:"values"`
}
|
||||
|
||||
type StoresFind struct {
|
||||
Store string `json:"store,omitempty" yaml:"store,omitempty"`
|
||||
|
||||
Key []float32 `json:"key" yaml:"key"`
|
||||
Topk int `json:"topk" yaml:"topk"`
|
||||
}
|
||||
|
||||
type StoresFindResponse struct {
|
||||
Keys [][]float32 `json:"keys" yaml:"keys"`
|
||||
Values []string `json:"values" yaml:"values"`
|
||||
Similarities []float32 `json:"similarities" yaml:"similarities"`
|
||||
}
|
||||
|
97
docs/content/docs/features/stores.md
Normal file
97
docs/content/docs/features/stores.md
Normal file
@ -0,0 +1,97 @@
|
||||
|
||||
+++
|
||||
disableToc = false
|
||||
title = "💾 Stores"
|
||||
|
||||
weight = 18
|
||||
url = '/stores'
|
||||
+++
|
||||
|
||||
Stores are an experimental feature to help with querying data using similarity search. It is
|
||||
a low level API that consists of only `get`, `set`, `delete` and `find`.
|
||||
|
||||
For example if you have an embedding of some text and want to find text with similar embeddings.
|
||||
You can create embeddings for chunks of all your text then compare them against the embedding of the text you
|
||||
are searching on.
|
||||
|
||||
An embedding here means a vector of numbers that represents some information about the text. The
|
||||
embeddings are created from an A.I. model such as BERT or a more traditional method such as word
|
||||
frequency.
|
||||
|
||||
Previously you would have to integrate with an external vector database or library directly.
|
||||
With the stores feature you can now do it through the LocalAI API.
|
||||
|
||||
Note however that doing a similarity search on embeddings is just one way to do retrieval. A higher level
|
||||
API can take this into account, so this may not be the best place to start.
|
||||
|
||||
## API overview
|
||||
|
||||
There is an internal gRPC API and an external facing HTTP JSON API. We'll just discuss the external HTTP API,
|
||||
however the HTTP API mirrors the gRPC API. Consult `pkg/store/client` for internal usage.
|
||||
|
||||
Everything is in columnar format. This means that you do not get an array of objects with a key and a value each.
|
||||
You instead get two separate arrays of keys and values.
|
||||
|
||||
Keys are arrays of floating point numbers with a maximum width of 32 bits. Values are strings (in gRPC they are bytes).
|
||||
|
||||
The key vectors must all be the same length and it's best for search performance if they are normalized. When
|
||||
adding keys it will be detected if they are not normalized and what length they are.
|
||||
|
||||
All endpoints accept a `store` field which specifies which store to operate on. Presently they are created
|
||||
on the fly and there is only one store backend so no configuration is required.
|
||||
|
||||
## Set
|
||||
|
||||
To set some keys you can do
|
||||
|
||||
```
|
||||
curl -X POST http://localhost:8080/stores/set \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"keys": [[0.1, 0.2], [0.3, 0.4]], "values": ["foo", "bar"]}'
|
||||
```
|
||||
|
||||
Setting the same keys again will update their values.
|
||||
|
||||
On success 200 OK is returned with no body.
|
||||
|
||||
## Get
|
||||
|
||||
To get some keys you can do
|
||||
|
||||
```
|
||||
curl -X POST http://localhost:8080/stores/get \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"keys": [[0.1, 0.2]]}'
|
||||
```
|
||||
|
||||
Both the keys and values are returned, e.g: `{"keys":[[0.1,0.2]],"values":["foo"]}`
|
||||
|
||||
The order of the keys is not preserved! If a key does not exist then nothing is returned.
|
||||
|
||||
## Delete
|
||||
|
||||
To delete keys and values you can do
|
||||
|
||||
```
|
||||
curl -X POST http://localhost:8080/stores/delete \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"keys": [[0.1, 0.2]]}'
|
||||
```
|
||||
|
||||
If a key doesn't exist then it is ignored.
|
||||
|
||||
On success 200 OK is returned with no body.
|
||||
|
||||
## Find
|
||||
|
||||
To do a similarity search you can do
|
||||
|
||||
```
|
||||
curl -X POST http://localhost:8080/stores/find \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"topk": 2, "key": [0.2, 0.1]}'
|
||||
```
|
||||
|
||||
`topk` limits the number of results returned. The result value is the same as `get`,
|
||||
except that it also includes an array of `similarities`, where `1.0` is the maximum similarity.
|
||||
They are returned in the order of most similar to least.
|
@ -73,6 +73,7 @@ Note that this started just as a fun weekend project by [mudler](https://github.
|
||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
|
||||
- 🆕 [Vision API](https://localai.io/features/gpt-vision/)
|
||||
- 💾 [Stores](https://localai.io/stores)
|
||||
|
||||
## Contribute and help
|
||||
|
||||
|
15
examples/semantic-todo/README.md
Normal file
15
examples/semantic-todo/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
This demonstrates the vector store backend in its simplest form.
|
||||
You can add tasks and then search/sort them using the TUI.
|
||||
|
||||
To build and run do
|
||||
|
||||
```bash
|
||||
$ go get .
|
||||
$ go run .
|
||||
```
|
||||
|
||||
A separate LocalAI instance is required of course. For example:
|
||||
|
||||
```bash
|
||||
$ docker run -e DEBUG=true --rm -it -p 8080:8080 <LocalAI-image> bert-cpp
|
||||
```
|
18
examples/semantic-todo/go.mod
Normal file
18
examples/semantic-todo/go.mod
Normal file
@ -0,0 +1,18 @@
|
||||
module semantic-todo
|
||||
|
||||
go 1.21.6
|
||||
|
||||
require (
|
||||
github.com/gdamore/tcell/v2 v2.7.1
|
||||
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gdamore/encoding v1.0.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/term v0.17.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
)
|
50
examples/semantic-todo/go.sum
Normal file
50
examples/semantic-todo/go.sum
Normal file
@ -0,0 +1,50 @@
|
||||
github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko=
|
||||
github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg=
|
||||
github.com/gdamore/tcell/v2 v2.7.1 h1:TiCcmpWHiAU7F0rA2I3S2Y4mmLmO9KHxJ7E1QhYzQbc=
|
||||
github.com/gdamore/tcell/v2 v2.7.1/go.mod h1:dSXtXTSK0VsW1biw65DZLZ2NKr7j0qP/0J7ONmsraWg=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1 h1:bWLHTRekAy497pE7+nXSuzXwwFHI0XauRzz6roUvY+s=
|
||||
github.com/rivo/tview v0.0.0-20240307173318-e804876934a1/go.mod h1:02iFIz7K/A9jGCvrizLPvoqr4cEIx7q54RH5Qudkrss=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
352
examples/semantic-todo/main.go
Normal file
352
examples/semantic-todo/main.go
Normal file
@ -0,0 +1,352 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
// localAI is the base URL of the LocalAI instance this example talks to.
const localAI string = "http://localhost:8080"

// rootStatus is the status-bar text shown outside of task entry.
const rootStatus string = "[::b]<space>[::-]: Add Task [::b]/[::-]: Search Task [::b]<C-c>[::-]: Exit"

// inputStatus is the status-bar text shown while a new task is being typed.
const inputStatus string = "Press [::b]<enter>[::-] to submit the task, [::b]<esc>[::-] to cancel"
|
||||
|
||||
// Task is one row of the TODO table.
type Task struct {
	// Description is the task text; it is also what gets embedded and stored.
	Description string
	// Similarity is the score reported by the last search; it is zero for
	// tasks that did not come from a search result.
	Similarity float32
}
|
||||
|
||||
// AppState identifies which mode the TUI is currently in.
type AppState int

const (
	// StateRoot is the default mode: the task table with no input field.
	StateRoot AppState = iota
	// StateInput is the mode in which a new task is being typed.
	StateInput
	// StateSearch is the mode in which a search query is being typed.
	StateSearch
)
|
||||
|
||||
type App struct {
|
||||
state AppState
|
||||
tasks []Task
|
||||
app *tview.Application
|
||||
flex *tview.Flex
|
||||
table *tview.Table
|
||||
}
|
||||
|
||||
func NewApp() *App {
|
||||
return &App{
|
||||
state: StateRoot,
|
||||
tasks: []Task{
|
||||
{Description: "Take the dog for a walk (after I get a dog)"},
|
||||
{Description: "Go to the toilet"},
|
||||
{Description: "Allow TODOs to be marked completed or removed"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbeddings(description string) ([]float32, error) {
|
||||
// Define the request payload
|
||||
payload := map[string]interface{}{
|
||||
"model": "bert-cpp-minilm-v6",
|
||||
"input": description,
|
||||
}
|
||||
|
||||
// Marshal the payload into JSON
|
||||
jsonPayload, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make the HTTP request to the local OpenAI embeddings API
|
||||
resp, err := http.Post(localAI+"/embeddings", "application/json", bytes.NewBuffer(jsonPayload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if the request was successful
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("request to embeddings API failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse the response body
|
||||
var result struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the embedding
|
||||
if len(result.Data) > 0 {
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
return nil, errors.New("no embedding received from API")
|
||||
}
|
||||
|
||||
// StoresSet is the JSON body posted to /stores/set. Keys and Values are
// columnar: Values[i] is stored under Keys[i].
type StoresSet struct {
	// Store names the target store; it is left out of the JSON when empty.
	Store string `json:"store,omitempty" yaml:"store,omitempty"`

	Keys   [][]float32 `json:"keys" yaml:"keys"`
	Values []string    `json:"values" yaml:"values"`
}
|
||||
|
||||
func postTasksToExternalService(tasks []Task) error {
|
||||
keys := make([][]float32, 0, len(tasks))
|
||||
// Get the embeddings for the task description
|
||||
for _, task := range tasks {
|
||||
embedding, err := getEmbeddings(task.Description)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys = append(keys, embedding)
|
||||
}
|
||||
|
||||
values := make([]string, 0, len(tasks))
|
||||
for _, task := range tasks {
|
||||
values = append(values, task.Description)
|
||||
}
|
||||
|
||||
// Construct the StoresSet object
|
||||
storesSet := StoresSet{
|
||||
Store: "tasks_store", // Assuming you have a specific store name
|
||||
Keys: keys,
|
||||
Values: values,
|
||||
}
|
||||
|
||||
// Marshal the StoresSet object into JSON
|
||||
jsonData, err := json.Marshal(storesSet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make the HTTP POST request to the external service
|
||||
resp, err := http.Post(localAI+"/stores/set", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if the request was successful
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// read resp body into string
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("store request failed with status code: %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoresFind is the JSON body posted to /stores/find: Key is the query vector
// and Topk limits how many results are returned.
type StoresFind struct {
	// Store names the target store; it is left out of the JSON when empty.
	Store string `json:"store,omitempty" yaml:"store,omitempty"`

	Key  []float32 `json:"key" yaml:"key"`
	Topk int       `json:"topk" yaml:"topk"`
}
|
||||
|
||||
// StoresFindResponse is the JSON reply of /stores/find. The three slices are
// columnar: entry i is described by Keys[i], Values[i] and Similarities[i].
type StoresFindResponse struct {
	Keys         [][]float32 `json:"keys" yaml:"keys"`
	Values       []string    `json:"values" yaml:"values"`
	Similarities []float32   `json:"similarities" yaml:"similarities"`
}
|
||||
|
||||
func findSimilarTexts(inputText string, topk int) (StoresFindResponse, error) {
|
||||
// Initialize an empty response object
|
||||
response := StoresFindResponse{}
|
||||
|
||||
// Get the embedding for the input text
|
||||
embedding, err := getEmbeddings(inputText)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
// Construct the StoresFind object
|
||||
storesFind := StoresFind{
|
||||
Store: "tasks_store", // Assuming you have a specific store name
|
||||
Key: embedding,
|
||||
Topk: topk,
|
||||
}
|
||||
|
||||
// Marshal the StoresFind object into JSON
|
||||
jsonData, err := json.Marshal(storesFind)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
// Make the HTTP POST request to the external service's /stores/find endpoint
|
||||
resp, err := http.Post(localAI+"/stores/find", "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if the request was successful
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return response, fmt.Errorf("request to /stores/find failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse the response body to retrieve similar texts and similarities
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (app *App) updateUI() {
|
||||
// Clear the flex layout
|
||||
app.flex.Clear()
|
||||
app.flex.SetDirection(tview.FlexColumn)
|
||||
app.flex.AddItem(nil, 0, 1, false)
|
||||
|
||||
midCol := tview.NewFlex()
|
||||
midCol.SetDirection(tview.FlexRow)
|
||||
midCol.AddItem(nil, 0, 1, false)
|
||||
|
||||
// Create a new table.
|
||||
app.table.Clear()
|
||||
app.table.SetBorders(true)
|
||||
|
||||
// Set table headers
|
||||
app.table.SetCell(0, 0, tview.NewTableCell("Description").SetAlign(tview.AlignLeft).SetExpansion(1).SetAttributes(tcell.AttrBold))
|
||||
app.table.SetCell(0, 1, tview.NewTableCell("Similarity").SetAlign(tview.AlignCenter).SetExpansion(0).SetAttributes(tcell.AttrBold))
|
||||
|
||||
// Add the tasks to the table.
|
||||
for i, task := range app.tasks {
|
||||
row := i + 1
|
||||
app.table.SetCell(row, 0, tview.NewTableCell(task.Description))
|
||||
app.table.SetCell(row, 1, tview.NewTableCell(fmt.Sprintf("%.2f", task.Similarity)))
|
||||
}
|
||||
|
||||
if app.state == StateInput {
|
||||
inputField := tview.NewInputField()
|
||||
inputField.
|
||||
SetLabel("New Task: ").
|
||||
SetFieldWidth(0).
|
||||
SetDoneFunc(func(key tcell.Key) {
|
||||
if key == tcell.KeyEnter {
|
||||
task := Task{Description: inputField.GetText()}
|
||||
app.tasks = append(app.tasks, task)
|
||||
app.state = StateRoot
|
||||
postTasksToExternalService([]Task{task})
|
||||
}
|
||||
app.updateUI()
|
||||
})
|
||||
midCol.AddItem(inputField, 3, 2, true)
|
||||
app.app.SetFocus(inputField)
|
||||
} else if app.state == StateSearch {
|
||||
searchField := tview.NewInputField()
|
||||
searchField.SetLabel("Search: ").
|
||||
SetFieldWidth(0).
|
||||
SetDoneFunc(func(key tcell.Key) {
|
||||
if key == tcell.KeyEnter {
|
||||
similar, err := findSimilarTexts(searchField.GetText(), 100)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
app.tasks = make([]Task, len(similar.Keys))
|
||||
for i, v := range similar.Values {
|
||||
app.tasks[i] = Task{Description: v, Similarity: similar.Similarities[i]}
|
||||
}
|
||||
}
|
||||
app.updateUI()
|
||||
})
|
||||
midCol.AddItem(searchField, 3, 2, true)
|
||||
app.app.SetFocus(searchField)
|
||||
} else {
|
||||
midCol.AddItem(nil, 3, 1, false)
|
||||
}
|
||||
|
||||
midCol.AddItem(app.table, 0, 2, true)
|
||||
|
||||
// Add the status bar to the flex layout
|
||||
statusBar := tview.NewTextView().
|
||||
SetText(rootStatus).
|
||||
SetDynamicColors(true).
|
||||
SetTextAlign(tview.AlignCenter)
|
||||
if app.state == StateInput {
|
||||
statusBar.SetText(inputStatus)
|
||||
}
|
||||
midCol.AddItem(statusBar, 1, 1, false)
|
||||
midCol.AddItem(nil, 0, 1, false)
|
||||
|
||||
app.flex.AddItem(midCol, 0, 10, true)
|
||||
app.flex.AddItem(nil, 0, 1, false)
|
||||
|
||||
// Set the flex as the root element
|
||||
app.app.SetRoot(app.flex, true)
|
||||
}
|
||||
|
||||
func main() {
|
||||
app := NewApp()
|
||||
tApp := tview.NewApplication()
|
||||
flex := tview.NewFlex().SetDirection(tview.FlexRow)
|
||||
table := tview.NewTable()
|
||||
|
||||
app.app = tApp
|
||||
app.flex = flex
|
||||
app.table = table
|
||||
|
||||
app.updateUI() // Initial UI setup
|
||||
|
||||
app.app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
|
||||
switch app.state {
|
||||
case StateRoot:
|
||||
// Handle key events when in the root state
|
||||
switch event.Key() {
|
||||
case tcell.KeyRune:
|
||||
switch event.Rune() {
|
||||
case ' ':
|
||||
app.state = StateInput
|
||||
app.updateUI()
|
||||
return nil // Event is handled
|
||||
case '/':
|
||||
app.state = StateSearch
|
||||
app.updateUI()
|
||||
return nil // Event is handled
|
||||
}
|
||||
}
|
||||
|
||||
case StateInput:
|
||||
// Handle key events when in the input state
|
||||
if event.Key() == tcell.KeyEsc {
|
||||
// Exit input state without adding a task
|
||||
app.state = StateRoot
|
||||
app.updateUI()
|
||||
return nil // Event is handled
|
||||
}
|
||||
|
||||
case StateSearch:
|
||||
// Handle key events when in the search state
|
||||
if event.Key() == tcell.KeyEsc {
|
||||
// Exit search state
|
||||
app.state = StateRoot
|
||||
app.updateUI()
|
||||
return nil // Event is handled
|
||||
}
|
||||
}
|
||||
|
||||
// Return the event for further processing by tview
|
||||
return event
|
||||
})
|
||||
|
||||
if err := postTasksToExternalService(app.tasks); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start the application
|
||||
if err := app.app.Run(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
@ -44,4 +44,9 @@ type Backend interface {
|
||||
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error)
|
||||
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
|
||||
Status(ctx context.Context) (*pb.StatusResponse, error)
|
||||
|
||||
StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
|
||||
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
||||
}
|
||||
|
@ -72,6 +72,22 @@ func (llm *Base) Status() (pb.StatusResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (llm *Base) StoresSet(*pb.StoresSetOptions) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
|
||||
return pb.StoresGetResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) StoresDelete(*pb.StoresDeleteOptions) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
|
||||
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func memoryUsage() *pb.MemoryUsageData {
|
||||
mud := pb.MemoryUsageData{
|
||||
Breakdown: make(map[string]uint64),
|
||||
|
@ -291,3 +291,67 @@ func (c *Client) Status(ctx context.Context) (*pb.StatusResponse, error) {
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.Status(ctx, &pb.HealthMessage{})
|
||||
}
|
||||
|
||||
func (c *Client) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresSet(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresDelete(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresGet(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StoresFind(ctx, in, opts...)
|
||||
}
|
||||
|
@ -85,6 +85,22 @@ func (e *embedBackend) Status(ctx context.Context) (*pb.StatusResponse, error) {
|
||||
return e.s.Status(ctx, &pb.HealthMessage{})
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresSet(ctx context.Context, in *pb.StoresSetOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.StoresSet(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
return e.s.StoresDelete(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) {
|
||||
return e.s.StoresGet(ctx, in)
|
||||
}
|
||||
|
||||
func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) {
|
||||
return e.s.StoresFind(ctx, in)
|
||||
}
|
||||
|
||||
type embedBackendServerStream struct {
|
||||
ctx context.Context
|
||||
fn func(s []byte)
|
||||
|
@ -19,6 +19,11 @@ type LLM interface {
|
||||
TTS(*pb.TTSRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
||||
StoresSet(*pb.StoresSetOptions) error
|
||||
StoresDelete(*pb.StoresDeleteOptions) error
|
||||
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
|
||||
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
|
||||
}
|
||||
|
||||
func newReply(s string) *pb.Reply {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc-gen-go-grpc v1.3.0
|
||||
// - protoc v4.23.4
|
||||
// source: backend.proto
|
||||
|
||||
@ -18,6 +18,23 @@ import (
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
const (
|
||||
Backend_Health_FullMethodName = "/backend.Backend/Health"
|
||||
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
|
||||
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
|
||||
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
|
||||
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
|
||||
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
|
||||
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
|
||||
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
|
||||
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
|
||||
Backend_Status_FullMethodName = "/backend.Backend/Status"
|
||||
Backend_StoresSet_FullMethodName = "/backend.Backend/StoresSet"
|
||||
Backend_StoresDelete_FullMethodName = "/backend.Backend/StoresDelete"
|
||||
Backend_StoresGet_FullMethodName = "/backend.Backend/StoresGet"
|
||||
Backend_StoresFind_FullMethodName = "/backend.Backend/StoresFind"
|
||||
)
|
||||
|
||||
// BackendClient is the client API for Backend service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
@ -32,6 +49,10 @@ type BackendClient interface {
|
||||
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error)
|
||||
TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error)
|
||||
Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error)
|
||||
StoresSet(ctx context.Context, in *StoresSetOptions, opts ...grpc.CallOption) (*Result, error)
|
||||
StoresDelete(ctx context.Context, in *StoresDeleteOptions, opts ...grpc.CallOption) (*Result, error)
|
||||
StoresGet(ctx context.Context, in *StoresGetOptions, opts ...grpc.CallOption) (*StoresGetResult, error)
|
||||
StoresFind(ctx context.Context, in *StoresFindOptions, opts ...grpc.CallOption) (*StoresFindResult, error)
|
||||
}
|
||||
|
||||
type backendClient struct {
|
||||
@ -44,7 +65,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
|
||||
|
||||
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -53,7 +74,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
|
||||
|
||||
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -62,7 +83,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
|
||||
|
||||
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -70,7 +91,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
|
||||
}
|
||||
|
||||
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -103,7 +124,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
|
||||
|
||||
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
|
||||
out := new(EmbeddingResult)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -112,7 +133,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
|
||||
|
||||
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -121,7 +142,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
|
||||
|
||||
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
|
||||
out := new(TranscriptResult)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -130,7 +151,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
|
||||
|
||||
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -139,7 +160,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
|
||||
|
||||
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
|
||||
out := new(TokenizationResponse)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -148,7 +169,43 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
|
||||
|
||||
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
|
||||
out := new(StatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *backendClient) StoresSet(ctx context.Context, in *StoresSetOptions, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, Backend_StoresSet_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *backendClient) StoresDelete(ctx context.Context, in *StoresDeleteOptions, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, Backend_StoresDelete_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *backendClient) StoresGet(ctx context.Context, in *StoresGetOptions, opts ...grpc.CallOption) (*StoresGetResult, error) {
|
||||
out := new(StoresGetResult)
|
||||
err := c.cc.Invoke(ctx, Backend_StoresGet_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *backendClient) StoresFind(ctx context.Context, in *StoresFindOptions, opts ...grpc.CallOption) (*StoresFindResult, error) {
|
||||
out := new(StoresFindResult)
|
||||
err := c.cc.Invoke(ctx, Backend_StoresFind_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -169,6 +226,10 @@ type BackendServer interface {
|
||||
TTS(context.Context, *TTSRequest) (*Result, error)
|
||||
TokenizeString(context.Context, *PredictOptions) (*TokenizationResponse, error)
|
||||
Status(context.Context, *HealthMessage) (*StatusResponse, error)
|
||||
StoresSet(context.Context, *StoresSetOptions) (*Result, error)
|
||||
StoresDelete(context.Context, *StoresDeleteOptions) (*Result, error)
|
||||
StoresGet(context.Context, *StoresGetOptions) (*StoresGetResult, error)
|
||||
StoresFind(context.Context, *StoresFindOptions) (*StoresFindResult, error)
|
||||
mustEmbedUnimplementedBackendServer()
|
||||
}
|
||||
|
||||
@ -206,6 +267,18 @@ func (UnimplementedBackendServer) TokenizeString(context.Context, *PredictOption
|
||||
func (UnimplementedBackendServer) Status(context.Context, *HealthMessage) (*StatusResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Status not implemented")
|
||||
}
|
||||
func (UnimplementedBackendServer) StoresSet(context.Context, *StoresSetOptions) (*Result, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StoresSet not implemented")
|
||||
}
|
||||
func (UnimplementedBackendServer) StoresDelete(context.Context, *StoresDeleteOptions) (*Result, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StoresDelete not implemented")
|
||||
}
|
||||
func (UnimplementedBackendServer) StoresGet(context.Context, *StoresGetOptions) (*StoresGetResult, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StoresGet not implemented")
|
||||
}
|
||||
func (UnimplementedBackendServer) StoresFind(context.Context, *StoresFindOptions) (*StoresFindResult, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method StoresFind not implemented")
|
||||
}
|
||||
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {}
|
||||
|
||||
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
|
||||
@ -229,7 +302,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Health",
|
||||
FullMethod: Backend_Health_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
|
||||
@ -247,7 +320,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Predict",
|
||||
FullMethod: Backend_Predict_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
|
||||
@ -265,7 +338,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/LoadModel",
|
||||
FullMethod: Backend_LoadModel_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
|
||||
@ -304,7 +377,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Embedding",
|
||||
FullMethod: Backend_Embedding_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
|
||||
@ -322,7 +395,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/GenerateImage",
|
||||
FullMethod: Backend_GenerateImage_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
|
||||
@ -340,7 +413,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/AudioTranscription",
|
||||
FullMethod: Backend_AudioTranscription_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
|
||||
@ -358,7 +431,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/TTS",
|
||||
FullMethod: Backend_TTS_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
|
||||
@ -376,7 +449,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/TokenizeString",
|
||||
FullMethod: Backend_TokenizeString_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
|
||||
@ -394,7 +467,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Status",
|
||||
FullMethod: Backend_Status_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
|
||||
@ -402,6 +475,78 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Backend_StoresSet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StoresSetOptions)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(BackendServer).StoresSet(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_StoresSet_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).StoresSet(ctx, req.(*StoresSetOptions))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Backend_StoresDelete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StoresDeleteOptions)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(BackendServer).StoresDelete(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_StoresDelete_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).StoresDelete(ctx, req.(*StoresDeleteOptions))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Backend_StoresGet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StoresGetOptions)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(BackendServer).StoresGet(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_StoresGet_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).StoresGet(ctx, req.(*StoresGetOptions))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Backend_StoresFind_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StoresFindOptions)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(BackendServer).StoresFind(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: Backend_StoresFind_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).StoresFind(ctx, req.(*StoresFindOptions))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@ -445,6 +590,22 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "Status",
|
||||
Handler: _Backend_Status_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StoresSet",
|
||||
Handler: _Backend_StoresSet_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StoresDelete",
|
||||
Handler: _Backend_StoresDelete_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StoresGet",
|
||||
Handler: _Backend_StoresGet_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "StoresFind",
|
||||
Handler: _Backend_StoresFind_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
@ -167,6 +167,54 @@ func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusRe
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.StoresSet(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err
|
||||
}
|
||||
return &pb.Result{Message: "Set key", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.StoresDelete(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
|
||||
}
|
||||
return &pb.Result{Message: "Deleted key", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
res, err := s.llm.StoresGet(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
res, err := s.llm.StoresFind(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func StartServer(address string, model LLM) error {
|
||||
lis, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
var Aliases map[string]string = map[string]string{
|
||||
"go-llama": LLamaCPP,
|
||||
"llama": LLamaCPP,
|
||||
"embedded-store": LocalStoreBackend,
|
||||
}
|
||||
|
||||
const (
|
||||
@ -34,6 +35,8 @@ const (
|
||||
TinyDreamBackend = "tinydream"
|
||||
PiperBackend = "piper"
|
||||
LCHuggingFaceBackend = "langchain-huggingface"
|
||||
|
||||
LocalStoreBackend = "local-store"
|
||||
)
|
||||
|
||||
var AutoLoadBackends []string = []string{
|
||||
|
155
pkg/store/client.go
Normal file
155
pkg/store/client.go
Normal file
@ -0,0 +1,155 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
// Wrapper for the GRPC client so that simple use cases are handled without verbosity
|
||||
|
||||
// SetCols sets multiple key-value pairs in the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
func SetCols(ctx context.Context, c grpc.Backend, keys [][]float32, values [][]byte) error {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
protoValues := make([]*proto.StoresValue, len(values))
|
||||
for i, v := range values {
|
||||
protoValues[i] = &proto.StoresValue{
|
||||
Bytes: v,
|
||||
}
|
||||
}
|
||||
setOpts := &proto.StoresSetOptions{
|
||||
Keys: protoKeys,
|
||||
Values: protoValues,
|
||||
}
|
||||
|
||||
res, err := c.StoresSet(ctx, setOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Success {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to set keys: %v", res.Message)
|
||||
}
|
||||
|
||||
// SetSingle sets a single key-value pair in the store
|
||||
// Don't call this in a tight loop, instead use SetCols
|
||||
func SetSingle(ctx context.Context, c grpc.Backend, key []float32, value []byte) error {
|
||||
return SetCols(ctx, c, [][]float32{key}, [][]byte{value})
|
||||
}
|
||||
|
||||
// DeleteCols deletes multiple key-value pairs from the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
func DeleteCols(ctx context.Context, c grpc.Backend, keys [][]float32) error {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
deleteOpts := &proto.StoresDeleteOptions{
|
||||
Keys: protoKeys,
|
||||
}
|
||||
|
||||
res, err := c.StoresDelete(ctx, deleteOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Success {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to delete keys: %v", res.Message)
|
||||
}
|
||||
|
||||
// DeleteSingle deletes a single key-value pair from the store
|
||||
// Don't call this in a tight loop, instead use DeleteCols
|
||||
func DeleteSingle(ctx context.Context, c grpc.Backend, key []float32) error {
|
||||
return DeleteCols(ctx, c, [][]float32{key})
|
||||
}
|
||||
|
||||
// GetCols gets multiple key-value pairs from the store
|
||||
// It's in columnar format so that keys[i] is associated with values[i]
|
||||
// Be warned the keys are sorted and will be returned in a different order than they were input
|
||||
// There is no guarantee as to how the keys are sorted
|
||||
func GetCols(ctx context.Context, c grpc.Backend, keys [][]float32) ([][]float32, [][]byte, error) {
|
||||
protoKeys := make([]*proto.StoresKey, len(keys))
|
||||
for i, k := range keys {
|
||||
protoKeys[i] = &proto.StoresKey{
|
||||
Floats: k,
|
||||
}
|
||||
}
|
||||
getOpts := &proto.StoresGetOptions{
|
||||
Keys: protoKeys,
|
||||
}
|
||||
|
||||
res, err := c.StoresGet(ctx, getOpts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ks := make([][]float32, len(res.Keys))
|
||||
for i, k := range res.Keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
vs := make([][]byte, len(res.Values))
|
||||
for i, v := range res.Values {
|
||||
vs[i] = v.Bytes
|
||||
}
|
||||
|
||||
return ks, vs, nil
|
||||
}
|
||||
|
||||
// GetSingle gets a single key-value pair from the store
|
||||
// Don't call this in a tight loop, instead use GetCols
|
||||
func GetSingle(ctx context.Context, c grpc.Backend, key []float32) ([]byte, error) {
|
||||
_, values, err := GetCols(ctx, c, [][]float32{key})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get key")
|
||||
}
|
||||
|
||||
// Find similar keys to the given key. Returns the keys, values, and similarities
|
||||
func Find(ctx context.Context, c grpc.Backend, key []float32, topk int) ([][]float32, [][]byte, []float32, error) {
|
||||
findOpts := &proto.StoresFindOptions{
|
||||
Key: &proto.StoresKey{
|
||||
Floats: key,
|
||||
},
|
||||
TopK: int32(topk),
|
||||
}
|
||||
|
||||
res, err := c.StoresFind(ctx, findOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ks := make([][]float32, len(res.Keys))
|
||||
vs := make([][]byte, len(res.Values))
|
||||
|
||||
for i, k := range res.Keys {
|
||||
ks[i] = k.Floats
|
||||
}
|
||||
|
||||
for i, v := range res.Values {
|
||||
vs[i] = v.Bytes
|
||||
}
|
||||
|
||||
return ks, vs, res.Similarities, nil
|
||||
}
|
17
tests/integration/integration_suite_test.go
Normal file
17
tests/integration/integration_suite_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TestLocalAI(t *testing.T) {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "LocalAI test suite")
|
||||
}
|
228
tests/integration/stores_test.go
Normal file
228
tests/integration/stores_test.go
Normal file
@ -0,0 +1,228 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/store"
|
||||
)
|
||||
|
||||
//go:embed backend-assets/*
|
||||
var backendAssets embed.FS
|
||||
|
||||
var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
|
||||
Context("Embedded Store get,set and delete", func() {
|
||||
var sl *model.ModelLoader
|
||||
var sc grpc.Backend
|
||||
var tmpdir string
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||
err = os.Mkdir(backendAssetsDir, 0755)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = assets.ExtractFiles(backendAssets, backendAssetsDir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
debug := true
|
||||
|
||||
bc := config.BackendConfig{
|
||||
Name: "store test",
|
||||
Debug: &debug,
|
||||
Backend: model.LocalStoreBackend,
|
||||
}
|
||||
|
||||
storeOpts := []model.Option{
|
||||
model.WithBackendString(bc.Backend),
|
||||
model.WithAssetDir(backendAssetsDir),
|
||||
model.WithModel("test"),
|
||||
}
|
||||
|
||||
sl = model.NewModelLoader("")
|
||||
sc, err = sl.BackendLoader(storeOpts...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sc).ToNot(BeNil())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
sl.StopAllGRPC()
|
||||
err := os.RemoveAll(tmpdir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should be able to set a key", func() {
|
||||
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should be able to set keys", func() {
|
||||
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}, [][]byte{[]byte("test1"), []byte("test2")})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = store.SetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.10, 0.11, 0.12}}, [][]byte{[]byte("test3"), []byte("test4")})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should be able to get a key", func() {
|
||||
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
val, err := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal([]byte("test")))
|
||||
})
|
||||
|
||||
It("should be able to get keys", func() {
|
||||
//set 3 entries
|
||||
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
//get 3 entries
|
||||
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(3))
|
||||
Expect(vals).To(HaveLen(3))
|
||||
for i, k := range keys {
|
||||
v := vals[i]
|
||||
|
||||
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
|
||||
Expect(v).To(Equal([]byte("test1")))
|
||||
} else if k[0] == 0.4 && k[1] == 0.5 && k[2] == 0.6 {
|
||||
Expect(v).To(Equal([]byte("test2")))
|
||||
} else {
|
||||
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
|
||||
Expect(v).To(Equal([]byte("test3")))
|
||||
}
|
||||
}
|
||||
|
||||
//get 2 entries
|
||||
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(2))
|
||||
Expect(vals).To(HaveLen(2))
|
||||
for i, k := range keys {
|
||||
v := vals[i]
|
||||
|
||||
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
|
||||
Expect(v).To(Equal([]byte("test1")))
|
||||
} else {
|
||||
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
|
||||
Expect(v).To(Equal([]byte("test3")))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("should be able to delete a key", func() {
|
||||
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = store.DeleteSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
val, _ := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||||
Expect(val).To(BeNil())
|
||||
})
|
||||
|
||||
It("should be able to delete keys", func() {
|
||||
//set 3 entries
|
||||
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
//delete 2 entries
|
||||
err = store.DeleteCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
//get 1 entry
|
||||
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.4, 0.5, 0.6}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(1))
|
||||
Expect(vals).To(HaveLen(1))
|
||||
Expect(keys[0]).To(Equal([]float32{0.4, 0.5, 0.6}))
|
||||
Expect(vals[0]).To(Equal([]byte("test2")))
|
||||
|
||||
//get deleted entries
|
||||
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(0))
|
||||
Expect(vals).To(HaveLen(0))
|
||||
})
|
||||
|
||||
It("should be able to find smilar keys", func() {
|
||||
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
|
||||
err := store.SetCols(context.Background(), sc, [][]float32{{0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// find similar keys
|
||||
keys, vals, sims, err := store.Find(context.Background(), sc, []float32{0.1, 0.3, 0.5}, 2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(keys).To(HaveLen(2))
|
||||
Expect(vals).To(HaveLen(2))
|
||||
Expect(sims).To(HaveLen(2))
|
||||
|
||||
for i, k := range keys {
|
||||
s := sims[i]
|
||||
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
|
||||
}
|
||||
|
||||
Expect(keys[0]).To(Equal([]float32{0.5, 0.5, 0.5}))
|
||||
Expect(vals[0]).To(Equal([]byte("test1")))
|
||||
Expect(keys[1]).To(Equal([]float32{0.6, 0.6, -0.6}))
|
||||
})
|
||||
|
||||
It("should be able to find similar normalized keys", func() {
|
||||
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
|
||||
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
|
||||
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
|
||||
// normalize the keys
|
||||
for i, k := range keys {
|
||||
norm := float64(0)
|
||||
for _, x := range k {
|
||||
norm += float64(x * x)
|
||||
}
|
||||
norm = math.Sqrt(norm)
|
||||
for j, x := range k {
|
||||
keys[i][j] = x / float32(norm)
|
||||
}
|
||||
}
|
||||
|
||||
err := store.SetCols(context.Background(), sc, keys, vals)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// find similar keys
|
||||
ks, vals, sims, err := store.Find(context.Background(), sc, keys[0], 3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ks).To(HaveLen(3))
|
||||
Expect(vals).To(HaveLen(3))
|
||||
Expect(sims).To(HaveLen(3))
|
||||
|
||||
for i, k := range ks {
|
||||
s := sims[i]
|
||||
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
|
||||
}
|
||||
|
||||
Expect(ks[0]).To(Equal(keys[0]))
|
||||
Expect(vals[0]).To(Equal(vals[0]))
|
||||
Expect(sims[0]).To(BeNumerically("~", 1, 0.0001))
|
||||
Expect(ks[1]).To(Equal(keys[1]))
|
||||
Expect(vals[1]).To(Equal(vals[1]))
|
||||
})
|
||||
})
|
||||
})
|
Loading…
Reference in New Issue
Block a user