feat: tokenization endpoint (#3710)

endpoint to access the tokenizer

Signed-off-by: shraddhazpy <shraddha@shraddhafive.in>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Dave <dave@gray101.com>
This commit is contained in:
Shraddha 2024-10-02 12:26:18 +05:30 committed by GitHub
parent 0965c6cd68
commit 5488fc3bc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 121 additions and 0 deletions

50
core/backend/tokenize.go Normal file
View File

@ -0,0 +1,50 @@
package backend
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc"
model "github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
modelFile := backendConfig.Model
grpcOpts := GRPCModelOpts(backendConfig)
var inferenceModel grpc.Backend
var err error
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
})
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return schema.TokenizeResponse{}, err
}
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
predictOptions.Prompt = s
// tokenize the string
resp, err := inferenceModel.TokenizeString(appConfig.Context, predictOptions)
if err != nil {
return schema.TokenizeResponse{}, err
}
return schema.TokenizeResponse{
Tokens: resp.Tokens,
}, nil
}

View File

@ -0,0 +1,58 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
// @Summary Tokenize the input.
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenizeRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
c.JSON(tokenResponse)
return nil
}
}

View File

@ -63,4 +63,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
app.Get("/system", localai.SystemInformations(ml, appConfig))
// misc
app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
}

10
core/schema/tokenize.go Normal file
View File

@ -0,0 +1,10 @@
package schema
type TokenizeRequest struct {
Content string `json:"content"`
Model string `json:"model"`
}
type TokenizeResponse struct {
Tokens []int32 `json:"tokens"`
}