From 846e1b95484af7091bf12432c904245eac9eb73a Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Tue, 9 May 2023 13:36:53 -0400
Subject: [PATCH] WIP check num responses in cache for given prompt

---
 chain-forge/src/PromptNode.js        |  4 ++--
 python-backend/flask_app.py          |  5 ++++-
 python-backend/promptengine/query.py |  5 +++--
 python-backend/promptengine/utils.py | 18 +++++++++++++++++-
 4 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index fe30e41..e4774f0 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -232,11 +232,11 @@ const PromptNode = ({ data, id }) => {
         const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
         if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
           const req = some_llm_num > 1 ? 'requests' : 'request';
-          setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
+          setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ' per LLM' : ''));
         } else if (all_same_num_queries) {
           const req = some_llm_num > 1 ? 'requests' : 'request';
-          setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
+          setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
         } else { // Counts are different
           const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index b929457..5c21bcc 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -321,18 +321,21 @@ async def queryLLM():
         # Prompt the LLM with all permutations of the input prompt template:
         # NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
         resps = []
+        num_resps = 0
         try:
             print(f'Querying {llm}...')
             async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
                 resps.append(response)
                 print(f"collected response from {llm.name}:", str(response))
+                num_resps += len(extract_responses(response, llm))
+
                 # Save the number of responses collected to a temp file on disk
                 with open(tempfilepath, 'r') as f:
                     txt = f.read().strip()
                     cur_data = json.loads(txt) if len(txt) > 0 else {}
-                    cur_data[llm_str] = len(resps)
+                    cur_data[llm_str] = num_resps
                 with open(tempfilepath, 'w') as f:
                     json.dump(cur_data, f)
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index 9f022ab..1833fb5 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import List, Dict, Tuple, Iterator, Union
 import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, cull_responses, extract_responses
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 
 # LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -59,8 +59,9 @@ class PromptPipeline:
             prompt_str = str(prompt)
 
             # First check if there is already a response for this item. If so, we can save an LLM call:
-            if prompt_str in responses:
+            if prompt_str in responses and len(extract_responses(responses[prompt_str], llm)) >= n:
                 print(f" - Found cache'd response for prompt {prompt_str}. Using...")
+                responses[prompt_str] = cull_responses(responses[prompt_str], llm, n)
                 yield {
                     "prompt": prompt_str,
                     "query": responses[prompt_str]["query"],
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index 31d2c45..6054225 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -174,7 +174,7 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
     choices = response["response"]["choices"]
     return [
         c["message"]["content"]
-        for i, c in enumerate(choices)
+        for c in choices
     ]
 
 def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
@@ -191,6 +191,22 @@ def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List
     else:
         raise ValueError(f"LLM {llm} is unsupported.")
 
+def cull_responses(response: Union[list, dict], llm: Union[LLM, str], n: int) -> Union[list, dict]:
+    """
+        Returns the same 'response' but with only 'n' responses.
+    """
+    if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
+        response["response"]["choices"] = response["response"]["choices"][:n]
+        return response
+    elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
+        response["response"] = response["response"][:n]
+        return response
+    elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
+        response["response"] = response["response"][:n]
+        return response
+    else:
+        raise ValueError(f"LLM {llm} is unsupported.")
+
 def create_dir_if_not_exists(path: str) -> None:
     if not os.path.exists(path):
         os.makedirs(path)
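
Note (not part of the patch): the snippet below is a rough standalone sketch of the cache-check behavior the query.py hunk introduces, specialized to a ChatGPT-style cache entry. The helper names and the toy `cached_entry` dict are illustrative stand-ins, not code from the repo.

```python
# Illustrative sketch only: enough cached responses -> trim to n and reuse, skipping the LLM call.

def extract_chatgpt_responses(cached: dict) -> list:
    # Mirrors the shape handled by _extract_chatgpt_responses in promptengine/utils.py
    return [c["message"]["content"] for c in cached["response"]["choices"]]

def cull_chatgpt_responses(cached: dict, n: int) -> dict:
    # Mirrors the ChatGPT branch of the new cull_responses(): keep only the first n choices
    cached["response"]["choices"] = cached["response"]["choices"][:n]
    return cached

# Made-up cache entry with 3 stored responses for one prompt
cached_entry = {
    "query": {"prompt": "What is 2+2?"},
    "response": {"choices": [{"message": {"content": f"Answer {i}"}} for i in range(3)]},
}

n = 2  # number of responses now requested
if len(extract_chatgpt_responses(cached_entry)) >= n:
    cached_entry = cull_chatgpt_responses(cached_entry, n)
    print(extract_chatgpt_responses(cached_entry))  # ['Answer 0', 'Answer 1']
else:
    print("Too few cached responses; would re-query the LLM")
```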