Saifeddine ALOUI 2024-05-08 01:14:07 +02:00
parent d76a96cbb8
commit 8ca192ea64


@@ -24,6 +24,7 @@ import string
import json
from enum import Enum
import asyncio
from datetime import datetime
def _generate_id(length=10):
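The body of _generate_id is elided by the diff; given the string import above, a plausible but purely hypothetical reconstruction is a random alphanumeric id builder (an assumption, not the actual code):

import random
import string

# Hypothetical body for _generate_id; inferred from the imports, not taken from the repository.
def _generate_id(length=10):
    return "".join(random.choices(string.ascii_letters + string.digits, k=length))

print(_generate_id())  # e.g. 'aZ3kQ9xLmP'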
@@ -549,13 +550,32 @@ async def ollama_chat_completion(request: ChatGenerationRequest):
return {"status":False,"error":str(ex)}
class CompletionGenerationRequest(BaseModel):
model: Optional[str] = ""
prompt: str = ""
model: Optional[str] = None
prompt: str
max_tokens: Optional[int] = -1
stream: Optional[bool] = False
temperature: Optional[float] = -1
mirostat: Optional[int] = None # Mirostat sampling mode (0 = off, 1, or 2)
mirostat_eta: Optional[float] = None # Mirostat learning rate
mirostat_tau: Optional[float] = None # Mirostat target entropy
num_ctx: Optional[int] = None # context window size
repeat_last_n: Optional[int] = -1 # how far back to look for repetition
repeat_penalty: Optional[float] = -1 # penalty applied to repeated tokens
seed: Optional[int] = None # RNG seed for reproducible output
stop: Optional[str] = None # stop sequence
tfs_z: Optional[float] = None # tail-free sampling parameter
num_predict: Optional[int] = None # maximum number of tokens to generate
top_k: Optional[int] = -1
top_p: Optional[float] = -1
format: Optional[str] = None # response format (e.g. "json")
system: Optional[str] = None # system prompt override
template: Optional[str] = None # prompt template override
context: Optional[str] = None # context returned by a previous request
raw: Optional[bool] = None # bypass prompt templating
keep_alive: Optional[str] = None # how long to keep the model loaded
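As a quick sanity check of the expanded schema, here is a minimal, self-contained sketch (trimmed to a few fields; the payload values are invented) showing how an Ollama-style /api/generate body validates against it:

from typing import Optional
from pydantic import BaseModel

# Trimmed copy of the request model above, for a standalone demonstration.
class CompletionGenerationRequest(BaseModel):
    model: Optional[str] = None
    prompt: str
    max_tokens: Optional[int] = -1
    stream: Optional[bool] = False
    temperature: Optional[float] = -1

payload = {"model": "llama3", "prompt": "Why is the sky blue?"}
req = CompletionGenerationRequest(**payload)
assert req.temperature == -1  # sentinel value meaning "use the server default"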
@router.options("/api/generate")
@router.post("/api/generate")
async def ollama_generate(request: CompletionGenerationRequest):
"""
@@ -565,15 +585,27 @@ async def ollama_generate(request: CompletionGenerationRequest):
:return: A JSON response with the status of the operation.
"""
try:
start_time = time.perf_counter_ns()
ASCIIColors.cyan("> Ollama Server emulator: Received request")
text = request.prompt
n_predict = request.max_tokens if request.max_tokens>=0 else elf_server.config.max_n_predict
temperature = request.temperature if request.temperature>=0 else elf_server.config.temperature
# top_k = request.top_k if request.top_k>=0 else elf_server.config.top_k
# top_p = request.top_p if request.top_p>=0 else elf_server.config.top_p
# repeat_last_n = request.repeat_last_n if request.repeat_last_n>=0 else elf_server.config.repeat_last_n
# repeat_penalty = request.repeat_penalty if request.repeat_penalty>=0 else elf_server.config.repeat_penalty
n_predict = request.max_tokens if request.max_tokens > 0 else elf_server.config.max_n_predict
temperature = request.temperature if request.temperature > 0 else elf_server.config.temperature
top_k = request.top_k if request.top_k > 0 else elf_server.config.top_k
top_p = request.top_p if request.top_p > 0 else elf_server.config.top_p
repeat_last_n = request.repeat_last_n if request.repeat_last_n > 0 else elf_server.config.repeat_last_n
repeat_penalty = request.repeat_penalty if request.repeat_penalty > 0 else elf_server.config.repeat_penalty
stream = request.stream
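The repeated sentinel-fallback pattern above could be collapsed into a single helper; an illustrative sketch (the resolve name and the defaults dict are invented, not part of the codebase):

# Illustrative only: one helper for the repeated "value if value > 0 else default" pattern.
def resolve(value, default):
    return value if value is not None and value > 0 else default

defaults = {"max_n_predict": 1024, "temperature": 0.7}  # invented server defaults

print(resolve(-1, defaults["max_n_predict"]))  # 1024: sentinel falls back to the default
print(resolve(0.9, defaults["temperature"]))   # 0.9: an explicit client value wins

Note that, like the handler code, this treats 0 as unset, so a client cannot request a temperature of exactly 0; that is the usual trade-off of sentinel-based defaults.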
created_at = datetime.now().isoformat()
response_data = {
"model": request.model if request.model is not None else "llama3",
"created_at": created_at,
"response": "",
"done": False,
"context": [1, 2, 3], # Placeholder for actual context
}
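For orientation, the completed (non-streaming) reply this handler builds toward mirrors Ollama's documented /api/generate response; a sketch with invented values:

import json

# Shape of a finished Ollama-style /api/generate reply; every value here is invented.
final_reply = {
    "model": "llama3",
    "created_at": "2024-05-08T01:14:07.000000",
    "response": "The sky is blue because of Rayleigh scattering.",
    "done": True,
    "context": [1, 2, 3],
    "total_duration": 5_000_000_000,  # all *_duration fields are in nanoseconds
    "load_duration": 0,
    "prompt_eval_count": 5,
    "prompt_eval_duration": 300_000_000,
    "eval_count": 12,
    "eval_duration": 4_500_000_000,
}
print(json.dumps(final_reply, indent=2))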
ASCIIColors.cyan("> Processing ...")
if elf_server.binding is not None:
if stream:
output = {"text":""}
@@ -595,17 +627,20 @@ async def ollama_generate(request: CompletionGenerationRequest):
callback=callback,
temperature=temperature,
))
ASCIIColors.success("> Streaming ...")
return StreamingResponse(generate_chunks())
else:
output = {"text":""}
def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
if chunk is None:
return
# Accumulate the chunk; the antiprompt check below may stop generation
output["text"] += chunk
antiprompt = detect_antiprompt(output["text"])
if antiprompt:
ASCIIColors.warning(f"\n{antiprompt} detected. Stopping generation")
output["text"] = remove_text_from_string(output["text"],antiprompt)
ASCIIColors.success("Done")
return False
else:
return True
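detect_antiprompt and remove_text_from_string come from the lollms utilities; a rough, hypothetical sketch of the behavior the callback relies on (not the actual implementations, and the marker list is invented):

# Hypothetical stand-ins for the lollms helpers used in the callback above.
ANTIPROMPTS = ["!@>", "<|user|>"]  # invented stop markers

def detect_antiprompt(text):
    """Return the first stop marker found in text, or None."""
    for marker in ANTIPROMPTS:
        if marker in text:
            return marker
    return None

def remove_text_from_string(text, marker):
    """Truncate text at the first occurrence of marker."""
    return text[:text.index(marker)]

buf = "The answer is 42.!@> next turn"
marker = detect_antiprompt(buf)
if marker:
    buf = remove_text_from_string(buf, marker)
print(buf)  # "The answer is 42."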
@@ -615,8 +650,18 @@ async def ollama_generate(request: CompletionGenerationRequest):
callback=callback,
temperature=request.temperature if request.temperature>=0 else elf_server.config.temperature
)
return output["text"]
ASCIIColors.success("> Done")
response_data["total_duration"] = time.perf_counter_ns() - start_time
response_data["load_duration"] = 0
response_data["prompt_eval_count"] = len(request.prompt.split())
response_data["prompt_eval_duration"] = time.perf_counter_ns() - start_time
response_data["eval_count"] = len(elf_server.binding.tokenize(output["text"])) # Simulated number of tokens in the response
response_data["eval_duration"] = time.perf_counter_ns() - start_time
response_data["response"] = output["text"]
response_data["done"] = True
return response_data
else:
ASCIIColors.error("> Failed")
return None
except Exception as ex:
trace_exception(ex)
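Ollama reports its *_duration fields in nanoseconds, which is why the handler samples time.perf_counter_ns(); a minimal timing sketch:

import time

start = time.perf_counter_ns()
time.sleep(0.05)  # stand-in for the actual generation work
elapsed_ns = time.perf_counter_ns() - start
print(f"total_duration: {elapsed_ns} ns (~{elapsed_ns / 1e9:.3f} s)")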
@@ -733,114 +778,6 @@ async def ollama_completion(request: CompletionGenerationRequest):
return {"status":False,"error":str(ex)}
@router.post("/api/generate")
async def ollama_chat(request: CompletionGenerationRequest):
"""
Handles an Ollama-style generation request and returns the generated text.
:param request: The HTTP request object.
:return: A JSON response with the status of the operation.
"""
try:
reception_manager=RECEPTION_MANAGER()
prompt = request.prompt
n_predict = request.max_tokens if request.max_tokens>=0 else elf_server.config.max_n_predict
temperature = request.temperature if request.temperature>=0 else elf_server.config.temperature
# top_k = request.top_k if request.top_k>=0 else elf_server.config.top_k
# top_p = request.top_p if request.top_p>=0 else elf_server.config.top_p
# repeat_last_n = request.repeat_last_n if request.repeat_last_n>=0 else elf_server.config.repeat_last_n
# repeat_penalty = request.repeat_penalty if request.repeat_penalty>=0 else elf_server.config.repeat_penalty
stream = request.stream
headers = {'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive'}
if elf_server.binding is not None:
if stream:
new_output={"new_values":[]}
async def generate_chunks():
lk = threading.Lock()
def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
if elf_server.cancel_gen:
return False
if chunk is None:
return
rx = reception_manager.new_chunk(chunk)
if rx.status!=ROLE_CHANGE_DECISION.MOVE_ON:
if rx.status==ROLE_CHANGE_DECISION.PROGRESSING:
return True
elif rx.status==ROLE_CHANGE_DECISION.ROLE_CHANGED:
return False
else:
chunk = chunk + rx.value
# Buffer the processed chunk; the polling loop below picks it up and yields it
lk.acquire()
try:
new_output["new_values"].append(reception_manager.chunk)
lk.release()
return True
except Exception as ex:
trace_exception(ex)
lk.release()
return False
def chunks_builder():
if request.model in elf_server.binding.list_models() and elf_server.binding.model_name!=request.model:
elf_server.binding.build_model(request.model)
elf_server.binding.generate(
prompt,
n_predict,
callback=callback,
temperature=temperature or elf_server.config.temperature
)
reception_manager.done = True
thread = threading.Thread(target=chunks_builder)
thread.start()
current_index = 0
while (not reception_manager.done and elf_server.cancel_gen == False):
while (not reception_manager.done and len(new_output["new_values"])==0):
time.sleep(0.001)
lk.acquire()
for i in range(len(new_output["new_values"])):
current_index += 1
yield (json.dumps({"response":new_output["new_values"][i]})+"\n").encode("utf-8")
new_output["new_values"]=[]
lk.release()
elf_server.cancel_gen = False
return StreamingResponse(generate_chunks(), media_type="application/json", headers=headers)
else:
def callback(chunk, chunk_type:MSG_TYPE=MSG_TYPE.MSG_TYPE_CHUNK):
# Feed the chunk to the reception manager; nothing is yielded in this non-streaming path
if chunk is None:
return True
rx = reception_manager.new_chunk(chunk)
if rx.status!=ROLE_CHANGE_DECISION.MOVE_ON:
if rx.status==ROLE_CHANGE_DECISION.PROGRESSING:
return True
elif rx.status==ROLE_CHANGE_DECISION.ROLE_CHANGED:
return False
else:
chunk = chunk + rx.value
return True
elf_server.binding.generate(
prompt,
n_predict,
callback=callback,
temperature=request.temperature or elf_server.config.temperature
)
return json.dumps(reception_manager.reception_buffer).encode("utf-8")
except Exception as ex:
trace_exception(ex)
elf_server.error(ex)
return {"status":False,"error":str(ex)}
@router.post("/v1/completions")
async def v1_completion(request: CompletionGenerationRequest):
"""