mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-23 06:22:23 +00:00
feat(transformers): add embeddings with Automodel (#1308)
* Update huggingface.py Switch SentenceTransformer for AutoModel in order to set trust_remote_code needed to use the encode method with embeddings models like jinai-v2 Signed-off-by: Lucas Hänke de Cansino <lhc@next-boss.eu> * feat(transformers): split in separate backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Lucas Hänke de Cansino <lhc@next-boss.eu> Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Lucas Hänke de Cansino <lhc@next-boss.eu>
This commit is contained in:
parent
ff9afdb0fe
commit
92cbc4d516
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -78,7 +78,7 @@ jobs:
|
|||||||
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||||
|
|
||||||
sudo rm -rfv /usr/bin/conda || true
|
sudo rm -rfv /usr/bin/conda || true
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/huggingface
|
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers
|
||||||
|
|
||||||
# Pre-build piper before we start tests in order to have shared libraries in place
|
# Pre-build piper before we start tests in order to have shared libraries in place
|
||||||
make sources/go-piper && \
|
make sources/go-piper && \
|
||||||
|
@ -12,7 +12,7 @@ ARG TARGETARCH
|
|||||||
ARG TARGETVARIANT
|
ARG TARGETVARIANT
|
||||||
|
|
||||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||||
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/backend/python/huggingface/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh"
|
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh"
|
||||||
ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
|
ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
|
||||||
ARG GO_TAGS="stablediffusion tts"
|
ARG GO_TAGS="stablediffusion tts"
|
||||||
|
|
||||||
@ -169,7 +169,10 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
|||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
|
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/huggingface \
|
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
|
||||||
|
; fi
|
||||||
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
|
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
|
PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
|
||||||
|
20
Makefile
20
Makefile
@ -296,7 +296,7 @@ test: prepare test-models/testmodel grpcs
|
|||||||
@echo 'Running tests'
|
@echo 'Running tests'
|
||||||
export GO_TAGS="tts stablediffusion"
|
export GO_TAGS="tts stablediffusion"
|
||||||
$(MAKE) prepare-test
|
$(MAKE) prepare-test
|
||||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/huggingface/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
||||||
$(MAKE) test-gpt4all
|
$(MAKE) test-gpt4all
|
||||||
$(MAKE) test-llama
|
$(MAKE) test-llama
|
||||||
@ -367,13 +367,14 @@ protogen-go:
|
|||||||
backend/backend.proto
|
backend/backend.proto
|
||||||
|
|
||||||
protogen-python:
|
protogen-python:
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/huggingface/ --grpc_python_out=backend/python/huggingface/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/sentencetransformers/ --grpc_python_out=backend/python/sentencetransformers/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/autogptq/ --grpc_python_out=backend/python/autogptq/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/transformers/ --grpc_python_out=backend/python/transformers/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/exllama/ --grpc_python_out=backend/python/exllama/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/autogptq/ --grpc_python_out=backend/python/autogptq/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/bark/ --grpc_python_out=backend/python/bark/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama/ --grpc_python_out=backend/python/exllama/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/diffusers/ --grpc_python_out=backend/python/diffusers/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/bark/ --grpc_python_out=backend/python/bark/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/diffusers/ --grpc_python_out=backend/python/diffusers/ backend/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
|
||||||
|
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
|
||||||
|
|
||||||
## GRPC
|
## GRPC
|
||||||
# Note: it is duplicated in the Dockerfile
|
# Note: it is duplicated in the Dockerfile
|
||||||
@ -382,7 +383,8 @@ prepare-extra-conda-environments:
|
|||||||
$(MAKE) -C backend/python/bark
|
$(MAKE) -C backend/python/bark
|
||||||
$(MAKE) -C backend/python/diffusers
|
$(MAKE) -C backend/python/diffusers
|
||||||
$(MAKE) -C backend/python/vllm
|
$(MAKE) -C backend/python/vllm
|
||||||
$(MAKE) -C backend/python/huggingface
|
$(MAKE) -C backend/python/sentencetransformers
|
||||||
|
$(MAKE) -C backend/python/transformers
|
||||||
$(MAKE) -C backend/python/vall-e-x
|
$(MAKE) -C backend/python/vall-e-x
|
||||||
$(MAKE) -C backend/python/exllama
|
$(MAKE) -C backend/python/exllama
|
||||||
|
|
||||||
|
@ -704,7 +704,7 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Context("External gRPC calls", func() {
|
Context("External gRPC calls", func() {
|
||||||
It("calculate embeddings with huggingface", func() {
|
It("calculate embeddings with sentencetransformers", func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
Skip("test supported only on linux")
|
Skip("test supported only on linux")
|
||||||
}
|
}
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
# Creating a separate environment for the huggingface project
|
|
||||||
|
|
||||||
```
|
|
||||||
make huggingface
|
|
||||||
```
|
|
18
backend/python/sentencetransformers/Makefile
Normal file
18
backend/python/sentencetransformers/Makefile
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
.PONY: sentencetransformers
|
||||||
|
sentencetransformers:
|
||||||
|
@echo "Creating virtual environment..."
|
||||||
|
@conda env create --name sentencetransformers --file sentencetransformers.yml
|
||||||
|
@echo "Virtual environment created."
|
||||||
|
|
||||||
|
.PONY: run
|
||||||
|
run:
|
||||||
|
@echo "Running sentencetransformers..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "sentencetransformers run."
|
||||||
|
|
||||||
|
# It is not working well by using command line. It only6 works with IDE like VSCode.
|
||||||
|
.PONY: test
|
||||||
|
test:
|
||||||
|
@echo "Testing sentencetransformers..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "sentencetransformers tested."
|
5
backend/python/sentencetransformers/README.md
Normal file
5
backend/python/sentencetransformers/README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Creating a separate environment for the sentencetransformers project
|
||||||
|
|
||||||
|
```
|
||||||
|
make sentencetransformers
|
||||||
|
```
|
14
backend/python/sentencetransformers/run.sh
Executable file
14
backend/python/sentencetransformers/run.sh
Executable file
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
##
|
||||||
|
## A bash script wrapper that runs the sentencetransformers server with conda
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
|
# Activate conda environment
|
||||||
|
source activate sentencetransformers
|
||||||
|
|
||||||
|
# get the directory where the bash script is located
|
||||||
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
|
python $DIR/sentencetransformers.py $@
|
77
backend/python/sentencetransformers/sentencetransformers.yml
Normal file
77
backend/python/sentencetransformers/sentencetransformers.yml
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
name: sentencetransformers
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=5.1=1_gnu
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2023.08.22=h06a4308_0
|
||||||
|
- ld_impl_linux-64=2.38=h1181459_1
|
||||||
|
- libffi=3.4.4=h6a678d5_0
|
||||||
|
- libgcc-ng=11.2.0=h1234567_1
|
||||||
|
- libgomp=11.2.0=h1234567_1
|
||||||
|
- libstdcxx-ng=11.2.0=h1234567_1
|
||||||
|
- libuuid=1.41.5=h5eee18b_0
|
||||||
|
- ncurses=6.4=h6a678d5_0
|
||||||
|
- openssl=3.0.11=h7f8727e_2
|
||||||
|
- pip=23.2.1=py311h06a4308_0
|
||||||
|
- python=3.11.5=h955ad1f_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- setuptools=68.0.0=py311h06a4308_0
|
||||||
|
- sqlite=3.41.2=h5eee18b_0
|
||||||
|
- tk=8.6.12=h1ccaba5_0
|
||||||
|
- tzdata=2023c=h04d1e81_0
|
||||||
|
- wheel=0.41.2=py311h06a4308_0
|
||||||
|
- xz=5.4.2=h5eee18b_0
|
||||||
|
- zlib=1.2.13=h5eee18b_0
|
||||||
|
- pip:
|
||||||
|
- certifi==2023.7.22
|
||||||
|
- charset-normalizer==3.3.0
|
||||||
|
- click==8.1.7
|
||||||
|
- filelock==3.12.4
|
||||||
|
- fsspec==2023.9.2
|
||||||
|
- grpcio==1.59.0
|
||||||
|
- huggingface-hub==0.17.3
|
||||||
|
- idna==3.4
|
||||||
|
- install==1.3.5
|
||||||
|
- jinja2==3.1.2
|
||||||
|
- joblib==1.3.2
|
||||||
|
- markupsafe==2.1.3
|
||||||
|
- mpmath==1.3.0
|
||||||
|
- networkx==3.1
|
||||||
|
- nltk==3.8.1
|
||||||
|
- numpy==1.26.0
|
||||||
|
- nvidia-cublas-cu12==12.1.3.1
|
||||||
|
- nvidia-cuda-cupti-cu12==12.1.105
|
||||||
|
- nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
- nvidia-cuda-runtime-cu12==12.1.105
|
||||||
|
- nvidia-cudnn-cu12==8.9.2.26
|
||||||
|
- nvidia-cufft-cu12==11.0.2.54
|
||||||
|
- nvidia-curand-cu12==10.3.2.106
|
||||||
|
- nvidia-cusolver-cu12==11.4.5.107
|
||||||
|
- nvidia-cusparse-cu12==12.1.0.106
|
||||||
|
- nvidia-nccl-cu12==2.18.1
|
||||||
|
- nvidia-nvjitlink-cu12==12.2.140
|
||||||
|
- nvidia-nvtx-cu12==12.1.105
|
||||||
|
- packaging==23.2
|
||||||
|
- pillow==10.0.1
|
||||||
|
- protobuf==4.24.4
|
||||||
|
- pyyaml==6.0.1
|
||||||
|
- regex==2023.10.3
|
||||||
|
- requests==2.31.0
|
||||||
|
- safetensors==0.4.0
|
||||||
|
- scikit-learn==1.3.1
|
||||||
|
- scipy==1.11.3
|
||||||
|
- sentence-transformers==2.2.2
|
||||||
|
- sentencepiece==0.1.99
|
||||||
|
- sympy==1.12
|
||||||
|
- threadpoolctl==3.2.0
|
||||||
|
- tokenizers==0.14.1
|
||||||
|
- torch==2.1.0
|
||||||
|
- torchvision==0.16.0
|
||||||
|
- tqdm==4.66.1
|
||||||
|
- transformers==4.34.0
|
||||||
|
- triton==2.1.0
|
||||||
|
- typing-extensions==4.8.0
|
||||||
|
- urllib3==2.0.6
|
||||||
|
prefix: /opt/conda/envs/sentencetransformers
|
11
backend/python/sentencetransformers/test.sh
Normal file
11
backend/python/sentencetransformers/test.sh
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
##
|
||||||
|
## A bash script wrapper that runs the sentencetransformers server with conda
|
||||||
|
|
||||||
|
# Activate conda environment
|
||||||
|
source activate sentencetransformers
|
||||||
|
|
||||||
|
# get the directory where the bash script is located
|
||||||
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
|
python -m unittest $DIR/test_sentencetransformers.py
|
@ -0,0 +1,81 @@
|
|||||||
|
"""
|
||||||
|
A test script to test the gRPC service
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendServicer(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
TestBackendServicer is the class that tests the gRPC service
|
||||||
|
"""
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
This method sets up the gRPC service by starting the server
|
||||||
|
"""
|
||||||
|
self.service = subprocess.Popen(["python3", "sentencetransformers.py", "--addr", "localhost:50051"])
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
"""
|
||||||
|
This method tears down the gRPC service by terminating the server
|
||||||
|
"""
|
||||||
|
self.service.terminate()
|
||||||
|
self.service.wait()
|
||||||
|
|
||||||
|
def test_server_startup(self):
|
||||||
|
"""
|
||||||
|
This method tests if the server starts up successfully
|
||||||
|
"""
|
||||||
|
time.sleep(2)
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b'OK')
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Server failed to start")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_load_model(self):
|
||||||
|
"""
|
||||||
|
This method tests if the model is loaded successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
self.assertEqual(response.message, "Model loaded successfully")
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("LoadModel service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_embedding(self):
|
||||||
|
"""
|
||||||
|
This method tests if the embeddings are generated successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens"))
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
|
||||||
|
embedding_response = stub.Embedding(embedding_request)
|
||||||
|
self.assertIsNotNone(embedding_response.embeddings)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Embedding service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
@ -1,18 +1,18 @@
|
|||||||
.PONY: huggingface
|
.PONY: transformers
|
||||||
huggingface:
|
transformers:
|
||||||
@echo "Creating virtual environment..."
|
@echo "Creating virtual environment..."
|
||||||
@conda env create --name huggingface --file huggingface.yml
|
@conda env create --name transformers --file transformers.yml
|
||||||
@echo "Virtual environment created."
|
@echo "Virtual environment created."
|
||||||
|
|
||||||
.PONY: run
|
.PONY: run
|
||||||
run:
|
run:
|
||||||
@echo "Running huggingface..."
|
@echo "Running transformers..."
|
||||||
bash run.sh
|
bash run.sh
|
||||||
@echo "huggingface run."
|
@echo "transformers run."
|
||||||
|
|
||||||
# It is not working well by using command line. It only6 works with IDE like VSCode.
|
# It is not working well by using command line. It only6 works with IDE like VSCode.
|
||||||
.PONY: test
|
.PONY: test
|
||||||
test:
|
test:
|
||||||
@echo "Testing huggingface..."
|
@echo "Testing transformers..."
|
||||||
bash test.sh
|
bash test.sh
|
||||||
@echo "huggingface tested."
|
@echo "transformers tested."
|
5
backend/python/transformers/README.md
Normal file
5
backend/python/transformers/README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Creating a separate environment for the transformers project
|
||||||
|
|
||||||
|
```
|
||||||
|
make transformers
|
||||||
|
```
|
61
backend/python/transformers/backend_pb2.py
Normal file
61
backend/python/transformers/backend_pb2.py
Normal file
File diff suppressed because one or more lines are too long
363
backend/python/transformers/backend_pb2_grpc.py
Normal file
363
backend/python/transformers/backend_pb2_grpc.py
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
import backend_pb2 as backend__pb2
|
||||||
|
|
||||||
|
|
||||||
|
class BackendStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.Health = channel.unary_unary(
|
||||||
|
'/backend.Backend/Health',
|
||||||
|
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Reply.FromString,
|
||||||
|
)
|
||||||
|
self.Predict = channel.unary_unary(
|
||||||
|
'/backend.Backend/Predict',
|
||||||
|
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Reply.FromString,
|
||||||
|
)
|
||||||
|
self.LoadModel = channel.unary_unary(
|
||||||
|
'/backend.Backend/LoadModel',
|
||||||
|
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Result.FromString,
|
||||||
|
)
|
||||||
|
self.PredictStream = channel.unary_stream(
|
||||||
|
'/backend.Backend/PredictStream',
|
||||||
|
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Reply.FromString,
|
||||||
|
)
|
||||||
|
self.Embedding = channel.unary_unary(
|
||||||
|
'/backend.Backend/Embedding',
|
||||||
|
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||||
|
)
|
||||||
|
self.GenerateImage = channel.unary_unary(
|
||||||
|
'/backend.Backend/GenerateImage',
|
||||||
|
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Result.FromString,
|
||||||
|
)
|
||||||
|
self.AudioTranscription = channel.unary_unary(
|
||||||
|
'/backend.Backend/AudioTranscription',
|
||||||
|
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||||
|
)
|
||||||
|
self.TTS = channel.unary_unary(
|
||||||
|
'/backend.Backend/TTS',
|
||||||
|
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.Result.FromString,
|
||||||
|
)
|
||||||
|
self.TokenizeString = channel.unary_unary(
|
||||||
|
'/backend.Backend/TokenizeString',
|
||||||
|
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.TokenizationResponse.FromString,
|
||||||
|
)
|
||||||
|
self.Status = channel.unary_unary(
|
||||||
|
'/backend.Backend/Status',
|
||||||
|
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||||
|
response_deserializer=backend__pb2.StatusResponse.FromString,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackendServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def Health(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Predict(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def PredictStream(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Embedding(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def GenerateImage(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def AudioTranscription(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def TTS(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def TokenizeString(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Status(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_BackendServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'Health': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Health,
|
||||||
|
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||||
|
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||||
|
),
|
||||||
|
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Predict,
|
||||||
|
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||||
|
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||||
|
),
|
||||||
|
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.LoadModel,
|
||||||
|
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||||
|
response_serializer=backend__pb2.Result.SerializeToString,
|
||||||
|
),
|
||||||
|
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||||
|
servicer.PredictStream,
|
||||||
|
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||||
|
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||||
|
),
|
||||||
|
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Embedding,
|
||||||
|
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||||
|
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||||
|
),
|
||||||
|
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.GenerateImage,
|
||||||
|
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||||
|
response_serializer=backend__pb2.Result.SerializeToString,
|
||||||
|
),
|
||||||
|
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.AudioTranscription,
|
||||||
|
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||||
|
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||||
|
),
|
||||||
|
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.TTS,
|
||||||
|
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||||
|
response_serializer=backend__pb2.Result.SerializeToString,
|
||||||
|
),
|
||||||
|
'TokenizeString': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.TokenizeString,
|
||||||
|
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||||
|
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
'Status': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Status,
|
||||||
|
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||||
|
response_serializer=backend__pb2.StatusResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'backend.Backend', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class Backend(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Health(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||||
|
backend__pb2.HealthMessage.SerializeToString,
|
||||||
|
backend__pb2.Reply.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Predict(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||||
|
backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
backend__pb2.Reply.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def LoadModel(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||||
|
backend__pb2.ModelOptions.SerializeToString,
|
||||||
|
backend__pb2.Result.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def PredictStream(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||||
|
backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
backend__pb2.Reply.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Embedding(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||||
|
backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
backend__pb2.EmbeddingResult.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def GenerateImage(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||||
|
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||||
|
backend__pb2.Result.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def AudioTranscription(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||||
|
backend__pb2.TranscriptRequest.SerializeToString,
|
||||||
|
backend__pb2.TranscriptResult.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def TTS(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||||
|
backend__pb2.TTSRequest.SerializeToString,
|
||||||
|
backend__pb2.Result.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def TokenizeString(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
|
||||||
|
backend__pb2.PredictOptions.SerializeToString,
|
||||||
|
backend__pb2.TokenizationResponse.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Status(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
|
||||||
|
backend__pb2.HealthMessage.SerializeToString,
|
||||||
|
backend__pb2.StatusResponse.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
@ -1,14 +1,14 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
##
|
##
|
||||||
## A bash script wrapper that runs the huggingface server with conda
|
## A bash script wrapper that runs the transformers server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate huggingface
|
source activate transformers
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
python $DIR/huggingface.py $@
|
python $DIR/transformers.py $@
|
@ -1,11 +1,11 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
##
|
##
|
||||||
## A bash script wrapper that runs the huggingface server with conda
|
## A bash script wrapper that runs the transformers server with conda
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate huggingface
|
source activate transformers
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|
||||||
python -m unittest $DIR/test_huggingface.py
|
python -m unittest $DIR/test_transformers.py
|
@ -18,7 +18,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
This method sets up the gRPC service by starting the server
|
This method sets up the gRPC service by starting the server
|
||||||
"""
|
"""
|
||||||
self.service = subprocess.Popen(["python3", "huggingface.py", "--addr", "localhost:50051"])
|
self.service = subprocess.Popen(["python3", "transformers.py", "--addr", "localhost:50051"])
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
"""
|
"""
|
114
backend/python/transformers/transformers.py
Executable file
114
backend/python/transformers/transformers.py
Executable file
@ -0,0 +1,114 @@
|
|||||||
|
"""
|
||||||
|
Extra gRPC server for HuggingFace SentenceTransformer models.
|
||||||
|
"""
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import time
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
|
# Implement the BackendServicer class with the service methods
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
"""
|
||||||
|
A gRPC servicer for the backend service.
|
||||||
|
|
||||||
|
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
|
||||||
|
"""
|
||||||
|
def Health(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that returns the health status of the backend service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A HealthRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Reply object that contains the health status of the backend service.
|
||||||
|
"""
|
||||||
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that loads a model into memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A LoadModelRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Result object that contains the result of the LoadModel operation.
|
||||||
|
"""
|
||||||
|
model_name = request.Model
|
||||||
|
try:
|
||||||
|
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
|
||||||
|
except Exception as err:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
|
||||||
|
# Implement your logic here for the LoadModel service
|
||||||
|
# Replace this with your desired response
|
||||||
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
|
def Embedding(self, request, context):
|
||||||
|
"""
|
||||||
|
A gRPC method that calculates embeddings for a given sentence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: An EmbeddingRequest object that contains the request parameters.
|
||||||
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An EmbeddingResult object that contains the calculated embeddings.
|
||||||
|
"""
|
||||||
|
# Implement your logic here for the Embedding service
|
||||||
|
# Replace this with your desired response
|
||||||
|
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||||
|
sentence_embeddings = self.model.encode(request.Embeddings)
|
||||||
|
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||||
|
|
||||||
|
# Define the signal handler function
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
print("Received termination signal. Shutting down...")
|
||||||
|
server.stop(0)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Set the signal handlers for SIGINT and SIGTERM
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
serve(args.addr)
|
@ -1,4 +1,4 @@
|
|||||||
name: huggingface
|
name: transformers
|
||||||
channels:
|
channels:
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -74,4 +74,4 @@ dependencies:
|
|||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- urllib3==2.0.6
|
- urllib3==2.0.6
|
||||||
prefix: /opt/conda/envs/huggingface
|
prefix: /opt/conda/envs/transformers
|
Loading…
Reference in New Issue
Block a user