mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-02 02:46:41 +00:00
5866fc8ded
Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
300 lines
9.0 KiB
Go
300 lines
9.0 KiB
Go
package e2e_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"github.com/sashabaranov/go-openai"
|
|
"github.com/sashabaranov/go-openai/jsonschema"
|
|
)
|
|
|
|
var _ = Describe("E2E test", func() {
|
|
Context("Generating", func() {
|
|
BeforeEach(func() {
|
|
//
|
|
})
|
|
|
|
// Check that the GPU was used
|
|
AfterEach(func() {
|
|
//
|
|
})
|
|
|
|
Context("text", func() {
|
|
It("correctly", func() {
|
|
model := "gpt-4"
|
|
resp, err := client.CreateChatCompletion(context.TODO(),
|
|
openai.ChatCompletionRequest{
|
|
Model: model, Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: "user",
|
|
Content: "How much is 2+2?",
|
|
},
|
|
}})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content))
|
|
})
|
|
})
|
|
|
|
Context("function calls", func() {
|
|
It("correctly invoke", func() {
|
|
params := jsonschema.Definition{
|
|
Type: jsonschema.Object,
|
|
Properties: map[string]jsonschema.Definition{
|
|
"location": {
|
|
Type: jsonschema.String,
|
|
Description: "The city and state, e.g. San Francisco, CA",
|
|
},
|
|
"unit": {
|
|
Type: jsonschema.String,
|
|
Enum: []string{"celsius", "fahrenheit"},
|
|
},
|
|
},
|
|
Required: []string{"location"},
|
|
}
|
|
|
|
f := openai.FunctionDefinition{
|
|
Name: "get_current_weather",
|
|
Description: "Get the current weather in a given location",
|
|
Parameters: params,
|
|
}
|
|
t := openai.Tool{
|
|
Type: openai.ToolTypeFunction,
|
|
Function: &f,
|
|
}
|
|
|
|
dialogue := []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"},
|
|
}
|
|
resp, err := client.CreateChatCompletion(context.TODO(),
|
|
openai.ChatCompletionRequest{
|
|
Model: openai.GPT4,
|
|
Messages: dialogue,
|
|
Tools: []openai.Tool{t},
|
|
},
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
|
|
|
msg := resp.Choices[0].Message
|
|
Expect(len(msg.ToolCalls)).To(Equal(1), fmt.Sprint(msg.ToolCalls))
|
|
Expect(msg.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), fmt.Sprint(msg.ToolCalls[0].Function.Name))
|
|
Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments))
|
|
})
|
|
})
|
|
Context("json", func() {
|
|
It("correctly", func() {
|
|
model := "gpt-4"
|
|
|
|
req := openai.ChatCompletionRequest{
|
|
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
|
|
Model: model,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
|
|
Role: "user",
|
|
Content: "An animal with 'name', 'gender' and 'legs' fields",
|
|
},
|
|
},
|
|
}
|
|
|
|
resp, err := client.CreateChatCompletion(context.TODO(), req)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
|
|
|
var i map[string]interface{}
|
|
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &i)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(i).To(HaveKey("name"))
|
|
Expect(i).To(HaveKey("gender"))
|
|
Expect(i).To(HaveKey("legs"))
|
|
})
|
|
})
|
|
|
|
Context("images", func() {
|
|
It("correctly", func() {
|
|
resp, err := client.CreateImage(context.TODO(),
|
|
openai.ImageRequest{
|
|
Prompt: "test",
|
|
Size: openai.CreateImageSize512x512,
|
|
},
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
|
})
|
|
It("correctly changes the response format to url", func() {
|
|
resp, err := client.CreateImage(context.TODO(),
|
|
openai.ImageRequest{
|
|
Prompt: "test",
|
|
Size: openai.CreateImageSize512x512,
|
|
ResponseFormat: openai.CreateImageResponseFormatURL,
|
|
},
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
|
})
|
|
It("correctly changes the response format to base64", func() {
|
|
resp, err := client.CreateImage(context.TODO(),
|
|
openai.ImageRequest{
|
|
Prompt: "test",
|
|
Size: openai.CreateImageSize512x512,
|
|
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
|
|
},
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
|
|
})
|
|
})
|
|
Context("embeddings", func() {
|
|
It("correctly", func() {
|
|
resp, err := client.CreateEmbeddings(context.TODO(),
|
|
openai.EmbeddingRequestStrings{
|
|
Input: []string{"doc"},
|
|
Model: openai.AdaEmbeddingV2,
|
|
},
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
|
|
})
|
|
})
|
|
Context("vision", func() {
|
|
It("correctly", func() {
|
|
model := "gpt-4-vision-preview"
|
|
resp, err := client.CreateChatCompletion(context.TODO(),
|
|
openai.ChatCompletionRequest{
|
|
Model: model, Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
|
|
Role: "user",
|
|
MultiContent: []openai.ChatMessagePart{
|
|
{
|
|
Type: openai.ChatMessagePartTypeText,
|
|
Text: "What is in the image?",
|
|
},
|
|
{
|
|
Type: openai.ChatMessagePartTypeImageURL,
|
|
ImageURL: &openai.ChatMessageImageURL{
|
|
URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
|
Detail: openai.ImageURLDetailLow,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
|
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("wooden"), ContainSubstring("grass")), fmt.Sprint(resp.Choices[0].Message.Content))
|
|
})
|
|
})
|
|
Context("text to audio", func() {
|
|
It("correctly", func() {
|
|
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
|
|
Model: openai.TTSModel1,
|
|
Input: "Hello!",
|
|
Voice: openai.VoiceAlloy,
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer res.Close()
|
|
|
|
_, err = io.ReadAll(res)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
})
|
|
})
|
|
Context("audio to text", func() {
|
|
It("correctly", func() {
|
|
|
|
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
|
file, err := downloadHttpFile(downloadURL)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
req := openai.AudioRequest{
|
|
Model: openai.Whisper1,
|
|
FilePath: file,
|
|
}
|
|
resp, err := client.CreateTranscription(context.Background(), req)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
|
|
})
|
|
})
|
|
|
|
Context("reranker", func() {
|
|
It("correctly", func() {
|
|
modelName := "jina-reranker-v1-base-en"
|
|
|
|
req := schema.JINARerankRequest{
|
|
Model: modelName,
|
|
Query: "Organic skincare products for sensitive skin",
|
|
Documents: []string{
|
|
"Eco-friendly kitchenware for modern homes",
|
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
|
"Organic cotton baby clothes for sensitive skin",
|
|
"Natural organic skincare range for sensitive skin",
|
|
"Tech gadgets for smart homes: 2024 edition",
|
|
"Sustainable gardening tools and compost solutions",
|
|
"Sensitive skin-friendly facial cleansers and toners",
|
|
"Organic food wraps and storage solutions",
|
|
"All-natural pet food for dogs with allergies",
|
|
"Yoga mats made from recycled materials",
|
|
},
|
|
TopN: 3,
|
|
}
|
|
|
|
serialized, err := json.Marshal(req)
|
|
Expect(err).To(BeNil())
|
|
Expect(serialized).ToNot(BeNil())
|
|
|
|
rerankerEndpoint := apiEndpoint + "/rerank"
|
|
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
|
|
Expect(err).To(BeNil())
|
|
Expect(resp).ToNot(BeNil())
|
|
Expect(resp.StatusCode).To(Equal(200))
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
Expect(err).To(BeNil())
|
|
Expect(body).ToNot(BeNil())
|
|
|
|
deserializedResponse := schema.JINARerankResponse{}
|
|
err = json.Unmarshal(body, &deserializedResponse)
|
|
Expect(err).To(BeNil())
|
|
Expect(deserializedResponse).ToNot(BeZero())
|
|
Expect(deserializedResponse.Model).To(Equal(modelName))
|
|
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
|
|
})
|
|
})
|
|
})
|
|
})
|
|
|
|
// downloadHttpFile fetches url with an HTTP GET and writes the response body
// to a newly created temporary file, returning that file's path.
//
// The caller owns the returned file and is responsible for removing it.
// Any response status other than 200 OK is reported as an error, so an error
// page is never handed back as if it were the requested content. On every
// failure the returned path is empty and no temporary file is left behind.
func downloadHttpFile(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", fmt.Errorf("downloading file: %w", err)
	}
	defer resp.Body.Close()

	// http.Get only fails on transport errors; HTTP-level failures such as
	// 404 or 500 must be detected from the status code.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	tmpfile, err := os.CreateTemp("", "example")
	if err != nil {
		return "", fmt.Errorf("creating temporary file: %w", err)
	}

	if _, err := io.Copy(tmpfile, resp.Body); err != nil {
		// Best-effort cleanup: the copy error is the one worth reporting.
		_ = tmpfile.Close()
		_ = os.Remove(tmpfile.Name())
		return "", fmt.Errorf("writing %q: %w", tmpfile.Name(), err)
	}

	// Close explicitly instead of deferring so that a failure to flush the
	// data to disk is surfaced to the caller.
	if err := tmpfile.Close(); err != nil {
		_ = os.Remove(tmpfile.Name())
		return "", fmt.Errorf("closing %q: %w", tmpfile.Name(), err)
	}

	return tmpfile.Name(), nil
}
|