Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)
Only send the queries required when num_generations increases, per prompt

Commit 8de28cdac0 (parent 8714d78c76)
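The first two hunks below edit the React PromptNode component; the remaining hunks edit the Python promptengine back end (the PromptPipeline class and a new merge_response_objs helper in promptengine.utils). The core idea, as a minimal illustrative sketch (the function name num_new_queries and the values are not from the repository): when a prompt already has cached responses, only the gap between the requested number of generations and the cached count is sent as new queries.

# Minimal sketch of "send only the delta" (illustrative names, not repository code):
def num_new_queries(n_requested: int, cached_responses: list) -> int:
    # If enough responses are already cached, no new query is needed;
    # otherwise, request only the missing ones.
    return max(0, n_requested - len(cached_responses))

# e.g. 3 responses cached, num_generations raised to 5 -> only 2 new LLM calls
assert num_new_queries(5, ["r1", "r2", "r3"]) == 2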
@@ -362,7 +362,7 @@ const PromptNode = ({ data, id }) => {
       const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
         return acc + counts[llm_name];
       }, 0);
-      setProgress(total_num_resps / max_responses * 100);
+      setProgress(Math.max(5, total_num_resps / max_responses * 100));
     });

     // The process has finished; close the connection:
@@ -386,7 +386,7 @@ const PromptNode = ({ data, id }) => {
           temperature: 0.5,
           n: numGenerations,
         },
-        no_cache: true,
+        no_cache: false,
       }),
     }, rejected).then(function(response) {
       return response.json();

@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import List, Dict, Tuple, Iterator, Union
 import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, extract_responses
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, extract_responses, merge_response_objs
 from promptengine.template import PromptTemplate, PromptPermutationGenerator

 # LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -59,18 +59,18 @@ class PromptPipeline:
             prompt_str = str(prompt)

             cached_resp = responses[prompt_str] if prompt_str in responses else None
-            extracted_resps = extract_responses(cached_resp, llm) if cached_resp is not None else []
+            extracted_resps = cached_resp["responses"] if cached_resp is not None else []

             # First check if there is already a response for this item. If so, we can save an LLM call:
-            if cached_resp and len(extract_responses) >= n:
+            if cached_resp and len(extracted_resps) >= n:
                 print(f" - Found cache'd response for prompt {prompt_str}. Using...")
                 yield {
                     "prompt": prompt_str,
-                    "query": responses[prompt_str]["query"],
+                    "query": cached_resp["query"],
                     "responses": extracted_resps[:n],
                     "raw_response": cached_resp["raw_response"],
-                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
-                    "info": responses[prompt_str]["info"],
+                    "llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
+                    "info": cached_resp["info"],
                 }
                 continue

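For orientation, here is a hypothetical cached entry as the hunk above reads it; the field names match the code, but every value is invented for illustration.

# Hypothetical shape of one entry in `responses` (every value here is made up):
cached_resp = {
    "query": {"n": 3, "temperature": 0.5},        # query payload sent to the LLM
    "responses": ["text 1", "text 2", "text 3"],  # extracted response strings
    "raw_response": [{}],                         # raw provider object(s)
    "llm": "chatgpt",                             # LLM.<name>.value
    "info": {"language": "French"},               # the prompt's fill history
}
# With n <= 3, the generator yields this cached data and `continue`s,
# skipping the LLM call entirely.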
@@ -80,24 +80,14 @@ class PromptPipeline:
                     await asyncio.sleep(wait_secs)

                 # Call the LLM asynchronously to generate a response
-                tasks.append(self._prompt_llm(llm, prompt, n, temperature))
+                tasks.append(self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp))
             else:
                 # Blocking. Await + yield a single LLM call.
-                _, query, response = await self._prompt_llm(llm, prompt, n, temperature)
+                _, query, response, past_resp_obj = await self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp)
                 info = prompt.fill_history

-                # Save the response to a JSON file
-                responses[str(prompt)] = {
-                    "query": query,
-                    "responses": extract_responses(response, llm),
-                    "raw_response": response,
-                    "llm": llm.value,
-                    "info": info,
-                }
-                self._cache_responses(responses)
-
-                # Yield the response
-                yield {
+                # Create a response obj to represent the response
+                resp_obj = {
                     "prompt": str(prompt),
                     "query": query,
                     "responses": extract_responses(response, llm),
@@ -105,29 +95,29 @@ class PromptPipeline:
                     "llm": llm.value,
                     "info": info,
                 }

+                # Merge the response obj with the past one, if necessary
+                if past_resp_obj is not None:
+                    resp_obj = merge_response_objs(resp_obj, past_resp_obj)
+
+                # Save the response to a JSON file
+                responses[resp_obj["prompt"]] = {key: val for key, val in resp_obj.items()}
+                self._cache_responses(responses)
+
+                # Yield the response
+                yield resp_obj
+
         # Yield responses as they come in
         for task in asyncio.as_completed(tasks):
             # Collect the response from the earliest completed task
-            prompt, query, response = await task
+            prompt, query, response, past_resp_obj = await task

             # Each prompt has a history of what was filled in from its base template.
             # This data --like, "class", "language", "library" etc --can be useful when parsing responses.
             info = prompt.fill_history

-            # Save the response to a JSON file
-            # NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
-            responses[str(prompt)] = {
-                "query": query,
-                "responses": extract_responses(response, llm),
-                "raw_response": response,
-                "llm": llm.value,
-                "info": info,
-            }
-            self._cache_responses(responses)
-
-            # Yield the response
-            yield {
+            # Create a response obj to represent the response
+            resp_obj = {
                 "prompt": str(prompt),
                 "query": query,
                 "responses": extract_responses(response, llm),
@@ -135,6 +125,18 @@ class PromptPipeline:
                 "llm": llm.value,
                 "info": info,
             }

+            # Merge the response obj with the past one, if necessary
+            if past_resp_obj is not None:
+                resp_obj = merge_response_objs(resp_obj, past_resp_obj)
+
+            # Save the response to a JSON file
+            # NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
+            responses[resp_obj["prompt"]] = {key: val for key, val in resp_obj.items()}
+            self._cache_responses(responses)
+
+            # Yield the response
+            yield resp_obj
+
     def _load_cached_responses(self) -> Dict:
         """
@@ -155,7 +157,13 @@ class PromptPipeline:
     def clear_cached_responses(self) -> None:
         self._cache_responses({})

-    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
+    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0, past_resp_obj: Union[Dict, None] = None) -> Tuple[str, Dict, Union[List, Dict], Union[Dict, None]]:
+        # Detect how many responses we have already (from cache obj past_resp_obj)
+        if past_resp_obj is not None:
+            # How many *new* queries we need to send:
+            # NOTE: The check n > len(past_resp_obj["responses"]) should occur prior to calling this function.
+            n = n - len(past_resp_obj["responses"])
+
         if llm is LLM.ChatGPT or llm is LLM.GPT4:
             query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
         elif llm is LLM.Alpaca7B:
@@ -164,7 +172,7 @@ class PromptPipeline:
             query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
         else:
             raise Exception(f"Language model {llm} is not supported.")
-        return prompt, query, response
+        return prompt, query, response, past_resp_obj


 """
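Putting the _prompt_llm changes together: the provider is asked only for the missing generations, and the cached object is passed through the return value so the caller can merge old and new responses. A rough sketch of that flow, assuming a stand-in call_provider function in place of call_chatgpt / call_anthropic / call_dalai (none of the names below are repository code):

# Rough sketch of the adjusted-n flow (illustrative, not the repository code).
import asyncio
from typing import Optional

async def call_provider(prompt: str, n: int) -> list:
    # Stand-in for call_chatgpt / call_anthropic / call_dalai.
    return [f"response {i + 1} to {prompt!r}" for i in range(n)]

async def prompt_llm_sketch(prompt: str, n: int, past_resp_obj: Optional[dict] = None):
    if past_resp_obj is not None:
        # Only ask the provider for the generations we do not already have.
        n = n - len(past_resp_obj["responses"])
    new_responses = await call_provider(prompt, n)
    # past_resp_obj is returned untouched so the caller can merge old + new.
    return prompt, new_responses, past_resp_obj

cached = {"responses": ["old 1", "old 2", "old 3"]}
_, new, _ = asyncio.run(prompt_llm_sketch("hello", 5, past_resp_obj=cached))
assert len(new) == 2  # only the 2 missing generations were requested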
@@ -191,6 +191,29 @@ def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List
     else:
         raise ValueError(f"LLM {llm} is unsupported.")

+def merge_response_objs(resp_obj_A: Union[dict, None], resp_obj_B: Union[dict, None]) -> dict:
+    if resp_obj_B is None:
+        return resp_obj_A
+    elif resp_obj_A is None:
+        return resp_obj_B
+    raw_resp_A = resp_obj_A["raw_response"]
+    raw_resp_B = resp_obj_B["raw_response"]
+    if not isinstance(raw_resp_A, list):
+        raw_resp_A = [ raw_resp_A ]
+    if not isinstance(raw_resp_B, list):
+        raw_resp_B = [ raw_resp_B ]
+    C = {
+        "responses": resp_obj_A["responses"] + resp_obj_B["responses"],
+        "raw_response": raw_resp_A + raw_resp_B,
+    }
+    return {
+        **C,
+        "prompt": resp_obj_B['prompt'],
+        "query": resp_obj_B['query'],
+        "llm": resp_obj_B['llm'],
+        "info": resp_obj_B['info'],
+    }
+
 def create_dir_if_not_exists(path: str) -> None:
     if not os.path.exists(path):
         os.makedirs(path)
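A small usage sketch for the new merge_response_objs helper, with invented values and assuming the function is in scope (e.g., imported from promptengine.utils): the first argument is the fresh response object and the second the previously cached one; responses and raw_response are concatenated in that order, while prompt, query, llm, and info are taken from the second (cached) object.

# Usage sketch for merge_response_objs (illustrative values, not repository data):
fresh  = {"prompt": "p", "query": {}, "responses": ["new 1", "new 2"],
          "raw_response": ["raw new"], "llm": "chatgpt", "info": {}}
cached = {"prompt": "p", "query": {}, "responses": ["old 1", "old 2", "old 3"],
          "raw_response": ["raw old"], "llm": "chatgpt", "info": {}}

merged = merge_response_objs(fresh, cached)
assert merged["responses"] == ["new 1", "new 2", "old 1", "old 2", "old 3"]
assert merged["raw_response"] == ["raw new", "raw old"]
assert merged["llm"] == cached["llm"]  # metadata comes from the second argument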