WIP - improve start and end of speech detection
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
parent 30e3c47598
commit 9a0982066f
@@ -497,158 +497,165 @@ type VADState int
 const (
 	StateSilence VADState = iota
 	StateSpeaking
 	StateTrailingSilence
 )

-// handle VAD (Voice Activity Detection)
-func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
+const (
+	// tune these thresholds to taste
+	SpeechFramesThreshold  = 3 // must see X consecutive speech results to confirm "start"
+	SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
+)
+
+// handleVAD is a goroutine that listens for audio data from the client,
+// runs VAD on the audio data, and commits utterances to the conversation
+func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
 	vadContext, cancel := context.WithCancel(context.Background())
 	//var startListening time.Time

 	go func() {
 		<-done
 		cancel()
 	}()

-	vadState := VADState(StateSilence)
-	segments := []*proto.VADSegment{}
-	timeListening := time.Now()
+	ticker := time.NewTicker(300 * time.Millisecond)
+	defer ticker.Stop()
+
+	var (
+		lastSegmentCount int
+		timeOfLastNewSeg time.Time
+		speaking         bool
+	)

-	// Implement VAD logic here
-	// For brevity, this is a placeholder
-	// When VAD detects end of speech, generate a response
-	// TODO: use session.ModelInterface to handle VAD and cut audio and detect when to process that
 	for {
 		select {
 		case <-done:
 			return
-		default:
-			// Check if there's audio data to process
+		case <-ticker.C:
+			// 1) Copy the entire buffer
 			session.AudioBufferLock.Lock()
+			allAudio := make([]byte, len(session.InputAudioBuffer))
+			copy(allAudio, session.InputAudioBuffer)
+			session.AudioBufferLock.Unlock()
-
-			if len(session.InputAudioBuffer) > 0 {
-
-				if vadState == StateTrailingSilence {
-					log.Debug().Msgf("VAD detected speech that we can process")
-
-					// Commit the audio buffer as a conversation item
-					item := &Item{
-						ID:     generateItemID(),
-						Object: "realtime.item",
-						Type:   "message",
-						Status: "completed",
-						Role:   "user",
-						Content: []ConversationContent{
-							{
-								Type:  "input_audio",
-								Audio: base64.StdEncoding.EncodeToString(session.InputAudioBuffer),
-							},
-						},
-					}
-
-					// Add item to conversation
-					conversation.Lock.Lock()
-					conversation.Items = append(conversation.Items, item)
-					conversation.Lock.Unlock()
-
-					// Reset InputAudioBuffer
-					session.InputAudioBuffer = nil
-					session.AudioBufferLock.Unlock()
-
-					// Send item.created event
-					sendEvent(c, OutgoingMessage{
-						Type: "conversation.item.created",
-						Item: item,
-					})
-
-					vadState = StateSilence
-					segments = []*proto.VADSegment{}
-					// Generate a response
-					generateResponse(cfg, evaluator, session, conversation, ResponseCreate{}, c, websocket.TextMessage)
-					continue
-				}
-
-				adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
-
-				// Resample from 24kHz to 16kHz
-				// adata = sound.ResampleInt16(adata, 24000, 16000)
-
-				soundIntBuffer := &audio.IntBuffer{
-					Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
-				}
-				soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
-
-				/* if len(adata) < 16000 {
-					log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
-					session.AudioBufferLock.Unlock()
-					continue
-				} */
-				float32Data := soundIntBuffer.AsFloat32Buffer().Data
-
-				// TODO: testing wav decoding
-				// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
-				// buf, err := dec.FullPCMBuffer()
-				// if err != nil {
-				// 	//log.Error().Msgf("failed to process audio: %s", err.Error())
-				// 	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-				// 	session.AudioBufferLock.Unlock()
-				// 	continue
-				// }
-
-				//float32Data = buf.AsFloat32Buffer().Data
-
-				resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
-					Audio: float32Data,
-				})
-				if err != nil {
-					log.Error().Msgf("failed to process audio: %s", err.Error())
-					sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
-					session.AudioBufferLock.Unlock()
-					continue
-				}
-
-				if len(resp.Segments) == 0 {
-					log.Debug().Msg("VAD detected no speech activity")
-					log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
-					if len(session.InputAudioBuffer) > 16000 {
-						session.InputAudioBuffer = nil
-						segments = []*proto.VADSegment{}
-					}
-
-					log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
-				} else if (len(resp.Segments) != len(segments)) && vadState == StateSpeaking {
-					// We have new segments, but we are still speaking
-					// We need to wait for the trailing silence
-					segments = resp.Segments
-				} else if (len(resp.Segments) == len(segments)) && vadState == StateSpeaking {
-					// We have the same number of segments, but we are still speaking
-					// We need to check if we are in this state for long enough, update the timer
-
-					// Check if we have been listening for too long
-					if time.Since(timeListening) > sendToVADDelay {
-						vadState = StateTrailingSilence
-					} else {
-						timeListening = timeListening.Add(time.Since(timeListening))
-					}
-				} else {
-					log.Debug().Msg("VAD detected speech activity")
-					vadState = StateSpeaking
-					segments = resp.Segments
-				}
-
-				session.AudioBufferLock.Unlock()
-			} else {
-				session.AudioBufferLock.Unlock()
+
+			// 2) If there's no audio at all, just continue
+			if len(allAudio) == 0 {
+				continue
+			}
+
+			// 3) Run VAD on the entire audio so far
+			segments, err := runVAD(vadContext, session, allAudio)
+			if err != nil {
+				log.Error().Msgf("failed to process audio: %s", err.Error())
+				sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+				// handle or log error, continue
+				continue
+			}
+
+			segCount := len(segments)
+
+			if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				// no speech detected, and we haven't seen a new segment in > 1s
+				// clean up input
+				session.AudioBufferLock.Lock()
+				session.InputAudioBuffer = nil
+				session.AudioBufferLock.Unlock()
+				log.Debug().Msgf("Detected silence for a while, clearing audio buffer")
+				continue
+			}
+
+			// 4) If we see more segments than before => "new speech"
+			if segCount > lastSegmentCount {
+				speaking = true
+				lastSegmentCount = segCount
+				timeOfLastNewSeg = time.Now()
+				log.Debug().Msgf("Detected new speech segment")
+			}
+
+			// 5) If speaking, but we haven't seen a new segment in > 1s => finalize
+			if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
+				log.Debug().Msgf("Detected end of speech segment")
+				// user has presumably stopped talking
+				commitUtterance(allAudio, cfg, evaluator, session, conv, c)
+				// reset state
+				speaking = false
+				lastSegmentCount = 0
+			}
 		}
 	}
 }
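
Note on the new loop: instead of driving the StateSilence/StateSpeaking/StateTrailingSilence machine, it keeps just a segment count, the time the last new segment appeared, and a speaking flag, and finalizes the utterance once no new segment has shown up for over a second. A minimal, runnable sketch of that rule in isolation; speechTracker and observe are illustrative names, not part of this commit:

package main

import (
	"fmt"
	"time"
)

// speechTracker mirrors the three pieces of state kept by the polling loop.
type speechTracker struct {
	lastSegmentCount int
	timeOfLastNewSeg time.Time
	speaking         bool
}

// observe feeds one poll result (the current VAD segment count) into the
// tracker and reports whether the utterance should be committed now.
func (t *speechTracker) observe(segCount int, now time.Time, silenceWindow time.Duration) bool {
	// More segments than before => new speech arrived in this poll.
	if segCount > t.lastSegmentCount {
		t.speaking = true
		t.lastSegmentCount = segCount
		t.timeOfLastNewSeg = now
	}
	// Speaking, but no new segment for longer than the window => end of speech.
	if t.speaking && now.Sub(t.timeOfLastNewSeg) > silenceWindow {
		t.speaking = false
		t.lastSegmentCount = 0
		return true
	}
	return false
}

func main() {
	t := &speechTracker{}
	base := time.Now()
	fmt.Println(t.observe(1, base, time.Second))                            // false: speech just started
	fmt.Println(t.observe(1, base.Add(500*time.Millisecond), time.Second))  // false: still inside the window
	fmt.Println(t.observe(1, base.Add(1500*time.Millisecond), time.Second)) // true: >1s with no new segment
}

With the 300ms ticker above, the effective end-of-speech latency is the one-second silence window plus up to one tick.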
+
+func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
+	if len(utt) == 0 {
+		return
+	}
+	// Commit logic: create item, broadcast item.created, etc.
+	item := &Item{
+		ID:     generateItemID(),
+		Object: "realtime.item",
+		Type:   "message",
+		Status: "completed",
+		Role:   "user",
+		Content: []ConversationContent{
+			{
+				Type:  "input_audio",
+				Audio: base64.StdEncoding.EncodeToString(utt),
+			},
+		},
+	}
+	conv.Lock.Lock()
+	conv.Items = append(conv.Items, item)
+	conv.Lock.Unlock()
+
+	sendEvent(c, OutgoingMessage{
+		Type: "conversation.item.created",
+		Item: item,
+	})
+
+	// Optionally trigger the response generation
+	generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
+}
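
commitUtterance ships the raw buffer base64-encoded in the item's input_audio content. Assuming that buffer holds 16-bit little-endian PCM, which is what the sound.BytesToInt16sLE call elsewhere in this diff implies, a consumer of the conversation.item.created event could decode the payload as below; decodePCM16 is an illustrative helper, not part of the commit:

package main

import (
	"encoding/base64"
	"encoding/binary"
	"fmt"
)

// decodePCM16 converts the base64 "input_audio" payload back into
// 16-bit little-endian PCM samples.
func decodePCM16(b64 string) ([]int16, error) {
	raw, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return nil, err
	}
	samples := make([]int16, len(raw)/2)
	for i := range samples {
		samples[i] = int16(binary.LittleEndian.Uint16(raw[2*i:]))
	}
	return samples, nil
}

func main() {
	// Two samples on the wire: 0x0102 (258) and 0xFFFE (-2), little-endian.
	b64 := base64.StdEncoding.EncodeToString([]byte{0x02, 0x01, 0xFE, 0xFF})
	samples, err := decodePCM16(b64)
	fmt.Println(samples, err) // [258 -2] <nil>
}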
+
+// runVAD is a helper that calls your model's VAD method, returning
+// the detected speech segments (an empty slice means only silence)
+func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {
+	adata := sound.BytesToInt16sLE(chunk)
+
+	// Resample from 24kHz to 16kHz
+	// adata = sound.ResampleInt16(adata, 24000, 16000)
+
+	soundIntBuffer := &audio.IntBuffer{
+		Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
+	}
+	soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
+
+	/* if len(adata) < 16000 {
+		log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
+		session.AudioBufferLock.Unlock()
+		continue
+	} */
+	float32Data := soundIntBuffer.AsFloat32Buffer().Data
+
+	resp, err := session.ModelInterface.VAD(ctx, &proto.VADRequest{
+		Audio: float32Data,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO: testing wav decoding
+	// dec := wav.NewDecoder(bytes.NewReader(session.InputAudioBuffer))
+	// buf, err := dec.FullPCMBuffer()
+	// if err != nil {
+	// 	//log.Error().Msgf("failed to process audio: %s", err.Error())
+	// 	sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
+	// 	session.AudioBufferLock.Unlock()
+	// 	continue
+	// }
+
+	//float32Data = buf.AsFloat32Buffer().Data
+
+	// If resp.Segments is empty => no speech
+	return resp.Segments, nil
+}
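
runVAD goes bytes -> int16 -> int -> float32 through go-audio's IntBuffer, which, as far as I can tell, casts sample values without scaling them. If a given VAD backend instead wants samples normalized to [-1, 1], the conversion can be done directly; this is a hedged sketch under that assumption, not what the commit ships:

package main

import (
	"encoding/binary"
	"fmt"
)

// pcm16LEToFloat32 converts raw 16-bit little-endian PCM into float32
// samples, optionally scaling into [-1, 1]. Whether normalization is
// needed depends on the VAD backend; this helper is illustrative only.
func pcm16LEToFloat32(raw []byte, normalize bool) []float32 {
	out := make([]float32, len(raw)/2)
	for i := range out {
		s := int16(binary.LittleEndian.Uint16(raw[2*i:]))
		if normalize {
			out[i] = float32(s) / 32768.0
		} else {
			out[i] = float32(s)
		}
	}
	return out
}

func main() {
	raw := []byte{0x00, 0x80, 0xFF, 0x7F}    // samples -32768 and 32767
	fmt.Println(pcm16LEToFloat32(raw, true)) // [-1 0.9999695]
}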

 // Function to generate a response based on the conversation
 func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
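
One loose end worth flagging: SpeechFramesThreshold and SilenceFramesThreshold are introduced at the top of this diff, but nothing in the new loop reads them yet, consistent with the WIP label. They suggest per-frame debouncing along these lines; a sketch under that assumption, with frameDebouncer as an illustrative name:

package main

import "fmt"

// frameDebouncer confirms speech start/end only after a run of consecutive
// per-frame VAD verdicts, which is what the two thresholds hint at.
type frameDebouncer struct {
	speechRun  int
	silenceRun int
	speaking   bool
}

// push consumes one per-frame verdict and returns "start" when speech is
// confirmed, "end" when silence is confirmed, and "" otherwise.
func (d *frameDebouncer) push(isSpeech bool, speechThreshold, silenceThreshold int) string {
	if isSpeech {
		d.speechRun++
		d.silenceRun = 0
		if !d.speaking && d.speechRun >= speechThreshold {
			d.speaking = true
			return "start"
		}
		return ""
	}
	d.silenceRun++
	d.speechRun = 0
	if d.speaking && d.silenceRun >= silenceThreshold {
		d.speaking = false
		return "end"
	}
	return ""
}

func main() {
	d := &frameDebouncer{}
	frames := []bool{true, true, true, false, false, false, false, false}
	for _, f := range frames {
		if ev := d.push(f, 3, 5); ev != "" {
			fmt.Println(ev) // "start" after 3 speech frames, then "end" after 5 silence frames
		}
	}
}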