mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(rerankers): Add new backend, support jina rerankers API (#2121)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
e16658b7ec
commit
b664edde29
31
.github/workflows/test-extra.yml
vendored
31
.github/workflows/test-extra.yml
vendored
@ -74,6 +74,37 @@ jobs:
|
|||||||
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers
|
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers
|
||||||
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers test
|
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers test
|
||||||
|
|
||||||
|
|
||||||
|
tests-rerankers:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install build-essential ffmpeg
|
||||||
|
curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
|
||||||
|
sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
|
||||||
|
gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
|
||||||
|
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
|
||||||
|
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
|
||||||
|
sudo apt-get update && \
|
||||||
|
sudo apt-get install -y conda
|
||||||
|
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
|
||||||
|
sudo apt-get install -y libopencv-dev
|
||||||
|
pip install --user grpcio-tools
|
||||||
|
|
||||||
|
sudo rm -rfv /usr/bin/conda || true
|
||||||
|
|
||||||
|
- name: Test rerankers
|
||||||
|
run: |
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
make --jobs=5 --output-sync=target -C backend/python/rerankers
|
||||||
|
make --jobs=5 --output-sync=target -C backend/python/rerankers test
|
||||||
|
|
||||||
tests-diffusers:
|
tests-diffusers:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
@ -16,7 +16,7 @@ ARG TARGETVARIANT
|
|||||||
|
|
||||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
|
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,rerankers:/build/backend/python/rerankers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
|
||||||
|
|
||||||
ARG GO_TAGS="stablediffusion tinydream tts"
|
ARG GO_TAGS="stablediffusion tinydream tts"
|
||||||
|
|
||||||
@ -259,6 +259,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
|||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
make -C backend/python/sentencetransformers \
|
make -C backend/python/sentencetransformers \
|
||||||
; fi
|
; fi
|
||||||
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
|
make -C backend/python/rerankers \
|
||||||
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
make -C backend/python/transformers \
|
make -C backend/python/transformers \
|
||||||
; fi
|
; fi
|
||||||
|
13
Makefile
13
Makefile
@ -437,10 +437,10 @@ protogen-go-clean:
|
|||||||
$(RM) bin/*
|
$(RM) bin/*
|
||||||
|
|
||||||
.PHONY: protogen-python
|
.PHONY: protogen-python
|
||||||
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
|
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen rerankers-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
|
||||||
|
|
||||||
.PHONY: protogen-python-clean
|
.PHONY: protogen-python-clean
|
||||||
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
|
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean rerankers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
|
||||||
|
|
||||||
.PHONY: autogptq-protogen
|
.PHONY: autogptq-protogen
|
||||||
autogptq-protogen:
|
autogptq-protogen:
|
||||||
@ -506,6 +506,14 @@ petals-protogen:
|
|||||||
petals-protogen-clean:
|
petals-protogen-clean:
|
||||||
$(MAKE) -C backend/python/petals protogen-clean
|
$(MAKE) -C backend/python/petals protogen-clean
|
||||||
|
|
||||||
|
.PHONY: rerankers-protogen
|
||||||
|
rerankers-protogen:
|
||||||
|
$(MAKE) -C backend/python/rerankers protogen
|
||||||
|
|
||||||
|
.PHONY: rerankers-protogen-clean
|
||||||
|
rerankers-protogen-clean:
|
||||||
|
$(MAKE) -C backend/python/rerankers protogen-clean
|
||||||
|
|
||||||
.PHONY: sentencetransformers-protogen
|
.PHONY: sentencetransformers-protogen
|
||||||
sentencetransformers-protogen:
|
sentencetransformers-protogen:
|
||||||
$(MAKE) -C backend/python/sentencetransformers protogen
|
$(MAKE) -C backend/python/sentencetransformers protogen
|
||||||
@ -564,6 +572,7 @@ prepare-extra-conda-environments: protogen-python
|
|||||||
$(MAKE) -C backend/python/vllm
|
$(MAKE) -C backend/python/vllm
|
||||||
$(MAKE) -C backend/python/mamba
|
$(MAKE) -C backend/python/mamba
|
||||||
$(MAKE) -C backend/python/sentencetransformers
|
$(MAKE) -C backend/python/sentencetransformers
|
||||||
|
$(MAKE) -C backend/python/rerankers
|
||||||
$(MAKE) -C backend/python/transformers
|
$(MAKE) -C backend/python/transformers
|
||||||
$(MAKE) -C backend/python/transformers-musicgen
|
$(MAKE) -C backend/python/transformers-musicgen
|
||||||
$(MAKE) -C backend/python/parler-tts
|
$(MAKE) -C backend/python/parler-tts
|
||||||
|
27
aio/cpu/rerank.yaml
Normal file
27
aio/cpu/rerank.yaml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
name: jina-reranker-v1-base-en
|
||||||
|
backend: rerankers
|
||||||
|
parameters:
|
||||||
|
model: cross-encoder
|
||||||
|
|
||||||
|
usage: |
|
||||||
|
You can test this model with curl like this:
|
||||||
|
|
||||||
|
curl http://localhost:8080/v1/rerank \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "jina-reranker-v1-base-en",
|
||||||
|
"query": "Organic skincare products for sensitive skin",
|
||||||
|
"documents": [
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"All-natural pet food for dogs with allergies",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
"top_n": 3
|
||||||
|
}'
|
@ -129,7 +129,7 @@ detect_gpu
|
|||||||
detect_gpu_size
|
detect_gpu_size
|
||||||
|
|
||||||
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
|
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
|
||||||
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
|
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
|
||||||
|
|
||||||
check_vars
|
check_vars
|
||||||
|
|
||||||
|
27
aio/gpu-8g/rerank.yaml
Normal file
27
aio/gpu-8g/rerank.yaml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
name: jina-reranker-v1-base-en
|
||||||
|
backend: rerankers
|
||||||
|
parameters:
|
||||||
|
model: cross-encoder
|
||||||
|
|
||||||
|
usage: |
|
||||||
|
You can test this model with curl like this:
|
||||||
|
|
||||||
|
curl http://localhost:8080/v1/rerank \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "jina-reranker-v1-base-en",
|
||||||
|
"query": "Organic skincare products for sensitive skin",
|
||||||
|
"documents": [
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"All-natural pet food for dogs with allergies",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
"top_n": 3
|
||||||
|
}'
|
27
aio/intel/rerank.yaml
Normal file
27
aio/intel/rerank.yaml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
name: jina-reranker-v1-base-en
|
||||||
|
backend: rerankers
|
||||||
|
parameters:
|
||||||
|
model: cross-encoder
|
||||||
|
|
||||||
|
usage: |
|
||||||
|
You can test this model with curl like this:
|
||||||
|
|
||||||
|
curl http://localhost:8080/v1/rerank \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "jina-reranker-v1-base-en",
|
||||||
|
"query": "Organic skincare products for sensitive skin",
|
||||||
|
"documents": [
|
||||||
|
"Eco-friendly kitchenware for modern homes",
|
||||||
|
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||||
|
"Organic cotton baby clothes for sensitive skin",
|
||||||
|
"Natural organic skincare range for sensitive skin",
|
||||||
|
"Tech gadgets for smart homes: 2024 edition",
|
||||||
|
"Sustainable gardening tools and compost solutions",
|
||||||
|
"Sensitive skin-friendly facial cleansers and toners",
|
||||||
|
"Organic food wraps and storage solutions",
|
||||||
|
"All-natural pet food for dogs with allergies",
|
||||||
|
"Yoga mats made from recycled materials"
|
||||||
|
],
|
||||||
|
"top_n": 3
|
||||||
|
}'
|
@ -23,6 +23,30 @@ service Backend {
|
|||||||
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
|
||||||
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
|
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
|
||||||
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
|
||||||
|
|
||||||
|
rpc Rerank(RerankRequest) returns (RerankResult) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message RerankRequest {
|
||||||
|
string query = 1;
|
||||||
|
repeated string documents = 2;
|
||||||
|
int32 top_n = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RerankResult {
|
||||||
|
Usage usage = 1;
|
||||||
|
repeated DocumentResult results = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Usage {
|
||||||
|
int32 total_tokens = 1;
|
||||||
|
int32 prompt_tokens = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DocumentResult {
|
||||||
|
int32 index = 1;
|
||||||
|
string text = 2;
|
||||||
|
float relevance_score = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StoresKey {
|
message StoresKey {
|
||||||
|
@ -120,4 +120,6 @@ dependencies:
|
|||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
- transformers_stream_generator==0.0.5
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
|
- rerankers[transformers]
|
||||||
|
- pydantic
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
@ -108,4 +108,6 @@ dependencies:
|
|||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
- transformers_stream_generator==0.0.5
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
|
- rerankers[transformers]
|
||||||
|
- pydantic
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
@ -111,5 +111,7 @@ dependencies:
|
|||||||
- vllm>=0.4.0
|
- vllm>=0.4.0
|
||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
- transformers_stream_generator==0.0.5
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
|
- rerankers[transformers]
|
||||||
|
- pydantic
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
27
backend/python/rerankers/Makefile
Normal file
27
backend/python/rerankers/Makefile
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
.PHONY: rerankers
|
||||||
|
rerankers: protogen
|
||||||
|
$(MAKE) -C ../common-env/transformers
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
run: protogen
|
||||||
|
@echo "Running rerankers..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "rerankers run."
|
||||||
|
|
||||||
|
# It is not working well by using command line. It only6 works with IDE like VSCode.
|
||||||
|
.PHONY: test
|
||||||
|
test: protogen
|
||||||
|
@echo "Testing rerankers..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "rerankers tested."
|
||||||
|
|
||||||
|
.PHONY: protogen
|
||||||
|
protogen: backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
.PHONY: protogen-clean
|
||||||
|
protogen-clean:
|
||||||
|
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
backend_pb2_grpc.py backend_pb2.py:
|
||||||
|
python3 -m grpc_tools.protoc -I../.. --python_out=. --grpc_python_out=. backend.proto
|
5
backend/python/rerankers/README.md
Normal file
5
backend/python/rerankers/README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Creating a separate environment for the reranker project
|
||||||
|
|
||||||
|
```
|
||||||
|
make reranker
|
||||||
|
```
|
123
backend/python/rerankers/reranker.py
Executable file
123
backend/python/rerankers/reranker.py
Executable file
@ -0,0 +1,123 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Extra gRPC server for Rerankers models.
|
||||||
|
"""
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import time
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from rerankers import Reranker
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
|
# Implement the BackendServicer class with the service methods
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
"""
|
||||||
|
A gRPC servicer for the backend service.
|
||||||
|
|
||||||
|
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||||
|
"""
|
||||||
|
def Health(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that returns the health status of the backend service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A HealthRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Reply object that contains the health status of the backend service.
|
||||||
|
"""
|
||||||
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that loads a model into memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A LoadModelRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Result object that contains the result of the LoadModel operation.
|
||||||
|
"""
|
||||||
|
model_name = request.Model
|
||||||
|
try:
|
||||||
|
kwargs = {}
|
||||||
|
if request.Type != "":
|
||||||
|
kwargs['model_type'] = request.Type
|
||||||
|
if request.PipelineType != "": # Reuse the PipelineType field for language
|
||||||
|
kwargs['lang'] = request.PipelineType
|
||||||
|
self.model_name = model_name
|
||||||
|
self.model = Reranker(model_name, **kwargs)
|
||||||
|
except Exception as err:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
|
||||||
|
# Implement your logic here for the LoadModel service
|
||||||
|
# Replace this with your desired response
|
||||||
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
|
def Rerank(self, request, context):
|
||||||
|
documents = []
|
||||||
|
for idx, doc in enumerate(request.documents):
|
||||||
|
documents.append(doc)
|
||||||
|
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
|
||||||
|
# Prepare results to return
|
||||||
|
results = [
|
||||||
|
backend_pb2.DocumentResult(
|
||||||
|
index=res.doc_id,
|
||||||
|
text=res.text,
|
||||||
|
relevance_score=res.score
|
||||||
|
) for res in ranked_results.results
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate the usage and total tokens
|
||||||
|
# TODO: Implement the usage calculation with reranker
|
||||||
|
total_tokens = sum(len(doc.split()) for doc in request.documents) + len(request.query.split())
|
||||||
|
prompt_tokens = len(request.query.split())
|
||||||
|
usage = backend_pb2.Usage(total_tokens=total_tokens, prompt_tokens=prompt_tokens)
|
||||||
|
return backend_pb2.RerankResult(usage=usage, results=results)
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||||
|
|
||||||
|
# Define the signal handler function
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
print("Received termination signal. Shutting down...")
|
||||||
|
server.stop(0)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Set the signal handlers for SIGINT and SIGTERM
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
serve(args.addr)
|
14
backend/python/rerankers/run.sh
Executable file
14
backend/python/rerankers/run.sh
Executable file
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
##
|
||||||
|
## A bash script wrapper that runs the reranker server with conda
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
|
# Activate conda environment
|
||||||
|
source activate transformers
|
||||||
|
|
||||||
|
# get the directory where the bash script is located
|
||||||
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
|
python $DIR/reranker.py $@
|
11
backend/python/rerankers/test.sh
Executable file
11
backend/python/rerankers/test.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
##
|
||||||
|
## A bash script wrapper that runs the reranker server with conda
|
||||||
|
|
||||||
|
# Activate conda environment
|
||||||
|
source activate transformers
|
||||||
|
|
||||||
|
# get the directory where the bash script is located
|
||||||
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
|
python -m unittest $DIR/test_reranker.py
|
90
backend/python/rerankers/test_reranker.py
Executable file
90
backend/python/rerankers/test_reranker.py
Executable file
@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
A test script to test the gRPC service
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendServicer(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
TestBackendServicer is the class that tests the gRPC service
|
||||||
|
"""
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
This method sets up the gRPC service by starting the server
|
||||||
|
"""
|
||||||
|
self.service = subprocess.Popen(["python3", "reranker.py", "--addr", "localhost:50051"])
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
"""
|
||||||
|
This method tears down the gRPC service by terminating the server
|
||||||
|
"""
|
||||||
|
self.service.kill()
|
||||||
|
self.service.wait()
|
||||||
|
|
||||||
|
def test_server_startup(self):
|
||||||
|
"""
|
||||||
|
This method tests if the server starts up successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b'OK')
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Server failed to start")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_load_model(self):
|
||||||
|
"""
|
||||||
|
This method tests if the model is loaded successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
self.assertEqual(response.message, "Model loaded successfully")
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("LoadModel service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_rerank(self):
|
||||||
|
"""
|
||||||
|
This method tests if the embeddings are generated successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
request = backend_pb2.RerankRequest(
|
||||||
|
query="I love you",
|
||||||
|
documents=["I hate you", "I really like you"],
|
||||||
|
top_n=2
|
||||||
|
)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
|
||||||
|
rerank_response = stub.Rerank(request)
|
||||||
|
print(rerank_response.results[0])
|
||||||
|
self.assertIsNotNone(rerank_response.results)
|
||||||
|
self.assertEqual(len(rerank_response.results), 2)
|
||||||
|
self.assertEqual(rerank_response.results[0].text, "I really like you")
|
||||||
|
self.assertEqual(rerank_response.results[1].text, "I hate you")
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Reranker service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
39
core/backend/rerank.go
Normal file
39
core/backend/rerank.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
|
||||||
|
bb := backend
|
||||||
|
if bb == "" {
|
||||||
|
return nil, fmt.Errorf("backend is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
|
||||||
|
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(bb),
|
||||||
|
model.WithModel(modelFile),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
|
})
|
||||||
|
rerankModel, err := loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if rerankModel == nil {
|
||||||
|
return nil, fmt.Errorf("could not load rerank model")
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := rerankModel.Rerank(context.Background(), request)
|
||||||
|
|
||||||
|
return res, err
|
||||||
|
}
|
@ -194,6 +194,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
|||||||
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
|
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
|
||||||
routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth)
|
routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth)
|
||||||
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
|
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
|
||||||
|
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
|
||||||
|
|
||||||
// Define a custom 404 handler
|
// Define a custom 404 handler
|
||||||
// Note: keep this at the bottom!
|
// Note: keep this at the bottom!
|
||||||
|
84
core/http/endpoints/jina/rerank.go
Normal file
84
core/http/endpoints/jina/rerank.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package jina
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
req := new(schema.JINARerankRequest)
|
||||||
|
if err := c.BodyParser(req); err != nil {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||||
|
"error": "Cannot parse JSON",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
input := new(schema.TTSRequest)
|
||||||
|
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||||
|
|
||||||
|
if input.Backend != "" {
|
||||||
|
cfg.Backend = input.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &proto.RerankRequest{
|
||||||
|
Query: req.Query,
|
||||||
|
TopN: int32(req.TopN),
|
||||||
|
Documents: req.Documents,
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &schema.JINARerankResponse{
|
||||||
|
Model: req.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range results.Results {
|
||||||
|
response.Results = append(response.Results, schema.JINADocumentResult{
|
||||||
|
Index: int(r.Index),
|
||||||
|
Document: schema.JINAText{Text: r.Text},
|
||||||
|
RelevanceScore: float64(r.RelevanceScore),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
|
||||||
|
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusOK).JSON(response)
|
||||||
|
}
|
||||||
|
}
|
19
core/http/routes/jina.go
Normal file
19
core/http/routes/jina.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/http/endpoints/jina"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterJINARoutes(app *fiber.App,
|
||||||
|
cl *config.BackendConfigLoader,
|
||||||
|
ml *model.ModelLoader,
|
||||||
|
appConfig *config.ApplicationConfig,
|
||||||
|
auth func(*fiber.Ctx) error) {
|
||||||
|
|
||||||
|
// POST endpoint to mimic the reranking
|
||||||
|
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
|
||||||
|
}
|
34
core/schema/jina.go
Normal file
34
core/schema/jina.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
// JINARerankRequest defines the structure of the request payload accepted
// by the JINA-compatible /v1/rerank endpoint.
type JINARerankRequest struct {
	// Model is the name of the reranker model to use.
	Model string `json:"model"`
	// Query is the text the documents are ranked against.
	Query string `json:"query"`
	// Documents is the list of candidate texts to rerank.
	Documents []string `json:"documents"`
	// TopN mirrors the JINA API's top_n field; it is forwarded verbatim to
	// the backend — presumably a limit on the number of returned results
	// (confirm against the backend implementation).
	TopN int `json:"top_n"`
}
|
||||||
|
|
||||||
|
// JINADocumentResult represents a single ranked document in the response.
type JINADocumentResult struct {
	// Index is taken from the backend result — presumably the document's
	// position in the request's Documents slice (confirm with the backend).
	Index int `json:"index"`
	// Document carries the document text.
	Document JINAText `json:"document"`
	// RelevanceScore is the score assigned to the document by the reranker.
	RelevanceScore float64 `json:"relevance_score"`
}
|
||||||
|
|
||||||
|
// JINAText holds the text of a document.
type JINAText struct {
	Text string `json:"text"`
}
|
||||||
|
|
||||||
|
// JINARerankResponse defines the structure of the response payload returned
// by the JINA-compatible /v1/rerank endpoint.
type JINARerankResponse struct {
	// Model echoes the model name from the request.
	Model string `json:"model"`
	// Usage reports token accounting for the rerank call.
	Usage JINAUsageInfo `json:"usage"`
	// Results holds the ranked documents.
	Results []JINADocumentResult `json:"results"`
}
|
||||||
|
|
||||||
|
// JINAUsageInfo holds token-usage information for a rerank call.
type JINAUsageInfo struct {
	TotalTokens  int `json:"total_tokens"`
	PromptTokens int `json:"prompt_tokens"`
}
|
@ -49,4 +49,6 @@ type Backend interface {
|
|||||||
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
|
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
|
||||||
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
|
||||||
|
|
||||||
|
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
|
||||||
}
|
}
|
||||||
|
@ -355,3 +355,19 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
|
|||||||
client := pb.NewBackendClient(conn)
|
client := pb.NewBackendClient(conn)
|
||||||
return client.StoresFind(ctx, in, opts...)
|
return client.StoresFind(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rerank forwards a rerank request to the backend over gRPC and returns
// the scored results.
func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) {
	// Serialize calls when the client is not configured for parallel use.
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	// Mark the client busy for the duration of the call.
	c.setBusy(true)
	defer c.setBusy(false)
	// A fresh connection is dialed per call with insecure transport
	// credentials, matching the sibling client methods in this file
	// (e.g. StoresFind above).
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.Rerank(ctx, in, opts...)
}
|
||||||
|
@ -101,6 +101,10 @@ func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions,
|
|||||||
return e.s.StoresFind(ctx, in)
|
return e.s.StoresFind(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rerank satisfies the gRPC client interface by invoking the embedded
// backend server directly, without going over the wire.
func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) {
	// opts are accepted for interface compatibility but ignored, as in the
	// other embedBackend methods (e.g. StoresFind above).
	return e.s.Rerank(ctx, in)
}
|
||||||
|
|
||||||
type embedBackendServerStream struct {
|
type embedBackendServerStream struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
fn func(s []byte)
|
fn func(s []byte)
|
||||||
|
Loading…
Reference in New Issue
Block a user