fix: untangle pkg/grpc and core/schema for Transcription (#3419)

untangle pkg/grpc and core/schema in Transcribe

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-09-02 09:48:53 -04:00 committed by GitHub
parent 1655411ccd
commit c2804c42fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 162 additions and 186 deletions

View File

@ -846,7 +846,7 @@ endif
backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \ CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/whisper
ifneq ($(UPX),) ifneq ($(UPX),)
$(UPX) backend-assets/grpc/whisper $(UPX) backend-assets/grpc/whisper
endif endif

View File

@ -1,104 +0,0 @@
package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/core/schema"
)
func ffmpegCommand(args []string) (string, error) {
cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe.
cmd.Env = os.Environ()
out, err := cmd.CombinedOutput()
return string(out), err
}
// AudioToWav converts audio to wav for transcribe.
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
out, err := ffmpegCommand(commandArgs)
if err != nil {
return fmt.Errorf("error: %w out: %s", err, out)
}
return nil
}
func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) {
res := schema.TranscriptionResult{}
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return res, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := audioToWav(audiopath, convertedPath); err != nil {
return res, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return res, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return res, err
}
data := buf.AsFloat32Buffer().Data
// Process samples
context, err := model.NewContext()
if err != nil {
return res, err
}
context.SetThreads(threads)
if language != "" {
context.SetLanguage(language)
} else {
context.SetLanguage("auto")
}
if translate {
context.SetTranslate(true)
}
if err := context.Process(data, nil, nil); err != nil {
return res, err
}
for {
s, err := context.NextSegment()
if err != nil {
break
}
var tokens []int
for _, t := range s.Tokens {
tokens = append(tokens, t.Id)
}
segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
res.Segments = append(res.Segments, segment)
res.Text += s.Text
}
return res, nil
}

View File

@ -1,26 +0,0 @@
package main
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
type Whisper struct {
base.SingleThread
whisper whisper.Model
}
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.ModelFile)
sd.whisper = w
return err
}
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads))
}

View File

@ -0,0 +1,105 @@
package main
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"os"
"path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)
type Whisper struct {
base.SingleThread
whisper whisper.Model
}
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.ModelFile)
sd.whisper = w
return err
}
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}
data := buf.AsFloat32Buffer().Data
// Process samples
context, err := sd.whisper.NewContext()
if err != nil {
return pb.TranscriptResult{}, err
}
context.SetThreads(uint(opts.Threads))
if opts.Language != "" {
context.SetLanguage(opts.Language)
} else {
context.SetLanguage("auto")
}
if opts.Translate {
context.SetTranslate(true)
}
if err := context.Process(data, nil, nil); err != nil {
return pb.TranscriptResult{}, err
}
segments := []*pb.TranscriptSegment{}
text := ""
for {
s, err := context.NextSegment()
if err != nil {
break
}
var tokens []int32
for _, t := range s.Tokens {
tokens = append(tokens, int32(t.Id))
}
segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
segments = append(segments, segment)
text += s.Text
}
return pb.TranscriptResult{
Segments: segments,
Text: text,
}, nil
}

View File

@ -3,6 +3,7 @@ package backend
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
@ -30,10 +31,31 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
return nil, fmt.Errorf("could not load whisper model") return nil, fmt.Errorf("could not load whisper model")
} }
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ r, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio, Dst: audio,
Language: language, Language: language,
Translate: translate, Translate: translate,
Threads: uint32(*backendConfig.Threads), Threads: uint32(*backendConfig.Threads),
}) })
if err != nil {
return nil, err
}
tr := &schema.TranscriptionResult{
Text: r.Text,
}
for _, s := range r.Segments {
var tks []int
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
return tr, err
} }

View File

@ -3,7 +3,6 @@ package grpc
import ( import (
"context" "context"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -42,7 +41,7 @@ type Backend interface {
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error)

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
gopsutil "github.com/shirou/gopsutil/v3/process" gopsutil "github.com/shirou/gopsutil/v3/process"
) )
@ -53,8 +52,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented") return fmt.Errorf("unimplemented")
} }
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
} }
func (llm *Base) TTS(*pb.TTSRequest) error { func (llm *Base) TTS(*pb.TTSRequest) error {

View File

@ -7,7 +7,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -228,7 +227,7 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
return client.SoundGeneration(ctx, in, opts...) return client.SoundGeneration(ctx, in, opts...)
} }
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
if !c.parallel { if !c.parallel {
c.opMutex.Lock() c.opMutex.Lock()
defer c.opMutex.Unlock() defer c.opMutex.Unlock()
@ -243,27 +242,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
} }
defer conn.Close() defer conn.Close()
client := pb.NewBackendClient(conn) client := pb.NewBackendClient(conn)
res, err := client.AudioTranscription(ctx, in, opts...) return client.AudioTranscription(ctx, in, opts...)
if err != nil {
return nil, err
}
tresult := &schema.TranscriptionResult{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tresult.Segments = append(tresult.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tresult.Text = res.Text
return tresult, err
} }
func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {

View File

@ -2,9 +2,7 @@ package grpc
import ( import (
"context" "context"
"time"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -57,28 +55,8 @@ func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerati
return e.s.SoundGeneration(ctx, in) return e.s.SoundGeneration(ctx, in)
} }
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
r, err := e.s.AudioTranscription(ctx, in) return e.s.AudioTranscription(ctx, in)
if err != nil {
return nil, err
}
tr := &schema.TranscriptionResult{}
for _, s := range r.Segments {
var tks []int
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tr.Text = r.Text
return tr, err
} }
func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {

View File

@ -1,7 +1,6 @@
package grpc package grpc
import ( import (
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
) )
@ -15,7 +14,7 @@ type LLM interface {
Load(*pb.ModelOptions) error Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error) Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
TTS(*pb.TTSRequest) error TTS(*pb.TTSRequest) error
SoundGeneration(*pb.SoundGenerationRequest) error SoundGeneration(*pb.SoundGenerationRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)

25
pkg/utils/ffmpeg.go Normal file
View File

@ -0,0 +1,25 @@
package utils
import (
"fmt"
"os"
"os/exec"
)
func ffmpegCommand(args []string) (string, error) {
cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe.
cmd.Env = os.Environ()
out, err := cmd.CombinedOutput()
return string(out), err
}
// AudioToWav converts audio to wav for transcribe.
// TODO: use https://github.com/mccoyst/ogg?
func AudioToWav(src, dst string) error {
commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
out, err := ffmpegCommand(commandArgs)
if err != nil {
return fmt.Errorf("error: %w out: %s", err, out)
}
return nil
}