test: preliminary tests and merge fix for authv2 (#3584)

* add api key to existing app tests, add preliminary auth test

Signed-off-by: Dave Lee <dave@gray101.com>

* small fix, run test

Signed-off-by: Dave Lee <dave@gray101.com>

* status on non-opaque

Signed-off-by: Dave Lee <dave@gray101.com>

* tweak auth error

Signed-off-by: Dave Lee <dave@gray101.com>

* exp

Signed-off-by: Dave Lee <dave@gray101.com>

* quick fix on real laptop

Signed-off-by: Dave Lee <dave@gray101.com>

* add downloader version that allows providing an auth header

Signed-off-by: Dave Lee <dave@gray101.com>

* stash some devcontainer fixes during testing

Signed-off-by: Dave Lee <dave@gray101.com>

* s2

Signed-off-by: Dave Lee <dave@gray101.com>

* s

Signed-off-by: Dave Lee <dave@gray101.com>

* done with experiment

Signed-off-by: Dave Lee <dave@gray101.com>

* done with experiment

Signed-off-by: Dave Lee <dave@gray101.com>

* after merge fix

Signed-off-by: Dave Lee <dave@gray101.com>

* rename and fix

Signed-off-by: Dave Lee <dave@gray101.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
Dave 2024-09-24 03:32:48 -04:00 committed by GitHub
parent 69d2902b0a
commit 90cacb9692
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 95 additions and 41 deletions

View File

@ -9,6 +9,7 @@
# Param 2: email # Param 2: email
# #
config_user() { config_user() {
echo "Configuring git for $1 <$2>"
local gcn=$(git config --global user.name) local gcn=$(git config --global user.name)
if [ -z "${gcn}" ]; then if [ -z "${gcn}" ]; then
echo "Setting up git user / remote" echo "Setting up git user / remote"
@ -24,6 +25,7 @@ config_user() {
# Param 2: remote url # Param 2: remote url
# #
config_remote() { config_remote() {
echo "Adding git remote and fetching $2 as $1"
local gr=$(git remote -v | grep $1) local gr=$(git remote -v | grep $1)
if [ -z "${gr}" ]; then if [ -z "${gr}" ]; then
git remote add $1 $2 git remote add $1 $2

View File

@ -338,9 +338,8 @@ RUN if [ "${FFMPEG}" = "true" ]; then \
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
ssh less && \ ssh less wget
apt-get clean && \ # For the devcontainer, leave apt functional in case additional devtools are needed at runtime.
rm -rf /var/lib/apt/lists/*
RUN go install github.com/go-delve/delve/cmd/dlv@latest RUN go install github.com/go-delve/delve/cmd/dlv@latest

View File

@ -359,6 +359,9 @@ clean-tests:
rm -rf test-dir rm -rf test-dir
rm -rf core/http/backend-assets rm -rf core/http/backend-assets
clean-dc: clean
cp -r /build/backend-assets /workspace/backend-assets
## Build: ## Build:
build: prepare backend-assets grpcs ## Build the project build: prepare backend-assets grpcs ## Build the project
$(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I local-ai build info:${RESET})

View File

@ -132,7 +132,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string var refFile string
uri := downloader.URI(url) uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
refFile = string(d) refFile = string(d)
if len(refFile) == 0 { if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d) return fmt.Errorf("invalid reference file at url %s: %s", url, d)
@ -156,7 +156,7 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
} }
uri := downloader.URI(gallery.URL) uri := downloader.URI(gallery.URL)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models) return yaml.Unmarshal(d, &models)
}) })
if err != nil { if err != nil {

View File

@ -69,7 +69,7 @@ type PromptTemplate struct {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) { func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config var config Config
uri := downloader.URI(url) uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config) return yaml.Unmarshal(d, &config)
}) })
if err != nil { if err != nil {

View File

@ -31,24 +31,6 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
// Embed a directory // Embed a directory
// //
//go:embed static/* //go:embed static/*

View File

@ -31,6 +31,9 @@ import (
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
const apiKey = "joshua"
const bearerKey = "Bearer " + apiKey
const testPrompt = `### System: const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can. You are an AI assistant that follows instruction extremely well. Help as much as you can.
@ -50,11 +53,19 @@ type modelApplyRequest struct {
func getModelStatus(url string) (response map[string]interface{}) { func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request // Create the HTTP request
resp, err := http.Get(url) req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
if err != nil { if err != nil {
fmt.Println("Error creating request:", err) fmt.Println("Error creating request:", err)
return return
} }
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error sending request:", err)
return
}
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
@ -72,14 +83,15 @@ func getModelStatus(url string) (response map[string]interface{}) {
return return
} }
func getModels(url string) (response []gallery.GalleryModel) { func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url) uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix? // TODO: No tests currently seem to exercise file:// urls. Fix?
uri.DownloadAndUnmarshal("", func(url string, i []byte) error { err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct // Unmarshal YAML data into a struct
return json.Unmarshal(i, &response) return json.Unmarshal(i, &response)
}) })
return return response, err
} }
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
@ -101,6 +113,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return return
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
@ -140,6 +153,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
@ -175,6 +189,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
@ -195,6 +210,35 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
return json.Unmarshal(body, respJson) return json.Unmarshal(body, respJson)
} }
func postInvalidRequest(url string) (error, int) {
req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
if err != nil {
return err, -1
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode
}
return nil, resp.StatusCode
}
//go:embed backend-assets/* //go:embed backend-assets/*
var backendAssets embed.FS var backendAssets embed.FS
@ -260,6 +304,7 @@ var _ = Describe("API test", func() {
config.WithContext(c), config.WithContext(c),
config.WithGalleries(galleries), config.WithGalleries(galleries),
config.WithModelPath(modelDir), config.WithModelPath(modelDir),
config.WithApiKeys([]string{apiKey}),
config.WithBackendAssets(backendAssets), config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...) config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -269,7 +314,7 @@ var _ = Describe("API test", func() {
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("") client2 = openaigo.NewClient("")
@ -295,10 +340,19 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(403))
})
})
Context("Applying models", func() { Context("Applying models", func() {
It("applies models from a gallery", func() { It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available") models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
@ -331,7 +385,8 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar")) Expect(content["foo"]).To(Equal("bar"))
models = getModels("http://127.0.0.1:9090/models/available") models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))

View File

@ -38,6 +38,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403) return ctx.SendStatus(403)
} }
return ctx.Status(403).SendString(err.Error())
} }
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500) return ctx.SendStatus(500)

View File

@ -39,7 +39,7 @@ func init() {
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) { func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{} remoteLibrary := map[string]string{}
uri := downloader.URI(url) uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error { err := uri.DownloadWithCallback(basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary) return yaml.Unmarshal(i, &remoteLibrary)
}) })
if err != nil { if err != nil {

4
go.mod
View File

@ -1,8 +1,8 @@
module github.com/mudler/LocalAI module github.com/mudler/LocalAI
go 1.22.0 go 1.23
toolchain go1.22.4 toolchain go1.23.1
require ( require (
dario.cat/mergo v1.0.0 dario.cat/mergo v1.0.0

View File

@ -31,7 +31,11 @@ const (
type URI string type URI string
func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error { func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
}
func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL() url := uri.ResolveURL()
if strings.HasPrefix(url, LocalPrefix) { if strings.HasPrefix(url, LocalPrefix) {
@ -41,7 +45,6 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
if err != nil { if err != nil {
return err return err
} }
// ???
resolvedBasePath, err := filepath.EvalSymlinks(basePath) resolvedBasePath, err := filepath.EvalSymlinks(basePath)
if err != nil { if err != nil {
return err return err
@ -63,7 +66,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
} }
// Send a GET request to the URL // Send a GET request to the URL
response, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
if authorization != "" {
req.Header.Add("Authorization", authorization)
}
response, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,7 +11,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with a branch", func() { It("parses github with a branch", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml") uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml")
Expect( Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error { uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil return nil
}), }),
@ -21,7 +21,7 @@ var _ = Describe("Gallery API tests", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main") uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main")
Expect( Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error { uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil return nil
}), }),
@ -30,7 +30,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with urls", func() { It("parses github with urls", func() {
uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml") uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")
Expect( Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error { uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil return nil
}), }),