From e9a9663af53112207ee5d78d75ce53c81b98748c Mon Sep 17 00:00:00 2001
From: Ian Arawjo
Date: Tue, 9 May 2023 12:28:37 -0400
Subject: [PATCH] WIP fixed response cache's

---
 chain-forge/src/EvaluatorNode.js     |  4 +-
 chain-forge/src/PromptNode.js        | 48 ++++++++++++--
 chain-forge/src/VisNode.js           | 73 ++++++++++++++-------
 python-backend/flask_app.py          | 97 ++++++++++++++++++++++------
 python-backend/promptengine/query.py |  4 +-
 5 files changed, 178 insertions(+), 48 deletions(-)

diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js
index e7614f5..564d244 100644
--- a/chain-forge/src/EvaluatorNode.js
+++ b/chain-forge/src/EvaluatorNode.js
@@ -71,7 +71,7 @@ const EvaluatorNode = ({ data, id }) => {
     // Get all the script nodes, and get all the folder paths
     const script_nodes = nodes.filter(n => n.type === 'script');
     const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
-    console.log(script_paths);
+
     // Run evaluator in backend
     const codeTextOnRun = codeText + '';
     fetch(BASE_URL + 'app/execute', {
@@ -95,6 +95,8 @@ const EvaluatorNode = ({ data, id }) => {
        alertModal.current.trigger(json ? json.error : 'Unknown error encountered when requesting evaluations: empty response returned.');
        return;
      }
+
+     console.log(json.responses);
 
      // Ping any vis nodes attached to this node to refresh their contents:
      const output_nodes = outputEdgesForNode(id).map(e => e.target);
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 0c8880c..15d5e40 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -182,6 +182,8 @@ const PromptNode = ({ data, id }) => {
          prompt: prompt,
          vars: vars,
          llms: llms,
+         id: id,
+         n: numGenerations,
      })}, rejected).then(function(response) {
        return response.json();
      }, rejected).then(function(json) {
@@ -209,9 +211,47 @@ const PromptNode = ({ data, id }) => {
    fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
      console.warn(err.message);  // soft fail
    }).then((counts) => {
-     const n = counts[Object.keys(counts)[0]];
-     const req = n > 1 ? 'requests' : 'request';
-     setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
+     // Check for empty counts (means no requests will be sent!)
+     const num_llms_missing = Object.keys(counts).length;
+     if (num_llms_missing === 0) {
+       setRunTooltip('Will load responses from cache');
+       return;
+     }
+
+     // Tally how many queries per LLM:
+     let queries_per_llm = {};
+     Object.keys(counts).forEach(llm => {
+       queries_per_llm[llm] = Object.keys(counts[llm]).reduce(
+         (acc, prompt) => acc + counts[llm][prompt]
+       , 0);
+     });
+
+     // Check if all counts are the same:
+     if (num_llms_missing > 1) {
+       const some_llm_num = queries_per_llm[Object.keys(queries_per_llm)[0]];
+       const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
+       if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
+         const req = some_llm_num > 1 ? 'requests' : 'request';
+         setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
+       }
+       else if (all_same_num_queries) {
+         const req = some_llm_num > 1 ? 'requests' : 'request';
+         setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
+       }
+       else { // Counts are different
+         const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
+         setRunTooltip(`Will send a variable # of queries to LLM(s) (total=${sum_queries})`);
+       }
+     } else {
+       const llm_name = Object.keys(queries_per_llm)[0];
+       const llm_count = queries_per_llm[llm_name];
+       const req = llm_count > 1 ? 'queries' : 'query';
+       if (num_llms > num_llms_missing)
+         setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name} and load other responses from cache`);
+       else
+         setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name}`)
+     }
+
    });
  };
 
@@ -340,7 +380,7 @@ const PromptNode = ({ data, id }) => {
            temperature: 0.5,
            n: numGenerations,
          },
-         no_cache: true,
+         no_cache: false,
        }),
      }, rejected).then(function(response) {
        return response.json();
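A note for readers of the PromptNode.js hunk above: the run tooltip now reflects whatever /app/countQueriesRequired reports as still missing from the cache, summed per LLM. Below is a minimal Python sketch of that round trip; the base URL, node id, prompt, and LLM names are illustrative assumptions, not values taken from this patch.

    # Illustrative sketch only; BASE_URL, the node id, and the LLM names are assumptions.
    import requests

    BASE_URL = 'http://localhost:8000/'  # assumed address of the Flask backend

    payload = {
        'prompt': 'What is the capital of {country}?',  # hypothetical template; use whatever var syntax PromptTemplate expects
        'vars': {'country': ['France', 'Japan']},
        'llms': ['gpt3.5', 'alpaca.7B'],                # hypothetical LLM names
        'n': 3,                 # new field: responses expected per prompt
        'id': 'prompt-node-1',  # new optional field: node whose cache should be consulted
    }
    counts = requests.post(BASE_URL + 'app/countQueriesRequired', json=payload).json()['counts']

    # 'counts' only lists what still has to be queried, keyed llm -> {prompt: num_missing};
    # an empty dict means everything can be served from cache ("Will load responses from cache").
    for llm, per_prompt in counts.items():
        print(llm, 'still needs', sum(per_prompt.values()), 'queries')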
diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js
index c7347e6..17a75c3 100644
--- a/chain-forge/src/VisNode.js
+++ b/chain-forge/src/VisNode.js
@@ -66,48 +66,75 @@ const VisNode = ({ data, id }) => {
        else
          responses_by_llm[item.llm] = [item];
      });
+     const llm_names = Object.keys(responses_by_llm);
 
      // Create Plotly spec here
      const varnames = Object.keys(json.responses[0].vars);
-     let spec = {};
+     const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
+     let spec = [];
      let layout = {
        width: 420, height: 300, title: '', margin: {
-         l: 40, r: 20, b: 20, t: 20, pad: 2
+         l: 65, r: 20, b: 20, t: 20, pad: 0
        }
      }
 
-     if (varnames.length === 1) {
+     const plot_grouped_boxplot = (resp_to_x) => {
+       llm_names.forEach((llm, idx) => {
+         // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
+         const rs = responses_by_llm[llm];
+         const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
+         spec.push({
+           type: 'box',
+           name: llm,
+           marker: {color: colors[idx % colors.length]},
+           x: rs.map(r => r.eval_res.items).flat(),
+           y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
+           boxpoints: 'all',
+           text: hover_texts,
+           hovertemplate: '%{text} (%{x})',
+           orientation: 'h',
+         });
+       });
+       layout.boxmode = 'group';
+     };
+
+     if (varnames.length === 0) {
+       // No variables means they used a single prompt (no template) to generate responses
+       // (Users are likely evaluating differences in responses between LLMs)
+       plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
+       // llm_names.forEach((llm, idx) => {
+       //   // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
+       //   const rs = responses_by_llm[llm];
+       //   const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
+       //   spec.push({
+       //     type: 'scatter',
+       //     name: llm,
+       //     marker: {color: colors[idx % colors.length]},
+       //     y: rs.map(r => r.eval_res.items).flat(),
+       //     x: rs.map(r => Array(r.eval_res.items.length).fill(truncStr(r.prompt.trim(), 12))).flat(), // use the prompt str as var name
+       //     // boxpoints: 'all',
+       //     mode: 'markers',
+       //     text: hover_texts,
+       //     hovertemplate: '%{text} (%{y})',
+       //   });
+       // });
+       // layout.scattermode = 'group';
+     }
+     else if (varnames.length === 1) {
        // 1 var; numeric eval
-       if (Object.keys(responses_by_llm).length === 1) {
+       if (llm_names.length === 1) {
          // Simple box plot, as there is only a single LLM in the response
          spec = json.responses.map(r => {
            // Use the var value to 'name' this group of points:
            const s = truncStr(r.vars[varnames[0]].trim(), 12);
-           return {type: 'box', y: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}'};
+           return {type: 'box', x: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}', orientation: 'h'};
          });
          layout.hovermode = 'closest';
        } else {
          // There are multiple LLMs in the response; do a grouped box plot by LLM.
          // Note that 'name' is now the LLM, and 'x' stores the value of the var:
-         spec = [];
-         const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
-         Object.keys(responses_by_llm).forEach((llm, idx) => {
-           // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
-           const rs = responses_by_llm[llm];
-           const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
-           spec.push({
-             type: 'box',
-             name: llm,
-             marker: {color: colors[idx % colors.length]},
-             y: rs.map(r => r.eval_res.items).flat(),
-             x: rs.map(r => Array(r.eval_res.items.length).fill(r.vars[varnames[0]].trim())).flat(),
-             boxpoints: 'all',
-             text: hover_texts,
-             hovertemplate: '%{text}',
-           });
-         });
-         layout.boxmode = 'group';
+         plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
        }
      }
      else if (varnames.length === 2) {
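For orientation, the grouped horizontal box plot that plot_grouped_boxplot() assembles in the VisNode.js hunk above has roughly the following shape when sketched with Plotly's Python API. The LLM names, variable values, and scores below are made up for illustration.

    # Rough Python equivalent of the grouped, horizontal box traces built in VisNode.js.
    import plotly.graph_objects as go

    scores_by_llm = {           # hypothetical eval_res.items per LLM and var value
        'gpt3.5':    {'formal': [0.8, 0.9, 0.7], 'casual': [0.5, 0.6, 0.4]},
        'alpaca.7B': {'formal': [0.6, 0.5, 0.7], 'casual': [0.3, 0.4, 0.2]},
    }
    colors = ['#cbf078', '#f1b963']

    fig = go.Figure()
    for idx, (llm, by_var) in enumerate(scores_by_llm.items()):
        # One horizontal box trace per LLM; x holds eval scores, y the var value each score belongs to
        fig.add_trace(go.Box(
            name=llm,
            x=[s for scores in by_var.values() for s in scores],
            y=[var for var, scores in by_var.items() for _ in scores],
            marker_color=colors[idx % len(colors)],
            boxpoints='all',
            orientation='h',
        ))
    fig.update_layout(boxmode='group', width=420, height=300)
    fig.show()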
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index 619ffb3..cf3ac87 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -142,6 +142,30 @@ def reduce_responses(responses: list, vars: list) -> list:
 
     return ret
 
+def load_all_cached_responses(cache_ids):
+    if not isinstance(cache_ids, list):
+        cache_ids = [cache_ids]
+
+    # Load all responses with the given ID:
+    all_cache_files = get_files_at_dir('cache/')
+    responses = []
+    for cache_id in cache_ids:
+        cache_files = [fname for fname in get_filenames_with_id(all_cache_files, cache_id) if fname != f"{cache_id}.json"]
+        if len(cache_files) == 0:
+            continue
+
+        for filename in cache_files:
+            res = load_cache_json(os.path.join('cache', filename))
+            if isinstance(res, dict):
+                # Convert to standard response format
+                res = [
+                    to_standard_format({'prompt': prompt, **res_obj})
+                    for prompt, res_obj in res.items()
+                ]
+            responses.extend(res)
+
+    return responses
+
 @app.route('/app/countQueriesRequired', methods=['POST'])
 def countQueries():
     """
         Returns how many queries we need to make, given the passed prompt and vars.
 
         POST'd data should be in the form:
         {
             'prompt': str  # the prompt template, with any {{}} vars
             'vars': dict  # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed.
                             (Pass empty dict if no vars.)
             'llms': list  # the list of LLMs you will query
+            'n': int  # how many responses expected per prompt
+            'id': str (optional)  # a unique ID of the node with cache'd responses. If missing, assumes no cache will be used.
         }
     """
     data = request.get_json()
-    if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}): return jsonify({'error': 'POST data is improper format.'})
+    if not set(data.keys()).issuperset({'prompt', 'vars', 'llms', 'n'}): return jsonify({'error': 'POST data is improper format.'})
+
+    n = int(data['n'])
 
     try:
         gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
         all_prompt_permutations = list(gen_prompts(data['vars']))
     except Exception as e:
         return jsonify({'error': str(e)})
+
+    if 'id' in data:
+        # Load all cache'd responses with the given id:
+        cached_resps = load_all_cached_responses(data['id'])
+    else:
+        cached_resps = []
+
+    missing_queries = {}
+    def add_to_missing_queries(llm, prompt, num):
+        if llm not in missing_queries:
+            missing_queries[llm] = {}
+        missing_queries[llm][prompt] = num
+
+    # Iterate through all prompt permutations and check how many responses there are in the cache with that prompt
+    for prompt in all_prompt_permutations:
+        prompt = str(prompt)
+        matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
+        for llm in data['llms']:
+            match_per_llm = [r for r in matching_resps if r['llm'] == llm]
+            if len(match_per_llm) == 0:
+                add_to_missing_queries(llm, prompt, n)
+            elif len(match_per_llm) == 1:
+                # Check how many were stored; if not enough, add how many missing queries:
+                num_resps = len(match_per_llm[0]['responses'])
+                if n > len(match_per_llm[0]['responses']):
+                    add_to_missing_queries(llm, prompt, n - num_resps)
+            else:
+                raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
 
-    # TODO: Send more informative data back including how many queries per LLM based on cache'd data
-    num_queries = {}  # len(all_prompt_permutations) * len(data['llms'])
-    for llm in data['llms']:
-        num_queries[llm] = len(all_prompt_permutations)
-
-    ret = jsonify({'counts': num_queries})
+    ret = jsonify({'counts': missing_queries})
     ret.headers.add('Access-Control-Allow-Origin', '*')
     return ret
 
@@ -239,6 +290,11 @@ async def queryLLM():
     # Create a cache dir if it doesn't exist:
     create_dir_if_not_exists('cache')
 
+    # Check that the filepath used to cache eval'd responses is valid:
+    cache_filepath_last_run = os.path.join('cache', f"{data['id']}.json")
+    if not is_valid_filepath(cache_filepath_last_run):
+        return jsonify({'error': f'Invalid filepath: {cache_filepath_last_run}'})
+
     # For each LLM, generate and cache responses:
     responses = {}
     llms = data['llm']
@@ -303,6 +359,10 @@ async def queryLLM():
     # Remove the temp file used to stream progress updates:
     if os.path.exists(tempfilepath):
         os.remove(tempfilepath)
+
+    # Save the responses *of this run* to the disk, for further recall:
+    with open(cache_filepath_last_run, "w") as f:
+        json.dump(res, f)
 
     # Return all responses for all LLMs
     print('returning responses:', res)
@@ -476,19 +536,18 @@ def grabResponses():
     all_cache_files = get_files_at_dir('cache/')
     responses = []
     for cache_id in data['responses']:
-        cache_files = get_filenames_with_id(all_cache_files, cache_id)
-        if len(cache_files) == 0:
+        fname = f"{cache_id}.json"
+        if fname not in all_cache_files:
             return jsonify({'error': f'Did not find cache file for id {cache_id}'})
-
-        for filename in cache_files:
-            res = load_cache_json(os.path.join('cache', filename))
-            if isinstance(res, dict):
-                # Convert to standard response format
-                res = [
-                    to_standard_format({'prompt': prompt, **res_obj})
-                    for prompt, res_obj in res.items()
-                ]
-            responses.extend(res)
+
+        res = load_cache_json(os.path.join('cache', fname))
+        if isinstance(res, dict):
+            # Convert to standard response format
+            res = [
+                to_standard_format({'prompt': prompt, **res_obj})
+                for prompt, res_obj in res.items()
+            ]
+        responses.extend(res)
 
     ret = jsonify({'responses': responses})
     ret.headers.add('Access-Control-Allow-Origin', '*')
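To make the cache-aware counting added to countQueries() concrete, here is a standalone sketch under simplified assumptions: the dicts below stand in for the standard response format, and the prompts, LLM name, and counts are hypothetical.

    def count_missing_queries(prompts, llms, n, cached_resps):
        # Same idea as the loop in countQueries(): anything not cached n times still needs querying.
        missing = {}
        for prompt in prompts:
            matching = [r for r in cached_resps if r['prompt'] == prompt]
            for llm in llms:
                per_llm = [r for r in matching if r['llm'] == llm]
                have = len(per_llm[0]['responses']) if per_llm else 0
                if have < n:
                    missing.setdefault(llm, {})[prompt] = n - have
        return missing

    cache = [{'prompt': 'Hello in French?', 'llm': 'gpt3.5', 'responses': ['Bonjour', 'Salut']}]
    print(count_missing_queries(['Hello in French?', 'Hello in Japanese?'], ['gpt3.5'], 3, cache))
    # -> {'gpt3.5': {'Hello in French?': 1, 'Hello in Japanese?': 3}}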
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index dde875c..9f022ab 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -177,5 +177,7 @@ class PromptLLM(PromptPipeline):
 """
 class PromptLLMDummy(PromptLLM):
     async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
+        # Wait a random amount of time, to simulate wait times from real queries
         await asyncio.sleep(random.uniform(0.1, 3))
-        return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(40)) for _ in range(n)])
\ No newline at end of file
+        # Return a random string of characters of random length (within a predefined range)
+        return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)])
\ No newline at end of file
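The PromptLLMDummy tweak above randomizes both latency and response length, to better simulate real queries. A tiny, self-contained asyncio sketch of the same idea follows (not the class itself; the function and prompt names are hypothetical):

    # Dummy "queries" resolve out of order, like real LLM calls. Pure asyncio, for illustration only.
    import asyncio, random, string

    async def dummy_query(prompt: str, n: int = 1):
        await asyncio.sleep(random.uniform(0.1, 3))  # simulated network/API latency
        return prompt, [''.join(random.choice(string.ascii_letters)
                                for _ in range(random.randint(25, 80))) for _ in range(n)]

    async def main():
        results = await asyncio.gather(*(dummy_query(f"prompt {i}", n=2) for i in range(4)))
        for prompt, fake_responses in results:
            print(prompt, '->', len(fake_responses), 'dummy responses')

    asyncio.run(main())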