Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)

Make async requests to OpenAI, with batch rate limits

commit 789645d23a (parent 302021dae0)
@@ -144,7 +144,7 @@ def test():
     return "Hello, world!"

 @app.route('/queryllm', methods=['POST'])
-def queryLLM():
+async def queryLLM():
     """
         Queries LLM(s) given a JSON spec.

@@ -201,7 +201,7 @@ def queryLLM():
         # NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
         responses[llm] = []
         try:
-            for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
+            async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
                 responses[llm].append(response)
         except Exception as e:
             print('error generating responses:', e)
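Note on the hunk above: an `async def` view function and `async for` iteration rely on Flask's async support, which is why the requirements file at the end of this diff switches from flask to flask[async]. A minimal, self-contained sketch of the same pattern (not from the commit; the /demo route and fake_gen_responses stand-in are hypothetical):

from flask import Flask, jsonify

app = Flask(__name__)

async def fake_gen_responses(n):
    # Hypothetical stand-in for PromptPipeline.gen_responses()
    for i in range(n):
        yield {"response": f"stub {i}"}

@app.route('/demo', methods=['GET'])
async def demo():
    # Drain the async generator, similar to the queryLLM() change above
    results = []
    async for r in fake_gen_responses(3):
        results.append(r)
    return jsonify(results)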
@@ -1,10 +1,20 @@
 from abc import abstractmethod
 from typing import List, Dict, Tuple, Iterator
-import json
-import os
+import json, os, asyncio
 from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
 from promptengine.template import PromptTemplate, PromptPermutationGenerator

+# LLM APIs often have rate limits, which control the number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
+# For a basic organization in OpenAI, GPT3.5 is currently 3500 RPM and GPT4 is 200 RPM (requests per minute).
+# For the Anthropic evaluation preview of Claude, only 1 request can be sent at a time (synchronously).
+# A 'cheap' way of controlling for rate limits is to wait a few seconds between batches of requests being sent off.
+# The following is only a guideline, and a bit on the conservative side.
+MAX_SIMULTANEOUS_REQUESTS = {
+    LLM.ChatGPT: (50, 10),   # max 50 requests per batch; wait 10 seconds between batches
+    LLM.GPT4: (20, 10),      # max 20 requests per batch; wait 10 seconds between batches
+    LLM.Alpaca7B: (1, 0),    # 1 indicates synchronous
+}
+
 """
     Abstract class that captures a generic querying interface to prompt LLMs
 """
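The (batch size, seconds to wait) tuples above drive the batching added later in this diff. A simplified sketch of how such a guideline can throttle a list of request coroutines (an assumption for illustration, not the commit's exact implementation, which interleaves asyncio.as_completed with response caching below; string keys here stand in for the LLM enum):

import asyncio

# Sketch-only table; the real one is keyed by the LLM enum.
BATCH_LIMITS = {"chatgpt": (50, 10), "gpt4": (20, 10), "alpaca.7B": (1, 0)}

async def send_in_batches(coro_factories, model="chatgpt"):
    """Run request coroutines in batches of at most batch_size, sleeping wait_secs between batches."""
    batch_size, wait_secs = BATCH_LIMITS.get(model, (1, 0))
    results = []
    for i in range(0, len(coro_factories), max(batch_size, 1)):
        batch = [make() for make in coro_factories[i:i + batch_size]]
        results += await asyncio.gather(*batch)   # fire the whole batch concurrently
        if i + batch_size < len(coro_factories):
            await asyncio.sleep(wait_secs)        # 'cheap' rate limiting between batches
    return results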
@@ -19,17 +29,13 @@ class PromptPipeline:
     def gen_prompts(self, properties) -> Iterator[PromptTemplate]:
         raise NotImplementedError("Please Implement the gen_prompts method")

-    @abstractmethod
-    def analyze_response(self, response) -> bool:
-        """
-            Analyze the response and return True if the response is valid.
-        """
-        raise NotImplementedError("Please Implement the analyze_response method")
-
-    def gen_responses(self, properties, llm: LLM, n: int = 1, temperature: float = 1.0) -> Iterator[Dict]:
+    async def gen_responses(self, properties, llm: LLM, n: int = 1, temperature: float = 1.0) -> Iterator[Dict]:
         """
             Calls LLM 'llm' with all prompts, and yields responses as dicts in format {prompt, query, response, llm, info}.

+            Queries are sent off asynchronously (if possible).
+            Yields responses as they come in.
+
             By default, for each response, this also saves reponses to disk as JSON at the filepath given during init.
             (Very useful for saving money in case something goes awry!)
             To clear the cached responses, call clear_cached_responses().
@@ -44,7 +50,9 @@ class PromptPipeline:
         responses = self._load_cached_responses()

         # Query LLM with each prompt, yield + cache the responses
-        for prompt in self.gen_prompts(properties):
+        tasks = []
+        max_req, wait_secs = MAX_SIMULTANEOUS_REQUESTS[llm] if llm in MAX_SIMULTANEOUS_REQUESTS else (1, 0)
+        for num_queries, prompt in enumerate(self.gen_prompts(properties)):
             if isinstance(prompt, PromptTemplate) and not prompt.is_concrete():
                 raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")

@@ -65,9 +73,22 @@ class PromptPipeline:
                 }
                 continue

-            # Call the LLM to generate a response
-            query, response = self._prompt_llm(llm, prompt_str, n, temperature)
+            if max_req > 1:
+                if (num_queries+1) % max_req == 0:
+                    print(f"Batch rate limit of {max_req} reached. Waiting {wait_secs} seconds until sending further requests...")
+                    await asyncio.sleep(wait_secs)
+
+                # Call the LLM asynchronously to generate a response
+                tasks.append(self._prompt_llm(llm, prompt_str, n, temperature))
+            else:
+                # Blocking. Await + yield a single LLM call.
+                yield await self._prompt_llm(llm, prompt_str, n, temperature)
+
+        # Yield responses as they come in
+        for task in asyncio.as_completed(tasks):
+            # Collect the response from the earliest completed task
+            query, response = await task

             # Save the response to a JSON file
             # NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
             responses[prompt_str] = {
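For readers unfamiliar with the pattern in the hunk above: asyncio.as_completed yields awaitables in the order they finish, not the order they were submitted, which is what lets responses stream back as soon as each request returns. A tiny self-contained illustration (an assumption added for clarity, not code from the repo):

import asyncio, random

async def fake_llm_call(i):
    await asyncio.sleep(random.random())  # simulate variable network latency
    return i

async def main():
    tasks = [fake_llm_call(i) for i in range(5)]
    for task in asyncio.as_completed(tasks):
        print("finished:", await task)  # printed in completion order, not submission order

asyncio.run(main())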
@@ -78,6 +99,7 @@ class PromptPipeline:
             }
             self._cache_responses(responses)

+            # Yield the response
             yield {
                 "prompt":prompt_str,
                 "query":query,
@@ -105,11 +127,11 @@ class PromptPipeline:
     def clear_cached_responses(self) -> None:
         self._cache_responses({})

-    def _prompt_llm(self, llm: LLM, prompt: str, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
+    async def _prompt_llm(self, llm: LLM, prompt: str, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
         if llm is LLM.ChatGPT or llm is LLM.GPT4:
-            return call_chatgpt(prompt, model=llm, n=n, temperature=temperature)
+            return await call_chatgpt(prompt, model=llm, n=n, temperature=temperature)
         elif llm is LLM.Alpaca7B:
-            return call_dalai(llm_name='alpaca.7B', port=4000, prompt=prompt, n=n, temperature=temperature)
+            return await call_dalai(llm_name='alpaca.7B', port=4000, prompt=prompt, n=n, temperature=temperature)
         else:
             raise Exception(f"Language model {llm} is not supported.")

@@ -14,7 +14,7 @@ class LLM(Enum):
     Alpaca7B = 1
     GPT4 = 2

-def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
+async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
     """
         Calls GPT3.5 via OpenAI's API.
         Returns raw query and response JSON dicts.
@@ -36,7 +36,7 @@ def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0,
     response = openai.ChatCompletion.create(**query)
     return query, response

-def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
+async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
     """
         Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
         Returns the raw query and response JSON dicts.
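One observation on the call_chatgpt hunk: the unchanged context line `response = openai.ChatCompletion.create(**query)` is still the synchronous client call, so the coroutine blocks the event loop for the duration of each HTTP request. A hedged sketch of a fully non-blocking variant, assuming the pre-1.0 `openai` Python package, which also exposes awaitable methods such as ChatCompletion.acreate (the function name and model string below are illustrative, not from the commit):

import openai

async def call_chatgpt_async(prompt: str, model: str = "gpt-3.5-turbo",
                             n: int = 1, temperature: float = 1.0):
    # Same query dict shape as call_chatgpt(); model name here is an assumption.
    query = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "n": n,
        "temperature": temperature,
    }
    response = await openai.ChatCompletion.acreate(**query)  # awaitable variant of .create()
    return query, response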
@@ -1,5 +1,5 @@
 dalaipy==2.0.2
-flask
+flask[async]
 flask_cors
 openai
 python-socketio