Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-11-15 21:49:14 +01:00
parent d999991f39
commit 9849d2e823
4 changed files with 86 additions and 19 deletions

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"time"
"github.com/go-audio/audio" "github.com/go-audio/audio"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
@ -187,7 +188,6 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
log.Error().Msgf("read: %s", err.Error()) log.Error().Msgf("read: %s", err.Error())
break break
} }
log.Printf("recv: %s", msg)
// Parse the incoming message // Parse the incoming message
var incomingMsg IncomingMessage var incomingMsg IncomingMessage
@ -199,6 +199,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
switch incomingMsg.Type { switch incomingMsg.Type {
case "session.update": case "session.update":
log.Printf("recv: %s", msg)
// Update session configurations // Update session configurations
var sessionUpdate Session var sessionUpdate Session
if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil { if err := json.Unmarshal(incomingMsg.Session, &sessionUpdate); err != nil {
@ -258,6 +260,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
session.AudioBufferLock.Unlock() session.AudioBufferLock.Unlock()
case "input_audio_buffer.commit": case "input_audio_buffer.commit":
log.Printf("recv: %s", msg)
// Commit the audio buffer to the conversation as a new item // Commit the audio buffer to the conversation as a new item
item := &Item{ item := &Item{
ID: generateItemID(), ID: generateItemID(),
@ -290,6 +294,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
}) })
case "conversation.item.create": case "conversation.item.create":
log.Printf("recv: %s", msg)
// Handle creating new conversation items // Handle creating new conversation items
var item Item var item Item
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
@ -315,10 +321,14 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
}) })
case "conversation.item.delete": case "conversation.item.delete":
log.Printf("recv: %s", msg)
// Handle deleting conversation items // Handle deleting conversation items
// Implement deletion logic as needed // Implement deletion logic as needed
case "response.create": case "response.create":
log.Printf("recv: %s", msg)
// Handle generating a response // Handle generating a response
var responseCreate ResponseCreate var responseCreate ResponseCreate
if len(incomingMsg.Response) > 0 { if len(incomingMsg.Response) > 0 {
@ -342,6 +352,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
}() }()
case "conversation.item.update": case "conversation.item.update":
log.Printf("recv: %s", msg)
// Handle function_call_output from the client // Handle function_call_output from the client
var item Item var item Item
if err := json.Unmarshal(incomingMsg.Item, &item); err != nil { if err := json.Unmarshal(incomingMsg.Item, &item); err != nil {
@ -366,6 +378,8 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
}) })
case "response.cancel": case "response.cancel":
log.Printf("recv: %s", msg)
// Handle cancellation of ongoing responses // Handle cancellation of ongoing responses
// Implement cancellation logic as needed // Implement cancellation logic as needed
@ -443,12 +457,19 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
return nil return nil
} }
const (
minMicVolume = 450
sendToVADDelay = time.Second
maxWhisperSegmentDuration = time.Second * 25
)
// Placeholder function to handle VAD (Voice Activity Detection) // Placeholder function to handle VAD (Voice Activity Detection)
// https://github.com/snakers4/silero-vad/tree/master/examples/go // https://github.com/snakers4/silero-vad/tree/master/examples/go
// XXX: use session.ModelInterface for VAD or hook directly VAD runtime here? // XXX: use session.ModelInterface for VAD or hook directly VAD runtime here?
func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) { func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn, done chan struct{}) {
vadContext, cancel := context.WithCancel(context.Background()) vadContext, cancel := context.WithCancel(context.Background())
//var startListening time.Time
go func() { go func() {
<-done <-done
@ -466,7 +487,7 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn,
default: default:
// Check if there's audio data to process // Check if there's audio data to process
session.AudioBufferLock.Lock() session.AudioBufferLock.Lock()
if len(session.InputAudioBuffer) > 0 { if len(session.InputAudioBuffer) > 16000 {
adata := sound.BytesToInt16sLE(session.InputAudioBuffer) adata := sound.BytesToInt16sLE(session.InputAudioBuffer)
@ -475,37 +496,77 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn,
} }
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata) soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)
/* if len(adata) < 16000 {
log.Debug().Msgf("audio length too small %d", len(session.InputAudioBuffer))
session.AudioBufferLock.Unlock()
continue
} */
float32Data := soundIntBuffer.AsFloat32Buffer().Data
resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{ resp, err := session.ModelInterface.VAD(vadContext, &proto.VADRequest{
Audio: soundIntBuffer.AsFloat32Buffer().Data, Audio: float32Data,
}) })
if err != nil { if err != nil {
log.Error().Msgf("failed to process audio: %s", err.Error()) log.Error().Msgf("failed to process audio: %s", err.Error())
sendError(c, "processing_error", "Failed to process audio", "", "") sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
session.AudioBufferLock.Unlock() session.AudioBufferLock.Unlock()
continue continue
} }
speechStart, speechEnd := float32(0), float32(0) speechStart, speechEnd := float32(0), float32(0)
/*
volume := sound.CalculateRMS16(adata)
if volume > minMicVolume {
startListening = time.Now()
}
if time.Since(startListening) < sendToVADDelay && time.Since(startListening) < maxWhisperSegmentDuration {
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
session.AudioBufferLock.Unlock()
log.Debug().Msg("speech is ongoing")
continue
}
*/
if len(resp.Segments) == 0 {
log.Debug().Msg("VAD detected no speech activity")
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
session.InputAudioBuffer = nil
log.Debug().Msgf("audio length(after) %d", len(session.InputAudioBuffer))
session.AudioBufferLock.Unlock()
continue
}
log.Debug().Msgf("VAD detected %d segments", len(resp.Segments))
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
speechStart = resp.Segments[0].Start
log.Debug().Msgf("speech starts at %0.2fs", speechStart)
for _, s := range resp.Segments { for _, s := range resp.Segments {
log.Debug().Msgf("speech starts at %0.2fs", s.Start)
speechStart = s.Start
if s.End > 0 { if s.End > 0 {
log.Debug().Msgf("speech ends at %0.2fs", s.End) log.Debug().Msgf("speech ends at %0.2fs", s.End)
speechEnd = s.End speechEnd = s.End
} else {
continue
} }
} }
if speechEnd == 0 && speechStart != 0 { if speechEnd == 0 {
log.Debug().Msgf("audio length %d", len(session.InputAudioBuffer))
session.AudioBufferLock.Unlock() session.AudioBufferLock.Unlock()
log.Debug().Msg("speech is ongoing") log.Debug().Msg("speech is ongoing, no end found ?")
continue continue
} }
// Handle when input is too long without a voice activity (reset the buffer) // Handle when input is too long without a voice activity (reset the buffer)
if speechStart == 0 && speechEnd == 0 { if speechStart == 0 && speechEnd == 0 {
log.Debug().Msg("VAD detected no speech activity") // log.Debug().Msg("VAD detected no speech activity")
session.InputAudioBuffer = nil session.InputAudioBuffer = nil
session.AudioBufferLock.Unlock() session.AudioBufferLock.Unlock()
continue continue

1
go.mod
View File

@ -111,6 +111,7 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect github.com/nikolalohinski/gonja/v2 v2.3.2 // indirect
github.com/pion/datachannel v1.5.10 // indirect github.com/pion/datachannel v1.5.10 // indirect
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/ice/v2 v2.3.37 // indirect github.com/pion/ice/v2 v2.3.37 // indirect
github.com/pion/interceptor v0.1.37 // indirect github.com/pion/interceptor v0.1.37 // indirect

View File

@ -5,14 +5,6 @@ import (
"math" "math"
) )
func BytesToFloat32Array(aBytes []byte) []float32 {
aArr := make([]float32, 3)
for i := 0; i < 3; i++ {
aArr[i] = BytesFloat32(aBytes[i*4:])
}
return aArr
}
func BytesFloat32(bytes []byte) float32 { func BytesFloat32(bytes []byte) float32 {
bits := binary.LittleEndian.Uint32(bytes) bits := binary.LittleEndian.Uint32(bytes)
float := math.Float32frombits(bits) float := math.Float32frombits(bits)

View File

@ -1,5 +1,7 @@
package sound package sound
import "math"
/* /*
MIT License MIT License
@ -8,6 +10,17 @@ Copyright (c) 2024 Xbozon
*/ */
// calculateRMS16 calculates the root mean square of the audio buffer for int16 samples.
func CalculateRMS16(buffer []int16) float64 {
var sumSquares float64
for _, sample := range buffer {
val := float64(sample) // Convert int16 to float64 for calculation
sumSquares += val * val
}
meanSquares := sumSquares / float64(len(buffer))
return math.Sqrt(meanSquares)
}
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 { func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
// Calculate the resampling ratio // Calculate the resampling ratio
ratio := float64(inputRate) / float64(outputRate) ratio := float64(inputRate) / float64(outputRate)