Mirror of https://github.com/mudler/LocalAI.git

commit 120208c396 (parent 33a352825f)

Improve audio detection

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
@@ -13,7 +13,6 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/websocket/v2"
 	"github.com/mudler/LocalAI/core/application"
-	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/functions"
@@ -138,6 +137,8 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 		model = "gpt-4o"
 	}
 
+	log.Info().Msgf("New session with model: %s", model)
+
 	sessionID := generateSessionID()
 	session := &Session{
 		ID: sessionID,
@@ -487,9 +488,16 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
 }
 
 const (
-	minMicVolume = 450
-	sendToVADDelay = time.Second
 	maxWhisperSegmentDuration = time.Second * 15
+	minMicVolume              = 450
+	sendToVADDelay            = time.Second
 )
 
+type VADState int
+
+const (
+	StateSilence VADState = iota
+	StateSpeaking
+	StateTrailingSilence
+)
+
 // handle VAD (Voice Activity Detection)
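
The new VADState type turns the detection flow into a small state machine: silence until speech segments appear, speaking while segments keep arriving, and trailing silence once the speech has been stable for longer than sendToVADDelay. A minimal, self-contained sketch of that transition idea (the step helper and its arguments are illustrative assumptions, not part of the commit):

	package main

	import (
		"fmt"
		"time"
	)

	type VADState int

	const (
		StateSilence VADState = iota
		StateSpeaking
		StateTrailingSilence
	)

	const sendToVADDelay = time.Second

	// step is a simplified transition function: it moves to StateSpeaking when
	// new speech segments appear, and to StateTrailingSilence once no new
	// segments have arrived for longer than sendToVADDelay.
	func step(state VADState, segmentsGrew bool, lastGrowth time.Time) VADState {
		switch {
		case segmentsGrew:
			return StateSpeaking
		case state == StateSpeaking && time.Since(lastGrowth) > sendToVADDelay:
			return StateTrailingSilence
		default:
			return state
		}
	}

	func main() {
		state := StateSilence
		last := time.Now().Add(-2 * time.Second)   // pretend speech stalled 2s ago
		state = step(state, true, last)            // speech detected
		fmt.Println(state == StateSpeaking)        // true
		state = step(state, false, last)           // no new segments for >1s
		fmt.Println(state == StateTrailingSilence) // true
	}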
@@ -503,7 +511,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 		cancel()
 	}()
 
 	audioDetected := false
+	vadState := VADState(StateSilence)
 	segments := []*proto.VADSegment{}
 	timeListening := time.Now()
 	// Implement VAD logic here
@@ -520,15 +529,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 		if len(session.InputAudioBuffer) > 0 {
 
-			if audioDetected && time.Since(timeListening) < maxWhisperSegmentDuration {
-				log.Debug().Msgf("VAD detected speech, but still listening")
-				// audioDetected = false
-				// keep listening
-				session.AudioBufferLock.Unlock()
-				continue
-			}
-
-			if audioDetected {
+			if vadState == StateTrailingSilence {
 				log.Debug().Msgf("VAD detected speech that we can process")
 
 				// Commit the audio buffer as a conversation item
@@ -561,7 +562,8 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 					Item: item,
 				})
 				audioDetected = false
+				vadState = StateSilence
 				segments = []*proto.VADSegment{}
 				// Generate a response
 				generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
 				continue
@@ -570,7 +572,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
 
 				// Resample from 24kHz to 16kHz
-				adata = sound.ResampleInt16(adata, 24000, 16000)
+				// adata = sound.ResampleInt16(adata, 24000, 16000)
 
 				soundIntBuffer := &audio.IntBuffer{
 					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
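
With the resampling call disabled, the handler now treats the input buffer as already being 16kHz mono, matching the audio.Format above. For reference, a naive nearest-neighbor resampler shows what a 24kHz-to-16kHz int16 conversion involves; this is an illustrative stand-in, not the implementation of LocalAI's sound.ResampleInt16:

	package main

	import "fmt"

	// resampleInt16 picks the nearest source sample for each output sample.
	// Real resamplers typically interpolate and filter; this only shows the
	// rate arithmetic.
	func resampleInt16(in []int16, fromRate, toRate int) []int16 {
		out := make([]int16, len(in)*toRate/fromRate)
		for i := range out {
			out[i] = in[i*fromRate/toRate] // nearest source sample
		}
		return out
	}

	func main() {
		in := make([]int16, 24000) // 1 second of audio at 24kHz
		out := resampleInt16(in, 24000, 16000)
		fmt.Println(len(out)) // 16000 samples, i.e. 1 second at 16kHz
	}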
@@ -582,9 +584,20 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 					session.AudioBufferLock.Unlock()
 					continue
 				} */
 
 				float32Data := soundIntBuffer.AsFloat32Buffer().Data
 
+				// TODO: testing wav decoding
+				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+				// buf, err := dec.FullPCMBuffer()
+				// if err != nil {
+				//	//log.Error().Msgf("failed to process audio: %s", err.Error())
+				//	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				//	session.AudioBufferLock.Unlock()
+				//	continue
+				// }
+
+				//float32Data = buf.AsFloat32Buffer().Data
 				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
 					Audio: float32Data,
 				})
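
The VAD backend consumes float32 samples, so the handler converts little-endian PCM16 bytes to int16s (sound.BytesToInt16sLE) and then to float32 (AsFloat32Buffer). A standalone sketch of that conversion chain; the 1/32768 normalization is an assumption for illustration and not necessarily what the go-audio buffer types do internally:

	package main

	import (
		"encoding/binary"
		"fmt"
	)

	// bytesToFloat32 shows the shape of the conversion: little-endian PCM16
	// bytes -> int16 samples -> float32 samples in [-1, 1).
	func bytesToFloat32(raw []byte) []float32 {
		out := make([]float32, 0, len(raw)/2)
		for i := 0; i+1 < len(raw); i += 2 {
			s := int16(binary.LittleEndian.Uint16(raw[i:])) // one PCM16 sample
			out = append(out, float32(s)/32768.0)
		}
		return out
	}

	func main() {
		raw := []byte{0x00, 0x40, 0x00, 0xC0} // two samples: 16384, -16384
		fmt.Println(bytesToFloat32(raw))      // [0.5 -0.5]
	}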
@@ -598,20 +611,34 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
 				if len(resp.Segments) == 0 {
 					log.Debug().Msg("VAD detected no speech activity")
 					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
 
+					if !audioDetected {
+						if len(session.InputAudioBuffer) > 16000 {
+							session.InputAudioBuffer = nil
+							segments = []*proto.VADSegment{}
+						}
+					}
 
 					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
+				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
+					// We have new segments, but we are still speaking
+					// We need to wait for the trailing silence
 
 					session.AudioBufferLock.Unlock()
 					continue
 				}
+				segments = resp.Segments
 
 				if !audioDetected {
 					timeListening = time.Now()
+				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
+					// We have the same number of segments, but we are still speaking
+					// We need to check if we are in this state for long enough, update the timer
+
+					// Check if we have been listening for too long
+					if time.Since(timeListening) > sendToVADDelay {
+						vadState = StateTrailingSilence
+					} else {
+						timeListening = timeListening.Add(time.Since(timeListening))
+					}
+				} else {
+					log.Debug().Msg("VAD detected speech activity")
+					vadState = StateSpeaking
+					segments = resp.Segments
 				}
 				audioDetected = true
 
 				session.AudioBufferLock.Unlock()
 			} else {
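
The branching above boils down to a heuristic over the segment count returned by the VAD backend: no segments means silence (and possibly a stale buffer to drop), a changed count means the user is still speaking, and a stable count past sendToVADDelay means trailing silence, at which point the buffer is committed. A condensed sketch of that decision table; the decide helper and its string outcomes are hypothetical, for illustration only:

	package main

	import (
		"fmt"
		"time"
	)

	// decide maps the VAD observations onto the actions taken in the hunk
	// above. nPrev/nNow are the previous and current segment counts; quiet is
	// how long the count has been stable while speaking.
	func decide(nNow, nPrev int, speaking bool, quiet time.Duration) string {
		const sendToVADDelay = time.Second
		switch {
		case nNow == 0:
			return "no speech: maybe reset the input buffer"
		case speaking && nNow != nPrev:
			return "new segments while speaking: keep listening"
		case speaking && quiet > sendToVADDelay:
			return "trailing silence: commit buffer and respond"
		default:
			return "speech started: mark state as speaking"
		}
	}

	func main() {
		fmt.Println(decide(0, 0, false, 0))
		fmt.Println(decide(3, 2, true, 0))
		fmt.Println(decide(3, 3, true, 2*time.Second))
		fmt.Println(decide(1, 0, false, 0))
	}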
@@ -843,101 +870,104 @@ func processTextResponse(config *config.BackendConfig, session *Session, prompt
 	// Replace this with actual model inference logic using session.Model and prompt
 	// For example, the model might return a special token or JSON indicating a function call
 
-	predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+	/*
+		predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
 
 		result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
 			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
 				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
 				return
 			}
 
 			textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
 			s = functions.CleanupLLMResult(s, config.FunctionsConfig)
 			results := functions.ParseFunctionCall(s, config.FunctionsConfig)
 			log.Debug().Msgf("Text content to return: %s", textContentToReturn)
 			noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
 
 			switch {
 			case noActionsToRun:
 				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
 				if err != nil {
 					log.Error().Err(err).Msg("error handling question")
 					return
 				}
 				*c = append(*c, schema.Choice{
 					Message: &schema.Message{Role: "assistant", Content: &result}})
 			default:
 				toolChoice := schema.Choice{
 					Message: &schema.Message{
 						Role: "assistant",
 					},
 				}
 
 				if len(input.Tools) > 0 {
 					toolChoice.FinishReason = "tool_calls"
 				}
 
 				for _, ss := range results {
 					name, args := ss.Name, ss.Arguments
 					if len(input.Tools) > 0 {
 						// If we are using tools, we condense the function calls into
 						// a single response choice with all the tools
 						toolChoice.Message.Content = textContentToReturn
 						toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
 							schema.ToolCall{
 								ID:   id,
 								Type: "function",
 								FunctionCall: schema.FunctionCall{
 									Name:      name,
 									Arguments: args,
 								},
 							},
 						)
 					} else {
 						// otherwise we return more choices directly
 						*c = append(*c, schema.Choice{
 							FinishReason: "function_call",
 							Message: &schema.Message{
 								Role:    "assistant",
 								Content: &textContentToReturn,
 								FunctionCall: map[string]interface{}{
 									"name":      name,
 									"arguments": args,
 								},
 							},
 						})
 					}
 				}
 
 				if len(input.Tools) > 0 {
 					// we need to append our result if we are using tools
 					*c = append(*c, toolChoice)
 				}
 			}
 		}, nil)
 		if err != nil {
 			return err
 		}
 
 		resp := &schema.OpenAIResponse{
 			ID:      id,
 			Created: created,
 			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
 			Choices: result,
 			Object:  "chat.completion",
 			Usage: schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
 				TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
 			},
 		}
 		respData, _ := json.Marshal(resp)
 		log.Debug().Msgf("Response: %s", respData)
 
 		// Return the prediction in the response body
 		return c.JSON(resp)
+	*/
 
+	// TODO: use session.ModelInterface...
+	// Simulate a function call
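
The commented-out path shows the intended replacement target: model output is parsed for function calls (functions.ParseFunctionCall) and mapped onto schema.ToolCall values. A self-contained example of parsing one function-call payload of the "JSON indicating a function call" shape the comments mention; the FunctionCall struct and the payload here are hypothetical, not LocalAI's parser:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// FunctionCall mirrors the name/arguments pair that the commented-out
	// code ultimately copies into schema.ToolCall.
	type FunctionCall struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments"`
	}

	func main() {
		raw := `{"name": "get_weather", "arguments": {"city": "Rome"}}`

		var fc FunctionCall
		if err := json.Unmarshal([]byte(raw), &fc); err != nil {
			panic(err)
		}
		fmt.Println(fc.Name, string(fc.Arguments)) // get_weather {"city": "Rome"}
	}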