From 90cacb9692f3dc374766b0e32f75be8229a47db3 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 24 Sep 2024 03:32:48 -0400 Subject: [PATCH] test: preliminary tests and merge fix for authv2 (#3584) * add api key to existing app tests, add preliminary auth test Signed-off-by: Dave Lee * small fix, run test Signed-off-by: Dave Lee * status on non-opaque Signed-off-by: Dave Lee * tweak auth error Signed-off-by: Dave Lee * exp Signed-off-by: Dave Lee * quick fix on real laptop Signed-off-by: Dave Lee * add downloader version that allows providing an auth header Signed-off-by: Dave Lee * stash some devcontainer fixes during testing Signed-off-by: Dave Lee * s2 Signed-off-by: Dave Lee * s Signed-off-by: Dave Lee * done with experiment Signed-off-by: Dave Lee * done with experiment Signed-off-by: Dave Lee * after merge fix Signed-off-by: Dave Lee * rename and fix Signed-off-by: Dave Lee --------- Signed-off-by: Dave Lee Co-authored-by: Ettore Di Giacinto --- .devcontainer-scripts/utils.sh | 2 + Dockerfile | 5 +-- Makefile | 3 ++ core/gallery/gallery.go | 4 +- core/gallery/models.go | 2 +- core/http/app.go | 18 --------- core/http/app_test.go | 69 ++++++++++++++++++++++++++++++---- core/http/middleware/auth.go | 3 +- embedded/embedded.go | 2 +- go.mod | 4 +- pkg/downloader/uri.go | 18 +++++++-- pkg/downloader/uri_test.go | 6 +-- 12 files changed, 95 insertions(+), 41 deletions(-) diff --git a/.devcontainer-scripts/utils.sh b/.devcontainer-scripts/utils.sh index 98ac063c..8416d43d 100644 --- a/.devcontainer-scripts/utils.sh +++ b/.devcontainer-scripts/utils.sh @@ -9,6 +9,7 @@ # Param 2: email # config_user() { + echo "Configuring git for $1 <$2>" local gcn=$(git config --global user.name) if [ -z "${gcn}" ]; then echo "Setting up git user / remote" @@ -24,6 +25,7 @@ config_user() { # Param 2: remote url # config_remote() { + echo "Adding git remote and fetching $2 as $1" local gr=$(git remote -v | grep $1) if [ -z "${gr}" ]; then git remote add $1 $2 diff --git a/Dockerfile b/Dockerfile index 323c3d9a..8c657469 100644 --- a/Dockerfile +++ b/Dockerfile @@ -338,9 +338,8 @@ RUN if [ "${FFMPEG}" = "true" ]; then \ RUN apt-get update && \ apt-get install -y --no-install-recommends \ - ssh less && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* + ssh less wget +# For the devcontainer, leave apt functional in case additional devtools are needed at runtime. RUN go install github.com/go-delve/delve/cmd/dlv@latest diff --git a/Makefile b/Makefile index 578656e5..7523d5ff 100644 --- a/Makefile +++ b/Makefile @@ -359,6 +359,9 @@ clean-tests: rm -rf test-dir rm -rf core/http/backend-assets +clean-dc: clean + cp -r /build/backend-assets /workspace/backend-assets + ## Build: build: prepare backend-assets grpcs ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 6ced6244..3a60e618 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -132,7 +132,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string uri := downloader.URI(url) - err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { + err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) @@ -156,7 +156,7 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel, } uri := downloader.URI(gallery.URL) - err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { + err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &models) }) if err != nil { diff --git a/core/gallery/models.go b/core/gallery/models.go index dec6312e..58f1963a 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -69,7 +69,7 @@ type PromptTemplate struct { func GetGalleryConfigFromURL(url string, basePath string) (Config, error) { var config Config uri := downloader.URI(url) - err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { + err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { diff --git a/core/http/app.go b/core/http/app.go index fa9cd866..23e97f18 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -31,24 +31,6 @@ import ( "github.com/rs/zerolog/log" ) -func readAuthHeader(c *fiber.Ctx) string { - authHeader := c.Get("Authorization") - - // elevenlabs - xApiKey := c.Get("xi-api-key") - if xApiKey != "" { - authHeader = "Bearer " + xApiKey - } - - // anthropic - xApiKey = c.Get("x-api-key") - if xApiKey != "" { - authHeader = "Bearer " + xApiKey - } - - return authHeader -} - // Embed a directory // //go:embed static/* diff --git a/core/http/app_test.go b/core/http/app_test.go index 86fe7fdd..bbe52c34 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -31,6 +31,9 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const apiKey = "joshua" +const bearerKey = "Bearer " + apiKey + const testPrompt = `### System: You are an AI assistant that follows instruction extremely well. Help as much as you can. @@ -50,11 +53,19 @@ type modelApplyRequest struct { func getModelStatus(url string) (response map[string]interface{}) { // Create the HTTP request - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) if err != nil { fmt.Println("Error creating request:", err) return } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + fmt.Println("Error sending request:", err) + return + } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -72,14 +83,15 @@ func getModelStatus(url string) (response map[string]interface{}) { return } -func getModels(url string) (response []gallery.GalleryModel) { +func getModels(url string) ([]gallery.GalleryModel, error) { + response := []gallery.GalleryModel{} uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? - uri.DownloadAndUnmarshal("", func(url string, i []byte) error { + err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) - return + return response, err } func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { @@ -101,6 +113,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[ return } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) // Make the request client := &http.Client{} @@ -140,6 +153,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error { } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) @@ -175,6 +189,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson * } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) @@ -195,6 +210,35 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson * return json.Unmarshal(body, respJson) } +func postInvalidRequest(url string) (error, int) { + + req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request")) + if err != nil { + return err, -1 + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err, -1 + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err, -1 + } + + if resp.StatusCode < 200 || resp.StatusCode >= 400 { + return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode + } + + return nil, resp.StatusCode +} + //go:embed backend-assets/* var backendAssets embed.FS @@ -260,6 +304,7 @@ var _ = Describe("API test", func() { config.WithContext(c), config.WithGalleries(galleries), config.WithModelPath(modelDir), + config.WithApiKeys([]string{apiKey}), config.WithBackendAssets(backendAssets), config.WithBackendAssetsOutput(backendAssetsDir))...) Expect(err).ToNot(HaveOccurred()) @@ -269,7 +314,7 @@ var _ = Describe("API test", func() { go app.Listen("127.0.0.1:9090") - defaultConfig := openai.DefaultConfig("") + defaultConfig := openai.DefaultConfig(apiKey) defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") @@ -295,10 +340,19 @@ var _ = Describe("API test", func() { Expect(err).To(HaveOccurred()) }) + Context("Auth Tests", func() { + It("Should fail if the api key is missing", func() { + err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available") + Expect(err).ToNot(BeNil()) + Expect(sc).To(Equal(403)) + }) + }) + Context("Applying models", func() { It("applies models from a gallery", func() { - models := getModels("http://127.0.0.1:9090/models/available") + models, err := getModels("http://127.0.0.1:9090/models/available") + Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) @@ -331,7 +385,8 @@ var _ = Describe("API test", func() { Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["foo"]).To(Equal("bar")) - models = getModels("http://127.0.0.1:9090/models/available") + models, err = getModels("http://127.0.0.1:9090/models/available") + Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) diff --git a/core/http/middleware/auth.go b/core/http/middleware/auth.go index bc8bcf80..d2152e9b 100644 --- a/core/http/middleware/auth.go +++ b/core/http/middleware/auth.go @@ -38,6 +38,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er if applicationConfig.OpaqueErrors { return ctx.SendStatus(403) } + return ctx.Status(403).SendString(err.Error()) } if applicationConfig.OpaqueErrors { return ctx.SendStatus(500) @@ -90,4 +91,4 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig } } return func(c *fiber.Ctx) bool { return false } -} \ No newline at end of file +} diff --git a/embedded/embedded.go b/embedded/embedded.go index 672c32ed..3a4ea262 100644 --- a/embedded/embedded.go +++ b/embedded/embedded.go @@ -39,7 +39,7 @@ func init() { func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) { remoteLibrary := map[string]string{} uri := downloader.URI(url) - err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error { + err := uri.DownloadWithCallback(basePath, func(_ string, i []byte) error { return yaml.Unmarshal(i, &remoteLibrary) }) if err != nil { diff --git a/go.mod b/go.mod index a3359abf..dd8fce9f 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/mudler/LocalAI -go 1.22.0 +go 1.23 -toolchain go1.22.4 +toolchain go1.23.1 require ( dario.cat/mergo v1.0.0 diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 7fedd646..9acbb621 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -31,7 +31,11 @@ const ( type URI string -func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error { +func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error { + return uri.DownloadWithAuthorizationAndCallback(basePath, "", f) +} + +func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error { url := uri.ResolveURL() if strings.HasPrefix(url, LocalPrefix) { @@ -41,7 +45,6 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte if err != nil { return err } - // ??? resolvedBasePath, err := filepath.EvalSymlinks(basePath) if err != nil { return err @@ -63,7 +66,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte } // Send a GET request to the URL - response, err := http.Get(url) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + if authorization != "" { + req.Header.Add("Authorization", authorization) + } + + response, err := http.DefaultClient.Do(req) if err != nil { return err } diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index 21a093a9..3b7a80b3 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -11,7 +11,7 @@ var _ = Describe("Gallery API tests", func() { It("parses github with a branch", func() { uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml") Expect( - uri.DownloadAndUnmarshal("", func(url string, i []byte) error { + uri.DownloadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -21,7 +21,7 @@ var _ = Describe("Gallery API tests", func() { uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main") Expect( - uri.DownloadAndUnmarshal("", func(url string, i []byte) error { + uri.DownloadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -30,7 +30,7 @@ var _ = Describe("Gallery API tests", func() { It("parses github with urls", func() { uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml") Expect( - uri.DownloadAndUnmarshal("", func(url string, i []byte) error { + uri.DownloadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }),