From 8de28cdac0a9bd92ddc22ab71f2cb0b5ea924cf4 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Wed, 10 May 2023 14:52:34 -0400
Subject: [PATCH] Only send the queries required when num_generations increases, per prompt

---
 chain-forge/src/PromptNode.js        |  4 +-
 python-backend/promptengine/query.py | 80 +++++++++++++++-------------
 python-backend/promptengine/utils.py | 23 ++++
 3 files changed, 69 insertions(+), 38 deletions(-)

diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 0842169..0aa588c 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -362,7 +362,7 @@ const PromptNode = ({ data, id }) => {
         const total_num_resps = Object.keys(counts).reduce((acc, llm_name) => {
           return acc + counts[llm_name];
         }, 0);
-        setProgress(total_num_resps / max_responses * 100);
+        setProgress(Math.max(5, total_num_resps / max_responses * 100));
       });
 
       // The process has finished; close the connection:
@@ -386,7 +386,7 @@
           temperature: 0.5,
           n: numGenerations,
         },
-        no_cache: true,
+        no_cache: false,
       }),
     }, rejected).then(function(response) {
       return response.json();
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index 409980f..56f336e 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import List, Dict, Tuple, Iterator, Union
 import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, extract_responses
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, extract_responses, merge_response_objs
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 
 # LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -59,18 +59,18 @@ class PromptPipeline:
             prompt_str = str(prompt)
 
             cached_resp = responses[prompt_str] if prompt_str in responses else None
-            extracted_resps = extract_responses(cached_resp, llm) if cached_resp is not None else []
+            extracted_resps = cached_resp["responses"] if cached_resp is not None else []
 
             # First check if there is already a response for this item. If so, we can save an LLM call:
-            if cached_resp and len(extract_responses) >= n:
+            if cached_resp and len(extracted_resps) >= n:
                 print(f"   - Found cache'd response for prompt {prompt_str}. Using...")
                 yield {
                     "prompt": prompt_str,
-                    "query": responses[prompt_str]["query"],
+                    "query": cached_resp["query"],
                     "responses": extracted_resps[:n],
                     "raw_response": cached_resp["raw_response"],
-                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
-                    "info": responses[prompt_str]["info"],
+                    "llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
+                    "info": cached_resp["info"],
                 }
                 continue
 
@@ -80,24 +80,14 @@
                 await asyncio.sleep(wait_secs)
 
                 # Call the LLM asynchronously to generate a response
-                tasks.append(self._prompt_llm(llm, prompt, n, temperature))
+                tasks.append(self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp))
             else:
                 # Blocking. Await + yield a single LLM call.
-                _, query, response = await self._prompt_llm(llm, prompt, n, temperature)
+                _, query, response, past_resp_obj = await self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp)
                 info = prompt.fill_history
 
-                # Save the response to a JSON file
-                responses[str(prompt)] = {
-                    "query": query,
-                    "responses": extract_responses(response, llm),
-                    "raw_response": response,
-                    "llm": llm.value,
-                    "info": info,
-                }
-                self._cache_responses(responses)
-
-                # Yield the response
-                yield {
+                # Create a response obj to represent the response
+                resp_obj = {
                     "prompt": str(prompt),
                     "query": query,
                     "responses": extract_responses(response, llm),
@@ -105,29 +95,29 @@
                     "llm": llm.value,
                     "info": info,
                 }
+
+                # Merge the response obj with the past one, if necessary
+                if past_resp_obj is not None:
+                    resp_obj = merge_response_objs(resp_obj, past_resp_obj)
+
+                # Save the response to a JSON file
+                responses[resp_obj["prompt"]] = {key: val for key, val in resp_obj.items()}
+                self._cache_responses(responses)
+
+                # Yield the response
+                yield resp_obj
 
         # Yield responses as they come in
         for task in asyncio.as_completed(tasks):
             # Collect the response from the earliest completed task
-            prompt, query, response = await task
+            prompt, query, response, past_resp_obj = await task
 
             # Each prompt has a history of what was filled in from its base template.
             # This data --like, "class", "language", "library" etc --can be useful when parsing responses.
             info = prompt.fill_history
-
-            # Save the response to a JSON file
-            # NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
-            responses[str(prompt)] = {
-                "query": query,
-                "responses": extract_responses(response, llm),
-                "raw_response": response,
-                "llm": llm.value,
-                "info": info,
-            }
-            self._cache_responses(responses)
-            # Yield the response
-            yield {
+            # Create a response obj to represent the response
+            resp_obj = {
                 "prompt": str(prompt),
                 "query": query,
                 "responses": extract_responses(response, llm),
@@ -135,6 +125,18 @@
                 "llm": llm.value,
                 "info": info,
             }
+
+            # Merge the response obj with the past one, if necessary
+            if past_resp_obj is not None:
+                resp_obj = merge_response_objs(resp_obj, past_resp_obj)
+
+            # Save the response to a JSON file
+            # NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
+            responses[resp_obj["prompt"]] = {key: val for key, val in resp_obj.items()}
+            self._cache_responses(responses)
+
+            # Yield the response
+            yield resp_obj
 
     def _load_cached_responses(self) -> Dict:
         """
@@ -155,7 +157,13 @@
     def clear_cached_responses(self) -> None:
         self._cache_responses({})
 
-    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
+    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0, past_resp_obj: Union[Dict, None] = None) -> Tuple[str, Dict, Union[List, Dict], Union[Dict, None]]:
+        # Detect how many responses we have already (from cache obj past_resp_obj)
+        if past_resp_obj is not None:
+            # How many *new* queries we need to send:
+            # NOTE: The check n > len(past_resp_obj["responses"]) should occur prior to calling this function.
+            n = n - len(past_resp_obj["responses"])
+
         if llm is LLM.ChatGPT or llm is LLM.GPT4:
             query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
         elif llm is LLM.Alpaca7B:
@@ -164,7 +172,7 @@
             query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
         else:
             raise Exception(f"Language model {llm} is not supported.")
-        return prompt, query, response
+        return prompt, query, response, past_resp_obj
 
 
 """
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index 6b24b9b..52506fd 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -191,6 +191,29 @@ def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List
     else:
         raise ValueError(f"LLM {llm} is unsupported.")
 
+def merge_response_objs(resp_obj_A: Union[dict, None], resp_obj_B: Union[dict, None]) -> dict:
+    if resp_obj_B is None:
+        return resp_obj_A
+    elif resp_obj_A is None:
+        return resp_obj_B
+    raw_resp_A = resp_obj_A["raw_response"]
+    raw_resp_B = resp_obj_B["raw_response"]
+    if not isinstance(raw_resp_A, list):
+        raw_resp_A = [ raw_resp_A ]
+    if not isinstance(raw_resp_B, list):
+        raw_resp_B = [ raw_resp_B ]
+    C = {
+        "responses": resp_obj_A["responses"] + resp_obj_B["responses"],
+        "raw_response": raw_resp_A + raw_resp_B,
+    }
+    return {
+        **C,
+        "prompt": resp_obj_B['prompt'],
+        "query": resp_obj_B['query'],
+        "llm": resp_obj_B['llm'],
+        "info": resp_obj_B['info'],
+    }
+
 def create_dir_if_not_exists(path: str) -> None:
     if not os.path.exists(path):
         os.makedirs(path)
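
Note (not part of the patch): a small usage sketch of the new merge_response_objs helper, as query.py calls it with merge_response_objs(resp_obj, past_resp_obj). It assumes the patched promptengine package is importable; the dictionaries below are invented example data, not taken from a real cache file.

    from promptengine.utils import merge_response_objs

    # A cached response object that already holds 2 generations for this prompt:
    cached = {
        "prompt": "What is the capital of {country}?",
        "query": {"n": 2},
        "responses": ["Ottawa", "Ottawa."],
        "raw_response": [{"choices": ["..."]}],
        "llm": "gpt-3.5-turbo",
        "info": {"country": "Canada"},
    }

    # A fresh response object holding only the 1 *new* generation just requested:
    fresh = {
        "prompt": "What is the capital of {country}?",
        "query": {"n": 1},
        "responses": ["The capital of Canada is Ottawa."],
        "raw_response": [{"choices": ["..."]}],
        "llm": "gpt-3.5-turbo",
        "info": {"country": "Canada"},
    }

    merged = merge_response_objs(fresh, cached)
    assert len(merged["responses"]) == 3       # new + cached generations are concatenated
    assert len(merged["raw_response"]) == 2    # raw API payloads are kept as a list
    assert merged["query"] == cached["query"]  # metadata comes from the second (cached) object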
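
A second sketch of the arithmetic behind "only send the queries required": with k generations already cached for a prompt and n now requested, gen_responses short-circuits the LLM call entirely when k >= n, and _prompt_llm otherwise asks the API for only n - k more. The helper name num_new_generations_needed is hypothetical, used here only to illustrate that logic.

    def num_new_generations_needed(n_requested: int, cached_responses: list) -> int:
        """Return how many generations still need to be queried from the LLM."""
        if len(cached_responses) >= n_requested:
            return 0  # cache already satisfies the request; no API call needed
        return n_requested - len(cached_responses)

    # Example: 2 generations cached, user bumps num_generations to 5 -> only 3 new queries are sent.
    assert num_new_generations_needed(5, ["a", "b"]) == 3
    assert num_new_generations_needed(2, ["a", "b"]) == 0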