Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)
WIP check num responses in cache for given prompt
parent 47be5ec96f
commit 846e1b9548
@@ -232,11 +232,11 @@ const PromptNode = ({ data, id }) => {
         const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
         if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
             const req = some_llm_num > 1 ? 'requests' : 'request';
-            setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
+            setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ' per LLM' : ''));
         }
         else if (all_same_num_queries) {
             const req = some_llm_num > 1 ? 'requests' : 'request';
-            setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
+            setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
         }
         else { // Counts are different
             const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
@@ -321,18 +321,21 @@ async def queryLLM():
     # Prompt the LLM with all permutations of the input prompt template:
     # NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
     resps = []
+    num_resps = 0
     try:
         print(f'Querying {llm}...')
         async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
             resps.append(response)
             print(f"collected response from {llm.name}:", str(response))
 
+            num_resps += len(extract_responses(response, llm))
+
             # Save the number of responses collected to a temp file on disk
             with open(tempfilepath, 'r') as f:
                 txt = f.read().strip()
 
             cur_data = json.loads(txt) if len(txt) > 0 else {}
-            cur_data[llm_str] = len(resps)
+            cur_data[llm_str] = num_resps
 
             with open(tempfilepath, 'w') as f:
                 json.dump(cur_data, f)
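The per-LLM counts written to tempfilepath are presumably read back by whatever reports run progress to the frontend. A minimal sketch of such a reader, assuming the same JSON layout the loop above writes (a flat {"<llm name>": <count>} map); the helper name read_progress is hypothetical and not part of this commit:

import json, os

def read_progress(tempfilepath: str) -> dict:
    # Hypothetical helper: returns the {llm_name: responses_collected_so_far} map
    # that the loop above maintains, or {} if nothing has been written yet.
    if not os.path.exists(tempfilepath):
        return {}
    with open(tempfilepath, 'r') as f:
        txt = f.read().strip()
    return json.loads(txt) if len(txt) > 0 else {}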
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import List, Dict, Tuple, Iterator, Union
 import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, cull_responses, extract_responses
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 
 # LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -59,8 +59,9 @@ class PromptPipeline:
             prompt_str = str(prompt)
 
             # First check if there is already a response for this item. If so, we can save an LLM call:
-            if prompt_str in responses:
+            if prompt_str in responses and len(extract_responses(responses[prompt_str], llm)) >= n:
                 print(f" - Found cache'd response for prompt {prompt_str}. Using...")
+                responses[prompt_str] = cull_responses(responses[prompt_str], llm, n)
                 yield {
                     "prompt": prompt_str,
                     "query": responses[prompt_str]["query"],
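Together with the new >= n check, the cull_responses helper added below means a cached prompt is reused only when it already holds at least n responses, and is then trimmed to exactly n before being yielded. A small sketch of that behavior, assuming the ChatGPT-style cache-entry shape implied by _extract_chatgpt_responses and cull_responses (the query metadata here is made up for illustration):

from promptengine.utils import LLM, extract_responses, cull_responses

cached = {
    "query": {"model": "gpt-3.5-turbo"},  # hypothetical query metadata
    "response": {"choices": [
        {"message": {"content": "A"}},
        {"message": {"content": "B"}},
        {"message": {"content": "C"}},
    ]},
}

n = 2
if len(extract_responses(cached, LLM.ChatGPT)) >= n:  # cache already holds enough responses
    cached = cull_responses(cached, LLM.ChatGPT, n)   # trim down to exactly n
    print(extract_responses(cached, LLM.ChatGPT))     # ['A', 'B']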
@@ -174,7 +174,7 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
     choices = response["response"]["choices"]
     return [
         c["message"]["content"]
-        for i, c in enumerate(choices)
+        for c in choices
     ]
 
 def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
@@ -191,6 +191,22 @@ def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
     else:
         raise ValueError(f"LLM {llm} is unsupported.")
 
+def cull_responses(response: Union[list, dict], llm: Union[LLM, str], n: int) -> Union[list, dict]:
+    """
+    Returns the same 'response' but with only 'n' responses.
+    """
+    if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
+        response["response"]["choices"] = response["response"]["choices"][:n]
+        return response
+    elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
+        response["response"] = response["response"][:n]
+        return response
+    elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
+        response["response"] = response["response"][:n]
+        return response
+    else:
+        raise ValueError(f"LLM {llm} is unsupported.")
+
 def create_dir_if_not_exists(path: str) -> None:
     if not os.path.exists(path):
         os.makedirs(path)
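cull_responses dispatches on the per-provider cache shape: ChatGPT/GPT4 entries nest responses under response["response"]["choices"], while the Alpaca and Claude branches slice a plain list stored under response["response"]. A quick hedged example of the list-shaped branch (the model string and completions below are made up):

from promptengine.utils import cull_responses

entry = {"response": ["first completion", "second completion", "third completion"]}
trimmed = cull_responses(entry, "claude-v1", 2)  # string model names starting with 'claude' hit the Claude branch
print(trimmed["response"])  # ['first completion', 'second completion']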