enhanced generation: detect antiprompts in the generated text and trim/stop the output when one is found

Saifeddine ALOUI 2024-01-03 19:55:02 +01:00
parent c97a61a9e2
commit 52532df832


@@ -2,7 +2,9 @@ from fastapi import APIRouter
 from lollms.server.elf_server import LOLLMSElfServer
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 from lollms.types import MSG_TYPE
+from lollms.utilities import detect_antiprompt, remove_text_from_string
+from ascii_colors import ASCIIColors
 class GenerateRequest(BaseModel):
     text: str
     n_predict: int = 1024
@@ -19,16 +21,35 @@ def generate(request_data: GenerateRequest):
     if elf_server.binding is not None:
         if stream:
+            output = {"text":""}
             def generate_chunks():
-                def callback(chunk):
+                def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
                     # Yield each chunk of data
-                    yield chunk
-                elf_server.binding.generate(text, n_predict, callback=callback)
+                    output["text"] += chunk
+                    antiprompt = detect_antiprompt(output["text"])
+                    if antiprompt:
+                        ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
+                        output["text"] = remove_text_from_string(output["text"],antiprompt)
+                        return False
+                    else:
+                        yield chunk
+                        return True
+                return iter(elf_server.binding.generate(text, n_predict, callback=callback))
             return StreamingResponse(generate_chunks())
         else:
-            output = elf_server.binding.generate(text, n_predict)
-            return output
+            output = {"text":""}
+            def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
+                # Yield each chunk of data
+                output["text"] += chunk
+                antiprompt = detect_antiprompt(output["text"])
+                if antiprompt:
+                    ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
+                    output["text"] = remove_text_from_string(output["text"],antiprompt)
+                    return False
+                else:
+                    return True
+            elf_server.binding.generate(text, n_predict, callback=callback)
+            return output["text"]
     else:
         return None
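
For context, the change wires an antiprompt check into both generation paths: the callback accumulates the generated text, asks detect_antiprompt whether a stop sequence has appeared, and if so trims it with remove_text_from_string and returns False so the binding stops generating. Below is a minimal self-contained sketch of that pattern; the antiprompt list, detect_antiprompt, remove_text_from_string and fake_generate are simplified stand-ins written for illustration, not the real lollms.utilities helpers or binding API.

# Minimal sketch of the antiprompt-detection callback pattern used above.
# All helpers here are simplified stand-ins, not the real lollms implementations.
from typing import Callable, List, Optional

ANTIPROMPTS: List[str] = ["!@>user:", "!@>system:"]   # hypothetical stop sequences

def detect_antiprompt(text: str, antiprompts: List[str] = ANTIPROMPTS) -> Optional[str]:
    # Return the first antiprompt found in the accumulated text, or None.
    for ap in antiprompts:
        if ap in text:
            return ap
    return None

def remove_text_from_string(text: str, marker: str) -> str:
    # Cut the text at the first occurrence of the detected antiprompt.
    return text[: text.index(marker)]

def fake_generate(prompt: str, n_predict: int, callback: Callable[[str], bool]) -> None:
    # Stand-in for binding.generate: feeds chunks to the callback until it returns False.
    for chunk in ["Hello ", "world. ", "!@>user:", " this part is never sent"]:
        if not callback(chunk):
            break   # the callback detected an antiprompt and asked to stop

output = {"text": ""}

def callback(chunk: str) -> bool:
    output["text"] += chunk
    antiprompt = detect_antiprompt(output["text"])
    if antiprompt:
        # Trim the hallucinated continuation and signal the generator to stop.
        output["text"] = remove_text_from_string(output["text"], antiprompt)
        return False
    return True

fake_generate("prompt", 64, callback)
print(output["text"])   # -> "Hello world. "

The key convention is the callback's boolean return value: returning True lets the binding keep generating, while returning False cuts generation short as soon as an antiprompt is detected, and the already-accumulated text is returned with the antiprompt stripped.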