more robust approach
Some checks failed
Security Scan / tests (push) Has been cancelled

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2025-01-10 16:22:50 +01:00
parent 72b2883757
commit 4565b87e5c

View File

@ -1,14 +1,18 @@
package openai
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/go-audio/wav"
"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
@ -488,21 +492,8 @@ func updateSession(session *Session, update *Session, cl *config.BackendConfigLo
}
const (
minMicVolume = 450
sendToVADDelay = time.Second
)
type VADState int
const (
StateSilence VADState = iota
StateSpeaking
)
const (
// tune these thresholds to taste
SpeechFramesThreshold = 3 // must see X consecutive speech results to confirm "start"
SilenceFramesThreshold = 5 // must see X consecutive silence results to confirm "end"
sendToVADDelay = 2 * time.Second
silenceThreshold = 2 * time.Second
)
// handleVAD is a goroutine that listens for audio data from the client,
@ -534,14 +525,18 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
copy(allAudio, session.InputAudioBuffer)
session.AudioBufferLock.Unlock()
// 2) If there's no audio at all, just continue
if len(allAudio) == 0 {
// 2) If there's no audio at all, or the buffer is still too small, just continue
if len(allAudio) == 0 || len(allAudio) < 32000 {
continue
}
// 3) Run VAD on the entire audio so far
segments, err := runVAD(vadContext, session, allAudio)
if err != nil {
if err.Error() == "unexpected speech end" {
log.Debug().Msg("VAD cancelled")
continue
}
log.Error().Msgf("failed to process audio: %s", err.Error())
sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "")
// handle or log error, continue
@ -550,7 +545,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
segCount := len(segments)
if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
if len(segments) == 0 && !speaking && time.Since(timeOfLastNewSeg) > silenceThreshold {
// no speech detected, and we haven't seen a new segment for longer than silenceThreshold
// clean up input
session.AudioBufferLock.Lock()
@ -569,8 +564,11 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
}
// 5) If speaking, but we haven't seen a new segment for longer than sendToVADDelay => finalize
if speaking && time.Since(timeOfLastNewSeg) > 1*time.Second {
if speaking && time.Since(timeOfLastNewSeg) > sendToVADDelay {
log.Debug().Msgf("Detected end of speech segment")
session.AudioBufferLock.Lock()
session.InputAudioBuffer = nil
session.AudioBufferLock.Unlock()
// user has presumably stopped talking
commitUtterance(allAudio, cfg, evaluator, session, conv, c)
// reset state
@ -608,18 +606,38 @@ func commitUtterance(utt []byte, cfg *config.BackendConfig, evaluator *templates
Item: item,
})
// Optionally trigger the response generation
// save chunk to disk
f, err := os.CreateTemp("", "audio-*.wav")
if err != nil {
log.Error().Msgf("failed to create temp file: %s", err.Error())
return
}
defer f.Close()
//defer os.Remove(f.Name())
log.Debug().Msgf("Writing to %s\n", f.Name())
f.Write(utt)
f.Sync()
// trigger the response generation
generateResponse(cfg, evaluator, session, conv, ResponseCreate{}, c, websocket.TextMessage)
}
// runVAD is a helper that calls your model's VAD method, returning
// runVAD is a helper that calls the model's VAD method, returning
// the VAD segments it reports for the given audio, or an error
func runVAD(ctx context.Context, session *Session, chunk []byte) ([]*proto.VADSegment, error) {
adata := sound.BytesToInt16sLE(chunk)
// Resample from 24kHz to 16kHz
// adata = sound.ResampleInt16(adata, 24000, 16000)
adata = sound.ResampleInt16(adata, 24000, 16000)
dec := wav.NewDecoder(bytes.NewReader(chunk))
dur, err := dec.Duration()
if err != nil {
fmt.Printf("failed to get duration: %s\n", err)
}
fmt.Printf("duration: %s\n", dur)
soundIntBuffer := &audio.IntBuffer{
Format: &audio.Format{SampleRate: 16000, NumChannels: 1},