From c2804c42fe43da1b5d16feb8683bbb3c401e840f Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 2 Sep 2024 09:48:53 -0400 Subject: [PATCH] fix: untangle pkg/grpc and core/schema for Transcription (#3419) untangle pkg/grpc and core/schema in Transcribe Signed-off-by: Dave Lee --- Makefile | 2 +- backend/go/transcribe/transcript.go | 104 ------------------- backend/go/transcribe/whisper.go | 26 ----- backend/go/transcribe/{ => whisper}/main.go | 0 backend/go/transcribe/whisper/whisper.go | 105 ++++++++++++++++++++ core/backend/transcript.go | 24 ++++- pkg/grpc/backend.go | 3 +- pkg/grpc/base/base.go | 5 +- pkg/grpc/client.go | 25 +---- pkg/grpc/embed.go | 26 +---- pkg/grpc/interface.go | 3 +- pkg/utils/ffmpeg.go | 25 +++++ 12 files changed, 162 insertions(+), 186 deletions(-) delete mode 100644 backend/go/transcribe/transcript.go delete mode 100644 backend/go/transcribe/whisper.go rename backend/go/transcribe/{ => whisper}/main.go (100%) create mode 100644 backend/go/transcribe/whisper/whisper.go create mode 100644 pkg/utils/ffmpeg.go diff --git a/Makefile b/Makefile index be80d875..a360fe88 100644 --- a/Makefile +++ b/Makefile @@ -846,7 +846,7 @@ endif backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \ - $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/whisper ifneq ($(UPX),) $(UPX) backend-assets/grpc/whisper endif diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go deleted file mode 100644 index 6831167f..00000000 --- a/backend/go/transcribe/transcript.go +++ /dev/null @@ -1,104 +0,0 @@ -package main - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-audio/wav" - "github.com/mudler/LocalAI/core/schema" -) - -func ffmpegCommand(args []string) (string, error) { - cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe. - cmd.Env = os.Environ() - out, err := cmd.CombinedOutput() - return string(out), err -} - -// AudioToWav converts audio to wav for transcribe. -// TODO: use https://github.com/mccoyst/ogg? -func audioToWav(src, dst string) error { - commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} - out, err := ffmpegCommand(commandArgs) - if err != nil { - return fmt.Errorf("error: %w out: %s", err, out) - } - return nil -} - -func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) { - res := schema.TranscriptionResult{} - - dir, err := os.MkdirTemp("", "whisper") - if err != nil { - return res, err - } - defer os.RemoveAll(dir) - - convertedPath := filepath.Join(dir, "converted.wav") - - if err := audioToWav(audiopath, convertedPath); err != nil { - return res, err - } - - // Open samples - fh, err := os.Open(convertedPath) - if err != nil { - return res, err - } - defer fh.Close() - - // Read samples - d := wav.NewDecoder(fh) - buf, err := d.FullPCMBuffer() - if err != nil { - return res, err - } - - data := buf.AsFloat32Buffer().Data - - // Process samples - context, err := model.NewContext() - if err != nil { - return res, err - - } - - context.SetThreads(threads) - - if language != "" { - context.SetLanguage(language) - } else { - context.SetLanguage("auto") - } - - if translate { - context.SetTranslate(true) - } - - if err := context.Process(data, nil, nil); err != nil { - return res, err - } - - for { - s, err := context.NextSegment() - if err != nil { - break - } - - var tokens []int - for _, t := range s.Tokens { - tokens = append(tokens, t.Id) - } - - segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} - res.Segments = append(res.Segments, segment) - - res.Text += s.Text - } - - return res, nil -} diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go deleted file mode 100644 index 61ae98e9..00000000 --- a/backend/go/transcribe/whisper.go +++ /dev/null @@ -1,26 +0,0 @@ -package main - -// This is a wrapper to statisfy the GRPC service interface -// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) -import ( - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/grpc/base" - pb "github.com/mudler/LocalAI/pkg/grpc/proto" -) - -type Whisper struct { - base.SingleThread - whisper whisper.Model -} - -func (sd *Whisper) Load(opts *pb.ModelOptions) error { - // Note: the Model here is a path to a directory containing the model files - w, err := whisper.New(opts.ModelFile) - sd.whisper = w - return err -} - -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { - return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads)) -} diff --git a/backend/go/transcribe/main.go b/backend/go/transcribe/whisper/main.go similarity index 100% rename from backend/go/transcribe/main.go rename to backend/go/transcribe/whisper/main.go diff --git a/backend/go/transcribe/whisper/whisper.go b/backend/go/transcribe/whisper/whisper.go new file mode 100644 index 00000000..63416bb3 --- /dev/null +++ b/backend/go/transcribe/whisper/whisper.go @@ -0,0 +1,105 @@ +package main + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "os" + "path/filepath" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-audio/wav" + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/utils" +) + +type Whisper struct { + base.SingleThread + whisper whisper.Model +} + +func (sd *Whisper) Load(opts *pb.ModelOptions) error { + // Note: the Model here is a path to a directory containing the model files + w, err := whisper.New(opts.ModelFile) + sd.whisper = w + return err +} + +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { + + dir, err := os.MkdirTemp("", "whisper") + if err != nil { + return pb.TranscriptResult{}, err + } + defer os.RemoveAll(dir) + + convertedPath := filepath.Join(dir, "converted.wav") + + if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil { + return pb.TranscriptResult{}, err + } + + // Open samples + fh, err := os.Open(convertedPath) + if err != nil { + return pb.TranscriptResult{}, err + } + defer fh.Close() + + // Read samples + d := wav.NewDecoder(fh) + buf, err := d.FullPCMBuffer() + if err != nil { + return pb.TranscriptResult{}, err + } + + data := buf.AsFloat32Buffer().Data + + // Process samples + context, err := sd.whisper.NewContext() + if err != nil { + return pb.TranscriptResult{}, err + + } + + context.SetThreads(uint(opts.Threads)) + + if opts.Language != "" { + context.SetLanguage(opts.Language) + } else { + context.SetLanguage("auto") + } + + if opts.Translate { + context.SetTranslate(true) + } + + if err := context.Process(data, nil, nil); err != nil { + return pb.TranscriptResult{}, err + } + + segments := []*pb.TranscriptSegment{} + text := "" + for { + s, err := context.NextSegment() + if err != nil { + break + } + + var tokens []int32 + for _, t := range s.Tokens { + tokens = append(tokens, int32(t.Id)) + } + + segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens} + segments = append(segments, segment) + + text += s.Text + } + + return pb.TranscriptResult{ + Segments: segments, + Text: text, + }, nil + +} diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 0980288f..ed3e24a5 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -3,6 +3,7 @@ package backend import ( "context" "fmt" + "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" @@ -30,10 +31,31 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL return nil, fmt.Errorf("could not load whisper model") } - return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + r, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, Translate: translate, Threads: uint32(*backendConfig.Threads), }) + if err != nil { + return nil, err + } + tr := &schema.TranscriptionResult{ + Text: r.Text, + } + for _, s := range r.Segments { + var tks []int + for _, t := range s.Tokens { + tks = append(tks, int(t)) + } + tr.Segments = append(tr.Segments, + schema.Segment{ + Text: s.Text, + Id: int(s.Id), + Start: time.Duration(s.Start), + End: time.Duration(s.End), + Tokens: tks, + }) + } + return tr, err } diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 3821678c..85c9e5bc 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -3,7 +3,6 @@ package grpc import ( "context" - "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" ) @@ -42,7 +41,7 @@ type Backend interface { GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 21dd1578..95dca561 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -6,7 +6,6 @@ import ( "fmt" "os" - "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" gopsutil "github.com/shirou/gopsutil/v3/process" ) @@ -53,8 +52,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { - return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) { + return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index b654e9c9..032c9c00 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -228,7 +227,7 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ return client.SoundGeneration(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -243,27 +242,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } defer conn.Close() client := pb.NewBackendClient(conn) - res, err := client.AudioTranscription(ctx, in, opts...) - if err != nil { - return nil, err - } - tresult := &schema.TranscriptionResult{} - for _, s := range res.Segments { - tks := []int{} - for _, t := range s.Tokens { - tks = append(tks, int(t)) - } - tresult.Segments = append(tresult.Segments, - schema.Segment{ - Text: s.Text, - Id: int(s.Id), - Start: time.Duration(s.Start), - End: time.Duration(s.End), - Tokens: tks, - }) - } - tresult.Text = res.Text - return tresult, err + return client.AudioTranscription(ctx, in, opts...) } func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 67d83e27..3155ff59 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -2,9 +2,7 @@ package grpc import ( "context" - "time" - "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -57,28 +55,8 @@ func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerati return e.s.SoundGeneration(ctx, in) } -func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { - r, err := e.s.AudioTranscription(ctx, in) - if err != nil { - return nil, err - } - tr := &schema.TranscriptionResult{} - for _, s := range r.Segments { - var tks []int - for _, t := range s.Tokens { - tks = append(tks, int(t)) - } - tr.Segments = append(tr.Segments, - schema.Segment{ - Text: s.Text, - Id: int(s.Id), - Start: time.Duration(s.Start), - End: time.Duration(s.End), - Tokens: tks, - }) - } - tr.Text = r.Text - return tr, err +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) { + return e.s.AudioTranscription(ctx, in) } func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 731dcd5b..97b958cc 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,7 +1,6 @@ package grpc import ( - "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) @@ -15,7 +14,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) + AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) TTS(*pb.TTSRequest) error SoundGeneration(*pb.SoundGenerationRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) diff --git a/pkg/utils/ffmpeg.go b/pkg/utils/ffmpeg.go new file mode 100644 index 00000000..16656d8e --- /dev/null +++ b/pkg/utils/ffmpeg.go @@ -0,0 +1,25 @@ +package utils + +import ( + "fmt" + "os" + "os/exec" +) + +func ffmpegCommand(args []string) (string, error) { + cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe. + cmd.Env = os.Environ() + out, err := cmd.CombinedOutput() + return string(out), err +} + +// AudioToWav converts audio to wav for transcribe. +// TODO: use https://github.com/mccoyst/ogg? +func AudioToWav(src, dst string) error { + commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} + out, err := ffmpegCommand(commandArgs) + if err != nil { + return fmt.Errorf("error: %w out: %s", err, out) + } + return nil +}