gRPC client stubs

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-11-18 19:12:27 +01:00
parent 064a2a4f11
commit 41366efa66
4 changed files with 41 additions and 6 deletions

View File

@ -159,7 +159,7 @@ message Reply {
bytes message = 1;
int32 tokens = 2;
int32 prompt_tokens = 3;
string audio_output = 4;
bytes audio = 5;
}
message ModelOptions {

View File

@ -120,6 +120,8 @@ var sessionLock sync.Mutex
// TODO: implement interface as we start to define usages
type Model interface {
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
}
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
@ -800,7 +802,17 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *
// 4. Convert the response text to speech (audio)
//
// Placeholder implementation:
// TODO: use session.ModelInterface...
// TODO: template eventual messages, like chat.go
reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{
Prompt: "What's the weather in New York?",
})
if err != nil {
return "", nil, nil, err
}
generatedAudio := reply.Audio
transcribedText := "What's the weather in New York?"
var functionCall *FunctionCall
@ -819,9 +831,6 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *
// Generate a response
generatedText := "This is a response to your speech input."
generatedAudio := []byte{} // Generate audio bytes from the generatedText
// TODO: Implement actual transcription and TTS
return generatedText, generatedAudio, nil, nil
}

View File

@ -13,6 +13,11 @@ import (
"google.golang.org/grpc"
)
var (
_ Model = new(wrappedModel)
_ Model = new(anyToAnyModel)
)
// wrappedModel represent a model which does not support Any-to-Any operations
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
@ -47,6 +52,27 @@ func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...g
return m.VADClient.VAD(ctx, in)
}
func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
// sound.BufferAsWAV(audioData, "audio.wav")
return m.LLMClient.Predict(ctx, in)
}
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
return m.LLMClient.PredictStream(ctx, in, f)
}
func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
return m.LLMClient.Predict(ctx, in)
}
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
return m.LLMClient.PredictStream(ctx, in, f)
}
// returns and loads either a wrapped model or a model that support audio-to-audio
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {

View File

@ -35,9 +35,9 @@ type Backend interface {
IsBusy() bool
HealthCheck(ctx context.Context) (bool, error)
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)