mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(bark-cpp): add new bark.cpp backend (#4287)
* feat(bark-cpp): add new bark.cpp backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * build on linux only for now Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * track bark.cpp in CI bumps Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop old entries from bumper Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * No need to test rwkv specifically, now part of llama.cpp Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
0d6c3a7d57
commit
58ff47de26
16
.github/workflows/bump_deps.yaml
vendored
16
.github/workflows/bump_deps.yaml
vendored
@ -12,24 +12,12 @@ jobs:
|
|||||||
- repository: "ggerganov/llama.cpp"
|
- repository: "ggerganov/llama.cpp"
|
||||||
variable: "CPPLLAMA_VERSION"
|
variable: "CPPLLAMA_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
- repository: "go-skynet/go-ggml-transformers.cpp"
|
|
||||||
variable: "GOGGMLTRANSFORMERS_VERSION"
|
|
||||||
branch: "master"
|
|
||||||
- repository: "donomii/go-rwkv.cpp"
|
|
||||||
variable: "RWKV_VERSION"
|
|
||||||
branch: "main"
|
|
||||||
- repository: "ggerganov/whisper.cpp"
|
- repository: "ggerganov/whisper.cpp"
|
||||||
variable: "WHISPER_CPP_VERSION"
|
variable: "WHISPER_CPP_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
- repository: "go-skynet/go-bert.cpp"
|
- repository: "PABannier/bark.cpp"
|
||||||
variable: "BERT_VERSION"
|
variable: "BARKCPP_VERSION"
|
||||||
branch: "master"
|
|
||||||
- repository: "go-skynet/bloomz.cpp"
|
|
||||||
variable: "BLOOMZ_VERSION"
|
|
||||||
branch: "main"
|
branch: "main"
|
||||||
- repository: "mudler/go-ggllm.cpp"
|
|
||||||
variable: "GOGGLLM_VERSION"
|
|
||||||
branch: "master"
|
|
||||||
- repository: "mudler/go-stable-diffusion"
|
- repository: "mudler/go-stable-diffusion"
|
||||||
variable: "STABLEDIFFUSION_VERSION"
|
variable: "STABLEDIFFUSION_VERSION"
|
||||||
branch: "master"
|
branch: "master"
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
/sources/
|
/sources/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.a
|
*.a
|
||||||
|
*.o
|
||||||
get-sources
|
get-sources
|
||||||
prepare-sources
|
prepare-sources
|
||||||
/backend/cpp/llama/grpc-server
|
/backend/cpp/llama/grpc-server
|
||||||
|
37
Makefile
37
Makefile
@ -26,6 +26,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
|
|||||||
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
|
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
|
||||||
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
|
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
|
||||||
|
|
||||||
|
# bark.cpp
|
||||||
|
BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
|
||||||
|
BARKCPP_VERSION?=v1.0.0
|
||||||
|
|
||||||
ONNX_VERSION?=1.20.0
|
ONNX_VERSION?=1.20.0
|
||||||
ONNX_ARCH?=x64
|
ONNX_ARCH?=x64
|
||||||
ONNX_OS?=linux
|
ONNX_OS?=linux
|
||||||
@ -201,6 +205,13 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
|
|||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
|
ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
||||||
|
|
||||||
|
ifeq ($(ONNX_OS),linux)
|
||||||
|
ifeq ($(ONNX_ARCH),x64)
|
||||||
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
|
||||||
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
|
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
|
||||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||||
@ -233,6 +244,22 @@ sources/go-llama.cpp:
|
|||||||
git checkout $(GOLLAMA_VERSION) && \
|
git checkout $(GOLLAMA_VERSION) && \
|
||||||
git submodule update --init --recursive --depth 1 --single-branch
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
sources/bark.cpp:
|
||||||
|
git clone --recursive https://github.com/PABannier/bark.cpp.git sources/bark.cpp && \
|
||||||
|
cd sources/bark.cpp && \
|
||||||
|
git checkout $(BARKCPP_VERSION) && \
|
||||||
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
|
sources/bark.cpp/build/libbark.a: sources/bark.cpp
|
||||||
|
cd sources/bark.cpp && \
|
||||||
|
mkdir build && \
|
||||||
|
cd build && \
|
||||||
|
cmake $(CMAKE_ARGS) .. && \
|
||||||
|
cmake --build . --config Release
|
||||||
|
|
||||||
|
backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
|
||||||
|
$(MAKE) -C backend/go/bark libbark.a
|
||||||
|
|
||||||
sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
|
sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
|
||||||
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
|
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
|
||||||
|
|
||||||
@ -302,7 +329,7 @@ sources/whisper.cpp:
|
|||||||
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
|
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
|
||||||
cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a
|
cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a
|
||||||
|
|
||||||
get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
|
get-sources: sources/go-llama.cpp sources/go-piper sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
|
||||||
|
|
||||||
replace:
|
replace:
|
||||||
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
|
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
|
||||||
@ -343,6 +370,7 @@ clean: ## Remove build related file
|
|||||||
rm -rf release/
|
rm -rf release/
|
||||||
rm -rf backend-assets/*
|
rm -rf backend-assets/*
|
||||||
$(MAKE) -C backend/cpp/grpc clean
|
$(MAKE) -C backend/cpp/grpc clean
|
||||||
|
$(MAKE) -C backend/go/bark clean
|
||||||
$(MAKE) -C backend/cpp/llama clean
|
$(MAKE) -C backend/cpp/llama clean
|
||||||
rm -rf backend/cpp/llama-* || true
|
rm -rf backend/cpp/llama-* || true
|
||||||
$(MAKE) dropreplace
|
$(MAKE) dropreplace
|
||||||
@ -792,6 +820,13 @@ ifneq ($(UPX),)
|
|||||||
$(UPX) backend-assets/grpc/llama-ggml
|
$(UPX) backend-assets/grpc/llama-ggml
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
|
||||||
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
|
||||||
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
|
||||||
|
ifneq ($(UPX),)
|
||||||
|
$(UPX) backend-assets/grpc/bark-cpp
|
||||||
|
endif
|
||||||
|
|
||||||
backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
|
backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
|
||||||
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
|
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
|
||||||
|
25
backend/go/bark/Makefile
Normal file
25
backend/go/bark/Makefile
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
INCLUDE_PATH := $(abspath ./)
|
||||||
|
LIBRARY_PATH := $(abspath ./)
|
||||||
|
|
||||||
|
AR?=ar
|
||||||
|
|
||||||
|
BUILD_TYPE?=
|
||||||
|
# keep standard at C11 and C++11
|
||||||
|
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
|
||||||
|
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm
|
||||||
|
|
||||||
|
# warnings
|
||||||
|
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
|
||||||
|
|
||||||
|
gobark.o:
|
||||||
|
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)
|
||||||
|
|
||||||
|
libbark.a: gobark.o
|
||||||
|
cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
|
||||||
|
$(AR) rcs libbark.a gobark.o
|
||||||
|
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
|
||||||
|
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
|
||||||
|
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f gobark.o libbark.a
|
85
backend/go/bark/gobark.cpp
Normal file
85
backend/go/bark/gobark.cpp
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "bark.h"
|
||||||
|
#include "gobark.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
struct bark_context *c;
|
||||||
|
|
||||||
|
void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
|
||||||
|
if (step == bark_encoding_step::SEMANTIC) {
|
||||||
|
printf("\rGenerating semantic tokens... %d%%", progress);
|
||||||
|
} else if (step == bark_encoding_step::COARSE) {
|
||||||
|
printf("\rGenerating coarse tokens... %d%%", progress);
|
||||||
|
} else if (step == bark_encoding_step::FINE) {
|
||||||
|
printf("\rGenerating fine tokens... %d%%", progress);
|
||||||
|
}
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
int load_model(char *model) {
|
||||||
|
// initialize bark context
|
||||||
|
struct bark_context_params ctx_params = bark_context_default_params();
|
||||||
|
bark_params params;
|
||||||
|
|
||||||
|
params.model_path = model;
|
||||||
|
|
||||||
|
// ctx_params.verbosity = verbosity;
|
||||||
|
ctx_params.progress_callback = bark_print_progress_callback;
|
||||||
|
ctx_params.progress_callback_user_data = nullptr;
|
||||||
|
|
||||||
|
struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
|
||||||
|
if (!bctx) {
|
||||||
|
fprintf(stderr, "%s: Could not load model\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
c = bctx;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int tts(char *text,int threads, char *dst ) {
|
||||||
|
|
||||||
|
ggml_time_init();
|
||||||
|
const int64_t t_main_start_us = ggml_time_us();
|
||||||
|
|
||||||
|
// generate audio
|
||||||
|
if (!bark_generate_audio(c, text, threads)) {
|
||||||
|
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float *audio_data = bark_get_audio_data(c);
|
||||||
|
if (audio_data == NULL) {
|
||||||
|
fprintf(stderr, "%s: Could not get audio data\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int audio_arr_size = bark_get_audio_data_size(c);
|
||||||
|
|
||||||
|
std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);
|
||||||
|
|
||||||
|
write_wav_on_disk(audio_arr, dst);
|
||||||
|
|
||||||
|
// report timing
|
||||||
|
{
|
||||||
|
const int64_t t_main_end_us = ggml_time_us();
|
||||||
|
const int64_t t_load_us = bark_get_load_time(c);
|
||||||
|
const int64_t t_eval_us = bark_get_eval_time(c);
|
||||||
|
|
||||||
|
printf("\n\n");
|
||||||
|
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
|
||||||
|
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
|
||||||
|
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int unload() {
|
||||||
|
bark_free(c);
|
||||||
|
}
|
||||||
|
|
52
backend/go/bark/gobark.go
Normal file
52
backend/go/bark/gobark.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
|
||||||
|
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
|
||||||
|
// #include <gobark.h>
|
||||||
|
// #include <stdlib.h>
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bark struct {
|
||||||
|
base.SingleThread
|
||||||
|
threads int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Bark) Load(opts *pb.ModelOptions) error {
|
||||||
|
|
||||||
|
sd.threads = int(opts.Threads)
|
||||||
|
|
||||||
|
modelFile := C.CString(opts.ModelFile)
|
||||||
|
defer C.free(unsafe.Pointer(modelFile))
|
||||||
|
|
||||||
|
ret := C.load_model(modelFile)
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("inference failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Bark) TTS(opts *pb.TTSRequest) error {
|
||||||
|
t := C.CString(opts.Text)
|
||||||
|
defer C.free(unsafe.Pointer(t))
|
||||||
|
|
||||||
|
dst := C.CString(opts.Dst)
|
||||||
|
defer C.free(unsafe.Pointer(dst))
|
||||||
|
|
||||||
|
threads := C.int(sd.threads)
|
||||||
|
|
||||||
|
ret := C.tts(t, threads, dst)
|
||||||
|
if ret != 0 {
|
||||||
|
return fmt.Errorf("inference failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
8
backend/go/bark/gobark.h
Normal file
8
backend/go/bark/gobark.h
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
int load_model(char *model);
|
||||||
|
int tts(char *text,int threads, char *dst );
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
20
backend/go/bark/main.go
Normal file
20
backend/go/bark/main.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
|
||||||
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &Bark{}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
@ -5,14 +5,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
. "github.com/mudler/LocalAI/core/http"
|
. "github.com/mudler/LocalAI/core/http"
|
||||||
@ -913,71 +911,6 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("backends", func() {
|
|
||||||
It("runs rwkv completion", func() {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
Skip("test supported only on linux")
|
|
||||||
}
|
|
||||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
|
||||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
|
|
||||||
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
tokens := 0
|
|
||||||
text := ""
|
|
||||||
for {
|
|
||||||
response, err := stream.Recv()
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
text += response.Choices[0].Text
|
|
||||||
tokens++
|
|
||||||
}
|
|
||||||
Expect(text).ToNot(BeEmpty())
|
|
||||||
Expect(text).To(ContainSubstring("five"))
|
|
||||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
|
||||||
})
|
|
||||||
It("runs rwkv chat completion", func() {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
Skip("test supported only on linux")
|
|
||||||
}
|
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
|
||||||
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
|
||||||
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
tokens := 0
|
|
||||||
text := ""
|
|
||||||
for {
|
|
||||||
response, err := stream.Recv()
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
text += response.Choices[0].Delta.Content
|
|
||||||
tokens++
|
|
||||||
}
|
|
||||||
Expect(text).ToNot(BeEmpty())
|
|
||||||
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
|
|
||||||
|
|
||||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// See tests/integration/stores_test
|
// See tests/integration/stores_test
|
||||||
Context("Stores", Label("stores"), func() {
|
Context("Stores", Label("stores"), func() {
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user