test: e2e /reranker endpoint (#2211)

Create a simple e2e test for the /reranker api \\ go mod tidy

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-06-07 14:45:52 -04:00 committed by GitHub
parent 3b7a78adda
commit 219078a5e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 2 deletions

View File

@ -178,7 +178,7 @@ jobs:
submodules: true submodules: true
- name: Build images - name: Build images
run: | run: |
docker build --build-arg FFMPEG=true --build-arg IMAGE_TYPE=core --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" -t local-ai:tests -f Dockerfile . docker build --build-arg FFMPEG=true --build-arg IMAGE_TYPE=extras --build-arg EXTRA_BACKENDS=rerankers --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" -t local-ai:tests -f Dockerfile .
BASE_IMAGE=local-ai:tests DOCKER_AIO_IMAGE=local-ai-aio:test make docker-aio BASE_IMAGE=local-ai:tests DOCKER_AIO_IMAGE=local-ai-aio:test make docker-aio
- name: Test - name: Test
run: | run: |

View File

@ -40,7 +40,8 @@ var _ = BeforeSuite(func() {
if apiEndpoint == "" { if apiEndpoint == "" {
startDockerImage() startDockerImage()
defaultConfig = openai.DefaultConfig(apiKey) defaultConfig = openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1" apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
defaultConfig.BaseURL = apiEndpoint
} else { } else {
fmt.Println("Default ", apiEndpoint) fmt.Println("Default ", apiEndpoint)
defaultConfig = openai.DefaultConfig(apiKey) defaultConfig = openai.DefaultConfig(apiKey)

View File

@ -1,6 +1,7 @@
package e2e_test package e2e_test
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -8,6 +9,7 @@ import (
"net/http" "net/http"
"os" "os"
"github.com/go-skynet/LocalAI/core/schema"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
@ -227,6 +229,51 @@ var _ = Describe("E2E test", func() {
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text)) Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
}) })
}) })
Context("reranker", func() {
It("correctly", func() {
modelName := "jina-reranker-v1-base-en"
req := schema.JINARerankRequest{
Model: modelName,
Query: "Organic skincare products for sensitive skin",
Documents: []string{
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"All-natural pet food for dogs with allergies",
"Yoga mats made from recycled materials",
},
TopN: 3,
}
serialized, err := json.Marshal(req)
Expect(err).To(BeNil())
Expect(serialized).ToNot(BeNil())
rerankerEndpoint := apiEndpoint + "/rerank"
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
Expect(err).To(BeNil())
Expect(resp).ToNot(BeNil())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(resp.Body)
Expect(err).To(BeNil())
Expect(body).ToNot(BeNil())
deserializedResponse := schema.JINARerankResponse{}
err = json.Unmarshal(body, &deserializedResponse)
Expect(err).To(BeNil())
Expect(deserializedResponse).ToNot(BeZero())
Expect(deserializedResponse.Model).To(Equal(modelName))
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
})
})
}) })
}) })