From 01205fd4c0e606f7d8e68b23bd3c74fb4a032b3e Mon Sep 17 00:00:00 2001 From: Steven Christou <1302212+christ66@users.noreply.github.com> Date: Sun, 18 Feb 2024 02:12:02 -0800 Subject: [PATCH] Initial implementation of upload files api. (#1703) * Initial implementation of upload files api. * Move sanitize method to utils. * Save uploaded data to uploads folder. * Avoid loop if we do not have a purpose. * Minor cleanup of api and fix bug where deleting duplicate filename cause error. * Revert defer of saving config * Moved creation of directory to startup. * Make file names unique when storing on disk. * Add test for files api. * Update dependencies. --- api/api.go | 16 +++ api/openai/files.go | 207 ++++++++++++++++++++++++++++ api/openai/files_test.go | 286 +++++++++++++++++++++++++++++++++++++++ api/options/options.go | 7 + go.mod | 4 +- go.sum | 2 - main.go | 7 + pkg/utils/path.go | 12 ++ 8 files changed, 538 insertions(+), 3 deletions(-) create mode 100644 api/openai/files.go create mode 100644 api/openai/files_test.go diff --git a/api/api.go b/api/api.go index 946204d2..4442421e 100644 --- a/api/api.go +++ b/api/api.go @@ -223,8 +223,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Make sure directories exists os.MkdirAll(options.ImageDir, 0755) os.MkdirAll(options.AudioDir, 0755) + os.MkdirAll(options.UploadDir, 0755) os.MkdirAll(options.Loader.ModelPath, 0755) + // Load upload json + openai.LoadUploadConfig(options.UploadDir) + modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) @@ -244,6 +248,18 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/edits", auth, openai.EditEndpoint(cl, options)) + // files + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options)) + app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options)) + app.Get("/files", auth, openai.ListFilesEndpoint(cl, options)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) + app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) + // completion app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) diff --git a/api/openai/files.go b/api/openai/files.go new file mode 100644 index 00000000..f19e79d8 --- /dev/null +++ b/api/openai/files.go @@ -0,0 +1,207 @@ +package openai + +import ( + "encoding/json" + "errors" + "fmt" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "os" + "path/filepath" + "time" +) + +var uploadedFiles []File + +// File represents the structure of a file object from the OpenAI API. +type File struct { + ID string `json:"id"` // Unique identifier for the file + Object string `json:"object"` // Type of the object (e.g., "file") + Bytes int `json:"bytes"` // Size of the file in bytes + CreatedAt time.Time `json:"created_at"` // The time at which the file was created + Filename string `json:"filename"` // The name of the file + Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.) +} + +func saveUploadConfig(uploadDir string) { + file, err := json.MarshalIndent(uploadedFiles, "", " ") + if err != nil { + log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) + } + + err = os.WriteFile(filepath.Join(uploadDir, "uploadedFiles.json"), file, 0644) + if err != nil { + log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) + } +} + +func LoadUploadConfig(uploadPath string) { + file, err := os.ReadFile(filepath.Join(uploadPath, "uploadedFiles.json")) + if err != nil { + log.Error().Msgf("Failed to read file: %s", err) + } else { + err = json.Unmarshal(file, &uploadedFiles) + if err != nil { + log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err) + } + } +} + +// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create +func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := c.FormFile("file") + if err != nil { + return err + } + + // Check the file size + if file.Size > int64(o.UploadLimitMB*1024*1024) { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB)) + } + + purpose := c.FormValue("purpose", "") //TODO put in purpose dirs + if purpose == "" { + return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined") + } + + // Sanitize the filename to prevent directory traversal + filename := utils.SanitizeFileName(file.Filename) + + savePath := filepath.Join(o.UploadDir, filename) + + // Check if file already exists + if _, err := os.Stat(savePath); !os.IsNotExist(err) { + return c.Status(fiber.StatusBadRequest).SendString("File already exists") + } + + err = c.SaveFile(file, savePath) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error()) + } + + f := File{ + ID: fmt.Sprintf("file-%d", time.Now().Unix()), + Object: "file", + Bytes: int(file.Size), + CreatedAt: time.Now(), + Filename: file.Filename, + Purpose: purpose, + } + + uploadedFiles = append(uploadedFiles, f) + saveUploadConfig(o.UploadDir) + return c.Status(fiber.StatusOK).JSON(f) + } +} + +// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list +func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type ListFiles struct { + Data []File + Object string + } + + return func(c *fiber.Ctx) error { + var listFiles ListFiles + + purpose := c.Query("purpose") + if purpose == "" { + listFiles.Data = uploadedFiles + } else { + for _, f := range uploadedFiles { + if purpose == f.Purpose { + listFiles.Data = append(listFiles.Data, f) + } + } + } + listFiles.Object = "list" + return c.Status(fiber.StatusOK).JSON(listFiles) + } +} + +func getFileFromRequest(c *fiber.Ctx) (*File, error) { + id := c.Params("file_id") + if id == "" { + return nil, fmt.Errorf("file_id parameter is required") + } + + for _, f := range uploadedFiles { + if id == f.ID { + return &f, nil + } + } + + return nil, fmt.Errorf("unable to find file id %s", id) +} + +// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve +func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(file) + } +} + +// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete +func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type DeleteStatus struct { + Id string + Object string + Deleted bool + } + + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + err = os.Remove(filepath.Join(o.UploadDir, file.Filename)) + if err != nil { + // If the file doesn't exist then we should just continue to remove it + if !errors.Is(err, os.ErrNotExist) { + return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err)) + } + } + + // Remove upload from list + for i, f := range uploadedFiles { + if f.ID == file.ID { + uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...) + break + } + } + + saveUploadConfig(o.UploadDir) + return c.JSON(DeleteStatus{ + Id: file.ID, + Object: "file", + Deleted: true, + }) + } +} + +// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents +func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename)) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.Send(fileContents) + } +} diff --git a/api/openai/files_test.go b/api/openai/files_test.go new file mode 100644 index 00000000..cb111b4a --- /dev/null +++ b/api/openai/files_test.go @@ -0,0 +1,286 @@ +package openai + +import ( + "encoding/json" + "fmt" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + utils2 "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + + "testing" +) + +type ListFiles struct { + Data []File + Object string +} + +func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) { + // Preparing the mocked objects + loader = &config.ConfigLoader{} + + option = &options.Option{ + UploadLimitMB: 10, + UploadDir: "test_dir", + } + + _ = os.RemoveAll(option.UploadDir) + + app = fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. + }) + + // Create a Test Server + app.Post("/files", UploadFilesEndpoint(loader, option)) + app.Get("/files", ListFilesEndpoint(loader, option)) + app.Get("/files/:file_id", GetFilesEndpoint(loader, option)) + app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option)) + app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option)) + + return +} + +func TestUploadFileExceedSizeLimit(t *testing.T) { + // Preparing the mocked objects + loader := &config.ConfigLoader{} + + option := &options.Option{ + UploadLimitMB: 10, + UploadDir: "test_dir", + } + + _ = os.RemoveAll(option.UploadDir) + + app := fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. + }) + + // Create a Test Server + app.Post("/files", UploadFilesEndpoint(loader, option)) + app.Get("/files", ListFilesEndpoint(loader, option)) + app.Get("/files/:file_id", GetFilesEndpoint(loader, option)) + app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option)) + app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option)) + + t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) { + resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option) + assert.NoError(t, err) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "exceeds upload limit") + }) + t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) { + resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "Purpose is not defined") + }) + t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) { + f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option) + + resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option) + fmt.Println(f1) + fmt.Printf("ERror: %v", err) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "File already exists") + }) + t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) { + file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) + + // Check if file exists in the disk + filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt")) + _, err := os.Stat(filePath) + + assert.False(t, os.IsNotExist(err)) + assert.Equal(t, file.Bytes, 5242880) + assert.NotEmpty(t, file.CreatedAt) + assert.Equal(t, file.Filename, "test.txt") + assert.Equal(t, file.Purpose, "fine-tune") + }) + t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) { + resp, err := CallListFilesEndpoint(t, app, "") + assert.NoError(t, err) + + assert.Equal(t, 200, resp.StatusCode) + + listFiles := responseToListFile(t, resp) + if len(listFiles.Data) != len(uploadedFiles) { + t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data)) + } + }) + t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) { + _ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) + + resp, err := CallListFilesEndpoint(t, app, "fine-tune") + assert.NoError(t, err) + + listFiles := responseToListFile(t, resp) + if len(listFiles.Data) != 1 { + t.Errorf("Expected 1 file, got %v files", len(listFiles.Data)) + } + }) + t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) { + resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune") + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + listFiles := responseToListFile(t, resp) + + if len(listFiles.Data) != 0 { + t.Errorf("Expected 0 file, got %v files", len(listFiles.Data)) + } + }) + t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) { + req := httptest.NewRequest("GET", "/files", nil) + resp, _ := app.Test(req) + assert.Equal(t, 200, resp.StatusCode) + + var listFiles ListFiles + if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil { + t.Errorf("Failed to decode response: %v", err) + return + } + + if len(listFiles.Data) != 0 { + t.Errorf("Expected 0 file, got %v files", len(listFiles.Data)) + } + }) +} + +func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) { + var target string + if purpose != "" { + target = fmt.Sprintf("/files?purpose=%s", purpose) + } else { + target = "/files" + } + req := httptest.NewRequest("GET", target, nil) + return app.Test(req) +} + +func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) { + request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil) + return app.Test(request) +} + +func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) { + // Create a file that exceeds the limit + file := createTestFile(t, fileName, fileSize, o) + + // Creating a new HTTP Request + body, writer := newMultipartFile(file.Name(), tag, purpose) + + req := httptest.NewRequest(http.MethodPost, "/files", body) + req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType()) + return app.Test(req) +} + +func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File { + // Create a file that exceeds the limit + file := createTestFile(t, fileName, fileSize, o) + + // Creating a new HTTP Request + body, writer := newMultipartFile(file.Name(), tag, purpose) + + req := httptest.NewRequest(http.MethodPost, "/files", body) + req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType()) + resp, err := app.Test(req) + assert.NoError(t, err) + f := responseToFile(t, resp) + + id := f.ID + t.Cleanup(func() { + _, err := CallFilesDeleteEndpoint(t, app, id) + assert.NoError(t, err) + }) + + return f + +} + +func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) { + target := fmt.Sprintf("/files/%s", fileId) + req := httptest.NewRequest(http.MethodDelete, target, nil) + return app.Test(req) +} + +// Helper to create multi-part file +func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) { + body := new(strings.Builder) + writer := multipart.NewWriter(body) + file, _ := os.Open(filePath) + defer file.Close() + part, _ := writer.CreateFormFile(tag, filepath.Base(filePath)) + io.Copy(part, file) + + if purpose != "" { + _ = writer.WriteField("purpose", purpose) + } + + writer.Close() + return strings.NewReader(body.String()), writer +} + +// Helper to create test files +func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File { + err := os.MkdirAll(option.UploadDir, 0755) + if err != nil { + + t.Fatalf("Error MKDIR: %v", err) + } + + file, _ := os.Create(name) + file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File + + t.Cleanup(func() { + os.Remove(name) + os.RemoveAll(option.UploadDir) + }) + return file +} + +func bodyToString(resp *http.Response, t *testing.T) string { + return string(bodyToByteArray(resp, t)) +} + +func bodyToByteArray(resp *http.Response, t *testing.T) []byte { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + return bodyBytes +} + +func responseToFile(t *testing.T, resp *http.Response) File { + var file File + responseToString := bodyToString(resp, t) + + err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file) + if err != nil { + t.Errorf("Failed to decode response: %s", err) + } + + return file +} + +func responseToListFile(t *testing.T, resp *http.Response) ListFiles { + var listFiles ListFiles + responseToString := bodyToString(resp, t) + + err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles) + if err != nil { + fmt.Printf("Failed to decode response: %s", err) + } + + return listFiles +} diff --git a/api/options/options.go b/api/options/options.go index 8c066584..72aea1a3 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -21,6 +21,7 @@ type Option struct { Debug, DisableMessage bool ImageDir string AudioDir string + UploadDir string CORS bool PreloadJSONModels string PreloadModelsFromPath string @@ -249,6 +250,12 @@ func WithImageDir(imageDir string) AppOption { } } +func WithUploadDir(uploadDir string) AppOption { + return func(o *Option) { + o.UploadDir = uploadDir + } +} + func WithApiKeys(apiKeys []string) AppOption { return func(o *Option) { o.ApiKeys = apiKeys diff --git a/go.mod b/go.mod index 250a2361..bbd787b5 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 - github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 github.com/gofiber/fiber/v2 v2.50.0 github.com/google/uuid v1.3.1 @@ -28,6 +27,7 @@ require ( github.com/rs/zerolog v1.31.0 github.com/sashabaranov/go-openai v1.16.0 github.com/schollz/progressbar/v3 v3.13.1 + github.com/stretchr/testify v1.8.4 github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701 github.com/urfave/cli/v2 v2.25.7 github.com/valyala/fasthttp v1.50.0 @@ -55,6 +55,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -68,6 +69,7 @@ require ( github.com/nwaples/rardecode v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.1 // indirect diff --git a/go.sum b/go.sum index fc00bf6e..20dfbfb4 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 h1:WYjkXL0Nw7dN2uDBMVCWQ8xLavrIhjF/DLczuh5L9TY= github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428/go.mod h1:iub0ugfTnflE3rcIuqV2pQSo15nEw3GLW/utm5gyERo= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= diff --git a/main.go b/main.go index edf70328..2636b402 100644 --- a/main.go +++ b/main.go @@ -142,6 +142,12 @@ func main() { EnvVars: []string{"AUDIO_PATH"}, Value: "/tmp/generated/audio", }, + &cli.StringFlag{ + Name: "upload-path", + Usage: "Path to store uploads from files api", + EnvVars: []string{"UPLOAD_PATH"}, + Value: "/tmp/localai/upload", + }, &cli.StringFlag{ Name: "backend-assets-path", Usage: "Path used to extract libraries that are required by some of the backends in runtime.", @@ -227,6 +233,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithDebug(ctx.Bool("debug")), options.WithImageDir(ctx.String("image-path")), options.WithAudioDir(ctx.String("audio-path")), + options.WithUploadDir(ctx.String("upload-path")), options.WithF16(ctx.Bool("f16")), options.WithStringGalleries(ctx.String("galleries")), options.WithModelLibraryURL(ctx.String("remote-library")), diff --git a/pkg/utils/path.go b/pkg/utils/path.go index 05481d2c..f95b0138 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -3,6 +3,7 @@ package utils import ( "fmt" "path/filepath" + "strings" ) func inTrustedRoot(path string, trustedRoot string) error { @@ -20,3 +21,14 @@ func VerifyPath(path, basePath string) error { c := filepath.Clean(filepath.Join(basePath, path)) return inTrustedRoot(c, filepath.Clean(basePath)) } + +// SanitizeFileName sanitizes the given filename +func SanitizeFileName(fileName string) string { + // filepath.Clean to clean the path + cleanName := filepath.Clean(fileName) + // filepath.Base to ensure we only get the final element, not any directory path + baseName := filepath.Base(cleanName) + // Replace any remaining tricky characters that might have survived cleaning + safeName := strings.ReplaceAll(baseName, "..", "") + return safeName +}