from abc import abstractmethod
from typing import List, Dict, Tuple, Iterator, Union, Optional
import json, os, asyncio, random, string
from chainforge.promptengine.utils import call_chatgpt, call_dalai, call_anthropic, call_google_palm, call_azure_openai, is_valid_filepath, is_valid_json, extract_responses, merge_response_objs
from chainforge.promptengine.models import LLM, RATE_LIMITS
from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator

class LLMResponseException(Exception):
    """ Raised when there is an error generating a single response from an LLM """
    pass

"""
    Abstract class that captures a generic querying interface to prompt LLMs
"""
class PromptPipeline:
    def __init__(self, storageFile: str):
        if not is_valid_filepath(storageFile):
            raise IOError(f"Filepath {storageFile} is invalid, or you do not have write access.")

        self._filepath = storageFile

    @abstractmethod
    def gen_prompts(self, properties) -> Iterator[PromptTemplate]:
        raise NotImplementedError("Please Implement the gen_prompts method")

    async def gen_responses(self, properties, llm: LLM, n: int = 1, temperature: float = 1.0, **llm_params) -> Iterator[Union[Dict, LLMResponseException]]:
        """
            Calls LLM 'llm' with all prompts, and yields responses as dicts in the format
            {prompt, query, responses, raw_response, llm, info, metavars}.

            Queries are sent off asynchronously (if possible).
            Yields responses as they come in. Any LLM call that errors (e.g., a 'rate limit' error)
            will yield an individual LLMResponseException, so downstream tasks must check for this exception type.

            By default, for each response successfully collected, this also saves responses to disk as JSON at the filepath given during init.
            (Very useful for saving money in case something goes awry!)
            To clear the cached responses, call clear_cached_responses().

            NOTE: The reason we collect, rather than raise, LLMResponseExceptions is because some API calls
                  may still succeed, even if some fail. We don't want to stop listening to pending API calls,
                  because we may lose money. Instead, we fail selectively.

            Do not override this function.
        """
        # Double-check that properties is the correct type (JSON dict):
        if not is_valid_json(properties):
            raise ValueError("Properties argument is not valid JSON.")

        # Load any cache'd responses
        responses = self._load_cached_responses()

        # Query LLM with each prompt, yield + cache the responses
        tasks = []
        max_req, wait_secs = RATE_LIMITS[llm] if llm in RATE_LIMITS else (1, 0)
        num_queries_sent = -1

        for prompt in self.gen_prompts(properties):
            if isinstance(prompt, PromptTemplate) and not prompt.is_concrete():
                raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")

            prompt_str = str(prompt)
            info = prompt.fill_history
            metavars = prompt.metavars

            cached_resp = responses[prompt_str] if prompt_str in responses else None
            extracted_resps = cached_resp["responses"] if cached_resp is not None else []

            # First check if there is already a response for this item under these settings. If so, we can save an LLM call:
            if cached_resp and len(extracted_resps) >= n:
                print(f" - Found cache'd response for prompt {prompt_str}. Using...")
                yield {
                    "prompt": prompt_str,
                    "query": cached_resp["query"],
                    "responses": extracted_resps[:n],
                    "raw_response": cached_resp["raw_response"],
                    "llm": cached_resp["llm"] if "llm" in cached_resp else LLM.OpenAI_ChatGPT.value,
                    # We want to use the new info, since 'vars' could have changed even though
                    # the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
                    "info": info,
                    "metavars": metavars
                }
                continue

            num_queries_sent += 1

            if max_req > 1:
                # Call the LLM asynchronously to generate a response, sending off
                # requests in batches of size 'max_req' separated by seconds 'wait_secs' to avoid hitting rate limit
                tasks.append(self._prompt_llm(llm=llm,
                                              prompt=prompt,
                                              n=n,
                                              temperature=temperature,
                                              past_resp_obj=cached_resp,
                                              query_number=num_queries_sent,
                                              rate_limit_batch_size=max_req,
                                              rate_limit_wait_secs=wait_secs,
                                              **llm_params))
            else:
                # Block. Await + yield a single LLM call.
                _, query, response, past_resp_obj = await self._prompt_llm(llm=llm,
                                                                           prompt=prompt,
                                                                           n=n,
                                                                           temperature=temperature,
                                                                           past_resp_obj=cached_resp,
                                                                           **llm_params)

                # Check for selective failure
                if query is None and isinstance(response, LLMResponseException):
                    yield response  # yield the LLMResponseException
                    continue

                # Create a response obj to represent the response
                resp_obj = {
                    "prompt": str(prompt),
                    "query": query,
                    "responses": extract_responses(response, llm),
                    "raw_response": response,
                    "llm": llm.value,
                    "info": info,
                    "metavars": metavars
                }

                # Merge the response obj with the past one, if necessary
                if past_resp_obj is not None:
                    resp_obj = merge_response_objs(resp_obj, past_resp_obj)

                # Save the current state of cache'd responses to a JSON file
                responses[resp_obj["prompt"]] = resp_obj
                self._cache_responses(responses)

                print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])

                # Yield the response
                yield resp_obj

        # Yield responses as they come in
        for task in asyncio.as_completed(tasks):
            # Collect the response from the earliest completed task
            prompt, query, response, past_resp_obj = await task

            # Check for selective failure
            if query is None and isinstance(response, LLMResponseException):
                yield response  # yield the LLMResponseException
                continue

            # Each prompt has a history of what was filled in from its base template.
            # This data -- like "class", "language", "library", etc. -- can be useful when parsing responses.
            info = prompt.fill_history
            metavars = prompt.metavars

            # Create a response obj to represent the response
            resp_obj = {
                "prompt": str(prompt),
                "query": query,
                "responses": extract_responses(response, llm),
                "raw_response": response,
                "llm": llm.value,
                "info": info,
                "metavars": metavars,
            }

            # Merge the response obj with the past one, if necessary
            if past_resp_obj is not None:
                resp_obj = merge_response_objs(resp_obj, past_resp_obj)

            # Save the current state of cache'd responses to a JSON file
            # NOTE: We do this to save money -- in case something breaks between calls, we can ensure we got the data!
            responses[resp_obj["prompt"]] = resp_obj
            self._cache_responses(responses)

            print(f" - collected response from {llm.value} for prompt:", resp_obj['prompt'])

            # Yield the response
            yield resp_obj

    def _load_cached_responses(self) -> Dict:
        """
            Loads saved responses from the JSON file at self._filepath.
            Useful for continuing if computation was interrupted halfway through.
        """
        if os.path.isfile(self._filepath):
            with open(self._filepath, encoding="utf-8") as f:
                responses = json.load(f)
            return responses
        else:
            return {}

    def _cache_responses(self, responses) -> None:
        with open(self._filepath, "w", encoding='utf-8') as f:
            json.dump(responses, f)

    def clear_cached_responses(self) -> None:
        self._cache_responses({})
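
    # Cache layout sketch (illustrative; the example prompt string is not from this module):
    # the cache file maps each concrete prompt string to its response object, e.g.
    #   { "Tell me a joke.": { "prompt": "...", "query": {...}, "responses": ["..."],
    #                          "raw_response": {...}, "llm": "...", "info": {...}, "metavars": {...} } }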

    async def _prompt_llm(self,
                          llm: LLM,
                          prompt: PromptTemplate,
                          n: int = 1,
                          temperature: float = 1.0,
                          past_resp_obj: Optional[Dict] = None,
                          query_number: Optional[int] = None,
                          rate_limit_batch_size: Optional[int] = None,
                          rate_limit_wait_secs: Optional[float] = None,
                          **llm_params) -> Tuple[str, Dict, Union[List, Dict], Union[Dict, None]]:
        # Detect how many responses we have already (from cache obj past_resp_obj)
        if past_resp_obj is not None:
            # How many *new* queries we need to send:
            # NOTE: The check n > len(past_resp_obj["responses"]) should occur prior to calling this function.
            n = n - len(past_resp_obj["responses"])
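            # Worked example (illustrative): if the cache already holds 2 responses for this prompt
            # and the caller asked for n=3, only 3 - 2 = 1 new completion is requested here.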

        # Block asynchronously when we exceed rate limits
        if query_number is not None and rate_limit_batch_size is not None and rate_limit_wait_secs is not None and rate_limit_batch_size >= 1 and rate_limit_wait_secs > 0:
            batch_num = int(query_number / rate_limit_batch_size)
            if batch_num > 0:
                # We've exceeded the estimated batch rate limit and need to wait the appropriate seconds before sending off new API calls:
                wait_secs = rate_limit_wait_secs * batch_num
                if query_number % rate_limit_batch_size == 0:  # Print when we start blocking, for each batch
                    print(f"Batch rate limit of {rate_limit_batch_size} reached for LLM {llm}. Waiting {wait_secs} seconds until sending request batch #{batch_num}...")
                await asyncio.sleep(wait_secs)
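                # Worked example (illustrative): with rate_limit_batch_size=20 and rate_limit_wait_secs=10,
                # query_number=45 falls in batch_num=2 (int(45 / 20)), so this call sleeps 2 * 10 = 20
                # seconds before being sent; queries 0-19 (batch 0) go out immediately.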

        # Get the correct API call for the given LLM:
        call_llm = None
        if llm.name[:6] == 'OpenAI':
            call_llm = call_chatgpt
        elif llm.name[:5] == 'Azure':
            call_llm = call_azure_openai
        elif llm.name[:5] == 'PaLM2':
            call_llm = call_google_palm
        elif llm.name[:5] == 'Dalai':
            call_llm = call_dalai
        elif llm.value[:6] == 'claude':
            call_llm = call_anthropic
        else:
            raise Exception(f"Language model {llm} is not supported.")

        # Now try to call the API. If it fails for whatever reason, 'soft fail' by returning
        # an LLMResponseException object as the 'response'.
        try:
            query, response = await call_llm(prompt=str(prompt), model=llm, n=n, temperature=temperature, **llm_params)
        except Exception as e:
            return prompt, None, LLMResponseException(str(e)), None

        return prompt, query, response, past_resp_obj


"""
    Most basic prompt pipeline: given a prompt (and any variables, if it's a template),
    query the LLM with all prompt permutations, and cache responses.
"""
class PromptLLM(PromptPipeline):
    def __init__(self, template: str, storageFile: str):
        self._template = PromptTemplate(template)
        super().__init__(storageFile)

    def gen_prompts(self, properties: dict) -> Iterator[PromptTemplate]:
        gen_prompts = PromptPermutationGenerator(self._template)
        return gen_prompts(properties)
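
# Usage sketch (illustrative; the template, variable names, and cache filename below are
# assumptions, not part of this module). Assuming `properties` maps each template variable
# to a list of values, the generator yields one concrete prompt per permutation, e.g.:
#
#   pipeline = PromptLLM("Tell me about {topic}.", "cache.json")
#   prompts = list(pipeline.gen_prompts({"topic": ["cats", "dogs"]}))
#   # -> "Tell me about cats." and "Tell me about dogs."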


"""
    A dummy class that spoofs LLM responses. Used for testing.
"""
class PromptLLMDummy(PromptLLM):
    def __init__(self, template: str, storageFile: str):
        # Hijack the 'extract_responses' method so that for whichever 'llm' parameter,
        # it will just return the response verbatim (since dummy responses will always be strings)
        global extract_responses
        extract_responses = lambda response, llm: response
        super().__init__(template, storageFile)

    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0, past_resp_obj: Optional[Dict] = None, **params) -> Tuple[Dict, Dict]:
        # Wait a random amount of time, to simulate wait times from real queries
        await asyncio.sleep(random.uniform(0.1, 3))

        if random.random() > 0.2:
            # Return a random string of characters of random length (within a predefined range)
            return prompt, {'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)], past_resp_obj
        else:
            # Return a mock 'error' making the API request
            return prompt, None, LLMResponseException('Dummy error'), past_resp_obj
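

# Minimal end-to-end sketch (not part of the original module): exercises the pipeline with
# PromptLLMDummy, which spoofs LLM calls, so no API keys are needed. The template, variable
# names, and cache filename here are illustrative assumptions.
if __name__ == '__main__':
    async def _demo():
        pipeline = PromptLLMDummy("Tell me a {adjective} joke about {topic}.", "dummy_cache.json")
        async for res in pipeline.gen_responses({"adjective": ["short"], "topic": ["cats", "dogs"]},
                                                llm=LLM.OpenAI_ChatGPT, n=2):
            if isinstance(res, LLMResponseException):
                print("Collected error:", res)  # failures are yielded, not raised
            else:
                print(res["prompt"], "->", res["responses"])
    asyncio.run(_demo())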