Mirror of https://github.com/ianarawjo/ChainForge.git, synced 2025-03-14 16:26:45 +00:00
WIP: fixed response caches
parent 16135934f4
commit e9a9663af5
@@ -71,7 +71,7 @@ const EvaluatorNode = ({ data, id }) => {
// Get all the script nodes, and get all the folder paths
const script_nodes = nodes.filter(n => n.type === 'script');
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
console.log(script_paths);

// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'app/execute', {
@@ -95,6 +95,8 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(json ? json.error : 'Unknown error encountered when requesting evaluations: empty response returned.');
return;
}

console.log(json.responses);

// Ping any vis nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
@@ -182,6 +182,8 @@ const PromptNode = ({ data, id }) => {
prompt: prompt,
vars: vars,
llms: llms,
id: id,
n: numGenerations,
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
@@ -209,9 +211,47 @@ const PromptNode = ({ data, id }) => {
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
const n = counts[Object.keys(counts)[0]];
const req = n > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
// Check for empty counts (means no requests will be sent!)
const num_llms_missing = Object.keys(counts).length;
if (num_llms_missing === 0) {
setRunTooltip('Will load responses from cache');
return;
}

// Tally how many queries per LLM:
let queries_per_llm = {};
Object.keys(counts).forEach(llm => {
queries_per_llm[llm] = Object.keys(counts[llm]).reduce(
(acc, prompt) => acc + counts[llm][prompt]
, 0);
});

// Check if all counts are the same:
if (num_llms_missing > 1) {
const some_llm_num = queries_per_llm[Object.keys(queries_per_llm)[0]];
const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
}
else if (all_same_num_queries) {
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
}
else { // Counts are different
const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
setRunTooltip(`Will send a variable # of queries to LLM(s) (total=${sum_queries})`);
}
} else {
const llm_name = Object.keys(queries_per_llm)[0];
const llm_count = queries_per_llm[llm_name];
const req = llm_count > 1 ? 'queries' : 'query';
if (num_llms > num_llms_missing)
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name} and load other responses from cache`);
else
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name}`)
}

});
};
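The tooltip logic above consumes a per-LLM map of missing queries (presumably the payload returned by the /app/countQueriesRequired endpoint further down in this commit). A minimal Python sketch of that assumed shape and of the per-LLM tally the JavaScript reduce performs; the model names and numbers are illustrative only, not data from this commit:

# Illustrative shape of the `counts` object: { llm_name: { prompt_string: num_missing_responses } }
counts = {
    "gpt-3.5-turbo": {"Tell me a joke about {topic}": 3, "Explain {topic}": 3},
    "alpaca-7b": {"Tell me a joke about {topic}": 1},
}

# Same tally the frontend computes to build the run tooltip:
queries_per_llm = {llm: sum(per_prompt.values()) for llm, per_prompt in counts.items()}
print(queries_per_llm)  # {'gpt-3.5-turbo': 6, 'alpaca-7b': 1}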
@@ -340,7 +380,7 @@ const PromptNode = ({ data, id }) => {
temperature: 0.5,
n: numGenerations,
},
no_cache: true,
no_cache: false,
}),
}, rejected).then(function(response) {
return response.json();
@@ -66,48 +66,75 @@ const VisNode = ({ data, id }) => {
else
responses_by_llm[item.llm] = [item];
});
const llm_names = Object.keys(responses_by_llm);

// Create Plotly spec here
const varnames = Object.keys(json.responses[0].vars);
let spec = {};
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
let spec = [];
let layout = {
width: 420, height: 300, title: '', margin: {
l: 40, r: 20, b: 20, t: 20, pad: 2
l: 65, r: 20, b: 20, t: 20, pad: 0
}
}

if (varnames.length === 1) {
const plot_grouped_boxplot = (resp_to_x) => {
llm_names.forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
x: rs.map(r => r.eval_res.items).flat(),
y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text} <b><i>(%{x})</i></b>',
orientation: 'h',
});
});
layout.boxmode = 'group';
};

if (varnames.length === 0) {
// No variables means they used a single prompt (no template) to generate responses
// (Users are likely evaluating differences in responses between LLMs)
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
// llm_names.forEach((llm, idx) => {
// // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
// const rs = responses_by_llm[llm];
// const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
// spec.push({
// type: 'scatter',
// name: llm,
// marker: {color: colors[idx % colors.length]},
// y: rs.map(r => r.eval_res.items).flat(),
// x: rs.map(r => Array(r.eval_res.items.length).fill(truncStr(r.prompt.trim(), 12))).flat(), // use the prompt str as var name
// // boxpoints: 'all',
// mode: 'markers',
// text: hover_texts,
// hovertemplate: '%{text} <b><i>(%{y})</i></b>',
// });
// });
// layout.scattermode = 'group';
}
else if (varnames.length === 1) {
// 1 var; numeric eval
if (Object.keys(responses_by_llm).length === 1) {
if (llm_names.length === 1) {
// Simple box plot, as there is only a single LLM in the response
spec = json.responses.map(r => {
// Use the var value to 'name' this group of points:
const s = truncStr(r.vars[varnames[0]].trim(), 12);

return {type: 'box', y: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}'};
return {type: 'box', x: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}', orientation: 'h'};
});
layout.hovermode = 'closest';
} else {
// There are multiple LLMs in the response; do a grouped box plot by LLM.
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
spec = [];
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
Object.keys(responses_by_llm).forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
y: rs.map(r => r.eval_res.items).flat(),
x: rs.map(r => Array(r.eval_res.items.length).fill(r.vars[varnames[0]].trim())).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text}',
});
});
layout.boxmode = 'group';
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
}
}
else if (varnames.length === 2) {
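For readers unfamiliar with the Plotly trace objects being assembled above, here is a rough Python equivalent of the grouped horizontal box plot, written with plotly.graph_objects; the model names, scores, and variable labels are made up for illustration and are not the actual VisNode output:

import plotly.graph_objects as go

# Rough sketch of the grouped, horizontal box-plot spec built in VisNode.
fig = go.Figure()
for llm, scores, labels in [
    ("gpt-3.5-turbo", [3, 4, 5, 2], ["varA", "varA", "varB", "varB"]),
    ("alpaca-7b",     [1, 2, 2, 3], ["varA", "varA", "varB", "varB"]),
]:
    fig.add_trace(go.Box(
        name=llm,        # one trace (and legend entry) per LLM
        x=scores,        # evaluation scores along the x-axis
        y=labels,        # grouped by the template variable's value
        orientation="h",
        boxpoints="all",
    ))
fig.update_layout(boxmode="group", width=420, height=300,
                  margin=dict(l=65, r=20, b=20, t=20, pad=0))
# fig.show()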
@@ -142,6 +142,30 @@ def reduce_responses(responses: list, vars: list) -> list:

return ret

def load_all_cached_responses(cache_ids):
if not isinstance(cache_ids, list):
cache_ids = [cache_ids]

# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in cache_ids:
cache_files = [fname for fname in get_filenames_with_id(all_cache_files, cache_id) if fname != f"{cache_id}.json"]
if len(cache_files) == 0:
continue

for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)

return responses
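A quick sketch of how the helper above might be used. The field names ('prompt', 'llm', 'responses') are the ones countQueries reads further down; the cache id value is hypothetical:

# Hypothetical usage of load_all_cached_responses (id value made up):
cached = load_all_cached_responses('prompt-node-1')
for r in cached[:3]:
    # Each entry is a standard-format dict describing one prompt/LLM pair
    print(r['llm'], '|', r['prompt'], '|', len(r['responses']), 'stored responses')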
@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
"""
@@ -152,24 +176,51 @@ def countQueries():
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
'n': int # how many responses expected per prompt
'id': str (optional) # a unique ID of the node with cache'd responses. If missing, assumes no cache will be used.
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms', 'n'}):
return jsonify({'error': 'POST data is improper format.'})

n = int(data['n'])

try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})

if 'id' in data:
# Load all cache'd responses with the given id:
cached_resps = load_all_cached_responses(data['id'])
else:
cached_resps = []

missing_queries = {}
def add_to_missing_queries(llm, prompt, num):
if llm not in missing_queries:
missing_queries[llm] = {}
missing_queries[llm][prompt] = num

# Iterate through all prompt permutations and check how many responses there are in the cache for that prompt
for prompt in all_prompt_permutations:
prompt = str(prompt)
matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
for llm in data['llms']:
match_per_llm = [r for r in matching_resps if r['llm'] == llm]
if len(match_per_llm) == 0:
add_to_missing_queries(llm, prompt, n)
elif len(match_per_llm) == 1:
# Check how many were stored; if not enough, add how many missing queries:
num_resps = len(match_per_llm[0]['responses'])
if n > len(match_per_llm[0]['responses']):
add_to_missing_queries(llm, prompt, n - num_resps)
else:
raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")

# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
for llm in data['llms']:
num_queries[llm] = len(all_prompt_permutations)

ret = jsonify({'counts': num_queries})
ret = jsonify({'counts': missing_queries})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
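As a concrete illustration of the endpoint above, a request/response pair might look like the following. All values are hypothetical; the response shape follows the missing_queries dict built above, mapping LLM to prompt to number of missing responses:

# Hypothetical request body for POST /app/countQueriesRequired:
payload = {
    'prompt': 'Tell me a joke about {game}',
    'vars': {'game': ['chess', 'poker']},
    'llms': ['gpt-3.5-turbo', 'alpaca-7b'],
    'n': 3,
    'id': 'prompt-node-1',  # optional; enables the cache check
}

# Possible response, assuming the cache already holds 3 GPT responses for the
# 'chess' prompt but nothing else:
# {'counts': {
#     'gpt-3.5-turbo': {'Tell me a joke about poker': 3},
#     'alpaca-7b': {'Tell me a joke about chess': 3, 'Tell me a joke about poker': 3},
# }}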
@@ -239,6 +290,11 @@ async def queryLLM():
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')

# Check that the filepath used to cache eval'd responses is valid:
cache_filepath_last_run = os.path.join('cache', f"{data['id']}.json")
if not is_valid_filepath(cache_filepath_last_run):
return jsonify({'error': f'Invalid filepath: {cache_filepath_last_run}'})

# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
@@ -303,6 +359,10 @@ async def queryLLM():
# Remove the temp file used to stream progress updates:
if os.path.exists(tempfilepath):
os.remove(tempfilepath)

# Save the responses *of this run* to the disk, for further recall:
with open(cache_filepath_last_run, "w") as f:
json.dump(res, f)

# Return all responses for all LLMs
print('returning responses:', res)
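A minimal sketch of reading back the per-run cache file written above, assuming (as the conversion code in grabResponses and load_all_cached_responses suggests) that a dict-style cache maps each prompt string to a response object. The id, file contents, and flattening are illustrative only:

import json, os

cache_id = 'prompt-node-1'  # hypothetical node id
fname = os.path.join('cache', f"{cache_id}.json")
with open(fname) as f:
    res = json.load(f)
if isinstance(res, dict):
    # Flatten prompt -> response-object entries into a list of records
    res = [{'prompt': prompt, **res_obj} for prompt, res_obj in res.items()]
print(f"Loaded {len(res)} cached response entries")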
@@ -476,19 +536,18 @@ def grabResponses():
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
fname = f"{cache_id}.json"
if fname not in all_cache_files:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})

for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)

res = load_cache_json(os.path.join('cache', fname))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)

ret = jsonify({'responses': responses})
ret.headers.add('Access-Control-Allow-Origin', '*')
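For context, the handler above iterates over data['responses'] and now expects exactly one cache/<id>.json file per id. A hypothetical request body, with made-up ids:

# Hypothetical POST body for the response-grabbing endpoint shown above;
# each id must have a matching cache/<id>.json file on disk.
payload = {'responses': ['prompt-node-1', 'prompt-node-2']}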
@@ -177,5 +177,7 @@ class PromptLLM(PromptPipeline):
"""
class PromptLLMDummy(PromptLLM):
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
# Wait a random amount of time, to simulate wait times from real queries
await asyncio.sleep(random.uniform(0.1, 3))
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(40)) for _ in range(n)])
# Return a random string of characters of random length (within a predefined range)
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)])