Carry prompt vars across prompt node chains

Ian Arawjo 2023-05-11 14:54:36 -04:00
parent 09871cdc1f
commit c3994329e8
4 changed files with 79 additions and 18 deletions
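
In short: prompt node outputs are no longer passed around as bare strings. Each response is bundled together with the template variables that produced it, so that downstream prompt nodes can carry those vars through further fills. The bundled form used throughout this commit (a sketch in Python; the example values are illustrative, the field names come from the diffs below):

# Shape of a bundled value passed between prompt nodes:
# 'text' is the filled/response string; 'fill_history' maps each template
# var to the value it was filled with upstream.
bundled = {
    "text": "Toronto, Montreal, Vancouver",
    "fill_history": {"country": "Canada", "domain": "rent"},
}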

View File

@@ -147,6 +147,12 @@ const PromptNode = ({ data, id }) => {
const pullInputData = () => {
// Pull data from each source recursively:
const pulled_data = {};
const store_data = (_texts, _varname, _data) => {
if (_varname in _data)
_data[_varname] = _data[_varname].concat(_texts);
else
_data[_varname] = _texts;
};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
@@ -154,14 +160,24 @@ const PromptNode = ({ data, id }) => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
if (!out) return;
// Save the var data from the pulled output
if (varname in pulled_data)
pulled_data[varname] = pulled_data[varname].concat(out);
else
pulled_data[varname] = out;
if (!out || !Array.isArray(out) || out.length === 0) return;
// Check the format of the output. Can be str or dict with 'text' and 'vars' attrs:
if (typeof out[0] === 'object') {
out.forEach(obj => store_data([obj], varname, pulled_data));
// out.forEach((obj) => {
// store_data([obj.text], varname, pulled_data);
// // We need to carry through each individual var as well:
// Object.keys(obj.vars).forEach(_v =>
// store_data([obj.vars[_v]], _v, pulled_data)
// );
// });
}
else {
// Save the list of strings from the pulled output under the var 'varname'
store_data(out, varname, pulled_data);
}
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
@@ -172,8 +188,20 @@ const PromptNode = ({ data, id }) => {
};
get_outputs(templateVars, id);
console.log(pulled_data);
// Get a Pythonic version of the prompt by adding a $ before any template variables in braces:
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const str_to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
const to_py_template_format = (str_or_obj) => {
if (typeof str_or_obj === 'object') {
let new_vars = {};
Object.keys(str_or_obj.fill_history).forEach(v => {
new_vars[v] = str_to_py_template_format(str_or_obj.fill_history[v]);
});
return {text: str_to_py_template_format(str_or_obj.text), fill_history: new_vars};
} else
return str_to_py_template_format(str_or_obj);
};
const py_prompt_template = to_py_template_format(promptText);
// Do the same for the vars, since vars can themselves be prompt templates:
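
The negative-lookbehind regex above turns brace-style template vars into Python Template syntax while leaving escaped braces alone. A quick sketch of the same transform in Python:

import re

# Mirrors str_to_py_template_format above: {var} -> ${var}, skipping \{...\}
def to_py(s: str) -> str:
    return re.sub(r"(?<!\\)\{(.*?)(?<!\\)\}", r"${\1}", s)

assert to_py("Tell me about {topic}") == "Tell me about ${topic}"
assert to_py(r"a literal \{brace\}") == r"a literal \{brace\}"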
@@ -420,7 +448,10 @@ const PromptNode = ({ data, id }) => {
setPromptTextOnLastRun(promptText);
// Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()});
setDataPropsForNode(id, {fields: json.responses.map(
resp_obj => resp_obj['responses'].map(
r => ({text: r, fill_history: resp_obj['vars']}))).flat()
});
// Save preview strings of responses, for quick glance
// Bucket responses by LLM:
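
pullInputData above walks upstream edges recursively, merging each source's outputs under its template var name; after this commit the merged items may be {text, fill_history} bundles rather than plain strings. A hedged Python sketch of the same accumulation (edge and node shapes here are simplified assumptions, not the real store API):

def pull_input_data(template_vars, node_id, edges, output, get_node_vars):
    pulled = {}
    def get_outputs(varnames, nid):
        for varname in varnames:
            for e in edges:
                if e["target"] != nid or e["targetHandle"] != varname:
                    continue
                out = output(e["source"], e["sourceHandle"])
                if not out:
                    continue
                # Merge under this var; items may be strings or bundles:
                pulled[varname] = pulled.get(varname, []) + list(out)
                # Recurse into vars the upstream node itself depends on:
                n_vars = get_node_vars(e["source"])
                if n_vars:
                    get_outputs(n_vars, e["source"])
    get_outputs(template_vars, node_id)
    return pulled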

View File

@@ -1,9 +1,8 @@
import React, { useState, useEffect, useCallback } from 'react';
import { Handle } from 'react-flow-renderer';
import { MultiSelect } from '@mantine/core';
import useStore from './store';
import Plot from 'react-plotly.js';
import { hover } from '@testing-library/user-event/dist/hover';
import { create } from 'zustand';
import NodeLabel from './NodeLabelComponent';
import {BASE_URL} from './store';
@@ -42,6 +41,9 @@ const VisNode = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [plotlyObj, setPlotlyObj] = useState([]);
const [pastInputs, setPastInputs] = useState([]);
// The MultiSelect lets people dynamically set which vars they care about
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
const handleOnConnect = useCallback(() => {
// Grab the input node ids
@@ -78,6 +80,10 @@ const VisNode = ({ data, id }) => {
}
}
setMultiSelectVars(
varnames.map(name => ({value: name, label: name}))
);
const plot_grouped_boxplot = (resp_to_x) => {
llm_names.forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
@@ -192,8 +198,13 @@ const VisNode = ({ data, id }) => {
return (
<div className="vis-node cfnode">
<NodeLabel title={data.title || 'Vis Node'}
nodeId={id}
icon={'📊'} />
nodeId={id}
icon={'📊'} />
<MultiSelect className='nodrag nowheel'
data={multiSelectVars}
placeholder="Pick all vars you wish to plot"
size="sm"
defaultValue={multiSelectVars} />
<div className="nodrag">{plotlyObj}</div>
<Handle
type="target"

View File

@@ -57,6 +57,7 @@ class PromptPipeline:
raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")
prompt_str = str(prompt)
info = prompt.fill_history
cached_resp = responses[prompt_str] if prompt_str in responses else None
extracted_resps = cached_resp["responses"] if cached_resp is not None else []
@@ -72,7 +73,7 @@ class PromptPipeline:
"llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
# We want to use the new info, since 'vars' could have changed even though
# the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
"info": prompt.fill_history,
"info": info,
}
continue
@@ -86,7 +87,6 @@ class PromptPipeline:
else:
# Blocking. Await + yield a single LLM call.
_, query, response, past_resp_obj = await self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp)
info = prompt.fill_history
# Create a response obj to represent the response
resp_obj = {
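
Hoisting info = prompt.fill_history above the cache check matters because, as the comment above notes, the same prompt string can be produced by different fills, so a cached response must be re-labeled with the current fill history. A small sketch (assuming the PromptTemplate from the last file in this commit, a __str__ that returns the template text, and a fill_history that starts empty):

a = PromptTemplate("this is a ${x}").fill({"x": "tool"})
b = PromptTemplate("this is a tool")
assert str(a) == str(b)                 # identical prompt text...
assert a.fill_history == {"x": "tool"}  # ...but only 'a' records how it arose
assert b.fill_history == {}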

View File

@@ -52,11 +52,14 @@ class PromptTemplate:
except KeyError as e:
return False
def fill(self, paramDict: Dict[str, str]) -> 'PromptTemplate':
def fill(self, paramDict: Dict[str, Union[str, Dict[str, str]]]) -> 'PromptTemplate':
"""
Formats the template string with the given parameters, returning a new PromptTemplate.
Can return a partial completion.
NOTE: paramDict values can be in a special form: {text: <str>, fill_history: {varname: <str>}}
in order to bundle in any past fill history that would otherwise be lost in the current text.
Example usage:
prompt = prompt_template.fill({
"className": className,
@@ -64,13 +67,26 @@ class PromptTemplate:
"PL": "Python"
});
"""
# Check for special 'past fill history' format:
past_fill_history = {}
if len(paramDict) > 0 and isinstance(next(iter(paramDict.values())), dict):
for obj in paramDict.values():
past_fill_history = {**obj['fill_history'], **past_fill_history}
paramDict = {param: obj['text'] for param, obj in paramDict.items()}
filled_pt = PromptTemplate(
Template(self.template).safe_substitute(paramDict)
)
# Deep copy prior fill history from this version over to new one
# Deep copy prior fill history of this PromptTemplate from this version over to new one
filled_pt.fill_history = { key: val for (key, val) in self.fill_history.items() }
# Append any past history passed as vars:
for key, val in past_fill_history.items():
if key in filled_pt.fill_history:
print(f"Warning: PromptTemplate already has fill history for key {key}.")
filled_pt.fill_history[key] = val
# Add the new fill history using the passed parameters that we just filled in
for key, val in paramDict.items():
if key in filled_pt.fill_history:
@@ -85,6 +101,9 @@ class PromptPermutationGenerator:
Given a PromptTemplate and a parameter dict that includes arrays of items,
generate all the permutations of the prompt for all permutations of the items.
NOTE: Items can be in a special form: {text: <str>, fill_history: {varname: <str>}}
in order to bundle in any past fill history that would otherwise be lost in the current text.
Example usage:
prompt_gen = PromptPermutationGenerator('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
for prompt in prompt_gen({"country":["Canada", "South Africa", "China"],
@@ -129,7 +148,7 @@ class PromptPermutationGenerator:
res.extend(self._gen_perm(p, params_left, paramDict))
return res
def __call__(self, paramDict: Dict[str, Union[str, List[str]]]):
def __call__(self, paramDict: Dict[str, Union[str, List[str], Dict[str, str]]]):
if len(paramDict) == 0:
yield self.template
return
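
Putting it together, the special paramDict form lets one fill carry the history of upstream fills. A usage sketch of fill() as defined above (assuming fill_history initializes to an empty dict; names like 'adjective' are hypothetical):

pt = PromptTemplate("Is ${tool} a good ${category}?")
filled = pt.fill({
    "tool": {"text": "ChainForge", "fill_history": {"adjective": "visual"}},
    "category": {"text": "prompt tool", "fill_history": {}},
})
str(filled)           # "Is ChainForge a good prompt tool?"
filled.fill_history   # {'adjective': 'visual', 'tool': 'ChainForge', 'category': 'prompt tool'}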