mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Carry prompt vars across prompt node chains
This commit is contained in:
parent
09871cdc1f
commit
c3994329e8
@ -147,6 +147,12 @@ const PromptNode = ({ data, id }) => {
|
||||
const pullInputData = () => {
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const store_data = (_texts, _varname, _data) => {
|
||||
if (_varname in _data)
|
||||
_data[_varname] = _data[_varname].concat(_texts);
|
||||
else
|
||||
_data[_varname] = _texts;
|
||||
};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
varnames.forEach(varname => {
|
||||
// Find the relevant edge(s):
|
||||
@ -154,14 +160,24 @@ const PromptNode = ({ data, id }) => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
// Get the immediate output:
|
||||
let out = output(e.source, e.sourceHandle);
|
||||
if (!out) return;
|
||||
|
||||
// Save the var data from the pulled output
|
||||
if (varname in pulled_data)
|
||||
pulled_data[varname] = pulled_data[varname].concat(out);
|
||||
else
|
||||
pulled_data[varname] = out;
|
||||
if (!out || !Array.isArray(out) || out.length === 0) return;
|
||||
|
||||
// Check the format of the output. Can be str or dict with 'text' and 'vars' attrs:
|
||||
if (typeof out[0] === 'object') {
|
||||
out.forEach(obj => store_data([obj], varname, pulled_data));
|
||||
// out.forEach((obj) => {
|
||||
// store_data([obj.text], varname, pulled_data);
|
||||
// // We need to carry through each individual var as well:
|
||||
// Object.keys(obj.vars).forEach(_v =>
|
||||
// store_data([obj.vars[_v]], _v, pulled_data)
|
||||
// );
|
||||
// });
|
||||
}
|
||||
else {
|
||||
// Save the list of strings from the pulled output under the var 'varname'
|
||||
store_data(out, varname, pulled_data);
|
||||
}
|
||||
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
@ -172,8 +188,20 @@ const PromptNode = ({ data, id }) => {
|
||||
};
|
||||
get_outputs(templateVars, id);
|
||||
|
||||
console.log(pulled_data);
|
||||
|
||||
// Get Pythonic version of the prompt, by adding a $ before any template variables in braces:
|
||||
const to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
||||
const str_to_py_template_format = (str) => str.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}")
|
||||
const to_py_template_format = (str_or_obj) => {
|
||||
if (typeof str_or_obj === 'object') {
|
||||
let new_vars = {};
|
||||
Object.keys(str_or_obj.fill_history).forEach(v => {
|
||||
new_vars[v] = str_to_py_template_format(str_or_obj.fill_history[v]);
|
||||
});
|
||||
return {text: str_to_py_template_format(str_or_obj.text), fill_history: new_vars};
|
||||
} else
|
||||
return str_to_py_template_format(str_or_obj);
|
||||
};
|
||||
const py_prompt_template = to_py_template_format(promptText);
|
||||
|
||||
// Do the same for the vars, since vars can themselves be prompt templates:
|
||||
@ -420,7 +448,10 @@ const PromptNode = ({ data, id }) => {
|
||||
setPromptTextOnLastRun(promptText);
|
||||
|
||||
// Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
|
||||
setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()});
|
||||
setDataPropsForNode(id, {fields: json.responses.map(
|
||||
resp_obj => resp_obj['responses'].map(
|
||||
r => ({text: r, fill_history: resp_obj['vars']}))).flat()
|
||||
});
|
||||
|
||||
// Save preview strings of responses, for quick glance
|
||||
// Bucket responses by LLM:
|
||||
|
@ -1,9 +1,8 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import { MultiSelect } from '@mantine/core';
|
||||
import useStore from './store';
|
||||
import Plot from 'react-plotly.js';
|
||||
import { hover } from '@testing-library/user-event/dist/hover';
|
||||
import { create } from 'zustand';
|
||||
import NodeLabel from './NodeLabelComponent';
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
@ -42,6 +41,9 @@ const VisNode = ({ data, id }) => {
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const [plotlyObj, setPlotlyObj] = useState([]);
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
|
||||
// The MultiSelect so people can dynamically set what vars they care about
|
||||
const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []);
|
||||
|
||||
const handleOnConnect = useCallback(() => {
|
||||
// Grab the input node ids
|
||||
@ -78,6 +80,10 @@ const VisNode = ({ data, id }) => {
|
||||
}
|
||||
}
|
||||
|
||||
setMultiSelectVars(
|
||||
varnames.map(name => ({value: name, label: name}))
|
||||
);
|
||||
|
||||
const plot_grouped_boxplot = (resp_to_x) => {
|
||||
llm_names.forEach((llm, idx) => {
|
||||
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
|
||||
@ -192,8 +198,13 @@ const VisNode = ({ data, id }) => {
|
||||
return (
|
||||
<div className="vis-node cfnode">
|
||||
<NodeLabel title={data.title || 'Vis Node'}
|
||||
nodeId={id}
|
||||
icon={'📊'} />
|
||||
nodeId={id}
|
||||
icon={'📊'} />
|
||||
<MultiSelect className='nodrag nowheel'
|
||||
data={multiSelectVars}
|
||||
placeholder="Pick all vars you wish to plot"
|
||||
size="sm"
|
||||
defaultValue={multiSelectVars} />
|
||||
<div className="nodrag">{plotlyObj}</div>
|
||||
<Handle
|
||||
type="target"
|
||||
|
@ -57,6 +57,7 @@ class PromptPipeline:
|
||||
raise Exception(f"Cannot send a prompt '{prompt}' to LLM: Prompt is a template.")
|
||||
|
||||
prompt_str = str(prompt)
|
||||
info = prompt.fill_history
|
||||
|
||||
cached_resp = responses[prompt_str] if prompt_str in responses else None
|
||||
extracted_resps = cached_resp["responses"] if cached_resp is not None else []
|
||||
@ -72,7 +73,7 @@ class PromptPipeline:
|
||||
"llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
|
||||
# We want to use the new info, since 'vars' could have changed even though
|
||||
# the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
|
||||
"info": prompt.fill_history,
|
||||
"info": info,
|
||||
}
|
||||
continue
|
||||
|
||||
@ -86,7 +87,6 @@ class PromptPipeline:
|
||||
else:
|
||||
# Blocking. Await + yield a single LLM call.
|
||||
_, query, response, past_resp_obj = await self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp)
|
||||
info = prompt.fill_history
|
||||
|
||||
# Create a response obj to represent the response
|
||||
resp_obj = {
|
||||
|
@ -52,11 +52,14 @@ class PromptTemplate:
|
||||
except KeyError as e:
|
||||
return False
|
||||
|
||||
def fill(self, paramDict: Dict[str, str]) -> 'PromptTemplate':
|
||||
def fill(self, paramDict: Dict[str, Union[str, Dict[str, str]]]) -> 'PromptTemplate':
|
||||
"""
|
||||
Formats the template string with the given parameters, returning a new PromptTemplate.
|
||||
Can return a partial completion.
|
||||
|
||||
NOTE: paramDict values can be in a special form: {text: <str>, fill_history: {varname: <str>}}
|
||||
in order to bundle in any past fill history that is lost in the current text.
|
||||
|
||||
Example usage:
|
||||
prompt = prompt_template.fill({
|
||||
"className": className,
|
||||
@ -64,13 +67,26 @@ class PromptTemplate:
|
||||
"PL": "Python"
|
||||
});
|
||||
"""
|
||||
# Check for special 'past fill history' format:
|
||||
past_fill_history = {}
|
||||
if len(paramDict) > 0 and isinstance(next(iter(paramDict.values())), dict):
|
||||
for obj in paramDict.values():
|
||||
past_fill_history = {**obj['fill_history'], **past_fill_history}
|
||||
paramDict = {param: obj['text'] for param, obj in paramDict.items()}
|
||||
|
||||
filled_pt = PromptTemplate(
|
||||
Template(self.template).safe_substitute(paramDict)
|
||||
)
|
||||
|
||||
# Deep copy prior fill history from this version over to new one
|
||||
# Deep copy prior fill history of this PromptTemplate from this version over to new one
|
||||
filled_pt.fill_history = { key: val for (key, val) in self.fill_history.items() }
|
||||
|
||||
# Append any past history passed as vars:
|
||||
for key, val in past_fill_history.items():
|
||||
if key in filled_pt.fill_history:
|
||||
print(f"Warning: PromptTemplate already has fill history for key {key}.")
|
||||
filled_pt.fill_history[key] = val
|
||||
|
||||
# Add the new fill history using the passed parameters that we just filled in
|
||||
for key, val in paramDict.items():
|
||||
if key in filled_pt.fill_history:
|
||||
@ -85,6 +101,9 @@ class PromptPermutationGenerator:
|
||||
Given a PromptTemplate and a parameter dict that includes arrays of items,
|
||||
generate all the permutations of the prompt for all permutations of the items.
|
||||
|
||||
NOTE: Items can be in a special form: {text: <str>, fill_history: {varname: <str>}}
|
||||
in order to bundle in any past fill history that is lost in the current text.
|
||||
|
||||
Example usage:
|
||||
prompt_gen = PromptPermutationGenerator('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?')
|
||||
for prompt in prompt_gen({"country":["Canada", "South Africa", "China"],
|
||||
@ -129,7 +148,7 @@ class PromptPermutationGenerator:
|
||||
res.extend(self._gen_perm(p, params_left, paramDict))
|
||||
return res
|
||||
|
||||
def __call__(self, paramDict: Dict[str, Union[str, List[str]]]):
|
||||
def __call__(self, paramDict: Dict[str, Union[str, List[str], Dict[str, str]]]):
|
||||
if len(paramDict) == 0:
|
||||
yield self.template
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user