WIP fixing count queries

This commit is contained in:
Ian Arawjo 2023-05-09 12:52:37 -04:00
parent e9a9663af5
commit 47be5ec96f
2 changed files with 11 additions and 5 deletions

View File

@ -190,7 +190,7 @@ const PromptNode = ({ data, id }) => {
if (!json || !json.counts) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.counts;
return [json.counts, json.total_num_responses];
}, rejected);
};
@ -210,7 +210,7 @@ const PromptNode = ({ data, id }) => {
// Fetch response counts from backend
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
}).then(([counts, total_num_responses]) => {
// Check for empty counts (means no requests will be sent!)
const num_llms_missing = Object.keys(counts).length;
if (num_llms_missing === 0) {
@ -304,7 +304,7 @@ const PromptNode = ({ data, id }) => {
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
// Open a socket to listen for progress
const open_progress_listener_socket = (response_counts) => {
const open_progress_listener_socket = ([response_counts, total_num_responses]) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
@ -312,7 +312,7 @@ const PromptNode = ({ data, id }) => {
cors: {origin: "http://localhost:8000/"},
});
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
const max_responses = Object.keys(total_num_responses).reduce((acc, llm) => acc + total_num_responses[llm], 0);
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.

View File

@ -199,16 +199,22 @@ def countQueries():
cached_resps = []
missing_queries = {}
num_responses_req = {}
def add_to_missing_queries(llm, prompt, num):
if llm not in missing_queries:
missing_queries[llm] = {}
missing_queries[llm][prompt] = num
def add_to_num_responses_req(llm, num):
if llm not in num_responses_req:
num_responses_req[llm] = 0
num_responses_req[llm] += num
# Iterate through all prompt permutations and check if how many responses there are in the cache with that prompt
for prompt in all_prompt_permutations:
prompt = str(prompt)
matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
for llm in data['llms']:
add_to_num_responses_req(llm, n)
match_per_llm = [r for r in matching_resps if r['llm'] == llm]
if len(match_per_llm) == 0:
add_to_missing_queries(llm, prompt, n)
@ -220,7 +226,7 @@ def countQueries():
else:
raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
ret = jsonify({'counts': missing_queries})
ret = jsonify({'counts': missing_queries, 'total_num_responses': num_responses_req})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret