Mirror of https://github.com/mudler/LocalAI.git (synced 2024-12-24 23:06:42 +00:00)
feat: Token Stream support for Transformer, fix: missing package for OpenVINO (#1908)
* Streaming working
* Small fix for regression on CUDA and XPU
* use pip version of optimum[openvino]
* Update backend/python/transformers/transformers_server.py
  Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
* Token streaming support
  fix optimum[openvino] package in install.sh
* Token Streaming support

---------

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
parent e7cbe32601
commit 8210ffcb6c
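The change at the core of #1908 is token-by-token streaming from the Hugging Face transformers backend: model.generate() runs in a background thread while a TextIteratorStreamer is drained as tokens are decoded. Below is a minimal standalone sketch of that pattern; the model id and prompt are placeholders, not taken from this commit.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model, not taken from the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Once upon a time", return_tensors="pt")
# skip_prompt/skip_special_tokens limit the stream to newly generated text
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a thread while the caller drains the streamer
config = dict(inputs, max_new_tokens=50, streamer=streamer)
thread = Thread(target=model.generate, kwargs=config)
thread.start()
try:
    for new_text in streamer:
        print(new_text, end="", flush=True)  # each decoded chunk arrives as soon as it is ready
finally:
    thread.join()  # as in the diff below: always reap the generation thread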
install.sh

@@ -25,7 +25,7 @@ if [ -d "/opt/intel" ]; then
     # Intel GPU: If the directory exists, we assume we are using the intel image
     # (no conda env)
     # https://github.com/intel/intel-extension-for-pytorch/issues/538
-    pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed
+    pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
 fi

 if [ "$PIP_CACHE_PURGE" = true ] ; then
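The added optimum[openvino] dependency supplies the OpenVINO model classes under optimum.intel. A rough sketch of what that package enables, assuming a placeholder model id; the backend's actual loading code is not part of this hunk.

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "gpt2"  # placeholder; any causal LM from the Hub
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to OpenVINO IR at load time
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))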
backend/python/transformers/transformers_server.py

@@ -9,6 +9,7 @@ import signal
 import sys
 import os
 from threading import Thread
+import asyncio

 import time
 import backend_pb2
@@ -205,17 +206,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])

-    def Predict(self, request, context, streaming=False):
-        """
-        Generates text based on the given prompt and sampling parameters.
-
-        Args:
-            request: The predict request.
-            context: The gRPC context.
-
-        Returns:
-            backend_pb2.Reply: The predict result.
-        """
+    async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
         if request.TopP == 0:
             request.TopP = 0.9
@@ -248,10 +239,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
-            for new_text in streamer:
-                generated_text += new_text
-                yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+            try:
+                for new_text in streamer:
+                    generated_text += new_text
+                    yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+            finally:
+                thread.join()
         else:
-            outputs = self.model.generate(inputs["input_ids"],
+            if XPU and self.OV == False:
+                outputs = self.model.generate(inputs["input_ids"],
                                           max_new_tokens=max_tokens,
                                           temperature=request.Temperature,
@@ -259,10 +254,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                           top_k=request.TopK,
                                           do_sample=True,
                                           pad_token=self.tokenizer.eos_token_id)
+            else:
+                outputs = self.model.generate(inputs["input_ids"],
+                                              max_new_tokens=max_tokens,
+                                              temperature=request.Temperature,
+                                              top_p=request.TopP,
+                                              top_k=request.TopK,
+                                              do_sample=True,
+                                              attention_mask=inputs["attention_mask"],
+                                              eos_token_id=self.tokenizer.eos_token_id,
+                                              pad_token_id=self.tokenizer.eos_token_id)
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
-        return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
-
-    def PredictStream(self, request, context):
+        if streaming:
+            return
+
+        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+
+    async def Predict(self, request, context):
+        """
+        Generates text based on the given prompt and sampling parameters.
+
+        Args:
+            request: The predict request.
+            context: The gRPC context.
+
+        Returns:
+            backend_pb2.Reply: The predict result.
+        """
+        gen = self._predict(request, context, streaming=False)
+        res = await gen.__anext__()
+        return res
+
+    async def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.

@@ -273,33 +297,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        iterations = self.Predict(request, context, streaming=True)
-        for iteration in iterations:
-            yield iteration
-
-
-def serve(address):
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
-    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
-    server.add_insecure_port(address)
-    server.start()
-    print("Server started. Listening on: " + address, file=sys.stderr)
-
-    # Define the signal handler function
-    def signal_handler(sig, frame):
-        print("Received termination signal. Shutting down...")
-        server.stop(0)
-        sys.exit(0)
-
-    # Set the signal handlers for SIGINT and SIGTERM
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
-
-    try:
-        while True:
-            time.sleep(_ONE_DAY_IN_SECONDS)
-    except KeyboardInterrupt:
-        server.stop(0)
+        iterations = self._predict(request, context, streaming=True)
+        try:
+            async for iteration in iterations:
+                yield iteration
+        finally:
+            await iterations.aclose()
+
+async def serve(address):
+    # Start asyncio gRPC server
+    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
+    # Add the servicer to the server
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    # Bind the server to the address
+    server.add_insecure_port(address)
+
+    # Gracefully shutdown the server on SIGTERM or SIGINT
+    loop = asyncio.get_event_loop()
+    for sig in (signal.SIGINT, signal.SIGTERM):
+        loop.add_signal_handler(
+            sig, lambda: asyncio.ensure_future(server.stop(5))
+        )
+
+    # Start the server
+    await server.start()
+    print("Server started. Listening on: " + address, file=sys.stderr)
+    # Wait for the server to be terminated
+    await server.wait_for_termination()

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
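In the hunks above, Predict and PredictStream end up sharing a single async generator (_predict): the unary path pulls exactly one reply with __anext__(), while the streaming path iterates the generator and closes it in a finally block. A self-contained sketch of that pattern with purely illustrative names, nothing copied from the backend:

import asyncio

async def _produce(streaming=False):
    # stand-in for _predict: yields chunks when streaming, one full reply otherwise
    if streaming:
        for chunk in ("Hello", ", ", "world"):
            yield chunk
        return
    yield "Hello, world"

async def unary():
    gen = _produce(streaming=False)
    return await gen.__anext__()  # take the single yielded value, as the new Predict does

async def stream():
    gen = _produce(streaming=True)
    try:
        async for chunk in gen:
            print(chunk, end="")
    finally:
        await gen.aclose()  # finalize the generator, as the new PredictStream does

print(asyncio.run(unary()))
asyncio.run(stream())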
@@ -308,4 +332,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    serve(args.addr)
+    asyncio.run(serve(args.addr))
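With the server switched to grpc.aio and started via asyncio.run(serve(args.addr)), the streaming RPC can be consumed from an async client. This is a hypothetical sketch only: the stub name BackendStub follows from add_BackendServicer_to_server, but the request message PredictOptions, its Prompt field, and the localhost:50051 address are assumptions not shown in this diff, so check backend.proto for the real definitions.

import asyncio
import grpc

import backend_pb2
import backend_pb2_grpc

async def main():
    # address is a placeholder; the server listens on whatever address it was started with
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        request = backend_pb2.PredictOptions(Prompt="Hello")  # assumed message/field names
        async for reply in stub.PredictStream(request):
            print(reply.message.decode("utf-8"), end="", flush=True)

asyncio.run(main())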