Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)
Chaining prompts together
This commit is contained in:
parent 8de28cdac0
commit 09871cdc1f
@@ -1,9 +1,24 @@
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import { Badge } from '@mantine/core';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import {BASE_URL} from './store';

// Helper funcs
const truncStr = (s, maxLen) => {
  if (s.length > maxLen) // Cut the name short if it's long
    return s.substring(0, maxLen) + '...'
  else
    return s;
}
const vars_to_str = (vars) => {
  const pairs = Object.keys(vars).map(varname => {
    const s = truncStr(vars[varname].trim(), 12);
    return `${varname} = '${s}'`;
  });
  return pairs;
};
const bucketResponsesByLLM = (responses) => {
  let responses_by_llm = {};
  responses.forEach(item => {
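For illustration (not part of this commit), the helpers above turn a response's vars metadata into short name = 'value' strings, truncating values longer than 12 characters:

// Hypothetical vars metadata attached to one response object:
const example_vars = { topic: 'carbon sequestration', tone: 'formal' };
vars_to_str(example_vars);
// -> ["topic = 'carbon seque...'", "tone = 'formal'"]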
@@ -58,27 +73,19 @@ const InspectorNode = ({ data, id }) => {
  // // Create a dict version
  // let tempvars = {};

  const vars_to_str = (vars) => {
    const pairs = Object.keys(vars).map(varname => {
      let s = vars[varname].trim();
      if (s.length > 12)
        s = s.substring(0, 12) + '...'
      return `${varname} = '${s}'`;
    });
    return pairs.join('; ');
  };

  const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
  const colors = ['#ace1aeb1', '#f1b963b1', '#e46161b1', '#f8f398b1', '#defcf9b1', '#cadefcb1', '#c3bef0b1', '#cca8e9b1'];
  setResponses(Object.keys(responses_by_llm).map((llm, llm_idx) => {
    const res_divs = responses_by_llm[llm].map((res_obj, res_idx) => {
      const ps = res_obj.responses.map((r, idx) =>
        (<pre className="small-response" key={idx}>{r}</pre>)
      );
      // Object.keys(res_obj.vars).forEach(v => {tempvars[v].add(res_obj.vars[v])});
      const vars = vars_to_str(res_obj.vars);
      const var_tags = vars.map((v) => (
        <Badge key={v} color="blue" size="xs">{v}</Badge>
      ));
      return (
        <div key={"r"+res_idx} className="response-box" style={{ backgroundColor: colors[llm_idx % colors.length] }}>
          <p className="response-tag">{vars}</p>
          {var_tags}
          {ps}
        </div>
      );
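For context, bucketResponsesByLLM is only partially shown in this diff; judging from how its result is consumed above, it appears to group response objects under their LLM's name, so each LLM gets its own colored stack of response boxes. A rough, assumed sketch of that shape:

// Hypothetical grouped structure (LLM names and values are made up):
const responses_by_llm = {
  'GPT3.5': [ { vars: { topic: 'sea otters' }, responses: ['Response A', 'Response B'] } ],
  'Alpaca.7B': [ { vars: { topic: 'sea otters' }, responses: ['Response C'] } ],
};
// The render code above maps over Object.keys(responses_by_llm) and uses
// colors[llm_idx % colors.length] as the background for each LLM's boxes.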
@@ -36,7 +36,7 @@ export default function NodeLabel({ title, nodeId, icon, onEdit, onSave, editabl
  const run_btn = (<button className="AmitSahoo45-button-3 nodrag" onClick={handleRunClick} onPointerEnter={handleRunHover}>▶</button>);
  if (runButtonTooltip)
    setRunButton(
      <Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2}>
      <Tooltip label={runButtonTooltip} withArrow arrowSize={6} arrowRadius={2} zIndex={1001}>
        {run_btn}
      </Tooltip>
    );
@@ -54,17 +54,17 @@ const PromptNode = ({ data, id }) => {
  const getNode = useStore((state) => state.getNode);

  const [templateVars, setTemplateVars] = useState(data.vars || []);
  const [promptText, setPromptText] = useState(data.prompt);
  const [promptText, setPromptText] = useState(data.prompt || "");
  const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null);
  const [status, setStatus] = useState('none');
  const [responsePreviews, setReponsePreviews] = useState([]);
  const [responsePreviews, setResponsePreviews] = useState([]);
  const [numGenerations, setNumGenerations] = useState(data.n || 1);

  // For displaying error messages to user
  const alertModal = useRef(null);

  // Selecting LLM models to prompt
  const [llmItems, setLLMItems] = useState(initLLMs.map((i, idx) => ({key: uuid(), ...i})));
  const [llmItems, setLLMItems] = useState(data.llms || initLLMs.map((i) => ({key: uuid(), ...i})));
  const [llmItemsCurrState, setLLMItemsCurrState] = useState([]);
  const resetLLMItemsProgress = useCallback(() => {
    setLLMItems(llmItemsCurrState.map(item => {
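As a sketch, the node data prop these hooks read from (and now also write back to, via data.llms) looks roughly like the following; the field names come from the diff, the concrete values are hypothetical:

// Hypothetical persisted data for a PromptNode being restored:
const data = {
  prompt: 'Tell me a {adjective} joke about {topic}.',
  vars: ['adjective', 'topic'],   // restored into templateVars
  n: 3,                           // restored into numGenerations
  llms: [ { key: 'some-uuid', /* ...whatever fields initLLMs items carry... */ } ],
};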
@@ -101,8 +101,23 @@ const PromptNode = ({ data, id }) => {

  const onLLMListItemsChange = useCallback((new_items) => {
    setLLMItemsCurrState(new_items);
    setDataPropsForNode(id, { llms: new_items });
  }, [setLLMItemsCurrState]);

  const refreshTemplateHooks = (text) => {
    // Update template var fields + handles
    const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
    const found_template_vars = text.match(braces_regex);
    if (found_template_vars && found_template_vars.length > 0) {
      const temp_var_names = found_template_vars.map(
        name => name.substring(1, name.length-1) // remove brackets {}
      )
      setTemplateVars(temp_var_names);
    } else {
      setTemplateVars([]);
    }
  };

  const handleInputChange = (event) => {
    const value = event.target.value;
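A quick, illustrative check of what the template-variable regex captures (the sample prompt is made up):

const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g;  // same regex as above
const sample = "Write a {adjective} poem about {topic}, but ignore \\{this\\}.";
const found = sample.match(braces_regex);                    // ['{adjective}', '{topic}']
const names = found.map(m => m.substring(1, m.length - 1));  // ['adjective', 'topic']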
@@ -119,19 +134,14 @@ const PromptNode = ({ data, id }) => {
      }
    }

    // Update template var fields + handles
    const braces_regex = /(?<!\\){(.*?)(?<!\\)}/g; // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this}
    const found_template_vars = value.match(braces_regex);
    if (found_template_vars && found_template_vars.length > 0) {
      const temp_var_names = found_template_vars.map(
        name => name.substring(1, name.length-1) // remove brackets {}
      )
      setTemplateVars(temp_var_names);
    } else {
      setTemplateVars([]);
    }
    refreshTemplateHooks(value);
  };

  // On initialization
  useEffect(() => {
    refreshTemplateHooks(promptText);
  }, []);

  // Pull all inputs needed to request responses.
  // Returns [prompt, vars dict]
  const pullInputData = () => {
@@ -144,6 +154,7 @@ const PromptNode = ({ data, id }) => {
      if (e.target == nodeId && e.targetHandle == varname) {
        // Get the immediate output:
        let out = output(e.source, e.sourceHandle);
        if (!out) return;

        // Save the var data from the pulled output
        if (varname in pulled_data)
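Per the comment above, pullInputData returns [prompt, vars dict]; a hypothetical example of that return value, showing how a template variable collects whatever the node wired into its handle outputs (which is what lets prompt nodes chain):

// Hypothetical return value (prompt text and pulled values are made up):
const [py_prompt_template, pulled_data] = [
  'Summarize the following in one sentence: {paragraph}',
  { paragraph: ['response text #1 from the upstream node', 'response text #2'] },
];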
@@ -284,7 +295,7 @@ const PromptNode = ({ data, id }) => {

    // Set status indicator
    setStatus('loading');
    setReponsePreviews([]);
    setResponsePreviews([]);

    const [py_prompt_template, pulled_data] = pullInputData();

@@ -408,12 +419,15 @@ const PromptNode = ({ data, id }) => {
      // Save prompt text so we remember what prompt we have responses cache'd for:
      setPromptTextOnLastRun(promptText);

      // Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
      setDataPropsForNode(id, {fields: json.responses.map(r => r['responses']).flat()});

      // Save preview strings of responses, for quick glance
      // Bucket responses by LLM:
      const responses_by_llm = bucketResponsesByLLM(json.responses);
      // const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
      // const colors = ['green', 'yellow', 'orange', 'red', 'pink', 'grape', 'violet', 'indigo', 'blue', 'gray', 'cyan', 'lime'];
      setReponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => {
      setResponsePreviews(Object.keys(responses_by_llm).map((llm, llm_idx) => {
        const resp_boxes = responses_by_llm[llm].map((res_obj, idx) => {
          const num_resp = res_obj['responses'].length;
          const resp_prevs = res_obj['responses'].map((r, i) =>
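The setDataPropsForNode(id, {fields: ...}) call above is the chaining hook: every response text is flattened into a fields array on the node's data, so a downstream PromptNode can pull them as values for one of its template variables. A hedged sketch of the flattening (the json.responses shape is inferred from how it is used here and in bucketResponsesByLLM):

// Hypothetical server result:
const json = {
  responses: [
    { llm: 'GPT3.5', vars: { topic: 'otters' }, responses: ['resp A', 'resp B'] },
    { llm: 'GPT3.5', vars: { topic: 'whales' }, responses: ['resp C'] },
  ],
};
// Same expression as in the diff:
const fields = json.responses.map(r => r['responses']).flat();
// -> ['resp A', 'resp B', 'resp C']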
@@ -421,7 +435,7 @@ const PromptNode = ({ data, id }) => {
          );
          const vars = vars_to_str(res_obj.vars);
          const var_tags = vars.map((v, i) => (
            <Badge key={v} color="green" size="xs">{v}</Badge>
            <Badge key={v} color="blue" size="xs">{v}</Badge>
          ));
          return (
            <div key={idx} className="response-box">
@@ -60,7 +60,10 @@ const useStore = create((set, get) => ({
    if (src_node) {
      // Get the data related to that handle:
      if ("fields" in src_node.data) {
        return Object.values(src_node.data["fields"]);
        if (Array.isArray(src_node.data["fields"]))
          return src_node.data["fields"];
        else
          return Object.values(src_node.data["fields"]);
      }
      // NOTE: This assumes it's on the 'data' prop, with the same id as the handle:
      else return src_node.data[sourceHandleKey];
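A minimal sketch of the getter behavior after this change: fields stored as an array (as PromptNode now saves its responses) pass through unchanged, while fields stored as an object still go through Object.values:

// Illustrative only; mirrors the branch added above.
const getFieldValues = (src_node_data) => {
  if ("fields" in src_node_data) {
    return Array.isArray(src_node_data["fields"])
      ? src_node_data["fields"]                  // e.g., a PromptNode's flat responses array
      : Object.values(src_node_data["fields"]);  // e.g., a node storing fields as a dict
  }
  return undefined;
};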
@@ -163,7 +163,8 @@
}
.inspect-response-container {
  overflow-y: auto;
  width: 450px;
  min-width: 150px;
  max-width: 450px;
  max-height: 350px;
  resize: both;
}
@@ -172,9 +173,9 @@
  font-size: 8pt;
  font-family: monospace;
  border-style: dotted;
  border-color: #aaa;
  border-color: #fff;
  padding: 2px;
  margin: 0px;
  margin: 2px 1px;
  background-color: rgba(255, 255, 255, 0.4);
  white-space: pre-wrap;
}
@@ -199,6 +200,7 @@
.response-box {
  padding: 2px;
  margin: 0px 2px 4px 2px;
  border-radius: 5px;
}
.response-tag {
  font-size: 9pt;
@@ -70,7 +70,9 @@ class PromptPipeline:
        "responses": extracted_resps[:n],
        "raw_response": cached_resp["raw_response"],
        "llm": cached_resp["llm"] if "llm" in cached_resp else LLM.ChatGPT.value,
        "info": cached_resp["info"],
        # We want to use the new info, since 'vars' could have changed even though
        # the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
        "info": prompt.fill_history,
    }
    continue