mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-03 03:16:40 +00:00
Merge branch 'master' into default_miro
This commit is contained in:
commit
3a1727a4fe
@ -32,18 +32,22 @@ config_remote() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Setup special .ssh files
|
# Setup special .ssh files
|
||||||
#
|
# Prints out lines of text to make things pretty
|
||||||
# Param 1: bash array, filenames relative to the customization directory that should be copied to ~/.ssh
|
# Param 1: bash array, filenames relative to the customization directory that should be copied to ~/.ssh
|
||||||
setup_ssh() {
|
setup_ssh() {
|
||||||
|
echo "starting ~/.ssh directory setup..."
|
||||||
|
mkdir -p "${HOME}.ssh"
|
||||||
|
chmod 0700 "${HOME}/.ssh"
|
||||||
|
echo "-----"
|
||||||
local files=("$@")
|
local files=("$@")
|
||||||
for file in "${files[@]}"; then
|
for file in "${files[@]}" ; do
|
||||||
local cfile="/devcontainer-customization/${file}"
|
local cfile="/devcontainer-customization/${file}"
|
||||||
local hfile="~/.ssh/${file}"
|
local hfile="${HOME}/.ssh/${file}"
|
||||||
if [ ! -f "${hfile}" ]; then
|
if [ ! -f "${hfile}" ]; then
|
||||||
echo "copying ${file}"
|
echo "copying \"${file}\""
|
||||||
cp "${cfile}" "${hfile}"
|
cp "${cfile}" "${hfile}"
|
||||||
chmod 600 "${hfile}"
|
chmod 600 "${hfile}"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
ls ~/.ssh
|
echo "~/.ssh directory setup complete!"
|
||||||
}
|
}
|
||||||
|
2
.github/workflows/bump_deps.yaml
vendored
2
.github/workflows/bump_deps.yaml
vendored
@ -56,7 +56,7 @@ jobs:
|
|||||||
rm -rfv ${{ matrix.variable }}_message.txt
|
rm -rfv ${{ matrix.variable }}_message.txt
|
||||||
rm -rfv ${{ matrix.variable }}_commit.txt
|
rm -rfv ${{ matrix.variable }}_commit.txt
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||||
push-to-fork: ci-forks/LocalAI
|
push-to-fork: ci-forks/LocalAI
|
||||||
|
2
.github/workflows/bump_docs.yaml
vendored
2
.github/workflows/bump_docs.yaml
vendored
@ -17,7 +17,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
bash .github/bump_docs.sh ${{ matrix.repository }}
|
bash .github/bump_docs.sh ${{ matrix.repository }}
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||||
push-to-fork: ci-forks/LocalAI
|
push-to-fork: ci-forks/LocalAI
|
||||||
|
2
.github/workflows/checksum_checker.yaml
vendored
2
.github/workflows/checksum_checker.yaml
vendored
@ -36,7 +36,7 @@ jobs:
|
|||||||
sudo chmod 777 /hf_cache
|
sudo chmod 777 /hf_cache
|
||||||
bash .github/checksum_checker.sh gallery/index.yaml
|
bash .github/checksum_checker.sh gallery/index.yaml
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||||
push-to-fork: ci-forks/LocalAI
|
push-to-fork: ci-forks/LocalAI
|
||||||
|
6
.github/workflows/release.yaml
vendored
6
.github/workflows/release.yaml
vendored
@ -294,7 +294,7 @@ jobs:
|
|||||||
export C_INCLUDE_PATH=/usr/local/include
|
export C_INCLUDE_PATH=/usr/local/include
|
||||||
export CPLUS_INCLUDE_PATH=/usr/local/include
|
export CPLUS_INCLUDE_PATH=/usr/local/include
|
||||||
export PATH=$PATH:$GOPATH/bin
|
export PATH=$PATH:$GOPATH/bin
|
||||||
|
export SKIP_GRPC_BACKEND=backend-assets/grpc/whisper
|
||||||
make dist
|
make dist
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
@ -327,7 +327,7 @@ jobs:
|
|||||||
cache: false
|
cache: false
|
||||||
- name: Dependencies
|
- name: Dependencies
|
||||||
run: |
|
run: |
|
||||||
brew install protobuf grpc
|
brew install protobuf grpc libomp llvm
|
||||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
|
||||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
|
||||||
- name: Build
|
- name: Build
|
||||||
@ -336,7 +336,7 @@ jobs:
|
|||||||
export C_INCLUDE_PATH=/usr/local/include
|
export C_INCLUDE_PATH=/usr/local/include
|
||||||
export CPLUS_INCLUDE_PATH=/usr/local/include
|
export CPLUS_INCLUDE_PATH=/usr/local/include
|
||||||
export PATH=$PATH:$GOPATH/bin
|
export PATH=$PATH:$GOPATH/bin
|
||||||
|
export CC=/opt/homebrew/opt/llvm/bin/clang
|
||||||
make dist
|
make dist
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
2
.github/workflows/secscan.yaml
vendored
2
.github/workflows/secscan.yaml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
- name: Run Gosec Security Scanner
|
- name: Run Gosec Security Scanner
|
||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
uses: securego/gosec@master
|
uses: securego/gosec@v2.21.2
|
||||||
with:
|
with:
|
||||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||||
|
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@ -214,12 +214,13 @@ jobs:
|
|||||||
run: go version
|
run: go version
|
||||||
- name: Dependencies
|
- name: Dependencies
|
||||||
run: |
|
run: |
|
||||||
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc
|
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
|
||||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
export C_INCLUDE_PATH=/usr/local/include
|
export C_INCLUDE_PATH=/usr/local/include
|
||||||
export CPLUS_INCLUDE_PATH=/usr/local/include
|
export CPLUS_INCLUDE_PATH=/usr/local/include
|
||||||
|
export CC=/opt/homebrew/opt/llvm/bin/clang
|
||||||
# Used to run the newer GNUMake version from brew that supports --output-sync
|
# Used to run the newer GNUMake version from brew that supports --output-sync
|
||||||
export PATH="/opt/homebrew/opt/make/libexec/gnubin:$PATH"
|
export PATH="/opt/homebrew/opt/make/libexec/gnubin:$PATH"
|
||||||
BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
|
BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
|
||||||
|
2
.github/workflows/update_swagger.yaml
vendored
2
.github/workflows/update_swagger.yaml
vendored
@ -25,7 +25,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
make protogen-go swagger
|
make protogen-go swagger
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v7
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||||
push-to-fork: ci-forks/LocalAI
|
push-to-fork: ci-forks/LocalAI
|
||||||
|
39
Dockerfile
39
Dockerfile
@ -13,7 +13,7 @@ ARG TARGETARCH
|
|||||||
ARG TARGETVARIANT
|
ARG TARGETVARIANT
|
||||||
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,rerankers:/build/backend/python/rerankers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,openvoice:/build/backend/python/openvoice/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
|
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,rerankers:/build/backend/python/rerankers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,openvoice:/build/backend/python/openvoice/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
|
||||||
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
@ -263,14 +263,20 @@ EOT
|
|||||||
# In most cases, builder is the image you should be using - however, this can save build time if one just needs to copy backend-assets/grpc/stablediffusion and nothing else.
|
# In most cases, builder is the image you should be using - however, this can save build time if one just needs to copy backend-assets/grpc/stablediffusion and nothing else.
|
||||||
FROM builder-base AS builder-sd
|
FROM builder-base AS builder-sd
|
||||||
|
|
||||||
COPY . .
|
# stablediffusion does not tolerate a newer version of abseil, copy only over enough elements to build it
|
||||||
COPY .git .
|
COPY Makefile .
|
||||||
|
COPY go.mod .
|
||||||
|
COPY go.sum .
|
||||||
|
COPY backend/backend.proto ./backend/backend.proto
|
||||||
|
COPY backend/go/image/stablediffusion ./backend/go/image/stablediffusion
|
||||||
|
COPY pkg/grpc ./pkg/grpc
|
||||||
|
COPY pkg/stablediffusion ./pkg/stablediffusion
|
||||||
|
RUN git init
|
||||||
|
RUN make sources/go-stable-diffusion
|
||||||
|
RUN touch prepare-sources
|
||||||
|
|
||||||
RUN make prepare
|
# Actually build the backend
|
||||||
|
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make backend-assets/grpc/stablediffusion
|
||||||
|
|
||||||
# stablediffusion does not tolerate a newer version of abseil, build it first
|
|
||||||
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
|
|
||||||
|
|
||||||
###################################
|
###################################
|
||||||
###################################
|
###################################
|
||||||
@ -285,8 +291,20 @@ COPY --from=grpc /opt/grpc /usr/local
|
|||||||
# Rebuild with defaults backends
|
# Rebuild with defaults backends
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
COPY .git .
|
||||||
|
|
||||||
|
RUN make prepare
|
||||||
|
|
||||||
## Build the binary
|
## Build the binary
|
||||||
RUN make build
|
## If it's CUDA, we want to skip some of the llama-compat backends to save space
|
||||||
|
## We only leave the most CPU-optimized variant and the fallback for the cublas build
|
||||||
|
## (both will use CUDA for the actual computation)
|
||||||
|
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
|
||||||
|
SKIP_GRPC_BACKEND="backend-assets/grpc/llama-cpp-avx backend-assets/grpc/llama-cpp-avx2" make build; \
|
||||||
|
else \
|
||||||
|
make build; \
|
||||||
|
fi
|
||||||
|
|
||||||
RUN if [ ! -d "/build/sources/go-piper/piper-phonemize/pi/lib/" ]; then \
|
RUN if [ ! -d "/build/sources/go-piper/piper-phonemize/pi/lib/" ]; then \
|
||||||
mkdir -p /build/sources/go-piper/piper-phonemize/pi/lib/ \
|
mkdir -p /build/sources/go-piper/piper-phonemize/pi/lib/ \
|
||||||
@ -400,9 +418,6 @@ RUN if [[ ( "${EXTRA_BACKENDS}" =~ "coqui" || -z "${EXTRA_BACKENDS}" ) && "$IMAG
|
|||||||
; fi && \
|
; fi && \
|
||||||
if [[ ( "${EXTRA_BACKENDS}" =~ "transformers-musicgen" || -z "${EXTRA_BACKENDS}" ) && "$IMAGE_TYPE" == "extras" ]]; then \
|
if [[ ( "${EXTRA_BACKENDS}" =~ "transformers-musicgen" || -z "${EXTRA_BACKENDS}" ) && "$IMAGE_TYPE" == "extras" ]]; then \
|
||||||
make -C backend/python/transformers-musicgen \
|
make -C backend/python/transformers-musicgen \
|
||||||
; fi && \
|
|
||||||
if [[ ( "${EXTRA_BACKENDS}" =~ "exllama1" || -z "${EXTRA_BACKENDS}" ) && "$IMAGE_TYPE" == "extras" ]]; then \
|
|
||||||
make -C backend/python/exllama \
|
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
RUN if [[ ( "${EXTRA_BACKENDS}" =~ "vall-e-x" || -z "${EXTRA_BACKENDS}" ) && "$IMAGE_TYPE" == "extras" ]]; then \
|
RUN if [[ ( "${EXTRA_BACKENDS}" =~ "vall-e-x" || -z "${EXTRA_BACKENDS}" ) && "$IMAGE_TYPE" == "extras" ]]; then \
|
||||||
|
19
Makefile
19
Makefile
@ -8,7 +8,7 @@ DETECT_LIBS?=true
|
|||||||
# llama.cpp versions
|
# llama.cpp versions
|
||||||
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
|
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
|
||||||
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
|
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
|
||||||
CPPLLAMA_VERSION?=2f3c1466ff46a2413b0e363a5005c46538186ee6
|
CPPLLAMA_VERSION?=23e0d70bacaaca1429d365a44aa9e7434f17823b
|
||||||
|
|
||||||
# go-rwkv version
|
# go-rwkv version
|
||||||
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
||||||
@ -16,7 +16,7 @@ RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
|
|||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
|
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
|
||||||
WHISPER_CPP_VERSION?=d65786ea540a5aef21f67cacfa6f134097727780
|
WHISPER_CPP_VERSION?=049b3a0e53c8a8e4c4576c06a1a4fccf0063a73f
|
||||||
|
|
||||||
# bert.cpp version
|
# bert.cpp version
|
||||||
BERT_REPO?=https://github.com/go-skynet/go-bert.cpp
|
BERT_REPO?=https://github.com/go-skynet/go-bert.cpp
|
||||||
@ -534,10 +534,10 @@ protogen-go-clean:
|
|||||||
$(RM) bin/*
|
$(RM) bin/*
|
||||||
|
|
||||||
.PHONY: protogen-python
|
.PHONY: protogen-python
|
||||||
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen rerankers-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen openvoice-protogen
|
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama2-protogen mamba-protogen rerankers-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen openvoice-protogen
|
||||||
|
|
||||||
.PHONY: protogen-python-clean
|
.PHONY: protogen-python-clean
|
||||||
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean sentencetransformers-protogen-clean rerankers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean openvoice-protogen-clean
|
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama2-protogen-clean mamba-protogen-clean sentencetransformers-protogen-clean rerankers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean openvoice-protogen-clean
|
||||||
|
|
||||||
.PHONY: autogptq-protogen
|
.PHONY: autogptq-protogen
|
||||||
autogptq-protogen:
|
autogptq-protogen:
|
||||||
@ -571,14 +571,6 @@ diffusers-protogen:
|
|||||||
diffusers-protogen-clean:
|
diffusers-protogen-clean:
|
||||||
$(MAKE) -C backend/python/diffusers protogen-clean
|
$(MAKE) -C backend/python/diffusers protogen-clean
|
||||||
|
|
||||||
.PHONY: exllama-protogen
|
|
||||||
exllama-protogen:
|
|
||||||
$(MAKE) -C backend/python/exllama protogen
|
|
||||||
|
|
||||||
.PHONY: exllama-protogen-clean
|
|
||||||
exllama-protogen-clean:
|
|
||||||
$(MAKE) -C backend/python/exllama protogen-clean
|
|
||||||
|
|
||||||
.PHONY: exllama2-protogen
|
.PHONY: exllama2-protogen
|
||||||
exllama2-protogen:
|
exllama2-protogen:
|
||||||
$(MAKE) -C backend/python/exllama2 protogen
|
$(MAKE) -C backend/python/exllama2 protogen
|
||||||
@ -675,7 +667,6 @@ prepare-extra-conda-environments: protogen-python
|
|||||||
$(MAKE) -C backend/python/parler-tts
|
$(MAKE) -C backend/python/parler-tts
|
||||||
$(MAKE) -C backend/python/vall-e-x
|
$(MAKE) -C backend/python/vall-e-x
|
||||||
$(MAKE) -C backend/python/openvoice
|
$(MAKE) -C backend/python/openvoice
|
||||||
$(MAKE) -C backend/python/exllama
|
|
||||||
$(MAKE) -C backend/python/exllama2
|
$(MAKE) -C backend/python/exllama2
|
||||||
|
|
||||||
prepare-test-extra: protogen-python
|
prepare-test-extra: protogen-python
|
||||||
@ -846,7 +837,7 @@ endif
|
|||||||
|
|
||||||
backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc
|
backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
|
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/whisper
|
||||||
ifneq ($(UPX),)
|
ifneq ($(UPX),)
|
||||||
$(UPX) backend-assets/grpc/whisper
|
$(UPX) backend-assets/grpc/whisper
|
||||||
endif
|
endif
|
||||||
|
@ -40,7 +40,7 @@
|
|||||||
|
|
||||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
||||||
>
|
>
|
||||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [📣 News](https://localai.io/basics/news/) [ 🛫 Examples ](https://github.com/go-skynet/LocalAI/tree/master/examples/) [ 🖼️ Models ](https://localai.io/models/) [ 🚀 Roadmap ](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🥽 Demo](https://demo.localai.io) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/go-skynet/LocalAI/tree/master/examples/)
|
||||||
|
|
||||||
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
|
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
|
||||||
|
|
||||||
@ -72,6 +72,7 @@ docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
|
|||||||
|
|
||||||
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
||||||
|
|
||||||
|
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
|
||||||
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723
|
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723
|
||||||
- June 2024: 🆕 You can browse now the model gallery without LocalAI! Check out https://models.localai.io
|
- June 2024: 🆕 You can browse now the model gallery without LocalAI! Check out https://models.localai.io
|
||||||
- June 2024: Support for models from OCI registries: https://github.com/mudler/LocalAI/pull/2628
|
- June 2024: Support for models from OCI registries: https://github.com/mudler/LocalAI/pull/2628
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
name: stablediffusion
|
name: stablediffusion
|
||||||
parameters:
|
parameters:
|
||||||
model: runwayml/stable-diffusion-v1-5
|
model: Lykon/dreamshaper-8
|
||||||
backend: diffusers
|
backend: diffusers
|
||||||
step: 25
|
step: 25
|
||||||
f16: true
|
f16: true
|
||||||
|
@ -16,6 +16,7 @@ service Backend {
|
|||||||
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
|
||||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
|
||||||
rpc TTS(TTSRequest) returns (Result) {}
|
rpc TTS(TTSRequest) returns (Result) {}
|
||||||
|
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
|
||||||
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
|
||||||
rpc Status(HealthMessage) returns (StatusResponse) {}
|
rpc Status(HealthMessage) returns (StatusResponse) {}
|
||||||
|
|
||||||
@ -270,6 +271,17 @@ message TTSRequest {
|
|||||||
optional string language = 5;
|
optional string language = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message SoundGenerationRequest {
|
||||||
|
string text = 1;
|
||||||
|
string model = 2;
|
||||||
|
string dst = 3;
|
||||||
|
optional float duration = 4;
|
||||||
|
optional float temperature = 5;
|
||||||
|
optional bool sample = 6;
|
||||||
|
optional string src = 7;
|
||||||
|
optional int32 src_divisor = 8;
|
||||||
|
}
|
||||||
|
|
||||||
message TokenizationResponse {
|
message TokenizationResponse {
|
||||||
int32 length = 1;
|
int32 length = 1;
|
||||||
repeated int32 tokens = 2;
|
repeated int32 tokens = 2;
|
||||||
|
@ -13,15 +13,15 @@
|
|||||||
#include <getopt.h>
|
#include <getopt.h>
|
||||||
#include "clip.h"
|
#include "clip.h"
|
||||||
#include "llava.h"
|
#include "llava.h"
|
||||||
|
#include "log.h"
|
||||||
#include "stb_image.h"
|
#include "stb_image.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "grammar-parser.h"
|
|
||||||
#include "backend.pb.h"
|
#include "backend.pb.h"
|
||||||
#include "backend.grpc.pb.h"
|
#include "backend.grpc.pb.h"
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
|
#include "sampling.h"
|
||||||
// include std::regex
|
// include std::regex
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
@ -203,8 +203,8 @@ struct llama_client_slot
|
|||||||
std::string stopping_word;
|
std::string stopping_word;
|
||||||
|
|
||||||
// sampling
|
// sampling
|
||||||
struct llama_sampling_params sparams;
|
struct gpt_sampler_params sparams;
|
||||||
llama_sampling_context *ctx_sampling = nullptr;
|
gpt_sampler *ctx_sampling = nullptr;
|
||||||
|
|
||||||
int32_t ga_i = 0; // group-attention state
|
int32_t ga_i = 0; // group-attention state
|
||||||
int32_t ga_n = 1; // group-attention factor
|
int32_t ga_n = 1; // group-attention factor
|
||||||
@ -449,7 +449,7 @@ struct llama_server_context
|
|||||||
LOG_INFO("Multi Modal Mode Enabled", {});
|
LOG_INFO("Multi Modal Mode Enabled", {});
|
||||||
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
||||||
if(clp_ctx == nullptr) {
|
if(clp_ctx == nullptr) {
|
||||||
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
LOG_ERR("unable to load clip model: %s", params.mmproj.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -463,7 +463,7 @@ struct llama_server_context
|
|||||||
ctx = llama_init.context;
|
ctx = llama_init.context;
|
||||||
if (model == nullptr)
|
if (model == nullptr)
|
||||||
{
|
{
|
||||||
LOG_ERROR("unable to load model", {{"model", params.model}});
|
LOG_ERR("unable to load model: %s", params.model.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -471,7 +471,7 @@ struct llama_server_context
|
|||||||
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
|
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
|
||||||
const int n_embd_llm = llama_n_embd(model);
|
const int n_embd_llm = llama_n_embd(model);
|
||||||
if (n_embd_clip != n_embd_llm) {
|
if (n_embd_clip != n_embd_llm) {
|
||||||
LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
|
LOG("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
return false;
|
return false;
|
||||||
@ -490,7 +490,7 @@ struct llama_server_context
|
|||||||
std::vector<char> buf(1);
|
std::vector<char> buf(1);
|
||||||
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
LOG_ERR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", __func__);
|
||||||
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
|
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -619,7 +619,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
|
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
|
||||||
slot_params default_params;
|
slot_params default_params;
|
||||||
llama_sampling_params default_sparams;
|
gpt_sampler_params default_sparams;
|
||||||
|
|
||||||
slot->params.stream = json_value(data, "stream", false);
|
slot->params.stream = json_value(data, "stream", false);
|
||||||
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
|
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||||
@ -628,7 +628,7 @@ struct llama_server_context
|
|||||||
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||||
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||||
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||||
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
slot->sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||||
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||||
slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||||
slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||||
@ -641,7 +641,7 @@ struct llama_server_context
|
|||||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||||
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||||
@ -665,6 +665,7 @@ struct llama_server_context
|
|||||||
slot->params.input_prefix = "";
|
slot->params.input_prefix = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (data.count("input_suffix") != 0)
|
if (data.count("input_suffix") != 0)
|
||||||
{
|
{
|
||||||
slot->params.input_suffix = data["input_suffix"];
|
slot->params.input_suffix = data["input_suffix"];
|
||||||
@ -683,6 +684,10 @@ struct llama_server_context
|
|||||||
slot->prompt = "";
|
slot->prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (json_value(data, "ignore_eos", false)) {
|
||||||
|
slot->sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
|
||||||
|
}
|
||||||
|
/*
|
||||||
slot->sparams.penalty_prompt_tokens.clear();
|
slot->sparams.penalty_prompt_tokens.clear();
|
||||||
slot->sparams.use_penalty_prompt_tokens = false;
|
slot->sparams.use_penalty_prompt_tokens = false;
|
||||||
const auto &penalty_prompt = data.find("penalty_prompt");
|
const auto &penalty_prompt = data.find("penalty_prompt");
|
||||||
@ -718,14 +723,10 @@ struct llama_server_context
|
|||||||
slot->sparams.use_penalty_prompt_tokens = true;
|
slot->sparams.use_penalty_prompt_tokens = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
slot->sparams.logit_bias.clear();
|
slot->sparams.logit_bias.clear();
|
||||||
|
|
||||||
if (json_value(data, "ignore_eos", false))
|
|
||||||
{
|
|
||||||
slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto &logit_bias = data.find("logit_bias");
|
const auto &logit_bias = data.find("logit_bias");
|
||||||
if (logit_bias != data.end() && logit_bias->is_array())
|
if (logit_bias != data.end() && logit_bias->is_array())
|
||||||
{
|
{
|
||||||
@ -753,7 +754,7 @@ struct llama_server_context
|
|||||||
llama_token tok = el[0].get<llama_token>();
|
llama_token tok = el[0].get<llama_token>();
|
||||||
if (tok >= 0 && tok < n_vocab)
|
if (tok >= 0 && tok < n_vocab)
|
||||||
{
|
{
|
||||||
slot->sparams.logit_bias[tok] = bias;
|
slot->sparams.logit_bias.push_back({tok, bias});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (el[0].is_string())
|
else if (el[0].is_string())
|
||||||
@ -761,7 +762,7 @@ struct llama_server_context
|
|||||||
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
||||||
for (auto tok : toks)
|
for (auto tok : toks)
|
||||||
{
|
{
|
||||||
slot->sparams.logit_bias[tok] = bias;
|
slot->sparams.logit_bias.push_back({tok, bias});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -782,24 +783,22 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto &samplers_sequence = data.find("samplers");
|
const auto & samplers = data.find("samplers");
|
||||||
if (samplers_sequence != data.end() && samplers_sequence->is_array())
|
if (samplers != data.end() && samplers->is_array()) {
|
||||||
{
|
|
||||||
std::vector<std::string> sampler_names;
|
std::vector<std::string> sampler_names;
|
||||||
for (const auto &sampler_name : *samplers_sequence)
|
for (const auto & name : *samplers) {
|
||||||
{
|
if (name.is_string()) {
|
||||||
if (sampler_name.is_string())
|
sampler_names.emplace_back(name);
|
||||||
{
|
}
|
||||||
sampler_names.emplace_back(sampler_name);
|
|
||||||
}
|
}
|
||||||
}
|
slot->sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
|
||||||
slot->sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
slot->sparams.samplers_sequence = default_sparams.samplers_sequence;
|
slot->sparams.samplers = default_sparams.samplers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (multimodal)
|
if (multimodal)
|
||||||
{
|
{
|
||||||
const auto &images_data = data.find("image_data");
|
const auto &images_data = data.find("image_data");
|
||||||
@ -814,10 +813,11 @@ struct llama_server_context
|
|||||||
img_sl.img_data = clip_image_u8_init();
|
img_sl.img_data = clip_image_u8_init();
|
||||||
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
||||||
{
|
{
|
||||||
LOG_ERROR("failed to load image", {
|
LOG_ERR("%s: failed to load image, slot_id: %d, img_sl_id: %d",
|
||||||
{"slot_id", slot->id},
|
__func__,
|
||||||
{"img_sl_id", img_sl.id}
|
slot->id,
|
||||||
});
|
img_sl.id
|
||||||
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
LOG_VERBOSE("image loaded", {
|
LOG_VERBOSE("image loaded", {
|
||||||
@ -855,12 +855,12 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!found) {
|
if (!found) {
|
||||||
LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id);
|
LOG("ERROR: Image with id: %i, not found.\n", img_id);
|
||||||
slot->images.clear();
|
slot->images.clear();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} catch (const std::invalid_argument& e) {
|
} catch (const std::invalid_argument& e) {
|
||||||
LOG_TEE("Invalid image number id in prompt\n");
|
LOG("Invalid image number id in prompt\n");
|
||||||
slot->images.clear();
|
slot->images.clear();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -875,10 +875,10 @@ struct llama_server_context
|
|||||||
|
|
||||||
if (slot->ctx_sampling != nullptr)
|
if (slot->ctx_sampling != nullptr)
|
||||||
{
|
{
|
||||||
llama_sampling_free(slot->ctx_sampling);
|
gpt_sampler_free(slot->ctx_sampling);
|
||||||
}
|
}
|
||||||
slot->ctx_sampling = llama_sampling_init(slot->sparams);
|
slot->ctx_sampling = gpt_sampler_init(model, slot->sparams);
|
||||||
llama_set_rng_seed(ctx, slot->params.seed);
|
//llama_set_rng_seed(ctx, slot->params.seed);
|
||||||
slot->command = LOAD_PROMPT;
|
slot->command = LOAD_PROMPT;
|
||||||
|
|
||||||
all_slots_are_idle = false;
|
all_slots_are_idle = false;
|
||||||
@ -888,7 +888,7 @@ struct llama_server_context
|
|||||||
{"task_id", slot->task_id},
|
{"task_id", slot->task_id},
|
||||||
});
|
});
|
||||||
|
|
||||||
LOG_TEE("sampling: \n%s\n", llama_sampling_print(slot->sparams).c_str());
|
// LOG("sampling: \n%s\n", llama_sampling_print(slot->sparams).c_str());
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -928,7 +928,7 @@ struct llama_server_context
|
|||||||
};
|
};
|
||||||
if (llama_decode(ctx, batch_view) != 0)
|
if (llama_decode(ctx, batch_view) != 0)
|
||||||
{
|
{
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
LOG("%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -940,7 +940,7 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("system prompt updated\n");
|
LOG("system prompt updated\n");
|
||||||
system_need_update = false;
|
system_need_update = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1006,11 +1006,13 @@ struct llama_server_context
|
|||||||
slot.generated_text += token_str;
|
slot.generated_text += token_str;
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
|
|
||||||
|
/*
|
||||||
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
|
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
|
||||||
{
|
{
|
||||||
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||||
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// check if there is incomplete UTF-8 character at the end
|
// check if there is incomplete UTF-8 character at the end
|
||||||
bool incomplete = false;
|
bool incomplete = false;
|
||||||
@ -1119,8 +1121,8 @@ struct llama_server_context
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
|
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.cpuparams.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
|
||||||
LOG_TEE("Error processing the given image");
|
LOG("Error processing the given image");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1132,7 +1134,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
void send_error(task_server& task, const std::string &error)
|
void send_error(task_server& task, const std::string &error)
|
||||||
{
|
{
|
||||||
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
|
LOG("task %i - error: %s\n", task.id, error.c_str());
|
||||||
task_result res;
|
task_result res;
|
||||||
res.id = task.id;
|
res.id = task.id;
|
||||||
res.multitask_id = task.multitask_id;
|
res.multitask_id = task.multitask_id;
|
||||||
@ -1144,13 +1146,11 @@ struct llama_server_context
|
|||||||
|
|
||||||
json get_formated_generation(llama_client_slot &slot)
|
json get_formated_generation(llama_client_slot &slot)
|
||||||
{
|
{
|
||||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
std::vector<std::string> samplers;
|
||||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
|
samplers.reserve(slot.sparams.samplers.size());
|
||||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
for (const auto & sampler : slot.sparams.samplers)
|
||||||
std::vector<std::string> samplers_sequence;
|
|
||||||
for (const auto &sampler_type : slot.sparams.samplers_sequence)
|
|
||||||
{
|
{
|
||||||
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
|
samplers.emplace_back(gpt_sampler_type_to_str(sampler));
|
||||||
}
|
}
|
||||||
|
|
||||||
return json {
|
return json {
|
||||||
@ -1165,13 +1165,11 @@ struct llama_server_context
|
|||||||
{"top_p", slot.sparams.top_p},
|
{"top_p", slot.sparams.top_p},
|
||||||
{"min_p", slot.sparams.min_p},
|
{"min_p", slot.sparams.min_p},
|
||||||
{"tfs_z", slot.sparams.tfs_z},
|
{"tfs_z", slot.sparams.tfs_z},
|
||||||
{"typical_p", slot.sparams.typical_p},
|
{"typical_p", slot.sparams.typ_p},
|
||||||
{"repeat_last_n", slot.sparams.penalty_last_n},
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
||||||
{"repeat_penalty", slot.sparams.penalty_repeat},
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
||||||
{"presence_penalty", slot.sparams.penalty_present},
|
{"presence_penalty", slot.sparams.penalty_present},
|
||||||
{"frequency_penalty", slot.sparams.penalty_freq},
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
||||||
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
|
||||||
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
|
||||||
{"mirostat", slot.sparams.mirostat},
|
{"mirostat", slot.sparams.mirostat},
|
||||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||||
@ -1179,13 +1177,13 @@ struct llama_server_context
|
|||||||
{"stop", slot.params.antiprompt},
|
{"stop", slot.params.antiprompt},
|
||||||
{"n_predict", slot.params.n_predict},
|
{"n_predict", slot.params.n_predict},
|
||||||
{"n_keep", params.n_keep},
|
{"n_keep", params.n_keep},
|
||||||
{"ignore_eos", ignore_eos},
|
{"ignore_eos", slot.sparams.ignore_eos},
|
||||||
{"stream", slot.params.stream},
|
{"stream", slot.params.stream},
|
||||||
{"logit_bias", slot.sparams.logit_bias},
|
// {"logit_bias", slot.sparams.logit_bias},
|
||||||
{"n_probs", slot.sparams.n_probs},
|
{"n_probs", slot.sparams.n_probs},
|
||||||
{"min_keep", slot.sparams.min_keep},
|
{"min_keep", slot.sparams.min_keep},
|
||||||
{"grammar", slot.sparams.grammar},
|
{"grammar", slot.sparams.grammar},
|
||||||
{"samplers", samplers_sequence}
|
{"samplers", samplers}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1375,7 +1373,7 @@ struct llama_server_context
|
|||||||
};
|
};
|
||||||
if (llama_decode(ctx, batch_view))
|
if (llama_decode(ctx, batch_view))
|
||||||
{
|
{
|
||||||
LOG_TEE("%s : failed to eval\n", __func__);
|
LOG("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1393,7 +1391,7 @@ struct llama_server_context
|
|||||||
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
|
||||||
if (llama_decode(ctx, batch_img))
|
if (llama_decode(ctx, batch_img))
|
||||||
{
|
{
|
||||||
LOG_TEE("%s : failed to eval image\n", __func__);
|
LOG("%s : failed to eval image\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
slot.n_past += n_eval;
|
slot.n_past += n_eval;
|
||||||
@ -1576,7 +1574,7 @@ struct llama_server_context
|
|||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.truncated = false;
|
slot.truncated = false;
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
LOG_TEE("Context exhausted. Slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
|
LOG("Context exhausted. Slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
// END LOCALAI changes
|
// END LOCALAI changes
|
||||||
@ -1714,7 +1712,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
if (!slot.params.cache_prompt)
|
if (!slot.params.cache_prompt)
|
||||||
{
|
{
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
gpt_sampler_reset(slot.ctx_sampling);
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.n_past_se = 0;
|
slot.n_past_se = 0;
|
||||||
@ -1726,7 +1724,7 @@ struct llama_server_context
|
|||||||
// push the prompt into the sampling context (do not apply grammar)
|
// push the prompt into the sampling context (do not apply grammar)
|
||||||
for (auto &token : prompt_tokens)
|
for (auto &token : prompt_tokens)
|
||||||
{
|
{
|
||||||
llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
|
gpt_sampler_accept(slot.ctx_sampling, token, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
@ -1824,10 +1822,11 @@ struct llama_server_context
|
|||||||
|
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
{
|
{
|
||||||
LOG_ERROR("failed processing images", {
|
LOG_ERR("%s: failed processing images Slot id : %d, Task id: %d",
|
||||||
"slot_id", slot.id,
|
__func__,
|
||||||
"task_id", slot.task_id,
|
slot.id,
|
||||||
});
|
slot.task_id
|
||||||
|
);
|
||||||
// FIXME @phymbert: to be properly tested
|
// FIXME @phymbert: to be properly tested
|
||||||
// early returning without changing the slot state will block the slot for ever
|
// early returning without changing the slot state will block the slot for ever
|
||||||
// no one at the moment is checking the return value
|
// no one at the moment is checking the return value
|
||||||
@ -1867,10 +1866,10 @@ struct llama_server_context
|
|||||||
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
||||||
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
||||||
|
|
||||||
LOG_TEE("\n");
|
LOG("\n");
|
||||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
||||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||||
@ -1880,7 +1879,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
slot.ga_i += slot.ga_w / slot.ga_n;
|
slot.ga_i += slot.ga_w / slot.ga_n;
|
||||||
|
|
||||||
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
||||||
}
|
}
|
||||||
slot.n_past_se += n_tokens;
|
slot.n_past_se += n_tokens;
|
||||||
}
|
}
|
||||||
@ -1905,11 +1904,11 @@ struct llama_server_context
|
|||||||
if (n_batch == 1 || ret < 0)
|
if (n_batch == 1 || ret < 0)
|
||||||
{
|
{
|
||||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
LOG("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
|
LOG("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
|
||||||
|
|
||||||
// retry with half the batch size to try to find a free slot in the KV cache
|
// retry with half the batch size to try to find a free slot in the KV cache
|
||||||
n_batch /= 2;
|
n_batch /= 2;
|
||||||
@ -1934,9 +1933,9 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
const llama_token id = gpt_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
|
||||||
|
|
||||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
gpt_sampler_accept(slot.ctx_sampling, id, true);
|
||||||
|
|
||||||
slot.n_decoded += 1;
|
slot.n_decoded += 1;
|
||||||
if (slot.n_decoded == 1)
|
if (slot.n_decoded == 1)
|
||||||
@ -1946,19 +1945,14 @@ struct llama_server_context
|
|||||||
metrics.on_prompt_eval(slot);
|
metrics.on_prompt_eval(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
|
const auto * cur_p = gpt_sampler_get_candidates(slot.ctx_sampling);
|
||||||
|
|
||||||
const int32_t n_probs = slot.sparams.n_probs;
|
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
|
||||||
if (slot.sparams.temp <= 0 && n_probs > 0)
|
result.probs.push_back({
|
||||||
{
|
cur_p->data[i].id,
|
||||||
// for llama_sample_token_greedy we need to sort candidates
|
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
|
||||||
llama_sample_softmax(ctx, &cur_p);
|
});
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
|
|
||||||
{
|
|
||||||
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!process_token(result, slot))
|
if (!process_token(result, slot))
|
||||||
@ -2210,7 +2204,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
params.model_alias = request->modelfile();
|
params.model_alias = request->modelfile();
|
||||||
params.n_ctx = request->contextsize();
|
params.n_ctx = request->contextsize();
|
||||||
//params.memory_f16 = request->f16memory();
|
//params.memory_f16 = request->f16memory();
|
||||||
params.n_threads = request->threads();
|
params.cpuparams.n_threads = request->threads();
|
||||||
params.n_gpu_layers = request->ngpulayers();
|
params.n_gpu_layers = request->ngpulayers();
|
||||||
params.n_batch = request->nbatch();
|
params.n_batch = request->nbatch();
|
||||||
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
|
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
|
||||||
|
13
backend/cpp/llama/patches/01-llava.patch
Normal file
13
backend/cpp/llama/patches/01-llava.patch
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||||
|
index 342042ff..224db9b5 100644
|
||||||
|
--- a/examples/llava/clip.cpp
|
||||||
|
+++ b/examples/llava/clip.cpp
|
||||||
|
@@ -2419,7 +2419,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
|
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||||
|
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||||
|
for (int i = 0; i < num_patches; i++) {
|
||||||
|
- patches_data[i] = i + 1;
|
||||||
|
+ patches_data[i] = i;
|
||||||
|
}
|
||||||
|
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||||
|
free(patches_data);
|
@ -1,5 +1,12 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
## Patches
|
||||||
|
## Apply patches from the `patches` directory
|
||||||
|
for patch in $(ls patches); do
|
||||||
|
echo "Applying patch $patch"
|
||||||
|
patch -d llama.cpp/ -p1 < patches/$patch
|
||||||
|
done
|
||||||
|
|
||||||
cp -r CMakeLists.txt llama.cpp/examples/grpc-server/
|
cp -r CMakeLists.txt llama.cpp/examples/grpc-server/
|
||||||
cp -r grpc-server.cpp llama.cpp/examples/grpc-server/
|
cp -r grpc-server.cpp llama.cpp/examples/grpc-server/
|
||||||
cp -rfv json.hpp llama.cpp/examples/grpc-server/
|
cp -rfv json.hpp llama.cpp/examples/grpc-server/
|
||||||
|
@ -481,30 +481,3 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
|
|||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// random string / id
|
|
||||||
//
|
|
||||||
|
|
||||||
static std::string random_string()
|
|
||||||
{
|
|
||||||
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
|
|
||||||
|
|
||||||
std::random_device rd;
|
|
||||||
std::mt19937 generator(rd());
|
|
||||||
|
|
||||||
std::string result(32, ' ');
|
|
||||||
|
|
||||||
for (int i = 0; i < 32; ++i) {
|
|
||||||
result[i] = str[generator() % str.size()];
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string gen_chatcmplid()
|
|
||||||
{
|
|
||||||
std::stringstream chatcmplid;
|
|
||||||
chatcmplid << "chatcmpl-" << random_string();
|
|
||||||
return chatcmplid.str();
|
|
||||||
}
|
|
@ -1,104 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
||||||
"github.com/go-audio/wav"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ffmpegCommand(args []string) (string, error) {
|
|
||||||
cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe.
|
|
||||||
cmd.Env = os.Environ()
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
return string(out), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// AudioToWav converts audio to wav for transcribe.
|
|
||||||
// TODO: use https://github.com/mccoyst/ogg?
|
|
||||||
func audioToWav(src, dst string) error {
|
|
||||||
commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
|
|
||||||
out, err := ffmpegCommand(commandArgs)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error: %w out: %s", err, out)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) {
|
|
||||||
res := schema.TranscriptionResult{}
|
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(dir)
|
|
||||||
|
|
||||||
convertedPath := filepath.Join(dir, "converted.wav")
|
|
||||||
|
|
||||||
if err := audioToWav(audiopath, convertedPath); err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open samples
|
|
||||||
fh, err := os.Open(convertedPath)
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
defer fh.Close()
|
|
||||||
|
|
||||||
// Read samples
|
|
||||||
d := wav.NewDecoder(fh)
|
|
||||||
buf, err := d.FullPCMBuffer()
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
data := buf.AsFloat32Buffer().Data
|
|
||||||
|
|
||||||
// Process samples
|
|
||||||
context, err := model.NewContext()
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
context.SetThreads(threads)
|
|
||||||
|
|
||||||
if language != "" {
|
|
||||||
context.SetLanguage(language)
|
|
||||||
} else {
|
|
||||||
context.SetLanguage("auto")
|
|
||||||
}
|
|
||||||
|
|
||||||
if translate {
|
|
||||||
context.SetTranslate(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := context.Process(data, nil, nil); err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
s, err := context.NextSegment()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokens []int
|
|
||||||
for _, t := range s.Tokens {
|
|
||||||
tokens = append(tokens, t.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
|
|
||||||
res.Segments = append(res.Segments, segment)
|
|
||||||
|
|
||||||
res.Text += s.Text
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
@ -1,26 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// This is a wrapper to statisfy the GRPC service interface
|
|
||||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
|
||||||
import (
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Whisper struct {
|
|
||||||
base.SingleThread
|
|
||||||
whisper whisper.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
|
|
||||||
// Note: the Model here is a path to a directory containing the model files
|
|
||||||
w, err := whisper.New(opts.ModelFile)
|
|
||||||
sd.whisper = w
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
|
|
||||||
return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads))
|
|
||||||
}
|
|
105
backend/go/transcribe/whisper/whisper.go
Normal file
105
backend/go/transcribe/whisper/whisper.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Whisper struct {
|
||||||
|
base.SingleThread
|
||||||
|
whisper whisper.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
|
||||||
|
// Note: the Model here is a path to a directory containing the model files
|
||||||
|
w, err := whisper.New(opts.ModelFile)
|
||||||
|
sd.whisper = w
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
|
|
||||||
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
|
if err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
|
||||||
|
convertedPath := filepath.Join(dir, "converted.wav")
|
||||||
|
|
||||||
|
if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open samples
|
||||||
|
fh, err := os.Open(convertedPath)
|
||||||
|
if err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
defer fh.Close()
|
||||||
|
|
||||||
|
// Read samples
|
||||||
|
d := wav.NewDecoder(fh)
|
||||||
|
buf, err := d.FullPCMBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := buf.AsFloat32Buffer().Data
|
||||||
|
|
||||||
|
// Process samples
|
||||||
|
context, err := sd.whisper.NewContext()
|
||||||
|
if err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
context.SetThreads(uint(opts.Threads))
|
||||||
|
|
||||||
|
if opts.Language != "" {
|
||||||
|
context.SetLanguage(opts.Language)
|
||||||
|
} else {
|
||||||
|
context.SetLanguage("auto")
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Translate {
|
||||||
|
context.SetTranslate(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := context.Process(data, nil, nil); err != nil {
|
||||||
|
return pb.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
segments := []*pb.TranscriptSegment{}
|
||||||
|
text := ""
|
||||||
|
for {
|
||||||
|
s, err := context.NextSegment()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []int32
|
||||||
|
for _, t := range s.Tokens {
|
||||||
|
tokens = append(tokens, int32(t.Id))
|
||||||
|
}
|
||||||
|
|
||||||
|
segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
|
||||||
|
segments = append(segments, segment)
|
||||||
|
|
||||||
|
text += s.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
return pb.TranscriptResult{
|
||||||
|
Segments: segments,
|
||||||
|
Text: text,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
@ -2,4 +2,4 @@
|
|||||||
intel-extension-for-pytorch
|
intel-extension-for-pytorch
|
||||||
torch
|
torch
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
@ -1,6 +1,6 @@
|
|||||||
accelerate
|
accelerate
|
||||||
auto-gptq==0.7.1
|
auto-gptq==0.7.1
|
||||||
grpcio==1.65.4
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
transformers
|
transformers
|
@ -3,6 +3,6 @@ intel-extension-for-pytorch
|
|||||||
torch
|
torch
|
||||||
torchaudio
|
torchaudio
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
@ -1,4 +1,4 @@
|
|||||||
bark==0.1.5
|
bark==0.1.5
|
||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
@ -1,2 +1,2 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
@ -3,6 +3,6 @@ intel-extension-for-pytorch
|
|||||||
torch
|
torch
|
||||||
torchaudio
|
torchaudio
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
@ -1,4 +1,4 @@
|
|||||||
TTS==0.22.0
|
TTS==0.22.0
|
||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
@ -168,7 +168,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.CFGScale != 0:
|
if request.CFGScale != 0:
|
||||||
self.cfg_scale = request.CFGScale
|
self.cfg_scale = request.CFGScale
|
||||||
|
|
||||||
clipmodel = "runwayml/stable-diffusion-v1-5"
|
clipmodel = "Lykon/dreamshaper-8"
|
||||||
if request.CLIPModel != "":
|
if request.CLIPModel != "":
|
||||||
clipmodel = request.CLIPModel
|
clipmodel = request.CLIPModel
|
||||||
clipsubfolder = "text_encoder"
|
clipsubfolder = "text_encoder"
|
||||||
|
@ -3,7 +3,7 @@ intel-extension-for-pytorch
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||||
diffusers
|
diffusers
|
||||||
opencv-python
|
opencv-python
|
||||||
transformers
|
transformers
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
setuptools
|
setuptools
|
||||||
grpcio==1.65.4
|
grpcio==1.66.1
|
||||||
pillow
|
pillow
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
|
@ -53,7 +53,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
stub = backend_pb2_grpc.BackendStub(channel)
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
|
||||||
self.assertTrue(response.success)
|
self.assertTrue(response.success)
|
||||||
self.assertEqual(response.message, "Model loaded successfully")
|
self.assertEqual(response.message, "Model loaded successfully")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@ -71,7 +71,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
stub = backend_pb2_grpc.BackendStub(channel)
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
|
||||||
print(response.message)
|
print(response.message)
|
||||||
self.assertTrue(response.success)
|
self.assertTrue(response.success)
|
||||||
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
|
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
|
||||||
|
1
backend/python/exllama/.gitignore
vendored
1
backend/python/exllama/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
source
|
|
@ -1,25 +0,0 @@
|
|||||||
export CONDA_ENV_PATH = "exllama.yml"
|
|
||||||
|
|
||||||
.PHONY: exllama
|
|
||||||
exllama: protogen
|
|
||||||
bash install.sh ${CONDA_ENV_PATH}
|
|
||||||
|
|
||||||
.PHONY: run
|
|
||||||
run: protogen
|
|
||||||
@echo "Running exllama..."
|
|
||||||
bash run.sh
|
|
||||||
@echo "exllama run."
|
|
||||||
|
|
||||||
.PHONY: protogen
|
|
||||||
protogen: backend_pb2_grpc.py backend_pb2.py
|
|
||||||
|
|
||||||
.PHONY: protogen-clean
|
|
||||||
protogen-clean:
|
|
||||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
|
||||||
|
|
||||||
backend_pb2_grpc.py backend_pb2.py:
|
|
||||||
python3 -m grpc_tools.protoc -I../.. --python_out=. --grpc_python_out=. backend.proto
|
|
||||||
|
|
||||||
.PHONY: clean
|
|
||||||
clean: protogen-clean
|
|
||||||
$(RM) -r venv source __pycache__
|
|
@ -1,5 +0,0 @@
|
|||||||
# Creating a separate environment for the exllama project
|
|
||||||
|
|
||||||
```
|
|
||||||
make exllama
|
|
||||||
```
|
|
@ -1,159 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import grpc
|
|
||||||
from concurrent import futures
|
|
||||||
import time
|
|
||||||
import backend_pb2
|
|
||||||
import backend_pb2_grpc
|
|
||||||
import argparse
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import os, glob
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import version as torch_version
|
|
||||||
|
|
||||||
from source.tokenizer import ExLlamaTokenizer
|
|
||||||
from source.generator import ExLlamaGenerator
|
|
||||||
from source.model import ExLlama, ExLlamaCache, ExLlamaConfig
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
|
||||||
|
|
||||||
# Implement the BackendServicer class with the service methods
|
|
||||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|
||||||
def generate(self,prompt, max_new_tokens):
|
|
||||||
self.generator.end_beam_search()
|
|
||||||
|
|
||||||
# Tokenizing the input
|
|
||||||
ids = self.generator.tokenizer.encode(prompt)
|
|
||||||
|
|
||||||
self.generator.gen_begin_reuse(ids)
|
|
||||||
initial_len = self.generator.sequence[0].shape[0]
|
|
||||||
has_leading_space = False
|
|
||||||
decoded_text = ''
|
|
||||||
for i in range(max_new_tokens):
|
|
||||||
token = self.generator.gen_single_token()
|
|
||||||
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
|
||||||
has_leading_space = True
|
|
||||||
|
|
||||||
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
|
|
||||||
if has_leading_space:
|
|
||||||
decoded_text = ' ' + decoded_text
|
|
||||||
|
|
||||||
if token.item() == self.generator.tokenizer.eos_token_id:
|
|
||||||
break
|
|
||||||
return decoded_text
|
|
||||||
def Health(self, request, context):
|
|
||||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
|
||||||
def LoadModel(self, request, context):
|
|
||||||
try:
|
|
||||||
# https://github.com/turboderp/exllama/blob/master/example_cfg.py
|
|
||||||
model_directory = request.ModelFile
|
|
||||||
|
|
||||||
# Locate files we need within that directory
|
|
||||||
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
|
|
||||||
model_config_path = os.path.join(model_directory, "config.json")
|
|
||||||
st_pattern = os.path.join(model_directory, "*.safetensors")
|
|
||||||
model_path = glob.glob(st_pattern)[0]
|
|
||||||
|
|
||||||
# Create config, model, tokenizer and generator
|
|
||||||
|
|
||||||
config = ExLlamaConfig(model_config_path) # create config from config.json
|
|
||||||
config.model_path = model_path # supply path to model weights file
|
|
||||||
if (request.ContextSize):
|
|
||||||
config.max_seq_len = request.ContextSize # override max sequence length
|
|
||||||
config.max_attention_size = request.ContextSize**2 # Should be set to context_size^2.
|
|
||||||
# https://github.com/turboderp/exllama/issues/220#issuecomment-1720324163
|
|
||||||
|
|
||||||
# Set Rope scaling.
|
|
||||||
if (request.RopeFreqScale):
|
|
||||||
# Alpha value for Rope scaling.
|
|
||||||
# Higher value increases context but adds perplexity.
|
|
||||||
# alpha_value and compress_pos_emb are mutually exclusive.
|
|
||||||
# https://github.com/turboderp/exllama/issues/115
|
|
||||||
config.alpha_value = request.RopeFreqScale
|
|
||||||
config.calculate_rotary_embedding_base()
|
|
||||||
|
|
||||||
model = ExLlama(config) # create ExLlama instance and load the weights
|
|
||||||
tokenizer = ExLlamaTokenizer(tokenizer_path) # create tokenizer from tokenizer model file
|
|
||||||
|
|
||||||
cache = ExLlamaCache(model, batch_size = 2) # create cache for inference
|
|
||||||
generator = ExLlamaGenerator(model, tokenizer, cache) # create generator
|
|
||||||
|
|
||||||
self.generator= generator
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.cache = cache
|
|
||||||
except Exception as err:
|
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
|
||||||
|
|
||||||
def Predict(self, request, context):
|
|
||||||
penalty = 1.15
|
|
||||||
if request.Penalty != 0.0:
|
|
||||||
penalty = request.Penalty
|
|
||||||
self.generator.settings.token_repetition_penalty_max = penalty
|
|
||||||
self.generator.settings.temperature = request.Temperature
|
|
||||||
self.generator.settings.top_k = request.TopK
|
|
||||||
self.generator.settings.top_p = request.TopP
|
|
||||||
|
|
||||||
tokens = 512
|
|
||||||
if request.Tokens != 0:
|
|
||||||
tokens = request.Tokens
|
|
||||||
|
|
||||||
if self.cache.batch_size == 1:
|
|
||||||
del self.cache
|
|
||||||
self.cache = ExLlamaCache(self.model, batch_size=2)
|
|
||||||
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
|
||||||
|
|
||||||
t = self.generate(request.Prompt, tokens)
|
|
||||||
|
|
||||||
# Remove prompt from response if present
|
|
||||||
if request.Prompt in t:
|
|
||||||
t = t.replace(request.Prompt, "")
|
|
||||||
|
|
||||||
return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
|
|
||||||
|
|
||||||
def PredictStream(self, request, context):
|
|
||||||
# Implement PredictStream RPC
|
|
||||||
#for reply in some_data_generator():
|
|
||||||
# yield reply
|
|
||||||
# Not implemented yet
|
|
||||||
return self.Predict(request, context)
|
|
||||||
|
|
||||||
|
|
||||||
def serve(address):
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
|
||||||
server.add_insecure_port(address)
|
|
||||||
server.start()
|
|
||||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
|
||||||
|
|
||||||
# Define the signal handler function
|
|
||||||
def signal_handler(sig, frame):
|
|
||||||
print("Received termination signal. Shutting down...")
|
|
||||||
server.stop(0)
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Set the signal handlers for SIGINT and SIGTERM
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
server.stop(0)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
serve(args.addr)
|
|
@ -1,13 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
LIMIT_TARGETS="cublas"
|
|
||||||
|
|
||||||
source $(dirname $0)/../common/libbackend.sh
|
|
||||||
|
|
||||||
installRequirements
|
|
||||||
|
|
||||||
git clone https://github.com/turboderp/exllama $MY_DIR/source
|
|
||||||
uv pip install ${BUILD_ISOLATION_FLAG} --requirement ${MY_DIR}/source/requirements.txt
|
|
||||||
|
|
||||||
cp -v ./*py $MY_DIR/source/
|
|
@ -1,3 +0,0 @@
|
|||||||
transformers
|
|
||||||
accelerate
|
|
||||||
torch
|
|
@ -1,4 +0,0 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
|
||||||
torch
|
|
||||||
transformers
|
|
||||||
accelerate
|
|
@ -1,3 +0,0 @@
|
|||||||
torch
|
|
||||||
transformers
|
|
||||||
accelerate
|
|
@ -1,4 +0,0 @@
|
|||||||
grpcio==1.65.5
|
|
||||||
protobuf
|
|
||||||
certifi
|
|
||||||
setuptools
|
|
@ -1,7 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
LIMIT_TARGETS="cublas"
|
|
||||||
BACKEND_FILE="${MY_DIR}/source/backend.py"
|
|
||||||
|
|
||||||
source $(dirname $0)/../common/libbackend.sh
|
|
||||||
|
|
||||||
startBackend $@
|
|
@ -1,6 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
source $(dirname $0)/../common/libbackend.sh
|
|
||||||
|
|
||||||
runUnittests
|
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.4
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
wheel
|
wheel
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
@ -2,7 +2,7 @@
|
|||||||
intel-extension-for-pytorch
|
intel-extension-for-pytorch
|
||||||
torch
|
torch
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
librosa==0.9.1
|
librosa==0.9.1
|
||||||
faster-whisper==1.0.3
|
faster-whisper==1.0.3
|
||||||
@ -15,7 +15,7 @@ unidecode==1.3.7
|
|||||||
whisper-timestamped==1.15.4
|
whisper-timestamped==1.15.4
|
||||||
openai
|
openai
|
||||||
python-dotenv
|
python-dotenv
|
||||||
pypinyin==0.50.0
|
pypinyin==0.53.0
|
||||||
cn2an==0.5.22
|
cn2an==0.5.22
|
||||||
jieba==0.42.1
|
jieba==0.42.1
|
||||||
gradio==4.38.1
|
gradio==4.38.1
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
librosa
|
librosa
|
||||||
faster-whisper
|
faster-whisper
|
||||||
|
@ -1 +1,3 @@
|
|||||||
git+https://github.com/huggingface/parler-tts.git@8e465f1b5fcd223478e07175cb40494d19ffbe17
|
git+https://github.com/huggingface/parler-tts.git@8e465f1b5fcd223478e07175cb40494d19ffbe17
|
||||||
|
llvmlite==0.43.0
|
||||||
|
numba==0.60.0
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||||
torch
|
torch==2.3.0+rocm6.0
|
||||||
torchaudio
|
torchaudio==2.3.0+rocm6.0
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
@ -3,6 +3,6 @@ intel-extension-for-pytorch
|
|||||||
torch
|
torch
|
||||||
torchaudio
|
torchaudio
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
||||||
transformers
|
transformers
|
||||||
accelerate
|
accelerate
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
@ -5,4 +5,4 @@ accelerate
|
|||||||
torch
|
torch
|
||||||
rerankers[transformers]
|
rerankers[transformers]
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
@ -1,3 +1,3 @@
|
|||||||
grpcio==1.65.4
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
@ -55,7 +55,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
"""
|
"""
|
||||||
model_name = request.Model
|
model_name = request.Model
|
||||||
try:
|
try:
|
||||||
self.model = SentenceTransformer(model_name)
|
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
|
||||||
|
@ -2,5 +2,5 @@ torch
|
|||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
sentence-transformers==3.0.1
|
sentence-transformers==3.1.0
|
||||||
transformers
|
transformers
|
@ -1,5 +1,5 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||||
torch
|
torch
|
||||||
accelerate
|
accelerate
|
||||||
sentence-transformers==3.0.1
|
sentence-transformers==3.1.0
|
||||||
transformers
|
transformers
|
@ -1,4 +1,4 @@
|
|||||||
torch
|
torch
|
||||||
accelerate
|
accelerate
|
||||||
sentence-transformers==3.0.1
|
sentence-transformers==3.1.0
|
||||||
transformers
|
transformers
|
@ -1,5 +1,5 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||||
torch
|
torch
|
||||||
accelerate
|
accelerate
|
||||||
sentence-transformers==3.0.1
|
sentence-transformers==3.1.0
|
||||||
transformers
|
transformers
|
@ -4,5 +4,5 @@ torch
|
|||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
||||||
accelerate
|
accelerate
|
||||||
sentence-transformers==3.0.1
|
sentence-transformers==3.1.0
|
||||||
transformers
|
transformers
|
@ -1,3 +1,5 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
|
datasets
|
||||||
|
einops
|
@ -15,7 +15,7 @@ import backend_pb2_grpc
|
|||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from scipy.io.wavfile import write as write_wav
|
from scipy.io import wavfile
|
||||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@ -63,6 +63,61 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
|
def SoundGeneration(self, request, context):
|
||||||
|
model_name = request.model
|
||||||
|
if model_name == "":
|
||||||
|
return backend_pb2.Result(success=False, message="request.model is required")
|
||||||
|
try:
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||||
|
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||||
|
inputs = None
|
||||||
|
if request.text == "":
|
||||||
|
inputs = self.model.get_unconditional_inputs(num_samples=1)
|
||||||
|
elif request.HasField('src'):
|
||||||
|
# TODO SECURITY CODE GOES HERE LOL
|
||||||
|
# WHO KNOWS IF THIS WORKS???
|
||||||
|
sample_rate, wsamples = wavfile.read('path_to_your_file.wav')
|
||||||
|
|
||||||
|
if request.HasField('src_divisor'):
|
||||||
|
wsamples = wsamples[: len(wsamples) // request.src_divisor]
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
audio=wsamples,
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
text=[request.text],
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[request.text],
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = 256
|
||||||
|
if request.HasField('duration'):
|
||||||
|
tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
|
||||||
|
guidance = 3.0
|
||||||
|
if request.HasField('temperature'):
|
||||||
|
guidance = request.temperature
|
||||||
|
dosample = True
|
||||||
|
if request.HasField('sample'):
|
||||||
|
dosample = request.sample
|
||||||
|
audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
|
||||||
|
print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
|
||||||
|
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||||
|
wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||||
|
print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
|
||||||
|
print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
|
||||||
|
print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
|
||||||
|
print(request, file=sys.stderr)
|
||||||
|
except Exception as err:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
return backend_pb2.Result(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||||
def TTS(self, request, context):
|
def TTS(self, request, context):
|
||||||
model_name = request.model
|
model_name = request.model
|
||||||
if model_name == "":
|
if model_name == "":
|
||||||
@ -75,8 +130,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
padding=True,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
tokens = 256
|
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
|
||||||
# TODO get tokens from request?
|
|
||||||
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
|
||||||
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
|
||||||
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
sampling_rate = self.model.config.audio_encoder.sampling_rate
|
||||||
|
@ -4,4 +4,4 @@ transformers
|
|||||||
accelerate
|
accelerate
|
||||||
torch
|
torch
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
scipy==1.14.0
|
scipy==1.14.0
|
||||||
certifi
|
certifi
|
@ -63,7 +63,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
|
|
||||||
def test_tts(self):
|
def test_tts(self):
|
||||||
"""
|
"""
|
||||||
This method tests if the embeddings are generated successfully
|
This method tests if TTS is generated successfully
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.setUp()
|
self.setUp()
|
||||||
@ -79,3 +79,22 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
self.fail("TTS service failed")
|
self.fail("TTS service failed")
|
||||||
finally:
|
finally:
|
||||||
self.tearDown()
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_sound_generation(self):
|
||||||
|
"""
|
||||||
|
This method tests if SoundGeneration is generated successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
sg_request = backend_pb2.SoundGenerationRequest(text="80s TV news production music hit for tonight's biggest story")
|
||||||
|
sg_response = stub.SoundGeneration(sg_request)
|
||||||
|
self.assertIsNotNone(sg_response)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("SoundGeneration service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
|
@ -4,4 +4,4 @@ accelerate
|
|||||||
torch
|
torch
|
||||||
torchaudio
|
torchaudio
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
@ -1,3 +1,3 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
@ -135,6 +135,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
res = await gen.__anext__()
|
res = await gen.__anext__()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def Embedding(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that calculates embeddings for a given sentence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: An EmbeddingRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An EmbeddingResult object that contains the calculated embeddings.
|
||||||
|
"""
|
||||||
|
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||||
|
outputs = self.model.encode(request.Embeddings)
|
||||||
|
# Check if we have one result at least
|
||||||
|
if len(outputs) == 0:
|
||||||
|
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||||
|
context.set_details("No embeddings were calculated.")
|
||||||
|
return backend_pb2.EmbeddingResult()
|
||||||
|
return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)
|
||||||
|
|
||||||
async def PredictStream(self, request, context):
|
async def PredictStream(self, request, context):
|
||||||
"""
|
"""
|
||||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||||
|
@ -4,4 +4,4 @@ accelerate
|
|||||||
torch
|
torch
|
||||||
transformers
|
transformers
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406
|
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
|
@ -1,4 +1,4 @@
|
|||||||
grpcio==1.65.5
|
grpcio==1.66.1
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
setuptools
|
setuptools
|
@ -74,3 +74,26 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
self.fail("text service failed")
|
self.fail("text service failed")
|
||||||
finally:
|
finally:
|
||||||
self.tearDown()
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_embedding(self):
|
||||||
|
"""
|
||||||
|
This method tests if the embeddings are generated successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
|
||||||
|
embedding_response = stub.Embedding(embedding_request)
|
||||||
|
self.assertIsNotNone(embedding_response.embeddings)
|
||||||
|
# assert that is a list of floats
|
||||||
|
self.assertIsInstance(embedding_response.embeddings, list)
|
||||||
|
# assert that the list is not empty
|
||||||
|
self.assertTrue(len(embedding_response.embeddings) > 0)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Embedding service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
13
core/backend/backend_suite_test.go
Normal file
13
core/backend/backend_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package backend_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackend(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Backend test suite")
|
||||||
|
}
|
@ -9,6 +9,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
|
||||||
@ -87,7 +89,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
|
|||||||
case string:
|
case string:
|
||||||
protoMessages[i].Content = ct
|
protoMessages[i].Content = ct
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
|
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -181,13 +183,37 @@ func Finetune(config config.BackendConfig, input, prediction string) string {
|
|||||||
mu.Lock()
|
mu.Lock()
|
||||||
reg, ok := cutstrings[c]
|
reg, ok := cutstrings[c]
|
||||||
if !ok {
|
if !ok {
|
||||||
cutstrings[c] = regexp.MustCompile(c)
|
r, err := regexp.Compile(c)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("failed to compile regex")
|
||||||
|
}
|
||||||
|
cutstrings[c] = r
|
||||||
reg = cutstrings[c]
|
reg = cutstrings[c]
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
prediction = reg.ReplaceAllString(prediction, "")
|
prediction = reg.ReplaceAllString(prediction, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extract results from the response which can be for instance inside XML tags
|
||||||
|
var predResult string
|
||||||
|
for _, r := range config.ExtractRegex {
|
||||||
|
mu.Lock()
|
||||||
|
reg, ok := cutstrings[r]
|
||||||
|
if !ok {
|
||||||
|
regex, err := regexp.Compile(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("failed to compile regex")
|
||||||
|
}
|
||||||
|
cutstrings[r] = regex
|
||||||
|
reg = regex
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
predResult += reg.FindString(prediction)
|
||||||
|
}
|
||||||
|
if predResult != "" {
|
||||||
|
prediction = predResult
|
||||||
|
}
|
||||||
|
|
||||||
for _, c := range config.TrimSpace {
|
for _, c := range config.TrimSpace {
|
||||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||||
}
|
}
|
||||||
|
109
core/backend/llm_test.go
Normal file
109
core/backend/llm_test.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package backend_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Specs for Finetune, the post-processing step applied to raw LLM output.
// Every spec calls Finetune(testConfig, input, prediction) and compares the
// returned string against a literal expectation.
var _ = Describe("LLM tests", func() {
	Context("Finetune LLM output", func() {
		// Shared fixture: the outer BeforeEach rebuilds testConfig, and each
		// nested BeforeEach sets input/prediction (and optionally tweaks
		// testConfig) for its own spec.
		var (
			testConfig config.BackendConfig
			input      string
			prediction string
			result     string
		)

		BeforeEach(func() {
			// All post-processing options are enabled at once, so every spec
			// below runs with cutstrings, extract regex, trim-space and
			// trim-suffix configured, not only the one named in its Context.
			testConfig = config.BackendConfig{
				PredictionOptions: schema.PredictionOptions{
					Echo: false,
				},
				LLMConfig: config.LLMConfig{
					Cutstrings:   []string{`<.*?>`},                  // Example regex for removing XML tags
					ExtractRegex: []string{`<result>(.*?)</result>`}, // Example regex to extract from tags
					TrimSpace:    []string{" ", "\n"},
					TrimSuffix:   []string{".", "!"},
				},
			}
		})

		Context("when echo is enabled", func() {
			BeforeEach(func() {
				testConfig.Echo = true
				input = "Hello"
				prediction = "World"
			})

			It("should prepend input to prediction", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("HelloWorld"))
			})
		})

		Context("when echo is disabled", func() {
			BeforeEach(func() {
				testConfig.Echo = false
				input = "Hello"
				prediction = "World"
			})

			It("should not modify the prediction with input", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("World"))
			})
		})

		Context("when cutstrings regex is applied", func() {
			BeforeEach(func() {
				input = ""
				prediction = "<div>Hello</div> World"
			})

			It("should remove substrings matching cutstrings regex", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("Hello World"))
			})
		})

		Context("when extract regex is applied", func() {
			BeforeEach(func() {
				input = ""
				prediction = "<response><result>42</result></response>"
			})

			// NOTE(review): the shared config also sets Cutstrings to `<.*?>`.
			// If Finetune applies cutstrings before ExtractRegex, every tag is
			// already stripped when extraction runs, so "42" may be produced
			// without ExtractRegex matching at all. Confirm, and consider
			// clearing Cutstrings in this Context to isolate extraction.
			It("should extract substrings matching the extract regex", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("42"))
			})
		})

		Context("when trimming spaces", func() {
			BeforeEach(func() {
				input = ""
				prediction = " Hello World "
			})

			It("should trim spaces from the prediction", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("Hello World"))
			})
		})

		Context("when trimming suffixes", func() {
			BeforeEach(func() {
				input = ""
				prediction = "Hello World."
			})

			It("should trim suffixes from the prediction", func() {
				result = Finetune(testConfig, input, prediction)
				Expect(result).To(Equal("Hello World"))
			})
		})
	})
})
|
74
core/backend/soundgeneration.go
Normal file
74
core/backend/soundgeneration.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SoundGeneration(
|
||||||
|
backend string,
|
||||||
|
modelFile string,
|
||||||
|
text string,
|
||||||
|
duration *float32,
|
||||||
|
temperature *float32,
|
||||||
|
doSample *bool,
|
||||||
|
sourceFile *string,
|
||||||
|
sourceDivisor *int32,
|
||||||
|
loader *model.ModelLoader,
|
||||||
|
appConfig *config.ApplicationConfig,
|
||||||
|
backendConfig config.BackendConfig,
|
||||||
|
) (string, *proto.Result, error) {
|
||||||
|
if backend == "" {
|
||||||
|
return "", nil, fmt.Errorf("backend is a required parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(backend),
|
||||||
|
model.WithModel(modelFile),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
|
})
|
||||||
|
|
||||||
|
soundGenModel, err := loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if soundGenModel == nil {
|
||||||
|
return "", nil, fmt.Errorf("could not load sound generation model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
|
||||||
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
|
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
|
||||||
|
Text: text,
|
||||||
|
Model: modelFile,
|
||||||
|
Dst: filePath,
|
||||||
|
Sample: doSample,
|
||||||
|
Duration: duration,
|
||||||
|
Temperature: temperature,
|
||||||
|
Src: sourceFile,
|
||||||
|
SrcDivisor: sourceDivisor,
|
||||||
|
})
|
||||||
|
|
||||||
|
// return RPC error if any
|
||||||
|
if !res.Success {
|
||||||
|
return "", nil, fmt.Errorf(res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, res, err
|
||||||
|
}
|
@ -3,12 +3,13 @@ package backend
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
@ -21,19 +22,40 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
|
|||||||
model.WithAssetDir(appConfig.AssetsDestination),
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
})
|
})
|
||||||
|
|
||||||
whisperModel, err := ml.BackendLoader(opts...)
|
transcriptionModel, err := ml.BackendLoader(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if whisperModel == nil {
|
if transcriptionModel == nil {
|
||||||
return nil, fmt.Errorf("could not load whisper model")
|
return nil, fmt.Errorf("could not load transcription model")
|
||||||
}
|
}
|
||||||
|
|
||||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||||
Dst: audio,
|
Dst: audio,
|
||||||
Language: language,
|
Language: language,
|
||||||
Translate: translate,
|
Translate: translate,
|
||||||
Threads: uint32(*backendConfig.Threads),
|
Threads: uint32(*backendConfig.Threads),
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tr := &schema.TranscriptionResult{
|
||||||
|
Text: r.Text,
|
||||||
|
}
|
||||||
|
for _, s := range r.Segments {
|
||||||
|
var tks []int
|
||||||
|
for _, t := range s.Tokens {
|
||||||
|
tks = append(tks, int(t))
|
||||||
|
}
|
||||||
|
tr.Segments = append(tr.Segments,
|
||||||
|
schema.Segment{
|
||||||
|
Text: s.Text,
|
||||||
|
Id: int(s.Id),
|
||||||
|
Start: time.Duration(s.Start),
|
||||||
|
End: time.Duration(s.End),
|
||||||
|
Tokens: tks,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return tr, err
|
||||||
}
|
}
|
||||||
|
@ -9,31 +9,15 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
|
||||||
counter := 1
|
|
||||||
fileName := baseName + ext
|
|
||||||
|
|
||||||
for {
|
|
||||||
filePath := filepath.Join(dir, fileName)
|
|
||||||
_, err := os.Stat(filePath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fileName
|
|
||||||
}
|
|
||||||
|
|
||||||
counter++
|
|
||||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelTTS(
|
func ModelTTS(
|
||||||
backend,
|
backend,
|
||||||
text,
|
text,
|
||||||
modelFile,
|
modelFile,
|
||||||
voice ,
|
voice,
|
||||||
language string,
|
language string,
|
||||||
loader *model.ModelLoader,
|
loader *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
@ -66,7 +50,7 @@ func ModelTTS(
|
|||||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
||||||
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
// If the model file is not empty, we pass it joined with the model path
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
@ -88,12 +72,15 @@ func ModelTTS(
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||||
Text: text,
|
Text: text,
|
||||||
Model: modelPath,
|
Model: modelPath,
|
||||||
Voice: voice,
|
Voice: voice,
|
||||||
Dst: filePath,
|
Dst: filePath,
|
||||||
Language: &language,
|
Language: &language,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// return RPC error if any
|
// return RPC error if any
|
||||||
if !res.Success {
|
if !res.Success {
|
||||||
|
80
core/cli/api/p2p.go
Normal file
80
core/cli/api/p2p.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package cli_api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
|
"github.com/mudler/edgevpn/pkg/node"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StartP2PStack(ctx context.Context, address, token, networkID string, federated bool) error {
|
||||||
|
var n *node.Node
|
||||||
|
// Here we are avoiding creating multiple nodes:
|
||||||
|
// - if the federated mode is enabled, we create a federated node and expose a service
|
||||||
|
// - exposing a service creates a node with specific options, and we don't want to create another node
|
||||||
|
|
||||||
|
// If the federated mode is enabled, we expose a service to the local instance running
|
||||||
|
// at r.Address
|
||||||
|
if federated {
|
||||||
|
_, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Here a new node is created and started
|
||||||
|
// and a service is exposed by the node
|
||||||
|
node, err := p2p.ExposeService(ctx, "localhost", port, token, p2p.NetworkID(networkID, p2p.FederatedID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p2p.ServiceDiscoverer(ctx, node, token, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n = node
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the p2p mode is enabled, we start the service discovery
|
||||||
|
if token != "" {
|
||||||
|
// If a node wasn't created previously, create it
|
||||||
|
if n == nil {
|
||||||
|
node, err := p2p.NewNode(token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = node.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting new node: %w", err)
|
||||||
|
}
|
||||||
|
n = node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach a ServiceDiscoverer to the p2p node
|
||||||
|
log.Info().Msg("Starting P2P server discovery...")
|
||||||
|
if err := p2p.ServiceDiscoverer(ctx, n, token, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node p2p.NodeData) {
|
||||||
|
var tunnelAddresses []string
|
||||||
|
for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
|
||||||
|
if v.IsOnline() {
|
||||||
|
tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
|
||||||
|
} else {
|
||||||
|
log.Info().Msgf("Node %s is offline", v.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tunnelEnvVar := strings.Join(tunnelAddresses, ",")
|
||||||
|
|
||||||
|
os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
|
||||||
|
log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar)
|
||||||
|
}, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -8,12 +8,13 @@ import (
|
|||||||
var CLI struct {
|
var CLI struct {
|
||||||
cliContext.Context `embed:""`
|
cliContext.Context `embed:""`
|
||||||
|
|
||||||
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
|
||||||
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
|
||||||
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
|
||||||
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
|
||||||
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
|
||||||
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
|
||||||
Util UtilCMD `cmd:"" help:"Utility commands"`
|
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
|
||||||
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
Util UtilCMD `cmd:"" help:"Utility commands"`
|
||||||
|
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,10 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
cli_api "github.com/mudler/LocalAI/core/cli/api"
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/http"
|
"github.com/mudler/LocalAI/core/http"
|
||||||
@ -42,29 +41,34 @@ type RunCMD struct {
|
|||||||
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
|
||||||
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"`
|
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"`
|
||||||
|
|
||||||
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
|
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
|
||||||
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
|
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
|
||||||
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
|
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
|
||||||
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
|
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
|
||||||
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
|
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
|
||||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
||||||
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
||||||
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
|
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
|
||||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||||
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
|
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
||||||
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
|
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
||||||
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
|
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
|
||||||
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
|
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
|
||||||
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
|
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
|
||||||
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
|
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
|
||||||
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
|
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
|
||||||
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
|
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
|
||||||
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
|
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
|
||||||
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
|
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
|
||||||
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
|
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
|
||||||
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
|
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
|
||||||
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
|
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
|
||||||
|
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
|
||||||
|
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
|
||||||
|
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
|
||||||
|
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
|
||||||
|
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||||
@ -96,6 +100,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
|
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
|
||||||
config.WithOpaqueErrors(r.OpaqueErrors),
|
config.WithOpaqueErrors(r.OpaqueErrors),
|
||||||
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
|
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
|
||||||
|
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
|
||||||
|
config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
|
||||||
|
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
|
||||||
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
|
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +114,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
// IF no token is provided, and p2p is enabled,
|
// IF no token is provided, and p2p is enabled,
|
||||||
// we generate one and wait for the user to pick up the token (this is for interactive)
|
// we generate one and wait for the user to pick up the token (this is for interactive)
|
||||||
log.Info().Msg("No token provided, generating one")
|
log.Info().Msg("No token provided, generating one")
|
||||||
token = p2p.GenerateToken()
|
token = p2p.GenerateToken(r.Peer2PeerDHTInterval, r.Peer2PeerOTPInterval)
|
||||||
log.Info().Msg("Generated Token:")
|
log.Info().Msg("Generated Token:")
|
||||||
fmt.Println(token)
|
fmt.Println(token)
|
||||||
|
|
||||||
@ -115,45 +122,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
fmt.Printf("export TOKEN=\"%s\"\nlocal-ai worker p2p-llama-cpp-rpc\n", token)
|
fmt.Printf("export TOKEN=\"%s\"\nlocal-ai worker p2p-llama-cpp-rpc\n", token)
|
||||||
}
|
}
|
||||||
opts = append(opts, config.WithP2PToken(token))
|
opts = append(opts, config.WithP2PToken(token))
|
||||||
|
|
||||||
node, err := p2p.NewNode(token)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Msg("Starting P2P server discovery...")
|
|
||||||
if err := p2p.ServiceDiscoverer(context.Background(), node, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID), func(serviceID string, node p2p.NodeData) {
|
|
||||||
var tunnelAddresses []string
|
|
||||||
for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) {
|
|
||||||
if v.IsOnline() {
|
|
||||||
tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
|
|
||||||
} else {
|
|
||||||
log.Info().Msgf("Node %s is offline", v.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tunnelEnvVar := strings.Join(tunnelAddresses, ",")
|
|
||||||
|
|
||||||
os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
|
|
||||||
log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar)
|
|
||||||
}, true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Federated {
|
backgroundCtx := context.Background()
|
||||||
_, port, err := net.SplitHostPort(r.Address)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fedCtx := context.Background()
|
|
||||||
node, err := p2p.ExposeService(fedCtx, "localhost", port, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.FederatedID))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p2p.ServiceDiscoverer(fedCtx, node, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.FederatedID), nil, false); err != nil {
|
if err := cli_api.StartP2PStack(backgroundCtx, r.Address, token, r.Peer2PeerNetworkID, r.Federated); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
idleWatchDog := r.EnableWatchdogIdle
|
idleWatchDog := r.EnableWatchdogIdle
|
||||||
|
110
core/cli/soundgeneration.go
Normal file
110
core/cli/soundgeneration.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoundGenerationCMD holds the arguments and flags of the sound generation
// CLI command. The struct tags drive flag parsing; the help text for each
// flag lives in its tag.
type SoundGenerationCMD struct {
	// Positional words of the prompt; Run joins them with single spaces.
	Text []string `arg:""`

	Backend string `short:"b" required:"" help:"Backend to run the SoundGeneration model"`
	Model   string `short:"m" required:"" help:"Model name to run the SoundGeneration"`
	// Duration, Temperature and InputFileSampleDivisor are kept as strings.
	// NOTE(review): presumably converted with parseToFloat32Ptr /
	// parseToInt32Ptr, which yield nil for empty or unparsable input — the
	// call site is not visible here, confirm.
	Duration               string `short:"d" help:"If specified, the length of audio to generate in seconds"`
	Temperature            string `short:"t" help:"If specified, the temperature of the generation"`
	InputFile              string `short:"i" help:"If specified, the input file to condition generation upon"`
	InputFileSampleDivisor string `short:"f" help:"If InputFile and this divisor is specified, the first portion of the sample file will be used"`
	DoSample               bool   `short:"s" default:"true" help:"Enables sampling from the model. Better quality at the cost of speed. Defaults to enabled."`
	// When set, Run uses this path's directory as the audio output
	// directory; otherwise BackendAssetsPath is used.
	OutputFile        string `short:"o" type:"path" help:"The path to write the output wav file"`
	ModelsPath        string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
	BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
	// Each entry is split by Run at the first ':' into backend name and URI.
	ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
}
|
||||||
|
|
||||||
|
// parseToFloat32Ptr converts input to a float32 and returns a pointer to the
// result. It returns nil when strconv.ParseFloat rejects input as a 32-bit
// float, which includes the empty string.
func parseToFloat32Ptr(input string) *float32 {
	parsed, err := strconv.ParseFloat(input, 32)
	if err == nil {
		value := float32(parsed)
		return &value
	}
	return nil
}
|
||||||
|
|
||||||
|
// parseToInt32Ptr converts a base-10 input to an int32 and returns a pointer
// to the result. It returns nil when strconv.ParseInt rejects input as a
// 32-bit integer, which includes the empty string.
func parseToInt32Ptr(input string) *int32 {
	parsed, err := strconv.ParseInt(input, 10, 32)
	if err == nil {
		value := int32(parsed)
		return &value
	}
	return nil
}
|
||||||
|
|
||||||
|
func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||||
|
outputFile := t.OutputFile
|
||||||
|
outputDir := t.BackendAssetsPath
|
||||||
|
if outputFile != "" {
|
||||||
|
outputDir = filepath.Dir(outputFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
text := strings.Join(t.Text, " ")
|
||||||
|
|
||||||
|
externalBackends := make(map[string]string)
|
||||||
|
// split ":" to get backend name and the uri
|
||||||
|
for _, v := range t.ExternalGRPCBackends {
|
||||||
|
backend := v[:strings.IndexByte(v, ':')]
|
||||||
|
uri := v[strings.IndexByte(v, ':')+1:]
|
||||||
|
externalBackends[backend] = uri
|
||||||
|
fmt.Printf("TMP externalBackends[%q]=%q\n\n", backend, uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &config.ApplicationConfig{
|
||||||
|
ModelPath: t.ModelsPath,
|
||||||
|
Context: context.Background(),
|
||||||
|
AudioDir: outputDir,
|
||||||
|
AssetsDestination: t.BackendAssetsPath,
|
||||||
|
ExternalGRPCBackends: externalBackends,
|
||||||
|
}
|
||||||
|
ml := model.NewModelLoader(opts.ModelPath)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := ml.StopAllGRPC()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("unable to stop all grpc processes")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
options := config.BackendConfig{}
|
||||||
|
options.SetDefaults()
|
||||||
|
|
||||||
|
var inputFile *string
|
||||||
|
if t.InputFile != "" {
|
||||||
|
inputFile = &t.InputFile
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text,
|
||||||
|
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
|
||||||
|
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if outputFile != "" {
|
||||||
|
if err := os.Rename(filePath, outputFile); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("Generate file %s\n", outputFile)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Generate file %s\n", filePath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -2,6 +2,7 @@ package worker
|
|||||||
|
|
||||||
type WorkerFlags struct {
|
type WorkerFlags struct {
|
||||||
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
|
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
|
||||||
|
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Worker struct {
|
type Worker struct {
|
||||||
|
@ -3,6 +3,7 @@ package worker
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
@ -12,7 +13,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type LLamaCPP struct {
|
type LLamaCPP struct {
|
||||||
Args []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
|
|
||||||
WorkerFlags `embed:""`
|
WorkerFlags `embed:""`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,9 +34,8 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
|||||||
"llama-cpp-rpc-server",
|
"llama-cpp-rpc-server",
|
||||||
)
|
)
|
||||||
|
|
||||||
args := os.Args[4:]
|
args := strings.Split(r.ExtraLLamaCPPArgs, " ")
|
||||||
args, grpcProcess = library.LoadLDSO(r.BackendAssetsPath, args, grpcProcess)
|
args, grpcProcess = library.LoadLDSO(r.BackendAssetsPath, args, grpcProcess)
|
||||||
|
|
||||||
args = append([]string{grpcProcess}, args...)
|
args = append([]string{grpcProcess}, args...)
|
||||||
return syscall.Exec(
|
return syscall.Exec(
|
||||||
grpcProcess,
|
grpcProcess,
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
@ -20,12 +21,11 @@ import (
|
|||||||
|
|
||||||
type P2P struct {
|
type P2P struct {
|
||||||
WorkerFlags `embed:""`
|
WorkerFlags `embed:""`
|
||||||
Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"`
|
Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"`
|
||||||
NoRunner bool `env:"LOCALAI_NO_RUNNER,NO_RUNNER" help:"Do not start the llama-cpp-rpc-server"`
|
NoRunner bool `env:"LOCALAI_NO_RUNNER,NO_RUNNER" help:"Do not start the llama-cpp-rpc-server"`
|
||||||
RunnerAddress string `env:"LOCALAI_RUNNER_ADDRESS,RUNNER_ADDRESS" help:"Address of the llama-cpp-rpc-server"`
|
RunnerAddress string `env:"LOCALAI_RUNNER_ADDRESS,RUNNER_ADDRESS" help:"Address of the llama-cpp-rpc-server"`
|
||||||
RunnerPort string `env:"LOCALAI_RUNNER_PORT,RUNNER_PORT" help:"Port of the llama-cpp-rpc-server"`
|
RunnerPort string `env:"LOCALAI_RUNNER_PORT,RUNNER_PORT" help:"Port of the llama-cpp-rpc-server"`
|
||||||
ExtraLLamaCPPArgs []string `env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
|
||||||
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *P2P) Run(ctx *cliContext.Context) error {
|
func (r *P2P) Run(ctx *cliContext.Context) error {
|
||||||
@ -65,44 +65,42 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info().Msgf("You need to start llama-cpp-rpc-server on '%s:%s'", address, p)
|
log.Info().Msgf("You need to start llama-cpp-rpc-server on '%s:%s'", address, p)
|
||||||
|
} else {
|
||||||
|
// Start llama.cpp directly from the version we have pre-packaged
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
|
||||||
|
|
||||||
return nil
|
grpcProcess := assets.ResolvePath(
|
||||||
}
|
r.BackendAssetsPath,
|
||||||
|
"util",
|
||||||
|
"llama-cpp-rpc-server",
|
||||||
|
)
|
||||||
|
extraArgs := strings.Split(r.ExtraLLamaCPPArgs, " ")
|
||||||
|
args := append([]string{"--host", address, "--port", fmt.Sprint(port)}, extraArgs...)
|
||||||
|
args, grpcProcess = library.LoadLDSO(r.BackendAssetsPath, args, grpcProcess)
|
||||||
|
|
||||||
// Start llama.cpp directly from the version we have pre-packaged
|
cmd := exec.Command(
|
||||||
go func() {
|
grpcProcess, args...,
|
||||||
for {
|
)
|
||||||
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
|
|
||||||
|
|
||||||
grpcProcess := assets.ResolvePath(
|
cmd.Env = os.Environ()
|
||||||
r.BackendAssetsPath,
|
|
||||||
"util",
|
|
||||||
"llama-cpp-rpc-server",
|
|
||||||
)
|
|
||||||
|
|
||||||
args := append([]string{"--host", address, "--port", fmt.Sprint(port)}, r.ExtraLLamaCPPArgs...)
|
cmd.Stderr = os.Stdout
|
||||||
args, grpcProcess = library.LoadLDSO(r.BackendAssetsPath, args, grpcProcess)
|
cmd.Stdout = os.Stdout
|
||||||
|
|
||||||
cmd := exec.Command(
|
if err := cmd.Start(); err != nil {
|
||||||
grpcProcess, args...,
|
log.Error().Any("grpcProcess", grpcProcess).Any("args", args).Err(err).Msg("Failed to start llama-cpp-rpc-server")
|
||||||
)
|
}
|
||||||
|
|
||||||
cmd.Env = os.Environ()
|
cmd.Wait()
|
||||||
|
|
||||||
cmd.Stderr = os.Stdout
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
log.Error().Any("grpcProcess", grpcProcess).Any("args", args).Err(err).Msg("Failed to start llama-cpp-rpc-server")
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
cmd.Wait()
|
_, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"regexp"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||||
@ -16,7 +17,6 @@ type ApplicationConfig struct {
|
|||||||
ModelPath string
|
ModelPath string
|
||||||
LibPath string
|
LibPath string
|
||||||
UploadLimitMB, Threads, ContextSize int
|
UploadLimitMB, Threads, ContextSize int
|
||||||
DisableWebUI bool
|
|
||||||
F16 bool
|
F16 bool
|
||||||
Debug bool
|
Debug bool
|
||||||
ImageDir string
|
ImageDir string
|
||||||
@ -31,11 +31,17 @@ type ApplicationConfig struct {
|
|||||||
PreloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
CORSAllowOrigins string
|
CORSAllowOrigins string
|
||||||
ApiKeys []string
|
ApiKeys []string
|
||||||
EnforcePredownloadScans bool
|
|
||||||
OpaqueErrors bool
|
|
||||||
P2PToken string
|
P2PToken string
|
||||||
P2PNetworkID string
|
P2PNetworkID string
|
||||||
|
|
||||||
|
DisableWebUI bool
|
||||||
|
EnforcePredownloadScans bool
|
||||||
|
OpaqueErrors bool
|
||||||
|
UseSubtleKeyComparison bool
|
||||||
|
DisableApiKeyRequirementForHttpGet bool
|
||||||
|
HttpGetExemptedEndpoints []*regexp.Regexp
|
||||||
|
DisableGalleryEndpoint bool
|
||||||
|
|
||||||
ModelLibraryURL string
|
ModelLibraryURL string
|
||||||
|
|
||||||
Galleries []Gallery
|
Galleries []Gallery
|
||||||
@ -57,8 +63,6 @@ type ApplicationConfig struct {
|
|||||||
ModelsURL []string
|
ModelsURL []string
|
||||||
|
|
||||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||||
|
|
||||||
DisableGalleryEndpoint bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*ApplicationConfig)
|
type AppOption func(*ApplicationConfig)
|
||||||
@ -327,6 +331,32 @@ func WithOpaqueErrors(opaque bool) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithSubtleKeyComparison(subtle bool) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.UseSubtleKeyComparison = subtle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.DisableApiKeyRequirementForHttpGet = required
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.HttpGetExemptedEndpoints = []*regexp.Regexp{}
|
||||||
|
for _, epr := range endpoints {
|
||||||
|
r, err := regexp.Compile(epr)
|
||||||
|
if err == nil && r != nil {
|
||||||
|
o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r)
|
||||||
|
} else {
|
||||||
|
log.Warn().Err(err).Str("regex", epr).Msg("Error while compiling HTTP Get Exemption regex, skipping this entry.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
|
||||||
// Some options defined at the application level are going to be passed as defaults for
|
// Some options defined at the application level are going to be passed as defaults for
|
||||||
// all the configuration for the models.
|
// all the configuration for the models.
|
||||||
|
@ -126,6 +126,7 @@ type LLMConfig struct {
|
|||||||
Grammar string `yaml:"grammar"`
|
Grammar string `yaml:"grammar"`
|
||||||
StopWords []string `yaml:"stopwords"`
|
StopWords []string `yaml:"stopwords"`
|
||||||
Cutstrings []string `yaml:"cutstrings"`
|
Cutstrings []string `yaml:"cutstrings"`
|
||||||
|
ExtractRegex []string `yaml:"extract_regex"`
|
||||||
TrimSpace []string `yaml:"trimspace"`
|
TrimSpace []string `yaml:"trimspace"`
|
||||||
TrimSuffix []string `yaml:"trimsuffix"`
|
TrimSuffix []string `yaml:"trimsuffix"`
|
||||||
|
|
||||||
|
@ -3,13 +3,15 @@ package http
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
"github.com/dave-gray101/v2keyauth"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||||
|
"github.com/mudler/LocalAI/core/http/middleware"
|
||||||
"github.com/mudler/LocalAI/core/http/routes"
|
"github.com/mudler/LocalAI/core/http/routes"
|
||||||
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
@ -137,37 +139,14 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
|
||||||
auth := func(c *fiber.Ctx) error {
|
if err != nil || kaConfig == nil {
|
||||||
if len(appConfig.ApiKeys) == 0 {
|
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(appConfig.ApiKeys) == 0 {
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
authHeader := readAuthHeader(c)
|
|
||||||
if authHeader == "" {
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it's a bearer token
|
|
||||||
authHeaderParts := strings.Split(authHeader, " ")
|
|
||||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := authHeaderParts[1]
|
|
||||||
for _, key := range appConfig.ApiKeys {
|
|
||||||
if apiKey == key {
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
|
||||||
|
app.Use(v2keyauth.New(*kaConfig))
|
||||||
|
|
||||||
if appConfig.CORS {
|
if appConfig.CORS {
|
||||||
var c func(ctx *fiber.Ctx) error
|
var c func(ctx *fiber.Ctx) error
|
||||||
if appConfig.CORSAllowOrigins == "" {
|
if appConfig.CORSAllowOrigins == "" {
|
||||||
@ -192,13 +171,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||||||
galleryService := services.NewGalleryService(appConfig)
|
galleryService := services.NewGalleryService(appConfig)
|
||||||
galleryService.Start(appConfig.Context, cl)
|
galleryService.Start(appConfig.Context, cl)
|
||||||
|
|
||||||
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
|
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
|
||||||
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
|
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
|
||||||
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
|
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
|
||||||
if !appConfig.DisableWebUI {
|
if !appConfig.DisableWebUI {
|
||||||
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
|
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
|
||||||
}
|
}
|
||||||
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
|
routes.RegisterJINARoutes(app, cl, ml, appConfig)
|
||||||
|
|
||||||
httpFS := http.FS(embedDirStatic)
|
httpFS := http.FS(embedDirStatic)
|
||||||
|
|
||||||
|
@ -772,6 +772,17 @@ var _ = Describe("API test", func() {
|
|||||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error:"))
|
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error:"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("shows the external backend", func() {
	// do an http request to the /system endpoint
	resp, err := http.Get("http://127.0.0.1:9090/system")
	Expect(err).ToNot(HaveOccurred())
	// Close the body so the connection is released even when a later
	// expectation fails.
	defer resp.Body.Close()
	Expect(resp.StatusCode).To(Equal(200))
	dat, err := io.ReadAll(resp.Body)
	Expect(err).ToNot(HaveOccurred())
	Expect(string(dat)).To(ContainSubstring("huggingface"))
	Expect(string(dat)).To(ContainSubstring("llama-cpp"))
})
|
||||||
|
|
||||||
It("transcribes audio", func() {
|
It("transcribes audio", func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
Skip("test supported only on linux")
|
Skip("test supported only on linux")
|
||||||
|
65
core/http/endpoints/elevenlabs/soundgeneration.go
Normal file
65
core/http/endpoints/elevenlabs/soundgeneration.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package elevenlabs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation
|
||||||
|
// @Summary Generates audio from the input text.
|
||||||
|
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||||
|
// @Success 200 {string} binary "Response"
|
||||||
|
// @Router /v1/sound-generation [post]
|
||||||
|
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(schema.ElevenLabsSoundGenerationRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
|
||||||
|
} else {
|
||||||
|
if input.ModelID != "" {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
|
||||||
|
|
||||||
|
if input.Duration != nil {
|
||||||
|
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
|
||||||
|
}
|
||||||
|
if input.Temperature != nil {
|
||||||
|
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Support uploading files?
|
||||||
|
filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Download(filePath)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
29
core/http/endpoints/localai/system.go
Normal file
29
core/http/endpoints/localai/system.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemInformations returns the system informations
|
||||||
|
// @Summary Show the LocalAI instance information
|
||||||
|
// @Success 200 {object} schema.SystemInformationResponse "Response"
|
||||||
|
// @Router /system [get]
|
||||||
|
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
availableBackends, err := ml.ListAvailableBackends(appConfig.AssetsDestination)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for b := range appConfig.ExternalGRPCBackends {
|
||||||
|
availableBackends = append(availableBackends, b)
|
||||||
|
}
|
||||||
|
return c.JSON(
|
||||||
|
schema.SystemInformationResponse{
|
||||||
|
Backends: availableBackends,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -25,9 +25,8 @@ import (
|
|||||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||||
// @Router /v1/chat/completions [post]
|
// @Router /v1/chat/completions [post]
|
||||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
textContentToReturn := ""
|
var id, textContentToReturn string
|
||||||
id := uuid.New().String()
|
var created int
|
||||||
created := int(time.Now().Unix())
|
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
initialMessage := schema.OpenAIResponse{
|
initialMessage := schema.OpenAIResponse{
|
||||||
@ -69,9 +68,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||||||
|
|
||||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||||
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||||
log.Debug().Msgf("Text content to return: %s", textContentToReturn)
|
log.Debug().Msgf("Text content to return: %s", textContentToReturn)
|
||||||
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
|
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case noActionToRun:
|
case noActionToRun:
|
||||||
@ -84,7 +83,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||||||
}
|
}
|
||||||
responses <- initialMessage
|
responses <- initialMessage
|
||||||
|
|
||||||
result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt)
|
result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("error handling question")
|
log.Error().Err(err).Msg("error handling question")
|
||||||
return
|
return
|
||||||
@ -106,7 +105,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||||||
responses <- resp
|
responses <- resp
|
||||||
|
|
||||||
default:
|
default:
|
||||||
for i, ss := range results {
|
for i, ss := range functionResults {
|
||||||
name, args := ss.Name, ss.Arguments
|
name, args := ss.Name, ss.Arguments
|
||||||
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
initialMessage := schema.OpenAIResponse{
|
||||||
@ -159,6 +158,10 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
textContentToReturn = ""
|
||||||
|
id = uuid.New().String()
|
||||||
|
created = int(time.Now().Unix())
|
||||||
|
|
||||||
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
|
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
93
core/http/middleware/auth.go
Normal file
93
core/http/middleware/auth.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/dave-gray101/v2keyauth"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
|
||||||
|
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
|
||||||
|
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
|
||||||
|
|
||||||
|
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
|
||||||
|
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key"}, keyauth.ConfigDefault.AuthScheme)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &v2keyauth.Config{
|
||||||
|
CustomKeyLookup: customLookup,
|
||||||
|
Next: getApiKeyRequiredFilterFunction(applicationConfig),
|
||||||
|
Validator: getApiKeyValidationFunction(applicationConfig),
|
||||||
|
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
|
||||||
|
AuthScheme: "Bearer",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
|
||||||
|
return func(ctx *fiber.Ctx, err error) error {
|
||||||
|
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
|
||||||
|
if len(applicationConfig.ApiKeys) == 0 {
|
||||||
|
return ctx.Next() // if no keys are set up, any error we get here is not an error.
|
||||||
|
}
|
||||||
|
if applicationConfig.OpaqueErrors {
|
||||||
|
return ctx.SendStatus(403)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if applicationConfig.OpaqueErrors {
|
||||||
|
return ctx.SendStatus(500)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
|
||||||
|
|
||||||
|
if applicationConfig.UseSubtleKeyComparison {
|
||||||
|
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||||
|
if len(applicationConfig.ApiKeys) == 0 {
|
||||||
|
return true, nil // If no keys are setup, accept everything
|
||||||
|
}
|
||||||
|
for _, validKey := range applicationConfig.ApiKeys {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
|
||||||
|
if len(applicationConfig.ApiKeys) == 0 {
|
||||||
|
return true, nil // If no keys are setup, accept everything
|
||||||
|
}
|
||||||
|
for _, validKey := range applicationConfig.ApiKeys {
|
||||||
|
if apiKey == validKey {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, v2keyauth.ErrMissingOrMalformedAPIKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
|
||||||
|
if applicationConfig.DisableApiKeyRequirementForHttpGet {
|
||||||
|
return func(c *fiber.Ctx) bool {
|
||||||
|
if c.Method() != "GET" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
|
||||||
|
if rx.MatchString(c.Path()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(c *fiber.Ctx) bool { return false }
|
||||||
|
}
|
@ -10,10 +10,11 @@ import (
|
|||||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig) {
|
||||||
auth func(*fiber.Ctx) error) {
|
|
||||||
|
|
||||||
// Elevenlabs
|
// Elevenlabs
|
||||||
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -11,8 +11,7 @@ import (
|
|||||||
func RegisterJINARoutes(app *fiber.App,
|
func RegisterJINARoutes(app *fiber.App,
|
||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig) {
|
||||||
auth func(*fiber.Ctx) error) {
|
|
||||||
|
|
||||||
// POST endpoint to mimic the reranking
|
// POST endpoint to mimic the reranking
|
||||||
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
|
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||||
|
@ -15,33 +15,32 @@ func RegisterLocalAIRoutes(app *fiber.App,
|
|||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
galleryService *services.GalleryService,
|
galleryService *services.GalleryService) {
|
||||||
auth func(*fiber.Ctx) error) {
|
|
||||||
|
|
||||||
app.Get("/swagger/*", swagger.HandlerDefault) // default
|
app.Get("/swagger/*", swagger.HandlerDefault) // default
|
||||||
|
|
||||||
// LocalAI API endpoints
|
// LocalAI API endpoints
|
||||||
if !appConfig.DisableGalleryEndpoint {
|
if !appConfig.DisableGalleryEndpoint {
|
||||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||||
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||||
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||||
|
|
||||||
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||||
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||||
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
|
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||||
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||||
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
|
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||||
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
|
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// Stores
|
// Stores
|
||||||
sl := model.NewModelLoader("")
|
sl := model.NewModelLoader("")
|
||||||
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
|
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
|
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
|
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
|
||||||
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
|
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
|
||||||
|
|
||||||
// Kubernetes health checks
|
// Kubernetes health checks
|
||||||
ok := func(c *fiber.Ctx) error {
|
ok := func(c *fiber.Ctx) error {
|
||||||
@ -51,23 +50,25 @@ func RegisterLocalAIRoutes(app *fiber.App,
|
|||||||
app.Get("/healthz", ok)
|
app.Get("/healthz", ok)
|
||||||
app.Get("/readyz", ok)
|
app.Get("/readyz", ok)
|
||||||
|
|
||||||
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
|
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||||
|
|
||||||
// Experimental Backend Statistics Module
|
// Experimental Backend Statistics Module
|
||||||
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
|
||||||
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
|
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
|
||||||
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
|
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
|
||||||
|
|
||||||
// p2p
|
// p2p
|
||||||
if p2p.IsP2PEnabled() {
|
if p2p.IsP2PEnabled() {
|
||||||
app.Get("/api/p2p", auth, localai.ShowP2PNodes(appConfig))
|
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
|
||||||
app.Get("/api/p2p/token", auth, localai.ShowP2PToken(appConfig))
|
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
app.Get("/version", func(c *fiber.Ctx) error {
|
||||||
return c.JSON(struct {
|
return c.JSON(struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}{Version: internal.PrintableVersion()})
|
}{Version: internal.PrintableVersion()})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
app.Get("/system", auth, localai.SystemInformations(ml, appConfig))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -11,66 +11,65 @@ import (
|
|||||||
func RegisterOpenAIRoutes(app *fiber.App,
|
func RegisterOpenAIRoutes(app *fiber.App,
|
||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig) {
|
||||||
auth func(*fiber.Ctx) error) {
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
// chat
|
// chat
|
||||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// edit
|
// edit
|
||||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// assistant
|
// assistant
|
||||||
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// files
|
// files
|
||||||
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
|
||||||
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
|
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// images
|
// images
|
||||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
|
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
if appConfig.ImageDir != "" {
|
if appConfig.ImageDir != "" {
|
||||||
app.Static("/generated-images", appConfig.ImageDir)
|
app.Static("/generated-images", appConfig.ImageDir)
|
||||||
@ -81,6 +80,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List models
|
// List models
|
||||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
|
||||||
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
|
||||||
}
|
}
|
||||||
|
@ -59,8 +59,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
cl *config.BackendConfigLoader,
|
cl *config.BackendConfigLoader,
|
||||||
ml *model.ModelLoader,
|
ml *model.ModelLoader,
|
||||||
appConfig *config.ApplicationConfig,
|
appConfig *config.ApplicationConfig,
|
||||||
galleryService *services.GalleryService,
|
galleryService *services.GalleryService) {
|
||||||
auth func(*fiber.Ctx) error) {
|
|
||||||
|
|
||||||
// keeps the state of models that are being installed from the UI
|
// keeps the state of models that are being installed from the UI
|
||||||
var processingModels = NewModelOpCache()
|
var processingModels = NewModelOpCache()
|
||||||
@ -85,10 +84,10 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return processingModelsData, taskTypes
|
return processingModelsData, taskTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
|
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
|
||||||
|
|
||||||
if p2p.IsP2PEnabled() {
|
if p2p.IsP2PEnabled() {
|
||||||
app.Get("/p2p", auth, func(c *fiber.Ctx) error {
|
app.Get("/p2p", func(c *fiber.Ctx) error {
|
||||||
summary := fiber.Map{
|
summary := fiber.Map{
|
||||||
"Title": "LocalAI - P2P dashboard",
|
"Title": "LocalAI - P2P dashboard",
|
||||||
"Version": internal.PrintableVersion(),
|
"Version": internal.PrintableVersion(),
|
||||||
@ -104,17 +103,17 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
})
|
})
|
||||||
|
|
||||||
/* show nodes live! */
|
/* show nodes live! */
|
||||||
app.Get("/p2p/ui/workers", auth, func(c *fiber.Ctx) error {
|
app.Get("/p2p/ui/workers", func(c *fiber.Ctx) error {
|
||||||
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
|
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
|
||||||
})
|
})
|
||||||
app.Get("/p2p/ui/workers-federation", auth, func(c *fiber.Ctx) error {
|
app.Get("/p2p/ui/workers-federation", func(c *fiber.Ctx) error {
|
||||||
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
|
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/p2p/ui/workers-stats", auth, func(c *fiber.Ctx) error {
|
app.Get("/p2p/ui/workers-stats", func(c *fiber.Ctx) error {
|
||||||
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
|
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
|
||||||
})
|
})
|
||||||
app.Get("/p2p/ui/workers-federation-stats", auth, func(c *fiber.Ctx) error {
|
app.Get("/p2p/ui/workers-federation-stats", func(c *fiber.Ctx) error {
|
||||||
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
|
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -122,7 +121,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
if !appConfig.DisableGalleryEndpoint {
|
if !appConfig.DisableGalleryEndpoint {
|
||||||
|
|
||||||
// Show the Models page (all models)
|
// Show the Models page (all models)
|
||||||
app.Get("/browse", auth, func(c *fiber.Ctx) error {
|
app.Get("/browse", func(c *fiber.Ctx) error {
|
||||||
term := c.Query("term")
|
term := c.Query("term")
|
||||||
|
|
||||||
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
||||||
@ -167,7 +166,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
|
|
||||||
// Show the models, filtered from the user input
|
// Show the models, filtered from the user input
|
||||||
// https://htmx.org/examples/active-search/
|
// https://htmx.org/examples/active-search/
|
||||||
app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error {
|
app.Post("/browse/search/models", func(c *fiber.Ctx) error {
|
||||||
form := struct {
|
form := struct {
|
||||||
Search string `form:"search"`
|
Search string `form:"search"`
|
||||||
}{}
|
}{}
|
||||||
@ -188,7 +187,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
|
|
||||||
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
|
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
|
||||||
// https://htmx.org/examples/progress-bar/
|
// https://htmx.org/examples/progress-bar/
|
||||||
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
|
app.Post("/browse/install/model/:id", func(c *fiber.Ctx) error {
|
||||||
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
|
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
|
||||||
log.Debug().Msgf("UI job submitted to install : %+v\n", galleryID)
|
log.Debug().Msgf("UI job submitted to install : %+v\n", galleryID)
|
||||||
|
|
||||||
@ -215,7 +214,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
|
|
||||||
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
|
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
|
||||||
// https://htmx.org/examples/progress-bar/
|
// https://htmx.org/examples/progress-bar/
|
||||||
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
|
app.Post("/browse/delete/model/:id", func(c *fiber.Ctx) error {
|
||||||
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
|
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
|
||||||
log.Debug().Msgf("UI job submitted to delete : %+v\n", galleryID)
|
log.Debug().Msgf("UI job submitted to delete : %+v\n", galleryID)
|
||||||
var galleryName = galleryID
|
var galleryName = galleryID
|
||||||
@ -255,7 +254,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
// Display the job current progress status
|
// Display the job current progress status
|
||||||
// If the job is done, we trigger the /browse/job/:uid route
|
// If the job is done, we trigger the /browse/job/:uid route
|
||||||
// https://htmx.org/examples/progress-bar/
|
// https://htmx.org/examples/progress-bar/
|
||||||
app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error {
|
app.Get("/browse/job/progress/:uid", func(c *fiber.Ctx) error {
|
||||||
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
|
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
|
||||||
|
|
||||||
status := galleryService.GetStatus(jobUID)
|
status := galleryService.GetStatus(jobUID)
|
||||||
@ -279,7 +278,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
|
|
||||||
// this route is hit when the job is done, and we display the
|
// this route is hit when the job is done, and we display the
|
||||||
// final state (for now just displays "Installation completed")
|
// final state (for now just displays "Installation completed")
|
||||||
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {
|
app.Get("/browse/job/:uid", func(c *fiber.Ctx) error {
|
||||||
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
|
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
|
||||||
|
|
||||||
status := galleryService.GetStatus(jobUID)
|
status := galleryService.GetStatus(jobUID)
|
||||||
@ -303,7 +302,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Show the Chat page
|
// Show the Chat page
|
||||||
app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
|
app.Get("/chat/:model", func(c *fiber.Ctx) error {
|
||||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||||
|
|
||||||
summary := fiber.Map{
|
summary := fiber.Map{
|
||||||
@ -318,7 +317,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/chat", summary)
|
return c.Render("views/chat", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/talk/", auth, func(c *fiber.Ctx) error {
|
app.Get("/talk/", func(c *fiber.Ctx) error {
|
||||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||||
|
|
||||||
if len(backendConfigs) == 0 {
|
if len(backendConfigs) == 0 {
|
||||||
@ -338,7 +337,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/talk", summary)
|
return c.Render("views/talk", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/chat/", auth, func(c *fiber.Ctx) error {
|
app.Get("/chat/", func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
backendConfigs, _ := services.ListModels(cl, ml, "", true)
|
||||||
|
|
||||||
@ -359,7 +358,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/chat", summary)
|
return c.Render("views/chat", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/text2image/:model", auth, func(c *fiber.Ctx) error {
|
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
|
||||||
backendConfigs := cl.GetAllBackendConfigs()
|
backendConfigs := cl.GetAllBackendConfigs()
|
||||||
|
|
||||||
summary := fiber.Map{
|
summary := fiber.Map{
|
||||||
@ -374,7 +373,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/text2image", summary)
|
return c.Render("views/text2image", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/text2image/", auth, func(c *fiber.Ctx) error {
|
app.Get("/text2image/", func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
backendConfigs := cl.GetAllBackendConfigs()
|
backendConfigs := cl.GetAllBackendConfigs()
|
||||||
|
|
||||||
@ -395,7 +394,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/text2image", summary)
|
return c.Render("views/text2image", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/tts/:model", auth, func(c *fiber.Ctx) error {
|
app.Get("/tts/:model", func(c *fiber.Ctx) error {
|
||||||
backendConfigs := cl.GetAllBackendConfigs()
|
backendConfigs := cl.GetAllBackendConfigs()
|
||||||
|
|
||||||
summary := fiber.Map{
|
summary := fiber.Map{
|
||||||
@ -410,7 +409,7 @@ func RegisterUIRoutes(app *fiber.App,
|
|||||||
return c.Render("views/tts", summary)
|
return c.Render("views/tts", summary)
|
||||||
})
|
})
|
||||||
|
|
||||||
app.Get("/tts/", auth, func(c *fiber.Ctx) error {
|
app.Get("/tts/", func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
backendConfigs := cl.GetAllBackendConfigs()
|
backendConfigs := cl.GetAllBackendConfigs()
|
||||||
|
|
||||||
|
@ -6,11 +6,7 @@
|
|||||||
rel="stylesheet"
|
rel="stylesheet"
|
||||||
href="/static/assets/highlightjs.css"
|
href="/static/assets/highlightjs.css"
|
||||||
/>
|
/>
|
||||||
<script defer src="/static/assets/anime.min.js"></script>
|
<script defer src="/static/assets/highlightjs.js"></script>
|
||||||
<script
|
|
||||||
defer
|
|
||||||
src="/static/assets/highlightjs.js"
|
|
||||||
></script>
|
|
||||||
<script
|
<script
|
||||||
defer
|
defer
|
||||||
src="/static/assets/alpine.js"
|
src="/static/assets/alpine.js"
|
||||||
|
@ -28,9 +28,15 @@ import (
|
|||||||
"github.com/mudler/edgevpn/pkg/logger"
|
"github.com/mudler/edgevpn/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateNewConnectionData() *node.YAMLConnectionConfig {
|
func generateNewConnectionData(DHTInterval, OTPInterval int) *node.YAMLConnectionConfig {
|
||||||
maxMessSize := 20 << 20 // 20MB
|
maxMessSize := 20 << 20 // 20MB
|
||||||
keyLength := 43
|
keyLength := 43
|
||||||
|
if DHTInterval == 0 {
|
||||||
|
DHTInterval = 360
|
||||||
|
}
|
||||||
|
if OTPInterval == 0 {
|
||||||
|
OTPInterval = 9000
|
||||||
|
}
|
||||||
|
|
||||||
return &node.YAMLConnectionConfig{
|
return &node.YAMLConnectionConfig{
|
||||||
MaxMessageSize: maxMessSize,
|
MaxMessageSize: maxMessSize,
|
||||||
@ -40,21 +46,21 @@ func generateNewConnectionData() *node.YAMLConnectionConfig {
|
|||||||
OTP: node.OTP{
|
OTP: node.OTP{
|
||||||
DHT: node.OTPConfig{
|
DHT: node.OTPConfig{
|
||||||
Key: eutils.RandStringRunes(keyLength),
|
Key: eutils.RandStringRunes(keyLength),
|
||||||
Interval: 120,
|
Interval: DHTInterval,
|
||||||
Length: keyLength,
|
Length: keyLength,
|
||||||
},
|
},
|
||||||
Crypto: node.OTPConfig{
|
Crypto: node.OTPConfig{
|
||||||
Key: eutils.RandStringRunes(keyLength),
|
Key: eutils.RandStringRunes(keyLength),
|
||||||
Interval: 9000,
|
Interval: OTPInterval,
|
||||||
Length: keyLength,
|
Length: keyLength,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateToken() string {
|
func GenerateToken(DHTInterval, OTPInterval int) string {
|
||||||
// Generates a new config and exit
|
// Generates a new config and exit
|
||||||
return generateNewConnectionData().Base64()
|
return generateNewConnectionData(DHTInterval, OTPInterval).Base64()
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsP2PEnabled() bool {
|
func IsP2PEnabled() bool {
|
||||||
@ -202,13 +208,9 @@ func ServiceDiscoverer(ctx context.Context, n *node.Node, token, servicesID stri
|
|||||||
func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID string, allocate bool) (chan NodeData, error) {
|
func discoveryTunnels(ctx context.Context, n *node.Node, token, servicesID string, allocate bool) (chan NodeData, error) {
|
||||||
tunnels := make(chan NodeData)
|
tunnels := make(chan NodeData)
|
||||||
|
|
||||||
err := n.Start(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating a new node: %w", err)
|
|
||||||
}
|
|
||||||
ledger, err := n.Ledger()
|
ledger, err := n.Ledger()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating a new node: %w", err)
|
return nil, fmt.Errorf("getting the ledger: %w", err)
|
||||||
}
|
}
|
||||||
// get new services, allocate and return to the channel
|
// get new services, allocate and return to the channel
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/mudler/edgevpn/pkg/node"
|
"github.com/mudler/edgevpn/pkg/node"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateToken() string {
|
func GenerateToken(DHTInterval, OTPInterval int) string {
|
||||||
return "not implemented"
|
return "not implemented"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user