mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-13 08:19:57 +00:00
68fc014c6d
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
261 lines
9.4 KiB
Python
261 lines
9.4 KiB
Python
#!/usr/bin/env python3
|
|
import asyncio
|
|
from concurrent import futures
|
|
import argparse
|
|
import signal
|
|
import sys
|
|
import os
|
|
|
|
import backend_pb2
|
|
import backend_pb2_grpc
|
|
|
|
import grpc
|
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.utils import random_uuid
|
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
|
|
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
|
|
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
|
|
|
# Implement the BackendServicer class with the service methods
|
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|
"""
|
|
A gRPC servicer that implements the Backend service defined in backend.proto.
|
|
"""
|
|
def generate(self,prompt, max_new_tokens):
|
|
"""
|
|
Generates text based on the given prompt and maximum number of new tokens.
|
|
|
|
Args:
|
|
prompt (str): The prompt to generate text from.
|
|
max_new_tokens (int): The maximum number of new tokens to generate.
|
|
|
|
Returns:
|
|
str: The generated text.
|
|
"""
|
|
self.generator.end_beam_search()
|
|
|
|
# Tokenizing the input
|
|
ids = self.generator.tokenizer.encode(prompt)
|
|
|
|
self.generator.gen_begin_reuse(ids)
|
|
initial_len = self.generator.sequence[0].shape[0]
|
|
has_leading_space = False
|
|
decoded_text = ''
|
|
for i in range(max_new_tokens):
|
|
token = self.generator.gen_single_token()
|
|
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
|
has_leading_space = True
|
|
|
|
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
|
|
if has_leading_space:
|
|
decoded_text = ' ' + decoded_text
|
|
|
|
if token.item() == self.generator.tokenizer.eos_token_id:
|
|
break
|
|
return decoded_text
|
|
|
|
def Health(self, request, context):
|
|
"""
|
|
Returns a health check message.
|
|
|
|
Args:
|
|
request: The health check request.
|
|
context: The gRPC context.
|
|
|
|
Returns:
|
|
backend_pb2.Reply: The health check reply.
|
|
"""
|
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
|
|
|
async def LoadModel(self, request, context):
|
|
"""
|
|
Loads a language model.
|
|
|
|
Args:
|
|
request: The load model request.
|
|
context: The gRPC context.
|
|
|
|
Returns:
|
|
backend_pb2.Result: The load model result.
|
|
"""
|
|
engine_args = AsyncEngineArgs(
|
|
model=request.Model,
|
|
)
|
|
|
|
if request.Quantization != "":
|
|
engine_args.quantization = request.Quantization
|
|
if request.GPUMemoryUtilization != 0:
|
|
engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
|
|
if request.TrustRemoteCode:
|
|
engine_args.trust_remote_code = request.TrustRemoteCode
|
|
if request.EnforceEager:
|
|
engine_args.enforce_eager = request.EnforceEager
|
|
if request.TensorParallelSize:
|
|
engine_args.tensor_parallel_size = request.TensorParallelSize
|
|
if request.SwapSpace != 0:
|
|
engine_args.swap_space = request.SwapSpace
|
|
if request.MaxModelLen != 0:
|
|
engine_args.max_model_len = request.MaxModelLen
|
|
|
|
try:
|
|
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
|
|
except Exception as err:
|
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
|
|
|
try:
|
|
engine_model_config = await self.llm.get_model_config()
|
|
self.tokenizer = get_tokenizer(
|
|
engine_model_config.tokenizer,
|
|
tokenizer_mode=engine_model_config.tokenizer_mode,
|
|
trust_remote_code=engine_model_config.trust_remote_code,
|
|
truncation_side="left",
|
|
)
|
|
except Exception as err:
|
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
|
|
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
|
|
|
async def Predict(self, request, context):
|
|
"""
|
|
Generates text based on the given prompt and sampling parameters.
|
|
|
|
Args:
|
|
request: The predict request.
|
|
context: The gRPC context.
|
|
|
|
Returns:
|
|
backend_pb2.Reply: The predict result.
|
|
"""
|
|
gen = self._predict(request, context, streaming=False)
|
|
res = await gen.__anext__()
|
|
return res
|
|
|
|
def Embedding(self, request, context):
|
|
"""
|
|
A gRPC method that calculates embeddings for a given sentence.
|
|
|
|
Args:
|
|
request: An EmbeddingRequest object that contains the request parameters.
|
|
context: A grpc.ServicerContext object that provides information about the RPC.
|
|
|
|
Returns:
|
|
An EmbeddingResult object that contains the calculated embeddings.
|
|
"""
|
|
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
|
outputs = self.model.encode(request.Embeddings)
|
|
# Check if we have one result at least
|
|
if len(outputs) == 0:
|
|
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
|
context.set_details("No embeddings were calculated.")
|
|
return backend_pb2.EmbeddingResult()
|
|
return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)
|
|
|
|
async def PredictStream(self, request, context):
|
|
"""
|
|
Generates text based on the given prompt and sampling parameters, and streams the results.
|
|
|
|
Args:
|
|
request: The predict stream request.
|
|
context: The gRPC context.
|
|
|
|
Returns:
|
|
backend_pb2.Result: The predict stream result.
|
|
"""
|
|
iterations = self._predict(request, context, streaming=True)
|
|
try:
|
|
async for iteration in iterations:
|
|
yield iteration
|
|
finally:
|
|
await iterations.aclose()
|
|
|
|
async def _predict(self, request, context, streaming=False):
|
|
|
|
# Build sampling parameters
|
|
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
|
|
if request.TopP != 0:
|
|
sampling_params.top_p = request.TopP
|
|
if request.Tokens > 0:
|
|
sampling_params.max_tokens = request.Tokens
|
|
if request.Temperature != 0:
|
|
sampling_params.temperature = request.Temperature
|
|
if request.TopK != 0:
|
|
sampling_params.top_k = request.TopK
|
|
if request.PresencePenalty != 0:
|
|
sampling_params.presence_penalty = request.PresencePenalty
|
|
if request.FrequencyPenalty != 0:
|
|
sampling_params.frequency_penalty = request.FrequencyPenalty
|
|
if request.StopPrompts:
|
|
sampling_params.stop = request.StopPrompts
|
|
if request.IgnoreEOS:
|
|
sampling_params.ignore_eos = request.IgnoreEOS
|
|
if request.Seed != 0:
|
|
sampling_params.seed = request.Seed
|
|
|
|
prompt = request.Prompt
|
|
|
|
# If tokenizer template is enabled and messages are provided instead of prompt apply the tokenizer template
|
|
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
|
|
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
|
|
|
|
# Generate text
|
|
request_id = random_uuid()
|
|
outputs = self.llm.generate(prompt, sampling_params, request_id)
|
|
|
|
# Stream the results
|
|
generated_text = ""
|
|
try:
|
|
async for request_output in outputs:
|
|
iteration_text = request_output.outputs[0].text
|
|
|
|
if streaming:
|
|
# Remove text already sent as vllm concatenates the text from previous yields
|
|
delta_iteration_text = iteration_text.removeprefix(generated_text)
|
|
# Send the partial result
|
|
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
|
|
|
|
# Keep track of text generated
|
|
generated_text = iteration_text
|
|
finally:
|
|
await outputs.aclose()
|
|
|
|
# If streaming, we already sent everything
|
|
if streaming:
|
|
return
|
|
|
|
# Sending the final generated text
|
|
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
|
|
|
async def serve(address):
|
|
# Start asyncio gRPC server
|
|
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
|
# Add the servicer to the server
|
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
|
# Bind the server to the address
|
|
server.add_insecure_port(address)
|
|
|
|
# Gracefully shutdown the server on SIGTERM or SIGINT
|
|
loop = asyncio.get_event_loop()
|
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
|
loop.add_signal_handler(
|
|
sig, lambda: asyncio.ensure_future(server.stop(5))
|
|
)
|
|
|
|
# Start the server
|
|
await server.start()
|
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
|
# Wait for the server to be terminated
|
|
await server.wait_for_termination()
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
|
parser.add_argument(
|
|
"--addr", default="localhost:50051", help="The address to bind the server to."
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
asyncio.run(serve(args.addr)) |