Chaining prompts together

This commit is contained in:
Ian Arawjo 2023-05-11 10:24:46 -04:00
parent 8de28cdac0
commit 09871cdc1f
6 changed files with 64 additions and 36 deletions

View File

@ -1,9 +1,24 @@
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import { Badge } from '@mantine/core';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import {BASE_URL} from './store';
// Helper funcs
const truncStr = (s, maxLen) => {
if (s.length > maxLen) // Cut the name short if it's long
return s.substring(0, maxLen) + '...'
else
return s;
}
const vars_to_str = (vars) => {
const pairs = Object.keys(vars).map(varname => {
const s = truncStr(vars[varname].trim(), 12);
return `${varname} = '${s}'`;
});
return pairs;
};
const bucketResponsesByLLM = (responses) => {
let responses_by_llm = {};
responses.forEach(item => {
@ -58,27 +73,19 @@ const InspectorNode = ({ data, id }) => {
// // Create a dict version
// let tempvars = {};
const vars_to_str = (vars) => {
const pairs = Object.keys(vars).map(varname => {
let s = vars[varname].trim();
if (s.length > 12)
s = s.substring(0, 12) + '...'
return `${varname} = '${s}'`;
});
return pairs.join('; ');
};
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {
const res_divs = responses_by_llm[llm].map((res_obj, res_idx) => {
const ps = res_obj.responses.map((r, idx) =>
(<pre className="small-response" key={idx}>{r}</pre>)
);
// Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])});
const vars = vars_to_str(res_obj.vars);
const var_tags = vars.map((v) => (
<Badge key={v} color="blue" size="xs">{v}</Badge>
));
return (
<div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}>
<p className="response-tag">{vars}</p>
{var_tags}
{ps}
</div>
);

View File

@ -36,7 +36,7 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>&#9654;</button>);
if (runButtonTooltip)
setRunButton(
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
<Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2} zIndex={1001}>
{run_btn}
</Tooltip>
);

View File

@ -54,17 +54,17 @@ const PromptNode = ({ data, id }) => {
const getNode = useStore((state) => state.getNode);
const [templateVars, setTemplateVars] = useState(data.vars || []);
const [promptText, setPromptText] = useState(data.prompt);
const [promptText, setPromptText] = useState(data.prompt || "");
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null);
const [status, setStatus] = useState('none');
const [responsePreviews, setReponsePreviews] = useState([]);
const [responsePreviews, setResponsePreviews] = useState([]);
const [numGenerations, setNumGenerations] = useState(data.n || 1);
// For displaying error messages to user
const alertModal = useRef(null);
// Selecting LLM models to prompt
const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
const [llmItems, setLLMItems] = useState(data.llms || initLLMs.map((i) => ({key: uuid(), ...i})));
const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
const resetLLMItemsProgress = useCallback(() => {
setLLMItems(llmItemsCurrState.map(item => {
@ -101,8 +101,23 @@ const PromptNode = ({ data, id }) => {
const onLLMListItemsChange = useCallback((new_items) => {
setLLMItemsCurrState(new_items);
setDataPropsForNode(id, { llms: new_items });
}, [setLLMItemsCurrState]);
const refreshTemplateHooks = (text) => {
// Update template var fields + handles
const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
const found_template_vars = text.match(braces_regex);
if (found_template_vars && found_template_vars.length > 0) {
const temp_var_names = found_template_vars.map(
name => name.substring(1, name.length-1) // remove brackets {}
)
setTemplateVars(temp_var_names);
} else {
setTemplateVars([]);
}
};
const handleInputChange = (event) => {
const value = event.target.value;
@ -119,19 +134,14 @@ const PromptNode = ({ data, id }) => {
}
}
// Update template var fields + handles
const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
const found_template_vars = value.match(braces_regex);
if (found_template_vars && found_template_vars.length > 0) {
const temp_var_names = found_template_vars.map(
name => name.substring(1, name.length-1) // remove brackets {}
)
setTemplateVars(temp_var_names);
} else {
setTemplateVars([]);
}
refreshTemplateHooks(value);
};
// On initialization
useEffect(() => {
refreshTemplateHooks(promptText);
}, []);
// Pull all inputs needed to request responses.
// Returns [prompt, vars dict]
const pullInputData = () => {
@ -144,6 +154,7 @@ const PromptNode = ({ data, id }) => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
if (!out) return;
// Save the var data from the pulled output
if (varname in pulled_data)
@ -284,7 +295,7 @@ const PromptNode = ({ data, id }) => {
// Set status indicator
setStatus('loading');
setReponsePreviews([]);
setResponsePreviews([]);
const [py_prompt_template, pulled_data] = pullInputData();
@ -408,12 +419,15 @@ const PromptNode = ({ data, id }) => {
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
// Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()});
// Save preview strings of responses, for quick glance
// Bucket responses by LLM:
const responses_by_llm = bucketResponsesByLLM(json.responses);
// const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
// const colors = ['green', 'yellow', 'orange', 'red', 'pink', 'grape', 'violet', 'indigo', 'blue', 'gray', 'cyan', 'lime'];
setReponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => {
setResponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => {
const resp_boxes = responses_by_llm[llm].map((res_obj, idx) => {
const num_resp = res_obj['responses'].length;
const resp_prevs = res_obj['responses'].map((r, i) =>
@ -421,7 +435,7 @@ const PromptNode = ({ data, id }) => {
);
const vars = vars_to_str(res_obj.vars);
const var_tags = vars.map((v, i) => (
<Badge key={v} color="green" size="xs">{v}</Badge>
<Badge key={v} color="blue" size="xs">{v}</Badge>
));
return (
<div key={idx} className="response-box">

View File

@ -60,7 +60,10 @@ const useStore = create((set, get) => ({
if (src_node) {
// Get the data related to that handle:
if ("fields" in src_node.data) {
return Object.values(src_node.data["fields"]);
if (Array.isArray(src_node.data["fields"]))
return src_node.data["fields"];
else
return Object.values(src_node.data["fields"]);
}
// NOTE: This assumes it's on the 'data' prop, with the same id as the handle:
else return src_node.data[sourceHandleKey];

View File

@ -163,7 +163,8 @@
}
.inspect-response-container {
overflow-y: auto;
width: 450px;
min-width: 150px;
max-width: 450px;
max-height: 350px;
resize: both;
}
@ -172,9 +173,9 @@
font-size: 8pt;
font-family: monospace;
border-style: dotted;
border-color: #aaa;
border-color: #fff;
padding: 2px;
margin: 0px;
margin: 2px 1px;
background-color: rgba(255, 255, 255, 0.4);
white-space: pre-wrap;
}
@ -199,6 +200,7 @@
.response-box {
padding: 2px;
margin: 0px 2px 4px 2px;
border-radius: 5px;
}
.response-tag {
font-size: 9pt;

View File

@ -70,7 +70,9 @@ class PromptPipeline:
"responses": extracted_resps[:n],
"raw_response": cached_resp["raw_response"],
"llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
"info": cached_resp["info"],
# We want to use the new info, since 'vars' could have changed even though
# the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
"info": prompt.fill_history,
}
continue