// isNormalized reports whether k is approximately a unit vector.
//
// The squared Euclidean magnitude is accumulated in float64 and must lie
// within [0.99, 1.01]. The bound is applied to the squared magnitude rather
// than to the magnitude itself for two reasons:
//
//   - By the Cauchy-Schwarz inequality the dot product of any two accepted
//     vectors then has an absolute value of at most 1.01, which is the range
//     normalizedCosineSimilarity asserts (up to float32 rounding of the dot
//     product). Accepting magnitudes up to 1.01 would allow dot products of
//     roughly 1.02.
//   - No square root is needed.
//
// Vectors rejected here are still handled correctly by the general cosine
// similarity path; they only miss the faster normalized path. Empty vectors and
// vectors containing NaN or Inf are never considered normalized.
func isNormalized(k []float32) bool {
	// Maximum allowed deviation of the squared magnitude from 1.0.
	const tolerance = 0.01

	var sumSq float64
	for _, v := range k {
		x := float64(v)
		sumSq += x * x
	}

	// Any comparison involving NaN is false, so NaN input yields false.
	return sumSq >= 1-tolerance && sumSq <= 1+tolerance
}
// normalize rescales every vector in vecs, in place, to unit Euclidean length.
//
// Components are widened to float64 before squaring so that very large or very
// small float32 values do not overflow to +Inf or underflow to 0 while the
// magnitude is being accumulated. Vectors whose magnitude is exactly zero
// (including empty vectors) are left unchanged, because dividing by a zero
// magnitude would fill them with NaN.
func normalize(vecs [][]float32) {
	for _, vec := range vecs {
		var sumSq float64
		for _, x := range vec {
			x64 := float64(x)
			sumSq += x64 * x64
		}

		if sumSq == 0 {
			continue
		}

		norm := math.Sqrt(sumSq)
		for j, x := range vec {
			vec[j] = float32(float64(x) / norm)
		}
	}
}
1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}} + vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")} + + err := store.SetCols(context.Background(), sc, keys, vals); + Expect(err).ToNot(HaveOccurred()) + + _, _, sims, err := store.Find(context.Background(), sc, keys[0], 4) + Expect(err).ToNot(HaveOccurred()) + Expect(sims).To(Equal([]float32{1.0, 0.0, 0.0, -1.0})) + }) + + It("It produces the correct cosine similarities for orthogonal and opposite vectors", func() { + keys := [][]float32{{1.0, 0.0, 1.0}, {0.0, 2.0, 0.0}, {0.0, 0.0, -1.0}, {-1.0, 0.0, -1.0}} + vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")} + + err := store.SetCols(context.Background(), sc, keys, vals); + Expect(err).ToNot(HaveOccurred()) + + _, _, sims, err := store.Find(context.Background(), sc, keys[0], 4) + Expect(err).ToNot(HaveOccurred()) + Expect(sims[0]).To(BeNumerically("~", 1, 0.1)) + Expect(sims[1]).To(BeNumerically("~", 0, 0.1)) + Expect(sims[2]).To(BeNumerically("~", -0.7, 0.1)) + Expect(sims[3]).To(BeNumerically("~", -1, 0.1)) + }) + + expectTriangleEq := func(keys [][]float32, vals [][]byte) { + sims := map[string]map[string]float32{} + + // compare every key vector pair and store the similarities in a lookup table + // that uses the values as keys + for i, k := range keys { + _, valsk, simsk, err := store.Find(context.Background(), sc, k, 9) + Expect(err).ToNot(HaveOccurred()) + + for j, v := range valsk { + p := string(vals[i]) + q := string(v) + + if sims[p] == nil { + sims[p] = map[string]float32{} + } + + //log.Debug().Strs("vals", []string{p, q}).Float32("similarity", simsk[j]).Send() + + sims[p][q] = simsk[j] + } + } + + // Check that the triangle inequality holds for every combination of the triplet + // u, v and w + for _, simsu := range sims { + for w, simw := range simsu { + // acos(u,w) <= ... + uws := math.Acos(float64(simw)) + + // ... 
acos(u,v) + acos(v,w) + for v, _ := range simsu { + uvws := math.Acos(float64(simsu[v])) + math.Acos(float64(sims[v][w])) + + //log.Debug().Str("u", u).Str("v", v).Str("w", w).Send() + //log.Debug().Float32("uw", simw).Float32("uv", simsu[v]).Float32("vw", sims[v][w]).Send() + Expect(uws).To(BeNumerically("<=", uvws)) + } + } + } + } + + It("It obeys the triangle inequality for normalized values", func() { + keys := [][]float32{ + {1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, + {-1.0, 0.0, 0.0}, {0.0, -1.0, 0.0}, {0.0, 0.0, -1.0}, + {2.0, 3.0, 4.0}, {9.0, 7.0, 1.0}, {0.0, -1.2, 2.3}, + } + vals := [][]byte{ + []byte("x"), []byte("y"), []byte("z"), + []byte("-x"), []byte("-y"), []byte("-z"), + []byte("u"), []byte("v"), []byte("w"), + } + + normalize(keys[6:]) + + err := store.SetCols(context.Background(), sc, keys, vals); + Expect(err).ToNot(HaveOccurred()) + + expectTriangleEq(keys, vals) + }) + + It("It obeys the triangle inequality", func() { + rnd := rand.New(rand.NewSource(151)) + keys := make([][]float32, 20) + vals := make([][]byte, 20) + + for i := range keys { + k := make([]float32, 768) + + for j := range k { + k[j] = rnd.Float32() + } + + keys[i] = k + } + + c := byte('a') + for i := range vals { + vals[i] = []byte{c} + c += 1 + } + + err := store.SetCols(context.Background(), sc, keys, vals); + Expect(err).ToNot(HaveOccurred()) + + expectTriangleEq(keys, vals) + }) }) })