WIP fixed response caches

This commit is contained in:
Ian Arawjo 2023-05-09 12:28:37 -04:00
parent 16135934f4
commit e9a9663af5
5 changed files with 178 additions and 48 deletions

View File

@ -71,7 +71,7 @@ const EvaluatorNode = ({ data, id }) => {
// Get all the script nodes, and get all the folder paths
const script_nodes = nodes.filter(n => n.type === 'script');
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'app/execute', {
@ -95,6 +95,8 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(json ? json.error : 'Unknown error encountered when requesting evaluations: empty response returned.');
return;
}
console.log(json.responses);
// Ping any vis nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);

View File

@ -182,6 +182,8 @@ const PromptNode = ({ data, id }) => {
prompt: prompt,
vars: vars,
llms: llms,
id: id,
n: numGenerations,
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
@ -209,9 +211,47 @@ const PromptNode = ({ data, id }) => {
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
const n = counts[Object.keys(counts)[0]];
const req = n > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
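// 'counts' maps each LLM that still needs responses to {prompt: number of queries to send}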
// Check for empty counts (means no requests will be sent!)
const num_llms_missing = Object.keys(counts).length;
if (num_llms_missing === 0) {
setRunTooltip('Will load responses from cache');
return;
}
// Tally how many queries per LLM:
let queries_per_llm = {};
Object.keys(counts).forEach(llm => {
queries_per_llm[llm] = Object.keys(counts[llm]).reduce(
(acc, prompt) => acc + counts[llm][prompt]
, 0);
});
// Check if all counts are the same:
if (num_llms_missing > 1) {
const some_llm_num = queries_per_llm[Object.keys(queries_per_llm)[0]];
const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true);
if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
}
else if (all_same_num_queries) {
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
}
else { // Counts are different
const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
setRunTooltip(`Will send a variable # of queries to LLM(s) (total=${sum_queries})`);
}
} else {
const llm_name = Object.keys(queries_per_llm)[0];
const llm_count = queries_per_llm[llm_name];
const req = llm_count > 1 ? 'queries' : 'query';
if (num_llms > num_llms_missing)
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name} and load other responses from cache`);
else
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name}`);
}
});
};
@ -340,7 +380,7 @@ const PromptNode = ({ data, id }) => {
temperature: 0.5,
n: numGenerations,
},
no_cache: true,
no_cache: false,
}),
}, rejected).then(function(response) {
return response.json();

View File

@ -66,48 +66,75 @@ const VisNode = ({ data, id }) => {
else
responses_by_llm[item.llm] = [item];
});
const llm_names = Object.keys(responses_by_llm);
// Create Plotly spec here
const varnames = Object.keys(json.responses[0].vars);
let spec = {};
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
let spec = [];
let layout = {
width: 420, height: 300, title: '', margin: {
l: 40, r: 20, b: 20, t: 20, pad: 2
l: 65, r: 20, b: 20, t: 20, pad: 0
}
}
if (varnames.length === 1) {
const plot_grouped_boxplot = (resp_to_x) => {
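// Draws one horizontal box trace per LLM; resp_to_x maps a response object to the label used to group its points on the y-axis.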
llm_names.forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
x: rs.map(r => r.eval_res.items).flat(),
y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text} <b><i>(%{x})</i></b>',
orientation: 'h',
});
});
layout.boxmode = 'group';
};
if (varnames.length === 0) {
// No variables means they used a single prompt (no template) to generate responses
// (Users are likely evaluating differences in responses between LLMs)
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
// llm_names.forEach((llm, idx) => {
// // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
// const rs = responses_by_llm[llm];
// const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
// spec.push({
// type: 'scatter',
// name: llm,
// marker: {color: colors[idx % colors.length]},
// y: rs.map(r => r.eval_res.items).flat(),
// x: rs.map(r => Array(r.eval_res.items.length).fill(truncStr(r.prompt.trim(), 12))).flat(), // use the prompt str as var name
// // boxpoints: 'all',
// mode: 'markers',
// text: hover_texts,
// hovertemplate: '%{text} <b><i>(%{y})</i></b>',
// });
// });
// layout.scattermode = 'group';
}
else if (varnames.length === 1) {
// 1 var; numeric eval
if (Object.keys(responses_by_llm).length === 1) {
if (llm_names.length === 1) {
// Simple box plot, as there is only a single LLM in the response
spec = json.responses.map(r => {
// Use the var value to 'name' this group of points:
const s = truncStr(r.vars[varnames[0]].trim(), 12);
return {type: 'box', y: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}'};
return {type: 'box', x: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}', orientation: 'h'};
});
layout.hovermode = 'closest';
} else {
// There are multiple LLMs in the response; do a grouped box plot by LLM.
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
spec = [];
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
Object.keys(responses_by_llm).forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
y: rs.map(r => r.eval_res.items).flat(),
x: rs.map(r => Array(r.eval_res.items.length).fill(r.vars[varnames[0]].trim())).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text}',
});
});
layout.boxmode = 'group';
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
}
}
else if (varnames.length === 2) {

View File

@ -142,6 +142,30 @@ def reduce_responses(responses: list, vars: list) -> list:
return ret
def load_all_cached_responses(cache_ids):
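"""Load every cached response file matching the given cache id(s), returning a flat list in the standard response format."""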
if not isinstance(cache_ids, list):
cache_ids = [cache_ids]
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in cache_ids:
cache_files = [fname for fname in get_filenames_with_id(all_cache_files, cache_id) if fname != f"{cache_id}.json"]
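# (Excludes the '{cache_id}.json' file itself, which only stores the responses from the last run.)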
if len(cache_files) == 0:
continue
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
return responses
@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
"""
@ -152,24 +176,51 @@ def countQueries():
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
'n': int # how many responses expected per prompt
'id': str (optional) # a unique ID of the node with cache'd responses. If missing, assumes no cache will be used.
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms', 'n'}):
return jsonify({'error': 'POST data is improper format.'})
n = int(data['n'])
try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})
if 'id' in data:
# Load all cache'd responses with the given id:
cached_resps = load_all_cached_responses(data['id'])
else:
cached_resps = []
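# missing_queries maps each LLM name to {prompt: number of additional queries required}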
missing_queries = {}
def add_to_missing_queries(llm, prompt, num):
if llm not in missing_queries:
missing_queries[llm] = {}
missing_queries[llm][prompt] = num
# Iterate through all prompt permutations and check how many responses already exist in the cache for each prompt
for prompt in all_prompt_permutations:
prompt = str(prompt)
matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
for llm in data['llms']:
match_per_llm = [r for r in matching_resps if r['llm'] == llm]
if len(match_per_llm) == 0:
add_to_missing_queries(llm, prompt, n)
elif len(match_per_llm) == 1:
# Check how many were stored; if not enough, add how many missing queries:
num_resps = len(match_per_llm[0]['responses'])
if n > num_resps:
add_to_missing_queries(llm, prompt, n - num_resps)
else:
raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
for llm in data['llms']:
num_queries[llm] = len(all_prompt_permutations)
ret = jsonify({'counts': num_queries})
ret = jsonify({'counts': missing_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@ -239,6 +290,11 @@ async def queryLLM():
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')
# Check that the filepath used to cache this run's responses is valid:
cache_filepath_last_run = os.path.join('cache', f"{data['id']}.json")
if not is_valid_filepath(cache_filepath_last_run):
return jsonify({'error': f'Invalid filepath: {cache_filepath_last_run}'})
# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
@ -303,6 +359,10 @@ async def queryLLM():
# Remove the temp file used to stream progress updates:
if os.path.exists(tempfilepath):
os.remove(tempfilepath)
# Save the responses *of this run* to the disk, for further recall:
with open(cache_filepath_last_run, "w") as f:
json.dump(res, f)
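# Note: this per-run file follows the same '{id}.json' naming scheme that grabResponses uses to look up cached responses.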
# Return all responses for all LLMs
print('returning responses:', res)
@ -476,19 +536,18 @@ def grabResponses():
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
fname = f"{cache_id}.json"
if fname not in all_cache_files:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
res = load_cache_json(os.path.join('cache', fname))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
ret = jsonify({'responses': responses})
ret.headers.add('Access-Control-Allow-Origin', '*')

View File

@ -177,5 +177,7 @@ class PromptLLM(PromptPipeline):
"""
class PromptLLMDummy(PromptLLM):
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
# Wait a random amount of time, to simulate wait times from real queries
await asyncio.sleep(random.uniform(0.1, 3))
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(40)) for _ in range(n)])
# Return n random strings of characters, each of random length (within a predefined range)
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)])