feat: Token Stream support for Transformer, fix: missing package for OpenVINO (#1908)

* Streaming working

* Small fix for regression on CUDA and XPU

* use pip version of optimum[openvino]

* Update backend/python/transformers/transformers_server.py

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

* Token streaming support

fix optimum[openvino] package in install.sh

* Token Streaming support

---------

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
fakezeta 2024-03-27 17:50:35 +01:00 committed by GitHub
parent e7cbe32601
commit 8210ffcb6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 48 deletions

View File

@ -25,7 +25,7 @@ if [ -d "/opt/intel" ]; then
# Intel GPU: If the directory exists, we assume we are using the intel image # Intel GPU: If the directory exists, we assume we are using the intel image
# (no conda env) # (no conda env)
# https://github.com/intel/intel-extension-for-pytorch/issues/538 # https://github.com/intel/intel-extension-for-pytorch/issues/538
pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
fi fi
if [ "$PIP_CACHE_PURGE" = true ] ; then if [ "$PIP_CACHE_PURGE" = true ] ; then

View File

@ -9,6 +9,7 @@ import signal
import sys import sys
import os import os
from threading import Thread from threading import Thread
import asyncio
import time import time
import backend_pb2 import backend_pb2
@ -205,17 +206,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print("Embeddings:", sentence_embeddings, file=sys.stderr) print("Embeddings:", sentence_embeddings, file=sys.stderr)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0]) return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
def Predict(self, request, context, streaming=False): async def _predict(self, request, context, streaming=False):
"""
Generates text based on the given prompt and sampling parameters.
Args:
request: The predict request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The predict result.
"""
set_seed(request.Seed) set_seed(request.Seed)
if request.TopP == 0: if request.TopP == 0:
request.TopP = 0.9 request.TopP = 0.9
@ -248,10 +239,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
thread=Thread(target=self.model.generate, kwargs=config) thread=Thread(target=self.model.generate, kwargs=config)
thread.start() thread.start()
generated_text = "" generated_text = ""
try:
for new_text in streamer: for new_text in streamer:
generated_text += new_text generated_text += new_text
yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8')) yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
finally:
thread.join()
else: else:
if XPU and self.OV == False:
outputs = self.model.generate(inputs["input_ids"], outputs = self.model.generate(inputs["input_ids"],
max_new_tokens=max_tokens, max_new_tokens=max_tokens,
temperature=request.Temperature, temperature=request.Temperature,
@ -259,10 +254,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
top_k=request.TopK, top_k=request.TopK,
do_sample=True, do_sample=True,
pad_token=self.tokenizer.eos_token_id) pad_token=self.tokenizer.eos_token_id)
else:
outputs = self.model.generate(inputs["input_ids"],
max_new_tokens=max_tokens,
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
do_sample=True,
attention_mask=inputs["attention_mask"],
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id)
generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0] generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
def PredictStream(self, request, context): if streaming:
return
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
async def Predict(self, request, context):
"""
Generates text based on the given prompt and sampling parameters.
Args:
request: The predict request.
context: The gRPC context.
Returns:
backend_pb2.Reply: The predict result.
"""
gen = self._predict(request, context, streaming=False)
res = await gen.__anext__()
return res
async def PredictStream(self, request, context):
""" """
Generates text based on the given prompt and sampling parameters, and streams the results. Generates text based on the given prompt and sampling parameters, and streams the results.
@ -273,33 +297,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns: Returns:
backend_pb2.Result: The predict stream result. backend_pb2.Result: The predict stream result.
""" """
iterations = self.Predict(request, context, streaming=True) iterations = self._predict(request, context, streaming=True)
for iteration in iterations:
yield iteration
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try: try:
while True: async for iteration in iterations:
time.sleep(_ONE_DAY_IN_SECONDS) yield iteration
except KeyboardInterrupt: finally:
server.stop(0) await iterations.aclose()
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address
server.add_insecure_port(address)
# Gracefully shutdown the server on SIGTERM or SIGINT
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.ensure_future(server.stop(5))
)
# Start the server
await server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Wait for the server to be terminated
await server.wait_for_termination()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.") parser = argparse.ArgumentParser(description="Run the gRPC server.")
@ -308,4 +332,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
serve(args.addr) asyncio.run(serve(args.addr))