mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-19 20:57:54 +00:00
Initial implementation of upload files api. (#1703)
* Initial implementation of upload files api. * Move sanitize method to utils. * Save uploaded data to uploads folder. * Avoid loop if we do not have a purpose. * Minor cleanup of api and fix bug where deleting duplicate filename cause error. * Revert defer of saving config * Moved creation of directory to startup. * Make file names unique when storing on disk. * Add test for files api. * Update dependencies.
This commit is contained in:
parent
c72808f18b
commit
01205fd4c0
16
api/api.go
16
api/api.go
@ -223,8 +223,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
// Make sure directories exists
|
// Make sure directories exists
|
||||||
os.MkdirAll(options.ImageDir, 0755)
|
os.MkdirAll(options.ImageDir, 0755)
|
||||||
os.MkdirAll(options.AudioDir, 0755)
|
os.MkdirAll(options.AudioDir, 0755)
|
||||||
|
os.MkdirAll(options.UploadDir, 0755)
|
||||||
os.MkdirAll(options.Loader.ModelPath, 0755)
|
os.MkdirAll(options.Loader.ModelPath, 0755)
|
||||||
|
|
||||||
|
// Load upload json
|
||||||
|
openai.LoadUploadConfig(options.UploadDir)
|
||||||
|
|
||||||
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
||||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
||||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
||||||
@ -244,6 +248,18 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
||||||
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
||||||
|
|
||||||
|
// files
|
||||||
|
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
|
||||||
|
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
|
||||||
|
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
|
||||||
|
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
|
||||||
|
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
||||||
|
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
||||||
|
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
||||||
|
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
||||||
|
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
||||||
|
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||||
|
207
api/openai/files.go
Normal file
207
api/openai/files.go
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var uploadedFiles []File
|
||||||
|
|
||||||
|
// File represents the structure of a file object from the OpenAI API.
|
||||||
|
type File struct {
|
||||||
|
ID string `json:"id"` // Unique identifier for the file
|
||||||
|
Object string `json:"object"` // Type of the object (e.g., "file")
|
||||||
|
Bytes int `json:"bytes"` // Size of the file in bytes
|
||||||
|
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
|
||||||
|
Filename string `json:"filename"` // The name of the file
|
||||||
|
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveUploadConfig(uploadDir string) {
|
||||||
|
file, err := json.MarshalIndent(uploadedFiles, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(filepath.Join(uploadDir, "uploadedFiles.json"), file, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadUploadConfig(uploadPath string) {
|
||||||
|
file, err := os.ReadFile(filepath.Join(uploadPath, "uploadedFiles.json"))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to read file: %s", err)
|
||||||
|
} else {
|
||||||
|
err = json.Unmarshal(file, &uploadedFiles)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||||
|
func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the file size
|
||||||
|
if file.Size > int64(o.UploadLimitMB*1024*1024) {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB))
|
||||||
|
}
|
||||||
|
|
||||||
|
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
||||||
|
if purpose == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the filename to prevent directory traversal
|
||||||
|
filename := utils.SanitizeFileName(file.Filename)
|
||||||
|
|
||||||
|
savePath := filepath.Join(o.UploadDir, filename)
|
||||||
|
|
||||||
|
// Check if file already exists
|
||||||
|
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.SaveFile(file, savePath)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
f := File{
|
||||||
|
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
|
||||||
|
Object: "file",
|
||||||
|
Bytes: int(file.Size),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Filename: file.Filename,
|
||||||
|
Purpose: purpose,
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedFiles = append(uploadedFiles, f)
|
||||||
|
saveUploadConfig(o.UploadDir)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||||
|
func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
type ListFiles struct {
|
||||||
|
Data []File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
var listFiles ListFiles
|
||||||
|
|
||||||
|
purpose := c.Query("purpose")
|
||||||
|
if purpose == "" {
|
||||||
|
listFiles.Data = uploadedFiles
|
||||||
|
} else {
|
||||||
|
for _, f := range uploadedFiles {
|
||||||
|
if purpose == f.Purpose {
|
||||||
|
listFiles.Data = append(listFiles.Data, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
listFiles.Object = "list"
|
||||||
|
return c.Status(fiber.StatusOK).JSON(listFiles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
|
||||||
|
id := c.Params("file_id")
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("file_id parameter is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range uploadedFiles {
|
||||||
|
if id == f.ID {
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to find file id %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
|
||||||
|
func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
|
||||||
|
func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
type DeleteStatus struct {
|
||||||
|
Id string
|
||||||
|
Object string
|
||||||
|
Deleted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Remove(filepath.Join(o.UploadDir, file.Filename))
|
||||||
|
if err != nil {
|
||||||
|
// If the file doesn't exist then we should just continue to remove it
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove upload from list
|
||||||
|
for i, f := range uploadedFiles {
|
||||||
|
if f.ID == file.ID {
|
||||||
|
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
saveUploadConfig(o.UploadDir)
|
||||||
|
return c.JSON(DeleteStatus{
|
||||||
|
Id: file.ID,
|
||||||
|
Object: "file",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||||
|
func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename))
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Send(fileContents)
|
||||||
|
}
|
||||||
|
}
|
286
api/openai/files_test.go
Normal file
286
api/openai/files_test.go
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ListFiles struct {
|
||||||
|
Data []File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
loader = &config.ConfigLoader{}
|
||||||
|
|
||||||
|
option = &options.Option{
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(option.UploadDir)
|
||||||
|
|
||||||
|
app = fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||||
|
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
loader := &config.ConfigLoader{}
|
||||||
|
|
||||||
|
option := &options.Option{
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(option.UploadDir)
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||||
|
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||||
|
|
||||||
|
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
||||||
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
||||||
|
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
||||||
|
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
|
fmt.Println(f1)
|
||||||
|
fmt.Printf("ERror: %v", err)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
||||||
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
// Check if file exists in the disk
|
||||||
|
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
|
||||||
|
_, err := os.Stat(filePath)
|
||||||
|
|
||||||
|
assert.False(t, os.IsNotExist(err))
|
||||||
|
assert.Equal(t, file.Bytes, 5242880)
|
||||||
|
assert.NotEmpty(t, file.CreatedAt)
|
||||||
|
assert.Equal(t, file.Filename, "test.txt")
|
||||||
|
assert.Equal(t, file.Purpose, "fine-tune")
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
if len(listFiles.Data) != len(uploadedFiles) {
|
||||||
|
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
||||||
|
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
if len(listFiles.Data) != 1 {
|
||||||
|
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
|
||||||
|
if len(listFiles.Data) != 0 {
|
||||||
|
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/files", nil)
|
||||||
|
resp, _ := app.Test(req)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var listFiles ListFiles
|
||||||
|
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
|
||||||
|
t.Errorf("Failed to decode response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(listFiles.Data) != 0 {
|
||||||
|
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
|
||||||
|
var target string
|
||||||
|
if purpose != "" {
|
||||||
|
target = fmt.Sprintf("/files?purpose=%s", purpose)
|
||||||
|
} else {
|
||||||
|
target = "/files"
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest("GET", target, nil)
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||||
|
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
|
||||||
|
return app.Test(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) {
|
||||||
|
// Create a file that exceeds the limit
|
||||||
|
file := createTestFile(t, fileName, fileSize, o)
|
||||||
|
|
||||||
|
// Creating a new HTTP Request
|
||||||
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||||
|
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File {
|
||||||
|
// Create a file that exceeds the limit
|
||||||
|
file := createTestFile(t, fileName, fileSize, o)
|
||||||
|
|
||||||
|
// Creating a new HTTP Request
|
||||||
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||||
|
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
f := responseToFile(t, resp)
|
||||||
|
|
||||||
|
id := f.ID
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_, err := CallFilesDeleteEndpoint(t, app, id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||||
|
target := fmt.Sprintf("/files/%s", fileId)
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create multi-part file
|
||||||
|
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
|
||||||
|
body := new(strings.Builder)
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
file, _ := os.Open(filePath)
|
||||||
|
defer file.Close()
|
||||||
|
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
|
||||||
|
io.Copy(part, file)
|
||||||
|
|
||||||
|
if purpose != "" {
|
||||||
|
_ = writer.WriteField("purpose", purpose)
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Close()
|
||||||
|
return strings.NewReader(body.String()), writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create test files
|
||||||
|
func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File {
|
||||||
|
err := os.MkdirAll(option.UploadDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
t.Fatalf("Error MKDIR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, _ := os.Create(name)
|
||||||
|
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(name)
|
||||||
|
os.RemoveAll(option.UploadDir)
|
||||||
|
})
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func bodyToString(resp *http.Response, t *testing.T) string {
|
||||||
|
return string(bodyToByteArray(resp, t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return bodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseToFile(t *testing.T, resp *http.Response) File {
|
||||||
|
var file File
|
||||||
|
responseToString := bodyToString(resp, t)
|
||||||
|
|
||||||
|
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode response: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
|
||||||
|
var listFiles ListFiles
|
||||||
|
responseToString := bodyToString(resp, t)
|
||||||
|
|
||||||
|
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to decode response: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return listFiles
|
||||||
|
}
|
@ -21,6 +21,7 @@ type Option struct {
|
|||||||
Debug, DisableMessage bool
|
Debug, DisableMessage bool
|
||||||
ImageDir string
|
ImageDir string
|
||||||
AudioDir string
|
AudioDir string
|
||||||
|
UploadDir string
|
||||||
CORS bool
|
CORS bool
|
||||||
PreloadJSONModels string
|
PreloadJSONModels string
|
||||||
PreloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
@ -249,6 +250,12 @@ func WithImageDir(imageDir string) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithUploadDir(uploadDir string) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
o.UploadDir = uploadDir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithApiKeys(apiKeys []string) AppOption {
|
func WithApiKeys(apiKeys []string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.ApiKeys = apiKeys
|
o.ApiKeys = apiKeys
|
||||||
|
4
go.mod
4
go.mod
@ -8,7 +8,6 @@ require (
|
|||||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
|
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
|
||||||
github.com/go-audio/wav v1.1.0
|
github.com/go-audio/wav v1.1.0
|
||||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
|
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
|
||||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e
|
|
||||||
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428
|
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428
|
||||||
github.com/gofiber/fiber/v2 v2.50.0
|
github.com/gofiber/fiber/v2 v2.50.0
|
||||||
github.com/google/uuid v1.3.1
|
github.com/google/uuid v1.3.1
|
||||||
@ -28,6 +27,7 @@ require (
|
|||||||
github.com/rs/zerolog v1.31.0
|
github.com/rs/zerolog v1.31.0
|
||||||
github.com/sashabaranov/go-openai v1.16.0
|
github.com/sashabaranov/go-openai v1.16.0
|
||||||
github.com/schollz/progressbar/v3 v3.13.1
|
github.com/schollz/progressbar/v3 v3.13.1
|
||||||
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701
|
github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701
|
||||||
github.com/urfave/cli/v2 v2.25.7
|
github.com/urfave/cli/v2 v2.25.7
|
||||||
github.com/valyala/fasthttp v1.50.0
|
github.com/valyala/fasthttp v1.50.0
|
||||||
@ -55,6 +55,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
|
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
@ -68,6 +69,7 @@ require (
|
|||||||
github.com/nwaples/rardecode v1.1.0 // indirect
|
github.com/nwaples/rardecode v1.1.0 // indirect
|
||||||
github.com/pierrec/lz4/v4 v4.1.2 // indirect
|
github.com/pierrec/lz4/v4 v4.1.2 // indirect
|
||||||
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
|
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
|
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
|
||||||
github.com/prometheus/common v0.44.0 // indirect
|
github.com/prometheus/common v0.44.0 // indirect
|
||||||
github.com/prometheus/procfs v0.11.1 // indirect
|
github.com/prometheus/procfs v0.11.1 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -43,8 +43,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
|||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y=
|
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y=
|
||||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc=
|
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc=
|
||||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0=
|
|
||||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
|
|
||||||
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 h1:WYjkXL0Nw7dN2uDBMVCWQ8xLavrIhjF/DLczuh5L9TY=
|
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 h1:WYjkXL0Nw7dN2uDBMVCWQ8xLavrIhjF/DLczuh5L9TY=
|
||||||
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428/go.mod h1:iub0ugfTnflE3rcIuqV2pQSo15nEw3GLW/utm5gyERo=
|
github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428/go.mod h1:iub0ugfTnflE3rcIuqV2pQSo15nEw3GLW/utm5gyERo=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||||
|
7
main.go
7
main.go
@ -142,6 +142,12 @@ func main() {
|
|||||||
EnvVars: []string{"AUDIO_PATH"},
|
EnvVars: []string{"AUDIO_PATH"},
|
||||||
Value: "/tmp/generated/audio",
|
Value: "/tmp/generated/audio",
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "upload-path",
|
||||||
|
Usage: "Path to store uploads from files api",
|
||||||
|
EnvVars: []string{"UPLOAD_PATH"},
|
||||||
|
Value: "/tmp/localai/upload",
|
||||||
|
},
|
||||||
&cli.StringFlag{
|
&cli.StringFlag{
|
||||||
Name: "backend-assets-path",
|
Name: "backend-assets-path",
|
||||||
Usage: "Path used to extract libraries that are required by some of the backends in runtime.",
|
Usage: "Path used to extract libraries that are required by some of the backends in runtime.",
|
||||||
@ -227,6 +233,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
|||||||
options.WithDebug(ctx.Bool("debug")),
|
options.WithDebug(ctx.Bool("debug")),
|
||||||
options.WithImageDir(ctx.String("image-path")),
|
options.WithImageDir(ctx.String("image-path")),
|
||||||
options.WithAudioDir(ctx.String("audio-path")),
|
options.WithAudioDir(ctx.String("audio-path")),
|
||||||
|
options.WithUploadDir(ctx.String("upload-path")),
|
||||||
options.WithF16(ctx.Bool("f16")),
|
options.WithF16(ctx.Bool("f16")),
|
||||||
options.WithStringGalleries(ctx.String("galleries")),
|
options.WithStringGalleries(ctx.String("galleries")),
|
||||||
options.WithModelLibraryURL(ctx.String("remote-library")),
|
options.WithModelLibraryURL(ctx.String("remote-library")),
|
||||||
|
@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func inTrustedRoot(path string, trustedRoot string) error {
|
func inTrustedRoot(path string, trustedRoot string) error {
|
||||||
@ -20,3 +21,14 @@ func VerifyPath(path, basePath string) error {
|
|||||||
c := filepath.Clean(filepath.Join(basePath, path))
|
c := filepath.Clean(filepath.Join(basePath, path))
|
||||||
return inTrustedRoot(c, filepath.Clean(basePath))
|
return inTrustedRoot(c, filepath.Clean(basePath))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeFileName sanitizes the given filename
|
||||||
|
func SanitizeFileName(fileName string) string {
|
||||||
|
// filepath.Clean to clean the path
|
||||||
|
cleanName := filepath.Clean(fileName)
|
||||||
|
// filepath.Base to ensure we only get the final element, not any directory path
|
||||||
|
baseName := filepath.Base(cleanName)
|
||||||
|
// Replace any remaining tricky characters that might have survived cleaning
|
||||||
|
safeName := strings.ReplaceAll(baseName, "..", "")
|
||||||
|
return safeName
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user