// Mirror of https://github.com/mudler/LocalAI.git
// (synced 2025-01-17 18:30:07 +00:00; 229 lines, 7.2 KiB, Go).

package integration_test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"embed"
|
||
|
"math"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
|
||
|
. "github.com/onsi/ginkgo/v2"
|
||
|
. "github.com/onsi/gomega"
|
||
|
"github.com/rs/zerolog"
|
||
|
"github.com/rs/zerolog/log"
|
||
|
|
||
|
"github.com/go-skynet/LocalAI/core/config"
|
||
|
"github.com/go-skynet/LocalAI/pkg/assets"
|
||
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||
|
"github.com/go-skynet/LocalAI/pkg/store"
|
||
|
)
|
||
|
|
||
|
//go:embed backend-assets/*
|
||
|
var backendAssets embed.FS
|
||
|
|
||
|
var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
|
||
|
Context("Embedded Store get,set and delete", func() {
|
||
|
var sl *model.ModelLoader
|
||
|
var sc grpc.Backend
|
||
|
var tmpdir string
|
||
|
|
||
|
BeforeEach(func() {
|
||
|
var err error
|
||
|
|
||
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||
|
|
||
|
tmpdir, err = os.MkdirTemp("", "")
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||
|
err = os.Mkdir(backendAssetsDir, 0755)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
err = assets.ExtractFiles(backendAssets, backendAssetsDir)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
debug := true
|
||
|
|
||
|
bc := config.BackendConfig{
|
||
|
Name: "store test",
|
||
|
Debug: &debug,
|
||
|
Backend: model.LocalStoreBackend,
|
||
|
}
|
||
|
|
||
|
storeOpts := []model.Option{
|
||
|
model.WithBackendString(bc.Backend),
|
||
|
model.WithAssetDir(backendAssetsDir),
|
||
|
model.WithModel("test"),
|
||
|
}
|
||
|
|
||
|
sl = model.NewModelLoader("")
|
||
|
sc, err = sl.BackendLoader(storeOpts...)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(sc).ToNot(BeNil())
|
||
|
})
|
||
|
|
||
|
AfterEach(func() {
|
||
|
sl.StopAllGRPC()
|
||
|
err := os.RemoveAll(tmpdir)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
})
|
||
|
|
||
|
It("should be able to set a key", func() {
|
||
|
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
})
|
||
|
|
||
|
It("should be able to set keys", func() {
|
||
|
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}, [][]byte{[]byte("test1"), []byte("test2")})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
err = store.SetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.10, 0.11, 0.12}}, [][]byte{[]byte("test3"), []byte("test4")})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
})
|
||
|
|
||
|
It("should be able to get a key", func() {
|
||
|
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
val, err := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(val).To(Equal([]byte("test")))
|
||
|
})
|
||
|
|
||
|
It("should be able to get keys", func() {
|
||
|
//set 3 entries
|
||
|
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
//get 3 entries
|
||
|
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(keys).To(HaveLen(3))
|
||
|
Expect(vals).To(HaveLen(3))
|
||
|
for i, k := range keys {
|
||
|
v := vals[i]
|
||
|
|
||
|
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
|
||
|
Expect(v).To(Equal([]byte("test1")))
|
||
|
} else if k[0] == 0.4 && k[1] == 0.5 && k[2] == 0.6 {
|
||
|
Expect(v).To(Equal([]byte("test2")))
|
||
|
} else {
|
||
|
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
|
||
|
Expect(v).To(Equal([]byte("test3")))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//get 2 entries
|
||
|
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.7, 0.8, 0.9}, {0.1, 0.2, 0.3}})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(keys).To(HaveLen(2))
|
||
|
Expect(vals).To(HaveLen(2))
|
||
|
for i, k := range keys {
|
||
|
v := vals[i]
|
||
|
|
||
|
if k[0] == 0.1 && k[1] == 0.2 && k[2] == 0.3 {
|
||
|
Expect(v).To(Equal([]byte("test1")))
|
||
|
} else {
|
||
|
Expect(k).To(Equal([]float32{0.7, 0.8, 0.9}))
|
||
|
Expect(v).To(Equal([]byte("test3")))
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
|
||
|
It("should be able to delete a key", func() {
|
||
|
err := store.SetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3}, []byte("test"))
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
err = store.DeleteSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
val, _ := store.GetSingle(context.Background(), sc, []float32{0.1, 0.2, 0.3})
|
||
|
Expect(val).To(BeNil())
|
||
|
})
|
||
|
|
||
|
It("should be able to delete keys", func() {
|
||
|
//set 3 entries
|
||
|
err := store.SetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
//delete 2 entries
|
||
|
err = store.DeleteCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
//get 1 entry
|
||
|
keys, vals, err := store.GetCols(context.Background(), sc, [][]float32{{0.4, 0.5, 0.6}})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(keys).To(HaveLen(1))
|
||
|
Expect(vals).To(HaveLen(1))
|
||
|
Expect(keys[0]).To(Equal([]float32{0.4, 0.5, 0.6}))
|
||
|
Expect(vals[0]).To(Equal([]byte("test2")))
|
||
|
|
||
|
//get deleted entries
|
||
|
keys, vals, err = store.GetCols(context.Background(), sc, [][]float32{{0.1, 0.2, 0.3}, {0.7, 0.8, 0.9}})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(keys).To(HaveLen(0))
|
||
|
Expect(vals).To(HaveLen(0))
|
||
|
})
|
||
|
|
||
|
It("should be able to find smilar keys", func() {
|
||
|
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
|
||
|
err := store.SetCols(context.Background(), sc, [][]float32{{0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}, [][]byte{[]byte("test1"), []byte("test2"), []byte("test3")})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
// find similar keys
|
||
|
keys, vals, sims, err := store.Find(context.Background(), sc, []float32{0.1, 0.3, 0.5}, 2)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(keys).To(HaveLen(2))
|
||
|
Expect(vals).To(HaveLen(2))
|
||
|
Expect(sims).To(HaveLen(2))
|
||
|
|
||
|
for i, k := range keys {
|
||
|
s := sims[i]
|
||
|
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
|
||
|
}
|
||
|
|
||
|
Expect(keys[0]).To(Equal([]float32{0.5, 0.5, 0.5}))
|
||
|
Expect(vals[0]).To(Equal([]byte("test1")))
|
||
|
Expect(keys[1]).To(Equal([]float32{0.6, 0.6, -0.6}))
|
||
|
})
|
||
|
|
||
|
It("should be able to find similar normalized keys", func() {
|
||
|
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
|
||
|
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
|
||
|
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
|
||
|
// normalize the keys
|
||
|
for i, k := range keys {
|
||
|
norm := float64(0)
|
||
|
for _, x := range k {
|
||
|
norm += float64(x * x)
|
||
|
}
|
||
|
norm = math.Sqrt(norm)
|
||
|
for j, x := range k {
|
||
|
keys[i][j] = x / float32(norm)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err := store.SetCols(context.Background(), sc, keys, vals)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
// find similar keys
|
||
|
ks, vals, sims, err := store.Find(context.Background(), sc, keys[0], 3)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(ks).To(HaveLen(3))
|
||
|
Expect(vals).To(HaveLen(3))
|
||
|
Expect(sims).To(HaveLen(3))
|
||
|
|
||
|
for i, k := range ks {
|
||
|
s := sims[i]
|
||
|
log.Debug().Float32("similarity", s).Msgf("key: %v", k)
|
||
|
}
|
||
|
|
||
|
Expect(ks[0]).To(Equal(keys[0]))
|
||
|
Expect(vals[0]).To(Equal(vals[0]))
|
||
|
Expect(sims[0]).To(BeNumerically("~", 1, 0.0001))
|
||
|
Expect(ks[1]).To(Equal(keys[1]))
|
||
|
Expect(vals[1]).To(Equal(vals[1]))
|
||
|
})
|
||
|
})
|
||
|
})
|