mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-04 18:22:16 +00:00
wip(vad)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
d999991f39
commit
9849d2e823
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-audio/audio"
|
"github.com/go-audio/audio"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
@ -187,7 +188,6 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
log.Error().Msgf("read: %s", err.Error())
|
log.Error().Msgf("read: %s", err.Error())
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
log.Printf("recv: %s", msg)
|
|
||||||
|
|
||||||
// Parse the incoming message
|
// Parse the incoming message
|
||||||
var incomingMsg IncomingMessage
|
var incomingMsg IncomingMessage
|
||||||
@ -199,6 +199,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
|
|
||||||
switch incomingMsg.Type {
|
switch incomingMsg.Type {
|
||||||
case "session.update":
|
case "session.update":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Update session configurations
|
// Update session configurations
|
||||||
var sessionUpdate Session
|
var sessionUpdate Session
|
||||||
if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
|
if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
|
||||||
@ -258,6 +260,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
session.AudioBufferLock.Unlock()
|
session.AudioBufferLock.Unlock()
|
||||||
|
|
||||||
case "input_audio_buffer.commit":
|
case "input_audio_buffer.commit":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Commit the audio buffer to the conversation as a new item
|
// Commit the audio buffer to the conversation as a new item
|
||||||
item := &Item{
|
item := &Item{
|
||||||
ID: generateItemID(),
|
ID: generateItemID(),
|
||||||
@ -290,6 +294,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
})
|
})
|
||||||
|
|
||||||
case "conversation.item.create":
|
case "conversation.item.create":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Handle creating new conversation items
|
// Handle creating new conversation items
|
||||||
var item Item
|
var item Item
|
||||||
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
|
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
|
||||||
@ -315,10 +321,14 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
})
|
})
|
||||||
|
|
||||||
case "conversation.item.delete":
|
case "conversation.item.delete":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Handle deleting conversation items
|
// Handle deleting conversation items
|
||||||
// Implement deletion logic as needed
|
// Implement deletion logic as needed
|
||||||
|
|
||||||
case "response.create":
|
case "response.create":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Handle generating a response
|
// Handle generating a response
|
||||||
var responseCreate ResponseCreate
|
var responseCreate ResponseCreate
|
||||||
if len(incomingMsg.Response) > 0 {
|
if len(incomingMsg.Response) > 0 {
|
||||||
@ -342,6 +352,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
case "conversation.item.update":
|
case "conversation.item.update":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Handle function_call_output from the client
|
// Handle function_call_output from the client
|
||||||
var item Item
|
var item Item
|
||||||
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
|
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
|
||||||
@ -366,6 +378,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
|||||||
})
|
})
|
||||||
|
|
||||||
case "response.cancel":
|
case "response.cancel":
|
||||||
|
log.Printf("recv: %s", msg)
|
||||||
|
|
||||||
// Handle cancellation of ongoing responses
|
// Handle cancellation of ongoing responses
|
||||||
// Implement cancellation logic as needed
|
// Implement cancellation logic as needed
|
||||||
|
|
||||||
@ -443,12 +457,19 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minMicVolume = 450
|
||||||
|
sendToVADDelay = time.Second
|
||||||
|
maxWhisperSegmentDuration = time.Second * 25
|
||||||
|
)
|
||||||
|
|
||||||
// Placeholder function to handle VAD (Voice Activity Detection)
|
// Placeholder function to handle VAD (Voice Activity Detection)
|
||||||
// https://github.com/snakers4/silero-vad/tree/master/examples/go
|
// https://github.com/snakers4/silero-vad/tree/master/examples/go
|
||||||
// XXX: use session.ModelInterface for VAD or hook directly VAD runtime here?
|
// XXX: use session.ModelInterface for VAD or hook directly VAD runtime here?
|
||||||
func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
|
func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||||
|
|
||||||
vadContext, cancel := context.WithCancel(context.Background())
|
vadContext, cancel := context.WithCancel(context.Background())
|
||||||
|
//var startListening time.Time
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-done
|
<-done
|
||||||
@ -466,7 +487,7 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn,
|
|||||||
default:
|
default:
|
||||||
// Check if there's audio data to process
|
// Check if there's audio data to process
|
||||||
session.AudioBufferLock.Lock()
|
session.AudioBufferLock.Lock()
|
||||||
if len(session.InputAudioBuffer) > 0 {
|
if len(session.InputAudioBuffer) > 16000 {
|
||||||
|
|
||||||
adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
|
adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
|
||||||
|
|
||||||
@ -475,37 +496,77 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn,
|
|||||||
}
|
}
|
||||||
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
|
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
|
||||||
|
|
||||||
|
/* if len(adata) < 16000 {
|
||||||
|
log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
continue
|
||||||
|
} */
|
||||||
|
|
||||||
|
float32Data := soundIntBuffer.AsFloat32Buffer().Data
|
||||||
|
|
||||||
resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
|
resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
|
||||||
Audio: soundIntBuffer.AsFloat32Buffer().Data,
|
Audio: float32Data,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("failed to process audio: %s", err.Error())
|
log.Error().Msgf("failed to process audio: %s", err.Error())
|
||||||
sendError(c, "processing_error", "Failed to process audio", "", "")
|
sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
|
||||||
session.AudioBufferLock.Unlock()
|
session.AudioBufferLock.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
speechStart, speechEnd := float32(0), float32(0)
|
speechStart, speechEnd := float32(0), float32(0)
|
||||||
|
|
||||||
|
/*
|
||||||
|
volume := sound.CalculateRMS16(adata)
|
||||||
|
if volume > minMicVolume {
|
||||||
|
startListening = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Since(startListening) < sendToVADDelay && time.Since(startListening) < maxWhisperSegmentDuration {
|
||||||
|
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
|
||||||
|
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
log.Debug().Msg("speech is ongoing")
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
if len(resp.Segments) == 0 {
|
||||||
|
log.Debug().Msg("VAD detected no speech activity")
|
||||||
|
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
|
||||||
|
|
||||||
|
session.InputAudioBuffer = nil
|
||||||
|
log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
|
||||||
|
|
||||||
|
session.AudioBufferLock.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("VAD detected %d segments", len(resp.Segments))
|
||||||
|
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
|
||||||
|
|
||||||
|
speechStart = resp.Segments[0].Start
|
||||||
|
log.Debug().Msgf("speech starts at %0.2fs", speechStart)
|
||||||
|
|
||||||
for _, s := range resp.Segments {
|
for _, s := range resp.Segments {
|
||||||
log.Debug().Msgf("speech starts at %0.2fs", s.Start)
|
|
||||||
speechStart = s.Start
|
|
||||||
if s.End > 0 {
|
if s.End > 0 {
|
||||||
log.Debug().Msgf("speech ends at %0.2fs", s.End)
|
log.Debug().Msgf("speech ends at %0.2fs", s.End)
|
||||||
speechEnd = s.End
|
speechEnd = s.End
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if speechEnd == 0 && speechStart != 0 {
|
if speechEnd == 0 {
|
||||||
|
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
|
||||||
|
|
||||||
session.AudioBufferLock.Unlock()
|
session.AudioBufferLock.Unlock()
|
||||||
log.Debug().Msg("speech is ongoing")
|
log.Debug().Msg("speech is ongoing, no end found ?")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle when input is too long without a voice activity (reset the buffer)
|
// Handle when input is too long without a voice activity (reset the buffer)
|
||||||
if speechStart == 0 && speechEnd == 0 {
|
if speechStart == 0 && speechEnd == 0 {
|
||||||
log.Debug().Msg("VAD detected no speech activity")
|
// log.Debug().Msg("VAD detected no speech activity")
|
||||||
session.InputAudioBuffer = nil
|
session.InputAudioBuffer = nil
|
||||||
session.AudioBufferLock.Unlock()
|
session.AudioBufferLock.Unlock()
|
||||||
continue
|
continue
|
||||||
|
1
go.mod
1
go.mod
@ -111,6 +111,7 @@ require (
|
|||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
|
||||||
github.com/pion/datachannel v1.5.10 // indirect
|
github.com/pion/datachannel v1.5.10 // indirect
|
||||||
|
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e // indirect
|
||||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||||
github.com/pion/ice/v2 v2.3.37 // indirect
|
github.com/pion/ice/v2 v2.3.37 // indirect
|
||||||
github.com/pion/interceptor v0.1.37 // indirect
|
github.com/pion/interceptor v0.1.37 // indirect
|
||||||
|
@ -5,14 +5,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BytesToFloat32Array(aBytes []byte) []float32 {
|
|
||||||
aArr := make([]float32, 3)
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
aArr[i] = BytesFloat32(aBytes[i*4:])
|
|
||||||
}
|
|
||||||
return aArr
|
|
||||||
}
|
|
||||||
|
|
||||||
func BytesFloat32(bytes []byte) float32 {
|
func BytesFloat32(bytes []byte) float32 {
|
||||||
bits := binary.LittleEndian.Uint32(bytes)
|
bits := binary.LittleEndian.Uint32(bytes)
|
||||||
float := math.Float32frombits(bits)
|
float := math.Float32frombits(bits)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package sound
|
package sound
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
||||||
MIT License
|
MIT License
|
||||||
@ -8,6 +10,17 @@ Copyright (c) 2024 Xbozon
|
|||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// CalculateRMS16 returns the root mean square (RMS) of a buffer of int16
// audio samples.
//
// Every sample is widened to float64 before it is squared, so the running sum
// is accumulated in floating point rather than in the int16 range.
//
// An empty or nil buffer yields 0. Without the guard the mean would be computed
// as 0/0, which is NaN in floating point and compares false against any
// threshold a caller might test it with.
func CalculateRMS16(buffer []int16) float64 {
	if len(buffer) == 0 {
		return 0
	}

	var sumSquares float64
	for _, sample := range buffer {
		val := float64(sample) // Convert int16 to float64 for calculation
		sumSquares += val * val
	}

	meanSquares := sumSquares / float64(len(buffer))
	return math.Sqrt(meanSquares)
}
|
||||||
|
|
||||||
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
|
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
|
||||||
// Calculate the resampling ratio
|
// Calculate the resampling ratio
|
||||||
ratio := float64(inputRate) / float64(outputRate)
|
ratio := float64(inputRate) / float64(outputRate)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user