From c3994329e82df097efe08b54ed22714594cbb8a7 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Thu, 11 May 2023 14:54:36 -0400 Subject: [PATCH] Carry prompt vars across prompt node chains --- chain-forge/src/PromptNode.js | 49 ++++++++++++++++++++----- chain-forge/src/VisNode.js | 19 ++++++++-- python-backend/promptengine/query.py | 4 +- python-backend/promptengine/template.py | 25 +++++++++++-- 4 files changed, 79 insertions(+), 18 deletions(-) diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js index 134d987..0b16a5c 100644 --- a/chain-forge/src/PromptNode.js +++ b/chain-forge/src/PromptNode.js @@ -147,6 +147,12 @@ const PromptNode = ({ data, id }) => { const pullInputData = () => { // Pull data from each source recursively: const pulled_data = {}; + const store_data = (_texts, _varname, _data) => { + if (_varname in _data) + _data[_varname] = _data[_varname].concat(_texts); + else + _data[_varname] = _texts; + }; const get_outputs = (varnames, nodeId) => { varnames.forEach(varname => { // Find the relevant edge(s): @@ -154,14 +160,24 @@ const PromptNode = ({ data, id }) => { if (e.target == nodeId && e.targetHandle == varname) { // Get the immediate output: let out = output(e.source, e.sourceHandle); - if (!out) return; - - // Save the var data from the pulled output - if (varname in pulled_data) - pulled_data[varname] = pulled_data[varname].concat(out); - else - pulled_data[varname] = out; + if (!out || !Array.isArray(out) || out.length === 0) return; + // Check the format of the output. Can be str or dict with 'text' and 'vars' attrs: + if (typeof out[0] === 'object') { + out.forEach(obj => store_data([obj], varname, pulled_data)); + // out.forEach((obj) => { + // store_data([obj.text], varname, pulled_data); + // // We need to carry through each individual var as well: + // Object.keys(obj.vars).forEach(_v => + // store_data([obj.vars[_v]], _v, pulled_data) + // ); + // }); + } + else { + // Save the list of strings from the pulled output under the var 'varname' + store_data(out, varname, pulled_data); + } + // Get any vars that the output depends on, and recursively collect those outputs as well: const n_vars = getNode(e.source).data.vars; if (n_vars && Array.isArray(n_vars) && n_vars.length > 0) @@ -172,8 +188,20 @@ const PromptNode = ({ data, id }) => { }; get_outputs(templateVars, id); + console.log(pulled_data); + // Get Pythonic version of the prompt, by adding a $ before any template variables in braces: - const to_py_template_format = (str) => str.replace(/(? str.replace(/(? { + if (typeof str_or_obj === 'object') { + let new_vars = {}; + Object.keys(str_or_obj.fill_history).forEach(v => { + new_vars[v] = str_to_py_template_format(str_or_obj.fill_history[v]); + }); + return {text: str_to_py_template_format(str_or_obj.text), fill_history: new_vars}; + } else + return str_to_py_template_format(str_or_obj); + }; const py_prompt_template = to_py_template_format(promptText); // Do the same for the vars, since vars can themselves be prompt templates: @@ -420,7 +448,10 @@ const PromptNode = ({ data, id }) => { setPromptTextOnLastRun(promptText); // Save response texts as 'fields' of data, for any prompt nodes pulling the outputs - setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()}); + setDataPropsForNode(id, {fields: json.responses.map( + resp_obj => resp_obj['responses'].map( + r => ({text: r, fill_history: resp_obj['vars']}))).flat() + }); // Save preview strings of responses, for quick glance // Bucket responses by LLM: diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js index a4165c2..4dca2b0 100644 --- a/chain-forge/src/VisNode.js +++ b/chain-forge/src/VisNode.js @@ -1,9 +1,8 @@ import React, { useState, useEffect, useCallback } from 'react'; import { Handle } from 'react-flow-renderer'; +import { MultiSelect } from '@mantine/core'; import useStore from './store'; import Plot from 'react-plotly.js'; -import { hover } from '@testing-library/user-event/dist/hover'; -import { create } from 'zustand'; import NodeLabel from './NodeLabelComponent'; import {BASE_URL} from './store'; @@ -42,6 +41,9 @@ const VisNode = ({ data, id }) => { const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const [plotlyObj, setPlotlyObj] = useState([]); const [pastInputs, setPastInputs] = useState([]); + + // The MultiSelect so people can dynamically set what vars they care about + const [multiSelectVars, setMultiSelectVars] = useState(data.vars || []); const handleOnConnect = useCallback(() => { // Grab the input node ids @@ -78,6 +80,10 @@ const VisNode = ({ data, id }) => { } } + setMultiSelectVars( + varnames.map(name => ({value: name, label: name})) + ); + const plot_grouped_boxplot = (resp_to_x) => { llm_names.forEach((llm, idx) => { // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks. @@ -192,8 +198,13 @@ const VisNode = ({ data, id }) => { return (
+ nodeId={id} + icon={'📊'} /> +
{plotlyObj}
this is a {x} where x='tool'") - "info": prompt.fill_history, + "info": info, } continue @@ -86,7 +87,6 @@ class PromptPipeline: else: # Blocking. Await + yield a single LLM call. _, query, response, past_resp_obj = await self._prompt_llm(llm, prompt, n, temperature, past_resp_obj=cached_resp) - info = prompt.fill_history # Create a response obj to represent the response resp_obj = { diff --git a/python-backend/promptengine/template.py b/python-backend/promptengine/template.py index 100d896..27a5025 100644 --- a/python-backend/promptengine/template.py +++ b/python-backend/promptengine/template.py @@ -52,11 +52,14 @@ class PromptTemplate: except KeyError as e: return False - def fill(self, paramDict: Dict[str, str]) -> 'PromptTemplate': + def fill(self, paramDict: Dict[str, Union[str, Dict[str, str]]]) -> 'PromptTemplate': """ Formats the template string with the given parameters, returning a new PromptTemplate. Can return a partial completion. + NOTE: paramDict values can be in a special form: {text: , fill_history: {varname: }} + in order to bundle in any past fill history that is lost in the current text. + Example usage: prompt = prompt_template.fill({ "className": className, @@ -64,13 +67,26 @@ class PromptTemplate: "PL": "Python" }); """ + # Check for special 'past fill history' format: + past_fill_history = {} + if len(paramDict) > 0 and isinstance(next(iter(paramDict.values())), dict): + for obj in paramDict.values(): + past_fill_history = {**obj['fill_history'], **past_fill_history} + paramDict = {param: obj['text'] for param, obj in paramDict.items()} + filled_pt = PromptTemplate( Template(self.template).safe_substitute(paramDict) ) - # Deep copy prior fill history from this version over to new one + # Deep copy prior fill history of this PromptTemplate from this version over to new one filled_pt.fill_history = { key: val for (key, val) in self.fill_history.items() } + # Append any past history passed as vars: + for key, val in past_fill_history.items(): + if key in filled_pt.fill_history: + print(f"Warning: PromptTemplate already has fill history for key {key}.") + filled_pt.fill_history[key] = val + # Add the new fill history using the passed parameters that we just filled in for key, val in paramDict.items(): if key in filled_pt.fill_history: @@ -85,6 +101,9 @@ class PromptPermutationGenerator: Given a PromptTemplate and a parameter dict that includes arrays of items, generate all the permutations of the prompt for all permutations of the items. + NOTE: Items can be in a special form: {text: , fill_history: {varname: }} + in order to bundle in any past fill history that is lost in the current text. + Example usage: prompt_gen = PromptPermutationGenerator('Can you list all the cities in the country ${country} by the cheapest ${domain} prices?') for prompt in prompt_gen({"country":["Canada", "South Africa", "China"], @@ -129,7 +148,7 @@ class PromptPermutationGenerator: res.extend(self._gen_perm(p, params_left, paramDict)) return res - def __call__(self, paramDict: Dict[str, Union[str, List[str]]]): + def __call__(self, paramDict: Dict[str, Union[str, List[str], Dict[str, str]]]): if len(paramDict) == 0: yield self.template return