diff --git a/lollms/server/endpoints/lollms_generator.py b/lollms/server/endpoints/lollms_generator.py
index 884ef70..09487cb 100644
--- a/lollms/server/endpoints/lollms_generator.py
+++ b/lollms/server/endpoints/lollms_generator.py
@@ -2,7 +2,9 @@ from fastapi import APIRouter
 from lollms.server.elf_server import LOLLMSElfServer
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
-
+from lollms.types import MSG_TYPE
+from lollms.utilities import detect_antiprompt, remove_text_from_string
+from ascii_colors import ASCIIColors
 class GenerateRequest(BaseModel):
     text: str
     n_predict: int = 1024
@@ -19,16 +21,35 @@ def generate(request_data: GenerateRequest):
 
     if elf_server.binding is not None:
         if stream:
+            output = {"text":""}
             def generate_chunks():
-                def callback(chunk):
+                def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
                     # Yield each chunk of data
-                    yield chunk
-
-                elf_server.binding.generate(text, n_predict, callback=callback)
+                    output["text"] += chunk
+                    antiprompt = detect_antiprompt(output["text"])
+                    if antiprompt:
+                        ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
+                        output["text"] = remove_text_from_string(output["text"],antiprompt)
+                        return False
+                    else:
+                        yield chunk
+                        return True
+                return iter(elf_server.binding.generate(text, n_predict, callback=callback))
 
             return StreamingResponse(generate_chunks())
         else:
-            output = elf_server.binding.generate(text, n_predict)
-            return output
+            output = {"text":""}
+            def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
+                # Yield each chunk of data
+                output["text"] += chunk
+                antiprompt = detect_antiprompt(output["text"])
+                if antiprompt:
+                    ASCIIColors.warning(f"\nDetected hallucination with antiprompt: {antiprompt}")
+                    output["text"] = remove_text_from_string(output["text"],antiprompt)
+                    return False
+                else:
+                    return True
+            elf_server.binding.generate(text, n_predict, callback=callback)
+            return output["text"]
     else:
         return None
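
For context, the diff changes both branches to accumulate generated chunks in a callback and to abort generation by returning False as soon as an antiprompt is detected in the accumulated text. The following is a minimal, self-contained sketch of that callback contract; `detect_antiprompt`, `remove_text_from_string`, and `FakeBinding` below are simplified stand-ins for illustration only, not the actual lollms.utilities or binding implementations.

```python
# Sketch of the callback contract used in the diff: the binding feeds chunks
# to the callback, and a False return value aborts generation early.
ANTIPROMPTS = ["!@>user:", "!@>assistant:"]  # assumed example antiprompts

def detect_antiprompt(text: str):
    # Return the first antiprompt found in the text, or None (stand-in).
    return next((ap for ap in ANTIPROMPTS if ap in text), None)

def remove_text_from_string(text: str, substring: str) -> str:
    # Truncate the text at the first occurrence of the substring (stand-in).
    return text[: text.index(substring)]

class FakeBinding:
    def generate(self, prompt, n_predict, callback):
        # Emit canned chunks; stop as soon as the callback returns False.
        for chunk in ["Hello", " world", "!@>user:", " ignored"]:
            if not callback(chunk):
                break

output = {"text": ""}

def callback(chunk):
    output["text"] += chunk
    antiprompt = detect_antiprompt(output["text"])
    if antiprompt:
        output["text"] = remove_text_from_string(output["text"], antiprompt)
        return False   # antiprompt found: truncate and abort generation
    return True        # keep generating

FakeBinding().generate("Say hello", 32, callback)
print(output["text"])  # -> "Hello world"
```

The non-streaming branch of the diff follows this pattern directly; the streaming branch additionally wraps the call in a generator that is handed to StreamingResponse.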