From 47be5ec96f407933e0641891eef84544340f51c9 Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Tue, 9 May 2023 12:52:37 -0400
Subject: [PATCH] WIP fixing count queries

---
 chain-forge/src/PromptNode.js | 8 ++++----
 python-backend/flask_app.py   | 8 +++++++-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 15d5e40..fe30e41 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -190,7 +190,7 @@ const PromptNode = ({ data, id }) => {
         if (!json || !json.counts) {
             throw new Error('Request was sent and received by backend server, but there was no response.');
         }
-        return json.counts;
+        return [json.counts, json.total_num_responses];
     }, rejected);
   };
 
@@ -210,7 +210,7 @@ const PromptNode = ({ data, id }) => {
     // Fetch response counts from backend
     fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
         console.warn(err.message);  // soft fail
-    }).then((counts) => {
+    }).then(([counts, total_num_responses]) => {
         // Check for empty counts (means no requests will be sent!)
         const num_llms_missing = Object.keys(counts).length;
         if (num_llms_missing === 0) {
@@ -304,7 +304,7 @@ const PromptNode = ({ data, id }) => {
       py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
 
     // Open a socket to listen for progress
-    const open_progress_listener_socket = (response_counts) => {
+    const open_progress_listener_socket = ([response_counts, total_num_responses]) => {
       // With the counts information we can create progress bars. Now we load a socket connection to
       // the socketio server that will stream to us the current progress:
       const socket = io('http://localhost:8001/' + 'queryllm', {
@@ -312,7 +312,7 @@ const PromptNode = ({ data, id }) => {
         cors: {origin: "http://localhost:8000/"},
       });
 
-      const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
+      const max_responses = Object.keys(total_num_responses).reduce((acc, llm) => acc + total_num_responses[llm], 0);
 
       // On connect to the server, ask it to give us the current progress
       // for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index cf3ac87..b929457 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -199,16 +199,22 @@ def countQueries():
         cached_resps = []
 
     missing_queries = {}
+    num_responses_req = {}
     def add_to_missing_queries(llm, prompt, num):
         if llm not in missing_queries:
             missing_queries[llm] = {}
         missing_queries[llm][prompt] = num
+    def add_to_num_responses_req(llm, num):
+        if llm not in num_responses_req:
+            num_responses_req[llm] = 0
+        num_responses_req[llm] += num
 
     # Iterate through all prompt permutations and check if how many responses there are in the cache with that prompt
     for prompt in all_prompt_permutations:
         prompt = str(prompt)
         matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
         for llm in data['llms']:
+            add_to_num_responses_req(llm, n)
             match_per_llm = [r for r in matching_resps if r['llm'] == llm]
             if len(match_per_llm) == 0:
                 add_to_missing_queries(llm, prompt, n)
@@ -220,7 +226,7 @@ def countQueries():
             else:
                 raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
 
-    ret = jsonify({'counts': missing_queries})
+    ret = jsonify({'counts': missing_queries, 'total_num_responses': num_responses_req})
     ret.headers.add('Access-Control-Allow-Origin', '*')
     return ret
 