syntax = "proto3";

option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto";
option java_multiple_files = true;
option java_package = "io.skynet.localai.backend";
option java_outer_classname = "LocalAIBackend";

package backend;

// Backend is the gRPC service declared by this file. Every RPC is unary
// except PredictStream, whose response is a server-side stream of Reply.
service Backend {
  // Health takes an empty HealthMessage and returns a Reply.
  rpc Health(HealthMessage) returns (Reply) {}
  // Predict returns a single Reply for the given PredictOptions.
  rpc Predict(PredictOptions) returns (Reply) {}
  // LoadModel takes ModelOptions and reports the outcome as a Result.
  rpc LoadModel(ModelOptions) returns (Result) {}
  // PredictStream takes the same request as Predict but returns a stream of
  // Reply messages instead of a single one.
  rpc PredictStream(PredictOptions) returns (stream Reply) {}
  // Embedding reuses PredictOptions as its request and returns an
  // EmbeddingResult.
  rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
  // GenerateImage reports its outcome as a Result; the request carries a
  // `dst` field (presumably the output location -- confirm with callers).
  rpc GenerateImage(GenerateImageRequest) returns (Result) {}
  // AudioTranscription returns a TranscriptResult (segments plus full text).
  rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
  // TTS reports its outcome as a Result.
  rpc TTS(TTSRequest) returns (Result) {}
  // SoundGeneration reports its outcome as a Result.
  rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
  // TokenizeString reuses PredictOptions as its request and returns a
  // TokenizationResponse.
  rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
  // Status takes an empty HealthMessage and returns a StatusResponse.
  rpc Status(HealthMessage) returns (StatusResponse) {}

  // Store operations: set, delete, get and find entries addressed by
  // StoresKey.
  rpc StoresSet(StoresSetOptions) returns (Result) {}
  rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
  rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
  rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}

  // Rerank takes a query and a list of documents and returns a RerankResult.
  rpc Rerank(RerankRequest) returns (RerankResult) {}

  // GetMetrics takes an empty MetricsRequest and returns a MetricsResponse.
  rpc GetMetrics(MetricsRequest) returns (MetricsResponse) {}

  // VAD takes a VADRequest and returns a VADResponse holding a list of
  // segments.
  rpc VAD(VADRequest) returns (VADResponse) {}
}

// MetricsRequest is the request message of the GetMetrics RPC. It currently
// has no fields.
message MetricsRequest {}

// MetricsResponse is the response message of the GetMetrics RPC.
// NOTE(review): the fields appear to describe a single slot (see slot_id);
// confirm with the backend implementation how the slot is chosen.
message MetricsResponse {
  int32 slot_id = 1;
  string prompt_json_for_slot = 2;  // Stores the prompt as a JSON string.
  float tokens_per_second = 3;
  int32 tokens_generated = 4;
  int32 prompt_tokens_processed = 5;
}

// RerankRequest is the request message of the Rerank RPC: one query string
// and the list of documents to rank against it.
message RerankRequest {
  string query = 1;
  repeated string documents = 2;
  // NOTE(review): presumably the maximum number of results to return; the
  // meaning of 0 (the proto3 default) is not defined here -- confirm.
  int32 top_n = 3;
}

// RerankResult is the response message of the Rerank RPC: token usage plus
// one DocumentResult per returned document.
message RerankResult {
  Usage usage = 1;
  repeated DocumentResult results = 2;
}

// Usage carries token counts. In this file it is only referenced from
// RerankResult.
message Usage {
  int32 total_tokens = 1;
  int32 prompt_tokens = 2;
}

// DocumentResult is a single entry of RerankResult.results.
message DocumentResult {
  // NOTE(review): presumably the position of the document in
  // RerankRequest.documents -- confirm with the backend implementation.
  int32 index = 1;
  string text = 2;
  float relevance_score = 3;
}

// StoresKey is the key type used by the Stores* RPCs: a list of floats.
message StoresKey {
  repeated float Floats = 1;
}

// StoresValue is the value type used by the Stores* RPCs: an opaque byte
// string.
message StoresValue {
  bytes Bytes = 1;
}

// StoresSetOptions is the request message of the StoresSet RPC.
// NOTE(review): Keys and Values are separate repeated fields, presumably
// paired by index; the schema does not enforce equal lengths -- confirm how
// the server handles a mismatch.
message StoresSetOptions {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
}

// StoresDeleteOptions is the request message of the StoresDelete RPC: the
// list of keys to operate on.
message StoresDeleteOptions {
  repeated StoresKey Keys = 1;
}

// StoresGetOptions is the request message of the StoresGet RPC: the list of
// keys to look up.
message StoresGetOptions {
  repeated StoresKey Keys = 1;
}

// StoresGetResult is the response message of the StoresGet RPC.
// NOTE(review): Keys and Values are presumably paired by index -- confirm.
message StoresGetResult {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
}

// StoresFindOptions is the request message of the StoresFind RPC: a single
// key and a TopK count.
message StoresFindOptions {
  StoresKey Key = 1;
  // NOTE(review): presumably the maximum number of matches to return --
  // confirm with the store implementation.
  int32 TopK = 2;
}

// StoresFindResult is the response message of the StoresFind RPC.
// NOTE(review): Keys, Values and Similarities are three separate repeated
// fields, presumably paired by index; the similarity metric is not stated
// here -- confirm with the store implementation.
message StoresFindResult {
  repeated StoresKey Keys = 1;
  repeated StoresValue Values = 2;
  repeated float Similarities = 3;
}

// HealthMessage is the empty request message shared by the Health and Status
// RPCs.
message HealthMessage {}

// PredictOptions is the request message shared by the Predict, PredictStream,
// Embedding and TokenizeString RPCs.
//
// Field numbers identify fields on the wire; never renumber or reuse them.
message PredictOptions {
  string Prompt = 1;
  int32 Seed = 2;
  int32 Threads = 3;
  int32 Tokens = 4;
  int32 TopK = 5;
  int32 Repeat = 6;
  int32 Batch = 7;
  int32 NKeep = 8;
  float Temperature = 9;
  float Penalty = 10;
  bool F16KV = 11;
  // DebugMode (12) and Debug (34) are two distinct fields.
  bool DebugMode = 12;
  repeated string StopPrompts = 13;
  bool IgnoreEOS = 14;
  float TailFreeSamplingZ = 15;
  float TypicalP = 16;
  float FrequencyPenalty = 17;
  float PresencePenalty = 18;
  int32 Mirostat = 19;
  float MirostatETA = 20;
  float MirostatTAU = 21;
  bool PenalizeNL = 22;
  string LogitBias = 23;
  // Field number 24 is not assigned to any field in this message. It is
  // reserved so that a future field cannot pick it up by accident
  // (presumably it belonged to a removed field -- confirm against history).
  reserved 24;
  bool MLock = 25;
  bool MMap = 26;
  bool PromptCacheAll = 27;
  bool PromptCacheRO = 28;
  string Grammar = 29;
  string MainGPU = 30;
  string TensorSplit = 31;
  float TopP = 32;
  string PromptCachePath = 33;
  bool Debug = 34;
  repeated int32 EmbeddingTokens = 35;
  // Note: a string here, whereas ModelOptions.Embeddings is a bool.
  string Embeddings = 36;
  float RopeFreqBase = 37;
  float RopeFreqScale = 38;
  float NegativePromptScale = 39;
  string NegativePrompt = 40;
  int32 NDraft = 41;
  // NOTE(review): the encoding of the Images/Videos/Audios strings (path,
  // URL, base64, ...) is not defined by this schema -- confirm with callers.
  repeated string Images = 42;
  bool UseTokenizerTemplate = 43;
  repeated Message Messages = 44;
  repeated string Videos = 45;
  repeated string Audios = 46;
  string CorrelationId = 47;
}

// Reply is the response message of the Health, Predict and PredictStream
// RPCs (PredictStream returns a stream of them).
message Reply {
  // Payload as raw bytes (not a proto `string`, so it is not required to be
  // valid UTF-8).
  bytes message = 1;
  int32 tokens = 2;
  int32 prompt_tokens = 3;
}

// ModelOptions is the request message of the LoadModel RPC.
//
// Fields are grouped by topic rather than by number, so field numbers do not
// follow declaration order. Every number from 1 to 64 is currently assigned;
// a new field should take the next unused number (65 at the time of writing).
// Field numbers identify fields on the wire; never renumber or reuse them.
message ModelOptions {
  string Model = 1;
  int32 ContextSize = 2;
  int32 Seed = 3;
  int32 NBatch = 4;
  bool F16Memory = 5;
  bool MLock = 6;
  bool MMap = 7;
  bool VocabOnly = 8;
  bool LowVRAM = 9;
  // Note: a bool here, whereas PredictOptions.Embeddings is a string.
  bool Embeddings = 10;
  bool NUMA = 11;
  int32 NGPULayers = 12;
  string MainGPU = 13;
  string TensorSplit = 14;
  int32 Threads = 15;
  string LibrarySearchPath = 16;
  float RopeFreqBase = 17;
  float RopeFreqScale = 18;
  float RMSNormEps = 19;
  int32 NGQA = 20;
  string ModelFile = 21;

  // AutoGPTQ
  string Device = 22;
  bool UseTriton = 23;
  string ModelBaseName = 24;
  bool UseFastTokenizer = 25;

  // Diffusers
  string PipelineType = 26;
  string SchedulerType = 27;
  bool CUDA = 28;
  float CFGScale = 29;
  bool IMG2IMG = 30;
  string CLIPModel = 31;
  string CLIPSubfolder = 32;
  int32 CLIPSkip = 33;
  string ControlNet = 48;

  string Tokenizer = 34;

  // LLM (llama.cpp)
  string LoraBase = 35;
  string LoraAdapter = 36;
  float LoraScale = 42;

  bool NoMulMatQ = 37;
  string DraftModel = 39;

  string AudioPath = 38;

  // vllm
  string Quantization = 40;
  float  GPUMemoryUtilization = 50;
  bool   TrustRemoteCode = 51;
  bool   EnforceEager = 52;
  int32  SwapSpace = 53;
  int32  MaxModelLen = 54;
  int32  TensorParallelSize = 55;
  string LoadFormat = 58;

  string MMProj = 41;

  // RoPE / YaRN scaling parameters.
  string RopeScaling = 43;
  float YarnExtFactor = 44;
  float YarnAttnFactor = 45;
  float YarnBetaFast = 46;
  float YarnBetaSlow = 47;

  string Type = 49;

  bool FlashAttention = 56;
  bool NoKVOffload = 57;

  string ModelPath = 59;

  // Repeated counterparts of the singular LoraAdapter (36) / LoraScale (42).
  // NOTE(review): presumably paired by index; the schema does not enforce
  // equal lengths -- confirm how backends handle a mismatch.
  repeated string LoraAdapters = 60;
  repeated float LoraScales = 61;

  // NOTE(review): free-form option strings; their format is not defined by
  // this schema -- confirm with the individual backends.
  repeated string Options = 62;

  string CacheTypeKey = 63;
  string CacheTypeValue = 64;
}

// Result is the generic outcome message returned by the LoadModel,
// GenerateImage, TTS, SoundGeneration, StoresSet and StoresDelete RPCs.
message Result {
  string message = 1;
  // Defaults to false when unset (proto3 implicit presence), so an empty
  // Result reads as a failure.
  bool success = 2;
}

// EmbeddingResult is the response message of the Embedding RPC: a flat list
// of floats.
message EmbeddingResult {
  repeated float embeddings = 1;
}

// TranscriptRequest is the request message of the AudioTranscription RPC.
message TranscriptRequest {
  // Field number 1 is not assigned to any field in this message. It is
  // reserved so that a future field cannot pick it up by accident
  // (presumably it belonged to a removed field -- confirm against history).
  reserved 1;
  // NOTE(review): despite the name, this is the only path-like field in the
  // request; presumably it locates the audio to transcribe -- confirm.
  string dst = 2;
  string language = 3;
  uint32 threads = 4;
  bool translate = 5;
}

// TranscriptResult is the response message of the AudioTranscription RPC:
// a list of segments plus a single text field.
message TranscriptResult {
  repeated TranscriptSegment segments = 1;
  string text = 2;
}

// TranscriptSegment is a single entry of TranscriptResult.segments.
message TranscriptSegment {
  int32 id = 1;
  // NOTE(review): the time unit of start/end is not stated in this schema --
  // confirm with the producing backend before converting.
  int64 start = 2;
  int64 end = 3;
  string text = 4;
  repeated int32 tokens = 5;
}

// GenerateImageRequest is the request message of the GenerateImage RPC.
// Field naming is mixed: the first nine fields are snake_case, the
// Diffusers-specific ones are PascalCase.
message GenerateImageRequest {
  int32 height = 1;
  int32 width = 2;
  int32 mode = 3;
  int32 step = 4;
  int32 seed = 5;
  string positive_prompt = 6;
  string negative_prompt = 7;
  // NOTE(review): dst/src are presumably output and input locations; their
  // format is not defined by this schema -- confirm with callers.
  string dst = 8;
  string src = 9;

  // Diffusers
  string EnableParameters = 10;
  int32 CLIPSkip = 11;
}

// TTSRequest is the request message of the TTS RPC.
message TTSRequest {
  string text = 1;
  string model = 2;
  string dst = 3;
  string voice = 4;
  // Declared `optional`, so receivers can tell an unset language apart from
  // an explicitly empty one.
  optional string language = 5;
}

// VADRequest is the request message of the VAD RPC.
message VADRequest {
  // NOTE(review): audio samples as floats; sample rate, channel layout and
  // value range are not stated in this schema -- confirm with the backend.
  repeated float audio = 1;
}

// VADSegment is a single entry of VADResponse.segments.
// NOTE(review): the unit of start/end is not stated in this schema --
// confirm with the backend.
message VADSegment {
  float start = 1;
  float end = 2;
}

// VADResponse is the response message of the VAD RPC: a list of segments.
message VADResponse {
  repeated VADSegment segments = 1;
}

// SoundGenerationRequest is the request message of the SoundGeneration RPC.
// Fields 4-8 are declared `optional`, so receivers can tell an unset field
// apart from one explicitly set to its zero value.
message SoundGenerationRequest {
  string text = 1;
  string model = 2;
  string dst = 3;
  // NOTE(review): the unit of duration is not stated in this schema --
  // confirm with the backend.
  optional float duration = 4;
  optional float temperature = 5;
  optional bool sample = 6;
  optional string src = 7;
  optional int32 src_divisor = 8;
}

// TokenizationResponse is the response message of the TokenizeString RPC.
message TokenizationResponse {
  // NOTE(review): presumably the number of entries in `tokens` -- confirm.
  int32 length = 1;
  repeated int32 tokens = 2;
}

// MemoryUsageData is the type of StatusResponse.memory.
// NOTE(review): units (presumably bytes) are not stated in this schema --
// confirm with the backend.
message MemoryUsageData {
  uint64 total = 1;
  // Per-label breakdown. Protobuf map iteration order is undefined, so do
  // not rely on entry order.
  map<string, uint64> breakdown = 2;
}

// StatusResponse is the response message of the Status RPC.
message StatusResponse {
  // State is scoped to StatusResponse, so its value names do not collide
  // with other enums in the package.
  enum State {
    // Zero value: what `state` reads as when the field is unset.
    UNINITIALIZED = 0;
    BUSY = 1;
    READY = 2;
    // Negative enum values are legal, but they are varint-encoded as int32
    // and therefore take 10 bytes on the wire.
    ERROR = -1;
  }
  State state = 1;
  MemoryUsageData memory = 2;
}

// Message is a role/content pair; it is the element type of
// PredictOptions.Messages.
message Message {
  string role = 1;
  string content = 2;
}