feat(python-grpc): allow to set max workers with PYTHON_GRPC_MAX_WORKERS (#1081)

**Description**

this allows to customize the maximum number of grpc workers for python
backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-09-19 21:30:39 +02:00 committed by GitHub
parent 453e9c5da9
commit bdf3f95346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 34 additions and 8 deletions

5
.env
View File

@ -62,3 +62,8 @@ MODELS_PATH=/models
### Huggingface cache for models ### Huggingface cache for models
# HUGGINGFACE_HUB_CACHE=/usr/local/huggingface # HUGGINGFACE_HUB_CACHE=/usr/local/huggingface
### Python backends GRPC max workers
### Default number of workers for GRPC Python backends.
### This actually controls wether a backend can process multiple requests or not.
# PYTHON_GRPC_MAX_WORKERS=1

View File

@ -15,6 +15,9 @@ from transformers import TextGenerationPipeline
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context): def Health(self, request, context):
@ -77,7 +80,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -15,6 +15,9 @@ from scipy.io.wavfile import write as write_wav
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context): def Health(self, request, context):
@ -51,7 +54,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=True) return backend_pb2.Result(success=True)
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -26,6 +26,9 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL=os.environ.get("COMPEL", "1") == "1" COMPEL=os.environ.get("COMPEL", "1") == "1"
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1" CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287 # https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
def sc(self, clip_input, images) : return images, [False for i in images] def sc(self, clip_input, images) : return images, [False for i in images]
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values # edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
@ -346,7 +349,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(message="Model loaded successfully", success=True) return backend_pb2.Result(message="Model loaded successfully", success=True)
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -19,6 +19,9 @@ from exllama.tokenizer import ExLlamaTokenizer
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def generate(self,prompt, max_new_tokens): def generate(self,prompt, max_new_tokens):
@ -110,7 +113,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -12,6 +12,9 @@ from sentence_transformers import SentenceTransformer
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context): def Health(self, request, context):
@ -34,7 +37,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -16,6 +16,9 @@ from utils.prompt_making import make_prompt
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context): def Health(self, request, context):
@ -65,7 +68,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=True) return backend_pb2.Result(success=True)
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()

View File

@ -14,6 +14,9 @@ from vllm import LLM, SamplingParams
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods # Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer): class BackendServicer(backend_pb2_grpc.BackendServicer):
def generate(self,prompt, max_new_tokens): def generate(self,prompt, max_new_tokens):
@ -70,7 +73,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return self.Predict(request, context) return self.Predict(request, context)
def serve(address): def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address) server.add_insecure_port(address)
server.start() server.start()