From 12513ebae0c333262e45077438d745de82b87590 Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 24 Jun 2024 02:34:36 -0400 Subject: [PATCH] rf: centralize base64 image handling (#2595) contains simple fixes to warnings and errors, removes a broken / outdated test, runs go mod tidy, and as the actual change, centralizes base64 image handling Signed-off-by: Dave Lee --- backend/go/llm/rwkv/rwkv.go | 2 +- core/http/endpoints/openai/assistant.go | 2 +- core/http/endpoints/openai/assistant_test.go | 12 +++--- core/http/endpoints/openai/request.go | 44 ++------------------ core/http/render.go | 5 +-- core/startup/startup.go | 5 ++- pkg/library/dynaload.go | 19 ++++++--- pkg/model/loader_test.go | 21 +++++----- pkg/model/process.go | 5 ++- pkg/utils/base64.go | 11 +++-- pkg/utils/base64_test.go | 9 +++- tests/integration/reflect_test.go | 23 ---------- 12 files changed, 60 insertions(+), 98 deletions(-) delete mode 100644 tests/integration/reflect_test.go diff --git a/backend/go/llm/rwkv/rwkv.go b/backend/go/llm/rwkv/rwkv.go index 54047521..fe9cd815 100644 --- a/backend/go/llm/rwkv/rwkv.go +++ b/backend/go/llm/rwkv/rwkv.go @@ -31,7 +31,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads())) if model == nil { - return fmt.Errorf("could not load model") + return fmt.Errorf("rwkv could not load model") } llm.rwkv = model return nil diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go index 0f5ab08d..4882eeaf 100644 --- a/core/http/endpoints/openai/assistant.go +++ b/core/http/endpoints/openai/assistant.go @@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model } } - return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID)) } } diff --git a/core/http/endpoints/openai/assistant_test.go b/core/http/endpoints/openai/assistant_test.go index 76bc5712..7d6c0c06 100644 --- a/core/http/endpoints/openai/assistant_test.go +++ b/core/http/endpoints/openai/assistant_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -183,7 +182,7 @@ func TestAssistantEndpoints(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.expectedStatus, response.StatusCode) if tt.expectedStatus != fiber.StatusOK { - all, _ := ioutil.ReadAll(response.Body) + all, _ := io.ReadAll(response.Body) assert.Equal(t, tt.expectedStringResult, string(all)) } else { var result []Assistant @@ -279,6 +278,7 @@ func TestAssistantEndpoints(t *testing.T) { assert.NoError(t, err) var getAssistant Assistant err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant) + assert.NoError(t, err) t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID})) @@ -391,7 +391,10 @@ func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId s } var assistantFile AssistantFile - all, err := ioutil.ReadAll(resp.Body) + all, err := io.ReadAll(resp.Body) + if err != nil { + return AssistantFile{}, resp, err + } err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile) if err != nil { return AssistantFile{}, resp, err @@ -422,8 +425,7 @@ func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Resp var resultAssistant Assistant err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant) - - return resultAssistant, resp, nil + return resultAssistant, resp, err } func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index 95a02d0c..7f7cc3a2 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -2,19 +2,16 @@ package openai import ( "context" - "encoding/base64" "encoding/json" "fmt" - "io" - "net/http" - "strings" "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/config" fiberContext "github.com/mudler/LocalAI/core/http/ctx" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" - model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) @@ -39,41 +36,6 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi return modelFile, input, err } -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string -func getBase64Image(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := http.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/...;base64,", drop it - dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"} - for _, prefix := range dropPrefix { - if strings.HasPrefix(s, prefix) { - return strings.ReplaceAll(s, prefix, ""), nil - } - } - - return "", fmt.Errorf("not valid string") -} - func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { if input.Echo { config.Echo = input.Echo @@ -187,7 +149,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque input.Messages[i].StringContent = pp.Text } else if pp.Type == "image_url" { // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := getBase64Image(pp.ImageURL.URL) + base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL) if err == nil { input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff // set a placeholder for each image diff --git a/core/http/render.go b/core/http/render.go index d9e42c4d..205f7ca3 100644 --- a/core/http/render.go +++ b/core/http/render.go @@ -21,14 +21,13 @@ func notFoundHandler(c *fiber.Ctx) error { // Check if the request accepts JSON if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 { // The client expects a JSON response - c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{ + return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{ Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound}, }) } else { // The client expects an HTML response - c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{}) + return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{}) } - return nil } func renderEngine() *fiberhtml.Engine { diff --git a/core/startup/startup.go b/core/startup/startup.go index 4a90f6f8..278c8e1c 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -112,7 +112,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode if options.LibPath != "" { // If there is a lib directory, set LD_LIBRARY_PATH to include it - library.LoadExternal(options.LibPath) + err := library.LoadExternal(options.LibPath) + if err != nil { + log.Error().Err(err).Str("LibPath", options.LibPath).Msg("Error while loading external libraries") + } } // turn off any process that was started by GRPC if the context is canceled diff --git a/pkg/library/dynaload.go b/pkg/library/dynaload.go index 4e25ed91..c1f79f65 100644 --- a/pkg/library/dynaload.go +++ b/pkg/library/dynaload.go @@ -1,6 +1,7 @@ package library import ( + "errors" "fmt" "os" "path/filepath" @@ -17,14 +18,17 @@ import ( var skipLibraryPath = os.Getenv("LOCALAI_SKIP_LIBRARY_PATH") != "" // LoadExtractedLibs loads the extracted libraries from the asset dir -func LoadExtractedLibs(dir string) { +func LoadExtractedLibs(dir string) error { + // Skip this if LOCALAI_SKIP_LIBRARY_PATH is set if skipLibraryPath { - return + return nil } + var err error = nil for _, libDir := range []string{filepath.Join(dir, "backend-assets", "lib"), filepath.Join(dir, "lib")} { - LoadExternal(libDir) + err = errors.Join(err, LoadExternal(libDir)) } + return err } // LoadLDSO checks if there is a ld.so in the asset dir and if so, prefixes the grpc process with it. @@ -57,9 +61,10 @@ func LoadLDSO(assetDir string, args []string, grpcProcess string) ([]string, str } // LoadExternal sets the LD_LIBRARY_PATH to include the given directory -func LoadExternal(dir string) { +func LoadExternal(dir string) error { + // Skip this if LOCALAI_SKIP_LIBRARY_PATH is set if skipLibraryPath { - return + return nil } lpathVar := "LD_LIBRARY_PATH" @@ -67,6 +72,7 @@ func LoadExternal(dir string) { lpathVar = "DYLD_FALLBACK_LIBRARY_PATH" // should it be DYLD_LIBRARY_PATH ? } + var setErr error = nil if _, err := os.Stat(dir); err == nil { ldLibraryPath := os.Getenv(lpathVar) if ldLibraryPath == "" { @@ -74,6 +80,7 @@ func LoadExternal(dir string) { } else { ldLibraryPath = fmt.Sprintf("%s:%s", ldLibraryPath, dir) } - os.Setenv(lpathVar, ldLibraryPath) + setErr = errors.Join(setErr, os.Setenv(lpathVar, ldLibraryPath)) } + return setErr } diff --git a/pkg/model/loader_test.go b/pkg/model/loader_test.go index 17e318b6..1142ed0c 100644 --- a/pkg/model/loader_test.go +++ b/pkg/model/loader_test.go @@ -1,7 +1,6 @@ package model_test import ( - "github.com/mudler/LocalAI/pkg/model" . "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" @@ -44,7 +43,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in "user": { "template": llama3, "expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "user", RoleName: "user", @@ -59,7 +58,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in "assistant": { "template": llama3, "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "assistant", RoleName: "assistant", @@ -74,7 +73,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in "function_call": { "template": llama3, "expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "assistant", RoleName: "assistant", @@ -89,7 +88,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in "function_response": { "template": llama3, "expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "tool", RoleName: "tool", @@ -107,7 +106,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in "user": { "template": chatML, "expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "user", RoleName: "user", @@ -122,7 +121,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in "assistant": { "template": chatML, "expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "assistant", RoleName: "assistant", @@ -137,7 +136,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in "function_call": { "template": chatML, "expected": "<|im_start|>assistant\n\n{\"function\":\"test\"}\n<|im_end|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "assistant", RoleName: "assistant", @@ -152,7 +151,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in "function_response": { "template": chatML, "expected": "<|im_start|>tool\n\nResponse from tool\n<|im_end|>", - "data": model.ChatMessageTemplateData{ + "data": ChatMessageTemplateData{ SystemPrompt: "", Role: "tool", RoleName: "tool", @@ -175,7 +174,7 @@ var _ = Describe("Templates", func() { for key := range chatMLTestMatch { foo := chatMLTestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) + templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) Expect(err).ToNot(HaveOccurred()) Expect(templated).To(Equal(foo["expected"]), templated) }) @@ -189,7 +188,7 @@ var _ = Describe("Templates", func() { for key := range llama3TestMatch { foo := llama3TestMatch[key] It("renders correctly `"+key+"`", func() { - templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData)) + templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData)) Expect(err).ToNot(HaveOccurred()) Expect(templated).To(Equal(foo["expected"]), templated) }) diff --git a/pkg/model/process.go b/pkg/model/process.go index 58def58c..7b7ecb97 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -103,7 +103,10 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) <-c - grpcControlProcess.Stop() + err := grpcControlProcess.Stop() + if err != nil { + log.Error().Err(err).Msg("error while shutting down grpc process") + } }() go func() { diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go index 977156e9..3fbb405b 100644 --- a/pkg/utils/base64.go +++ b/pkg/utils/base64.go @@ -42,9 +42,12 @@ func GetImageURLAsBase64(s string) (string, error) { return encoded, nil } - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + // if the string instead is prefixed with "data:image/...;base64,", drop it + dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"} + for _, prefix := range dropPrefix { + if strings.HasPrefix(s, prefix) { + return strings.ReplaceAll(s, prefix, ""), nil + } } return "", fmt.Errorf("not valid string") -} \ No newline at end of file +} diff --git a/pkg/utils/base64_test.go b/pkg/utils/base64_test.go index f1fcda95..3b3dc9fb 100644 --- a/pkg/utils/base64_test.go +++ b/pkg/utils/base64_test.go @@ -7,13 +7,20 @@ import ( ) var _ = Describe("utils/base64 tests", func() { - It("GetImageURLAsBase64 can strip data url prefixes", func() { + It("GetImageURLAsBase64 can strip jpeg data url prefixes", func() { // This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes. input := "" b64, err := GetImageURLAsBase64(input) Expect(err).To(BeNil()) Expect(b64).To(Equal("FOO")) }) + It("GetImageURLAsBase64 can strip png data url prefixes", func() { + // This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes. + input := "" + b64, err := GetImageURLAsBase64(input) + Expect(err).To(BeNil()) + Expect(b64).To(Equal("BAR")) + }) It("GetImageURLAsBase64 returns an error for bogus data", func() { input := "FOO" b64, err := GetImageURLAsBase64(input) diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go deleted file mode 100644 index 3a99cdf2..00000000 --- a/tests/integration/reflect_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package integration_test - -import ( - "reflect" - - "github.com/mudler/LocalAI/core/config" - model "github.com/mudler/LocalAI/pkg/model" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("Integration Tests involving reflection in liue of code generation", func() { - Context("config.TemplateConfig and model.TemplateType must stay in sync", func() { - - ttc := reflect.TypeOf(config.TemplateConfig{}) - - It("TemplateConfig and TemplateType should have the same number of valid values", func() { - const lastValidTemplateType = model.IntegrationTestTemplate - 1 - Expect(lastValidTemplateType).To(Equal(ttc.NumField())) - }) - - }) -})