fix: vllm - use AsyncLLMEngine to allow true streaming mode (#1749)

* fix: use vLLM AsyncLLMEngine to enable true streaming

The current vLLM implementation uses LLMEngine, which was designed for offline batch inference; as a result, streaming mode outputs the whole response in a single blob only at the end of inference.

This PR reworks the gRPC server to use asyncio and grpc.aio, in combination with vLLM's AsyncLLMEngine, to provide true streaming.
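
For illustration, a condensed sketch of the streaming pattern the reworked backend relies on (the model name and prompt below are placeholders, and the gRPC servicer wiring from the diff is omitted):

# Condensed sketch only: stream text deltas from AsyncLLMEngine as they are produced.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def main():
    # Placeholder model; the real backend loads request.Model in LoadModel()
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    sent = ""
    # generate() is an async generator; each RequestOutput carries the full text so far,
    # so only the delta since the previous iteration is printed.
    async for output in engine.generate("Hello", SamplingParams(max_tokens=64), random_uuid()):
        text = output.outputs[0].text
        print(text.removeprefix(sent), end="", flush=True)
        sent = text

asyncio.run(main())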

This PR also passes more parameters to vLLM during inference (presence_penalty, frequency_penalty, stop, ignore_eos, seed, ...).
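
For reference, these map onto fields of vLLM's SamplingParams. A small illustrative construction follows (all values are made up; seed is assumed to be accepted by the constructor, matching the attribute set in the diff below):

# Illustrative only: the extra sampling fields now forwarded to vLLM.
from vllm.sampling_params import SamplingParams

params = SamplingParams(
    max_tokens=200,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    presence_penalty=0.5,
    frequency_penalty=0.5,
    stop=["</s>"],       # maps from request.StopPrompts
    ignore_eos=False,    # maps from request.IgnoreEOS
    seed=1234,           # maps from request.Seed
)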

* Remove unused import
Ludovic Leroux, 2024-02-24 05:48:45 -05:00 (committed by GitHub)
parent ff88c390bb
commit 0135e1e3b9


@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import time
import argparse
import signal
import sys
@@ -10,7 +10,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -79,16 +82,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Result: The load model result.
"""
engine_args = AsyncEngineArgs(
model=request.Model,
)
if request.Quantization != "":
engine_args.quantization = request.Quantization
try:
if request.Quantization != "":
self.llm = LLM(model=request.Model, quantization=request.Quantization)
else:
self.llm = LLM(model=request.Model)
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
def Predict(self, request, context):
async def Predict(self, request, context):
"""
Generates text based on the given prompt and sampling parameters.
@@ -99,24 +106,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Reply: The predict result.
"""
if request.TopP == 0:
request.TopP = 0.9
gen = self._predict(request, context, streaming=False)
res = await gen.__anext__()
return res
max_tokens = 200
if request.Tokens > 0:
max_tokens = request.Tokens
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
outputs = self.llm.generate([request.Prompt], sampling_params)
generated_text = outputs[0].outputs[0].text
# Remove prompt from response if present
if request.Prompt in generated_text:
generated_text = generated_text.replace(request.Prompt, "")
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
def PredictStream(self, request, context):
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results.
@@ -127,30 +121,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Result: The predict stream result.
"""
yield self.Predict(request, context)
iterations = self._predict(request, context, streaming=True)
try:
async for iteration in iterations:
yield iteration
finally:
await iterations.aclose()
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
async def _predict(self, request, context, streaming=False):
# Build sampling parameters
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
if request.TopP != 0:
sampling_params.top_p = request.TopP
if request.Tokens > 0:
sampling_params.max_tokens = request.Tokens
if request.Temperature != 0:
sampling_params.temperature = request.Temperature
if request.TopK != 0:
sampling_params.top_k = request.TopK
if request.PresencePenalty != 0:
sampling_params.presence_penalty = request.PresencePenalty
if request.FrequencyPenalty != 0:
sampling_params.frequency_penalty = request.FrequencyPenalty
if request.StopPrompts:
sampling_params.stop = request.StopPrompts
if request.IgnoreEOS:
sampling_params.ignore_eos = request.IgnoreEOS
if request.Seed != 0:
sampling_params.seed = request.Seed
# Generate text
request_id = random_uuid()
outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
# Stream the results
generated_text = ""
try:
async for request_output in outputs:
iteration_text = request_output.outputs[0].text
if streaming:
# Remove text already sent as vllm concatenates the text from previous yields
delta_iteration_text = iteration_text.removeprefix(generated_text)
# Send the partial result
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
# Keep track of text generated
generated_text = iteration_text
finally:
await outputs.aclose()
# If streaming, we already sent everything
if streaming:
return
# Sending the final generated text
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address
server.add_insecure_port(address)
server.start()
# Gracefully shutdown the server on SIGTERM or SIGINT
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.ensure_future(server.stop(5))
)
# Start the server
await server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
# Wait for the server to be terminated
await server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
@@ -159,4 +207,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
serve(args.addr)
asyncio.run(serve(args.addr))