mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-21 21:47:51 +00:00
9dddd1134d
* Fix python header comments for some extra gRPC backends When a Python script is to be executed directly via exec(3), either the platform knows how to execute the file itself (i.e. special configuration is necessary) or the first line contains a shebang (#!) specifying the interpreter to run it (similar to shell scripts). The shebang MUST be on the first line for the script to work on all platforms, so any header comments need to be in the lines following it. Otherwise executing these scripts as extra backends will yield an "exec format error" message. Changes: * Move introductory comments below the shebang line * Change header comment in transformers.py to refer to the correct python module Signed-off-by: Marcus Köhler <khler.marcus@gmail.com> * Make header comment in ttsbark.py more specific Signed-off-by: Marcus Köhler <khler.marcus@gmail.com> --------- Signed-off-by: Marcus Köhler <khler.marcus@gmail.com>
115 lines
3.8 KiB
Python
Executable File
115 lines
3.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Extra gRPC server for HuggingFace SentenceTransformer models.
|
|
"""
|
|
from concurrent import futures
|
|
|
|
import argparse
|
|
import signal
|
|
import sys
|
|
import os
|
|
|
|
import time
|
|
import backend_pb2
|
|
import backend_pb2_grpc
|
|
|
|
import grpc
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Length of one day in seconds; serve() sleeps in intervals of this size while idle.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60
|
|
|
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
|
MAX_WORKERS = int(os.getenv("PYTHON_GRPC_MAX_WORKERS", "1"))  # size of the gRPC server's thread pool
|
|
|
|
# Implement the BackendServicer class with the service methods
|
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer for the backend service.

    This class implements the gRPC methods for the backend service, including
    Health, LoadModel, and Embedding. The loaded SentenceTransformer instance
    is kept on the servicer as ``self.model``.
    """

    # Populated by a successful LoadModel call. Defaulting to None lets
    # Embedding detect "no model loaded" explicitly instead of failing with
    # an AttributeError.
    model = None

    def Health(self, request, context):
        """
        A gRPC method that returns the health status of the backend service.

        Args:
            request: A HealthRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Reply object whose message is the UTF-8 encoded bytes of "OK".
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        A gRPC method that loads a model into memory.

        The model named by ``request.Model`` is passed to SentenceTransformer
        and the resulting object is stored on ``self.model``.

        Args:
            request: A LoadModelRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object with ``success=True`` when the model was created,
            or ``success=False`` and a message describing the exception otherwise.
        """
        model_name = request.Model
        try:
            self.model = SentenceTransformer(model_name)
        except Exception as err:
            # Any load failure is reported to the caller through the Result
            # message rather than propagated out of the RPC handler.
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Args:
            request: An EmbeddingRequest object; ``request.Embeddings`` holds the text to encode.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
            If no model has been loaded yet, the RPC status is set to
            FAILED_PRECONDITION and an empty EmbeddingResult is returned.
        """
        if self.model is None:
            # Without a loaded model there is nothing to encode with; tell the
            # client what went wrong instead of raising AttributeError.
            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
            context.set_details("No model loaded; call LoadModel before Embedding")
            return backend_pb2.EmbeddingResult()

        sentence_embeddings = self.model.encode(request.Embeddings)
        # Log only after encode() returned, so the message is truthful.
        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
|
|
|
|
|
|
def serve(address):
    """
    Start the gRPC backend server on ``address`` and block until terminated.

    A thread pool of MAX_WORKERS workers backs the server, and a
    BackendServicer instance is registered on it. SIGINT and SIGTERM stop the
    server immediately and exit the process with status 0.

    Args:
        address: The address string the server's insecure port is bound to.
    """
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.server(worker_pool)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def _shutdown(signum, _frame):
        # Stop without a grace period, then end the process.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Both termination signals share the same handler.
    for termination_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(termination_signal, _shutdown)

    # Keep the main thread alive; the server runs on the worker pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the bind address from the command line and
    # run the server on it.
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
    cli_args = arg_parser.parse_args()

    serve(cli_args.addr)
|