mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
rf: centralize base64 image handling (#2595)
contains simple fixes to warnings and errors, removes a broken / outdated test, runs go mod tidy, and as the actual change, centralizes base64 image handling Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
4156a4f15f
commit
12513ebae0
@ -31,7 +31,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
|||||||
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
|
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
|
||||||
|
|
||||||
if model == nil {
|
if model == nil {
|
||||||
return fmt.Errorf("could not load model")
|
return fmt.Errorf("rwkv could not load model")
|
||||||
}
|
}
|
||||||
llm.rwkv = model
|
llm.rwkv = model
|
||||||
return nil
|
return nil
|
||||||
|
@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find "))
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@ -183,7 +182,7 @@ func TestAssistantEndpoints(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
||||||
if tt.expectedStatus != fiber.StatusOK {
|
if tt.expectedStatus != fiber.StatusOK {
|
||||||
all, _ := ioutil.ReadAll(response.Body)
|
all, _ := io.ReadAll(response.Body)
|
||||||
assert.Equal(t, tt.expectedStringResult, string(all))
|
assert.Equal(t, tt.expectedStringResult, string(all))
|
||||||
} else {
|
} else {
|
||||||
var result []Assistant
|
var result []Assistant
|
||||||
@ -279,6 +278,7 @@ func TestAssistantEndpoints(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
var getAssistant Assistant
|
var getAssistant Assistant
|
||||||
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
||||||
|
|
||||||
@ -391,7 +391,10 @@ func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var assistantFile AssistantFile
|
var assistantFile AssistantFile
|
||||||
all, err := ioutil.ReadAll(resp.Body)
|
all, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AssistantFile{}, resp, err
|
return AssistantFile{}, resp, err
|
||||||
@ -422,8 +425,7 @@ func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Resp
|
|||||||
|
|
||||||
var resultAssistant Assistant
|
var resultAssistant Assistant
|
||||||
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
||||||
|
return resultAssistant, resp, err
|
||||||
return resultAssistant, resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
||||||
|
@ -2,19 +2,16 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
"github.com/mudler/LocalAI/pkg/functions"
|
"github.com/mudler/LocalAI/pkg/functions"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,41 +36,6 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
|
|||||||
return modelFile, input, err
|
return modelFile, input, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
|
||||||
// encodes it in base64 and returns the base64 string
|
|
||||||
func getBase64Image(s string) (string, error) {
|
|
||||||
if strings.HasPrefix(s, "http") {
|
|
||||||
// download the image
|
|
||||||
resp, err := http.Get(s)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// read the image data into memory
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode the image data in base64
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(data)
|
|
||||||
|
|
||||||
// return the base64 string
|
|
||||||
return encoded, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the string instead is prefixed with "data:image/...;base64,", drop it
|
|
||||||
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
|
|
||||||
for _, prefix := range dropPrefix {
|
|
||||||
if strings.HasPrefix(s, prefix) {
|
|
||||||
return strings.ReplaceAll(s, prefix, ""), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("not valid string")
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||||
if input.Echo {
|
if input.Echo {
|
||||||
config.Echo = input.Echo
|
config.Echo = input.Echo
|
||||||
@ -187,7 +149,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
input.Messages[i].StringContent = pp.Text
|
input.Messages[i].StringContent = pp.Text
|
||||||
} else if pp.Type == "image_url" {
|
} else if pp.Type == "image_url" {
|
||||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||||
base64, err := getBase64Image(pp.ImageURL.URL)
|
base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||||
// set a placeholder for each image
|
// set a placeholder for each image
|
||||||
|
@ -21,14 +21,13 @@ func notFoundHandler(c *fiber.Ctx) error {
|
|||||||
// Check if the request accepts JSON
|
// Check if the request accepts JSON
|
||||||
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
|
||||||
// The client expects a JSON response
|
// The client expects a JSON response
|
||||||
c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
|
return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
|
||||||
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
|
Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// The client expects an HTML response
|
// The client expects an HTML response
|
||||||
c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{})
|
return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{})
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func renderEngine() *fiberhtml.Engine {
|
func renderEngine() *fiberhtml.Engine {
|
||||||
|
@ -112,7 +112,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
|||||||
|
|
||||||
if options.LibPath != "" {
|
if options.LibPath != "" {
|
||||||
// If there is a lib directory, set LD_LIBRARY_PATH to include it
|
// If there is a lib directory, set LD_LIBRARY_PATH to include it
|
||||||
library.LoadExternal(options.LibPath)
|
err := library.LoadExternal(options.LibPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("LibPath", options.LibPath).Msg("Error while loading external libraries")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// turn off any process that was started by GRPC if the context is canceled
|
// turn off any process that was started by GRPC if the context is canceled
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package library
|
package library
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -17,14 +18,17 @@ import (
|
|||||||
var skipLibraryPath = os.Getenv("LOCALAI_SKIP_LIBRARY_PATH") != ""
|
var skipLibraryPath = os.Getenv("LOCALAI_SKIP_LIBRARY_PATH") != ""
|
||||||
|
|
||||||
// LoadExtractedLibs loads the extracted libraries from the asset dir
|
// LoadExtractedLibs loads the extracted libraries from the asset dir
|
||||||
func LoadExtractedLibs(dir string) {
|
func LoadExtractedLibs(dir string) error {
|
||||||
|
// Skip this if LOCALAI_SKIP_LIBRARY_PATH is set
|
||||||
if skipLibraryPath {
|
if skipLibraryPath {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error = nil
|
||||||
for _, libDir := range []string{filepath.Join(dir, "backend-assets", "lib"), filepath.Join(dir, "lib")} {
|
for _, libDir := range []string{filepath.Join(dir, "backend-assets", "lib"), filepath.Join(dir, "lib")} {
|
||||||
LoadExternal(libDir)
|
err = errors.Join(err, LoadExternal(libDir))
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadLDSO checks if there is a ld.so in the asset dir and if so, prefixes the grpc process with it.
|
// LoadLDSO checks if there is a ld.so in the asset dir and if so, prefixes the grpc process with it.
|
||||||
@ -57,9 +61,10 @@ func LoadLDSO(assetDir string, args []string, grpcProcess string) ([]string, str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadExternal sets the LD_LIBRARY_PATH to include the given directory
|
// LoadExternal sets the LD_LIBRARY_PATH to include the given directory
|
||||||
func LoadExternal(dir string) {
|
func LoadExternal(dir string) error {
|
||||||
|
// Skip this if LOCALAI_SKIP_LIBRARY_PATH is set
|
||||||
if skipLibraryPath {
|
if skipLibraryPath {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lpathVar := "LD_LIBRARY_PATH"
|
lpathVar := "LD_LIBRARY_PATH"
|
||||||
@ -67,6 +72,7 @@ func LoadExternal(dir string) {
|
|||||||
lpathVar = "DYLD_FALLBACK_LIBRARY_PATH" // should it be DYLD_LIBRARY_PATH ?
|
lpathVar = "DYLD_FALLBACK_LIBRARY_PATH" // should it be DYLD_LIBRARY_PATH ?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var setErr error = nil
|
||||||
if _, err := os.Stat(dir); err == nil {
|
if _, err := os.Stat(dir); err == nil {
|
||||||
ldLibraryPath := os.Getenv(lpathVar)
|
ldLibraryPath := os.Getenv(lpathVar)
|
||||||
if ldLibraryPath == "" {
|
if ldLibraryPath == "" {
|
||||||
@ -74,6 +80,7 @@ func LoadExternal(dir string) {
|
|||||||
} else {
|
} else {
|
||||||
ldLibraryPath = fmt.Sprintf("%s:%s", ldLibraryPath, dir)
|
ldLibraryPath = fmt.Sprintf("%s:%s", ldLibraryPath, dir)
|
||||||
}
|
}
|
||||||
os.Setenv(lpathVar, ldLibraryPath)
|
setErr = errors.Join(setErr, os.Setenv(lpathVar, ldLibraryPath))
|
||||||
}
|
}
|
||||||
|
return setErr
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package model_test
|
package model_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/mudler/LocalAI/pkg/model"
|
|
||||||
. "github.com/mudler/LocalAI/pkg/model"
|
. "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
@ -44,7 +43,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"user": {
|
"user": {
|
||||||
"template": llama3,
|
"template": llama3,
|
||||||
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
"expected": "<|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "user",
|
Role: "user",
|
||||||
RoleName: "user",
|
RoleName: "user",
|
||||||
@ -59,7 +58,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"assistant": {
|
"assistant": {
|
||||||
"template": llama3,
|
"template": llama3,
|
||||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
RoleName: "assistant",
|
RoleName: "assistant",
|
||||||
@ -74,7 +73,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"function_call": {
|
"function_call": {
|
||||||
"template": llama3,
|
"template": llama3,
|
||||||
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
"expected": "<|start_header_id|>assistant<|end_header_id|>\n\nFunction call:\n{\"function\":\"test\"}<|eot_id|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
RoleName: "assistant",
|
RoleName: "assistant",
|
||||||
@ -89,7 +88,7 @@ var llama3TestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"function_response": {
|
"function_response": {
|
||||||
"template": llama3,
|
"template": llama3,
|
||||||
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
"expected": "<|start_header_id|>tool<|end_header_id|>\n\nFunction response:\nResponse from tool<|eot_id|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
RoleName: "tool",
|
RoleName: "tool",
|
||||||
@ -107,7 +106,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"user": {
|
"user": {
|
||||||
"template": chatML,
|
"template": chatML,
|
||||||
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "user",
|
Role: "user",
|
||||||
RoleName: "user",
|
RoleName: "user",
|
||||||
@ -122,7 +121,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"assistant": {
|
"assistant": {
|
||||||
"template": chatML,
|
"template": chatML,
|
||||||
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...<|im_end|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
RoleName: "assistant",
|
RoleName: "assistant",
|
||||||
@ -137,7 +136,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"function_call": {
|
"function_call": {
|
||||||
"template": chatML,
|
"template": chatML,
|
||||||
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call><|im_end|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
RoleName: "assistant",
|
RoleName: "assistant",
|
||||||
@ -152,7 +151,7 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in
|
|||||||
"function_response": {
|
"function_response": {
|
||||||
"template": chatML,
|
"template": chatML,
|
||||||
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response><|im_end|>",
|
||||||
"data": model.ChatMessageTemplateData{
|
"data": ChatMessageTemplateData{
|
||||||
SystemPrompt: "",
|
SystemPrompt: "",
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
RoleName: "tool",
|
RoleName: "tool",
|
||||||
@ -175,7 +174,7 @@ var _ = Describe("Templates", func() {
|
|||||||
for key := range chatMLTestMatch {
|
for key := range chatMLTestMatch {
|
||||||
foo := chatMLTestMatch[key]
|
foo := chatMLTestMatch[key]
|
||||||
It("renders correctly `"+key+"`", func() {
|
It("renders correctly `"+key+"`", func() {
|
||||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData))
|
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||||
})
|
})
|
||||||
@ -189,7 +188,7 @@ var _ = Describe("Templates", func() {
|
|||||||
for key := range llama3TestMatch {
|
for key := range llama3TestMatch {
|
||||||
foo := llama3TestMatch[key]
|
foo := llama3TestMatch[key]
|
||||||
It("renders correctly `"+key+"`", func() {
|
It("renders correctly `"+key+"`", func() {
|
||||||
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData))
|
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(ChatMessageTemplateData))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(templated).To(Equal(foo["expected"]), templated)
|
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||||
})
|
})
|
||||||
|
@ -103,7 +103,10 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
|
|||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
<-c
|
<-c
|
||||||
grpcControlProcess.Stop()
|
err := grpcControlProcess.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while shutting down grpc process")
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -42,9 +42,12 @@ func GetImageURLAsBase64(s string) (string, error) {
|
|||||||
return encoded, nil
|
return encoded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
// if the string instead is prefixed with "data:image/...;base64,", drop it
|
||||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
|
||||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
for _, prefix := range dropPrefix {
|
||||||
|
if strings.HasPrefix(s, prefix) {
|
||||||
|
return strings.ReplaceAll(s, prefix, ""), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("not valid string")
|
return "", fmt.Errorf("not valid string")
|
||||||
}
|
}
|
||||||
|
@ -7,13 +7,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("utils/base64 tests", func() {
|
var _ = Describe("utils/base64 tests", func() {
|
||||||
It("GetImageURLAsBase64 can strip data url prefixes", func() {
|
It("GetImageURLAsBase64 can strip jpeg data url prefixes", func() {
|
||||||
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
|
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
|
||||||
input := ""
|
input := ""
|
||||||
b64, err := GetImageURLAsBase64(input)
|
b64, err := GetImageURLAsBase64(input)
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(b64).To(Equal("FOO"))
|
Expect(b64).To(Equal("FOO"))
|
||||||
})
|
})
|
||||||
|
It("GetImageURLAsBase64 can strip png data url prefixes", func() {
|
||||||
|
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
|
||||||
|
input := ""
|
||||||
|
b64, err := GetImageURLAsBase64(input)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
Expect(b64).To(Equal("BAR"))
|
||||||
|
})
|
||||||
It("GetImageURLAsBase64 returns an error for bogus data", func() {
|
It("GetImageURLAsBase64 returns an error for bogus data", func() {
|
||||||
input := "FOO"
|
input := "FOO"
|
||||||
b64, err := GetImageURLAsBase64(input)
|
b64, err := GetImageURLAsBase64(input)
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
package integration_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
|
|
||||||
Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
|
|
||||||
|
|
||||||
ttc := reflect.TypeOf(config.TemplateConfig{})
|
|
||||||
|
|
||||||
It("TemplateConfig and TemplateType should have the same number of valid values", func() {
|
|
||||||
const lastValidTemplateType = model.IntegrationTestTemplate - 1
|
|
||||||
Expect(lastValidTemplateType).To(Equal(ttc.NumField()))
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
||||||
})
|
|
Loading…
Reference in New Issue
Block a user