mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-16 15:32:53 +00:00
WIP - improve start and end of speech detection
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
30e3c47598
commit
9a0982066f
@ -497,158 +497,165 @@ type VADState int
|
|||||||
const (
|
const (
|
||||||
StateSilence VADState = iota
|
StateSilence VADState = iota
|
||||||
StateSpeaking
|
StateSpeaking
|
||||||
StateTrailingSilence
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// handle VAD (Voice Activity Detection)
|
const (
|
||||||
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
|
// tune these thresholds to taste
|
||||||
|
SpeechFramesThreshold = 3 // must see X consecutive speech results to confirm "start"
|
||||||
|
SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleVAD is a goroutine that listens for audio data from the client,
|
||||||
|
// runs VAD on the audio data, and commits utterances to the conversation
|
||||||
|
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||||
vadContext, cancel := context.WithCancel(context.Background())
|
vadContext, cancel := context.WithCancel(context.Background())
|
||||||
//var startListening time.Time
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-done
|
<-done
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
vadState := VADState(StateSilence)
|
ticker := time.NewTicker(300 * time.Millisecond)
|
||||||
segments := []*proto.VADSegment{}
|
defer ticker.Stop()
|
||||||
timeListening := time.Now()
|
|
||||||
|
var (
|
||||||
|
lastSegmentCount int
|
||||||
|
timeOfLastNewSeg time.Time
|
||||||
|
speaking bool
|
||||||
|
)
|
||||||
|
|
||||||
// Implement VAD logic here
|
|
||||||
// For brevity, this is a placeholder
|
|
||||||
// When VAD detects end of speech, generate a response
|
|
||||||
// TODO: use session.ModelInterface to handle VAD and cut audio and detect when to process that
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
default:
|
case <-ticker.C:
|
||||||
// Check if there's audio data to process
|
// 1) Copy the entire buffer
|
||||||
session.AudioBufferLock.Lock()
|
session.AudioBufferLock.Lock()
|
||||||
|
allAudio := make([]byte, len(session.InputAudioBuffer))
|
||||||
|
copy(allAudio, session.InputAudioBuffer)
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
|
||||||
if len(session.InputAudioBuffer) > 0 {
|
// 2) If there's no audio at all, just continue
|
||||||
|
if len(allAudio) == 0 {
|
||||||
if vadState == StateTrailingSilence {
|
continue
|
||||||
log.Debug().Msgf("VAD detected speech that we can process")
|
|
||||||
|
|
||||||
// Commit the audio buffer as a conversation item
|
|
||||||
item := &Item{
|
|
||||||
ID: generateItemID(),
|
|
||||||
Object: "realtime.item",
|
|
||||||
Type: "message",
|
|
||||||
Status: "completed",
|
|
||||||
Role: "user",
|
|
||||||
Content: []ConversationContent{
|
|
||||||
{
|
|
||||||
Type: "input_audio",
|
|
||||||
Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add item to conversation
|
|
||||||
conversation.Lock.Lock()
|
|
||||||
conversation.Items = append(conversation.Items, item)
|
|
||||||
conversation.Lock.Unlock()
|
|
||||||
|
|
||||||
// Reset InputAudioBuffer
|
|
||||||
session.InputAudioBuffer = nil
|
|
||||||
session.AudioBufferLock.Unlock()
|
|
||||||
|
|
||||||
// Send item.created event
|
|
||||||
sendEvent(c, OutgoingMessage{
|
|
||||||
Type: "conversation.item.created",
|
|
||||||
Item: item,
|
|
||||||
})
|
|
||||||
|
|
||||||
vadState = StateSilence
|
|
||||||
segments = []*proto.VADSegment{}
|
|
||||||
// Generate a response
|
|
||||||
generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
|
|
||||||
|
|
||||||
// Resample from 24kHz to 16kHz
|
|
||||||
// adata = sound.ResampleInt16(adata, 24000, 16000)
|
|
||||||
|
|
||||||
soundIntBuffer := &audio.IntBuffer{
|
|
||||||
Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
|
|
||||||
}
|
|
||||||
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
|
|
||||||
|
|
||||||
/* if len(adata) < 16000 {
|
|
||||||
log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
|
|
||||||
session.AudioBufferLock.Unlock()
|
|
||||||
continue
|
|
||||||
} */
|
|
||||||
float32Data := soundIntBuffer.AsFloat32Buffer().Data
|
|
||||||
|
|
||||||
// TODO: testing wav decoding
|
|
||||||
// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
|
|
||||||
// buf, err := dec.FullPCMBuffer()
|
|
||||||
// if err != nil {
|
|
||||||
// //log.Error().Msgf("failed to process audio: %s", err.Error())
|
|
||||||
// sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
|
|
||||||
// session.AudioBufferLock.Unlock()
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
//float32Data = buf.AsFloat32Buffer().Data
|
|
||||||
|
|
||||||
resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
|
|
||||||
Audio: float32Data,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("failed to process audio: %s", err.Error())
|
|
||||||
sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
|
|
||||||
session.AudioBufferLock.Unlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Segments) == 0 {
|
|
||||||
log.Debug().Msg("VAD detected no speech activity")
|
|
||||||
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
|
|
||||||
if len(session.InputAudioBuffer) > 16000 {
|
|
||||||
session.InputAudioBuffer = nil
|
|
||||||
segments = []*proto.VADSegment{}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
|
|
||||||
} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
|
|
||||||
// We have new segments, but we are still speaking
|
|
||||||
// We need to wait for the trailing silence
|
|
||||||
|
|
||||||
segments = resp.Segments
|
|
||||||
|
|
||||||
} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
|
|
||||||
// We have the same number of segments, but we are still speaking
|
|
||||||
// We need to check if we are in this state for long enough, update the timer
|
|
||||||
|
|
||||||
// Check if we have been listening for too long
|
|
||||||
if time.Since(timeListening) > sendToVADDelay {
|
|
||||||
vadState = StateTrailingSilence
|
|
||||||
} else {
|
|
||||||
|
|
||||||
timeListening = timeListening.Add(time.Since(timeListening))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debug().Msg("VAD detected speech activity")
|
|
||||||
vadState = StateSpeaking
|
|
||||||
segments = resp.Segments
|
|
||||||
}
|
|
||||||
|
|
||||||
session.AudioBufferLock.Unlock()
|
|
||||||
} else {
|
|
||||||
session.AudioBufferLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3) Run VAD on the entire audio so far
|
||||||
|
segments, err := runVAD(vadContext, session, allAudio)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("failed to process audio: %s", err.Error())
|
||||||
|
sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
|
||||||
|
// handle or log error, continue
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
segCount := len(segments)
|
||||||
|
|
||||||
|
if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
|
||||||
|
// no speech detected, and we haven't seen a new segment in > 1s
|
||||||
|
// clean up input
|
||||||
|
session.AudioBufferLock.Lock()
|
||||||
|
session.InputAudioBuffer = nil
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) If we see more segments than before => "new speech"
|
||||||
|
if segCount > lastSegmentCount {
|
||||||
|
speaking = true
|
||||||
|
lastSegmentCount = segCount
|
||||||
|
timeOfLastNewSeg = time.Now()
|
||||||
|
log.Debug().Msgf("Detected new speech segment")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) If speaking, but we haven't seen a new segment in > 1s => finalize
|
||||||
|
if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
|
||||||
|
log.Debug().Msgf("Detected end of speech segment")
|
||||||
|
// user has presumably stopped talking
|
||||||
|
commitUtterance(allAudio, cfg, evaluator, session, conv, c)
|
||||||
|
// reset state
|
||||||
|
speaking = false
|
||||||
|
lastSegmentCount = 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||||
|
if len(utt) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Commit logic: create item, broadcast item.created, etc.
|
||||||
|
item := &Item{
|
||||||
|
ID: generateItemID(),
|
||||||
|
Object: "realtime.item",
|
||||||
|
Type: "message",
|
||||||
|
Status: "completed",
|
||||||
|
Role: "user",
|
||||||
|
Content: []ConversationContent{
|
||||||
|
{
|
||||||
|
Type: "input_audio",
|
||||||
|
Audio: base64.StdEncoding.EncodeToString(utt),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conv.Lock.Lock()
|
||||||
|
conv.Items = append(conv.Items, item)
|
||||||
|
conv.Lock.Unlock()
|
||||||
|
|
||||||
|
sendEvent(c, OutgoingMessage{
|
||||||
|
Type: "conversation.item.created",
|
||||||
|
Item: item,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Optionally trigger the response generation
|
||||||
|
generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runVAD is a helper that calls your model's VAD method, returning
|
||||||
|
// true if it detects speech, false if it detects silence
|
||||||
|
func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {
|
||||||
|
|
||||||
|
adata := sound.BytesToInt16sLE(chunk)
|
||||||
|
|
||||||
|
// Resample from 24kHz to 16kHz
|
||||||
|
// adata = sound.ResampleInt16(adata, 24000, 16000)
|
||||||
|
|
||||||
|
soundIntBuffer := &audio.IntBuffer{
|
||||||
|
Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
|
||||||
|
}
|
||||||
|
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
|
||||||
|
|
||||||
|
/* if len(adata) < 16000 {
|
||||||
|
log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
continue
|
||||||
|
} */
|
||||||
|
float32Data := soundIntBuffer.AsFloat32Buffer().Data
|
||||||
|
|
||||||
|
resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
|
||||||
|
Audio: float32Data,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: testing wav decoding
|
||||||
|
// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
|
||||||
|
// buf, err := dec.FullPCMBuffer()
|
||||||
|
// if err != nil {
|
||||||
|
// //log.Error().Msgf("failed to process audio: %s", err.Error())
|
||||||
|
// sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
|
||||||
|
// session.AudioBufferLock.Unlock()
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
|
||||||
|
//float32Data = buf.AsFloat32Buffer().Data
|
||||||
|
|
||||||
|
// If resp.Segments is empty => no speech
|
||||||
|
return resp.Segments, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Function to generate a response based on the conversation
|
// Function to generate a response based on the conversation
|
||||||
func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user