'Dirty' downstream nodes whenever upstream changes are made. Minor styling improvements. (#105)

* Invalidate eval node upon upstream changes.

* Chain update pings across nodes. Autoresize textfields when typing.

* Wide output handles when entire node is output

* update package version
This commit is contained in:
ianarawjo 2023-07-21 12:39:08 -04:00 committed by GitHub
parent de48255a8b
commit 666d5900b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 143 additions and 83 deletions

View File

@ -1,15 +1,15 @@
{
"files": {
"main.css": "/static/css/main.d97bf957.css",
"main.js": "/static/js/main.690672fa.js",
"main.css": "/static/css/main.a8d99f88.css",
"main.js": "/static/js/main.d97188e6.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.d97bf957.css.map": "/static/css/main.d97bf957.css.map",
"main.690672fa.js.map": "/static/js/main.690672fa.js.map",
"main.a8d99f88.css.map": "/static/css/main.a8d99f88.css.map",
"main.d97188e6.js.map": "/static/js/main.d97188e6.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.d97bf957.css",
"static/js/main.690672fa.js"
"static/css/main.a8d99f88.css",
"static/js/main.d97188e6.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.690672fa.js"></script><link href="/static/css/main.d97bf957.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.d97188e6.js"></script><link href="/static/css/main.a8d99f88.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,5 @@
import React, { useState, useRef, useEffect, useCallback } from 'react';
import { Badge, Text } from '@mantine/core';
import React, { useState, useEffect, useCallback } from 'react';
import { Text } from '@mantine/core';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import { IconCsv } from '@tabler/icons-react';
@ -7,6 +7,7 @@ import { Handle } from 'react-flow-renderer';
const CsvNode = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const [contentDiv, setContentDiv] = useState(null);
const [isEditing, setIsEditing] = useState(true);
const [csvInput, setCsvInput] = useState(null);
@ -29,21 +30,13 @@ const CsvNode = ({ data, id }) => {
return matches.map(e => e.trim()).filter(e => e.length > 0);
}
// const processCsv = (csv) => {
// if (!csv) return [];
// // Split the input string by rows, and merge
// var res = csv.split('\n').join(',');
// // remove all the empty or whitespace-only elements
// return res.split(',').map(e => e.trim()).filter(e => e.length > 0);
// }
// Handle a change in a text fields' input.
const handleInputChange = useCallback((event) => {
// Update the data for this text fields' id.
let new_data = { 'text': event.target.value, 'fields': processCsv(event.target.value) };
setDataPropsForNode(id, new_data);
}, [id]);
pingOutputNodes(id);
}, [id, pingOutputNodes, setDataPropsForNode]);
const handKeyDown = useCallback((event) => {
if (event.key === 'Enter' && data.text && data.text.trim().length > 0) {
@ -124,7 +117,8 @@ const CsvNode = ({ data, id }) => {
type="source"
position="right"
id="output"
style={{ top: "50%", background: '#555' }}
className="grouped-handle"
style={{ top: "50%" }}
/>
</div>
);

View File

@ -118,7 +118,7 @@ const EditableTable = ({ rows, columns, handleSaveCell, handleInsertColumn, hand
suppressContentEditableWarning={true}>{c.header}</p>
<Menu closeOnClickOutside styles={{dropdown: {boxShadow: '1px 1px 4px #ccc'}}}>
<Menu.Target>
<IconDots size='12pt' style={{padding: '0px', marginTop: '8pt', marginLeft: '2pt'}} className='table-col-edit-btn' />
<IconDots size='12pt' style={{padding: '0px', marginTop: '3pt', marginLeft: '2pt'}} className='table-col-edit-btn' />
</Menu.Target>
<Menu.Dropdown>
<Menu.Item key='rename_col' onClick={() => handleRenameColumn(c)}><IconPencil size='10pt' />&nbsp;Rename column</Menu.Item>

View File

@ -79,8 +79,7 @@ function evaluate(response) {
const EvaluatorNode = ({ data, id }) => {
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
const getNode = useStore((state) => state.getNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [status, setStatus] = useState('none');
const nodes = useStore((state) => state.nodes);
@ -127,13 +126,21 @@ const EvaluatorNode = ({ data, id }) => {
responses: [id],
}).then(function(json) {
if (json.responses && json.responses.length > 0) {
// Store responses and set status to green checkmark
setLastResponses(json.responses);
setStatus('ready');
// Store responses and set status to green checkmark
setLastResponses(json.responses);
setStatus('ready');
}
});
}, []);
// On upstream changes
useEffect(() => {
if (data.refresh && data.refresh === true) {
setDataPropsForNode(id, { refresh: false });
setStatus('warning');
}
}, [data]);
const handleCodeChange = (code) => {
if (codeTextOnLastRun !== false) {
const code_changed = code !== codeTextOnLastRun;
@ -216,13 +223,7 @@ const EvaluatorNode = ({ data, id }) => {
}
// Ping any vis + inspect nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
output_nodes.forEach(n => {
const node = getNode(n);
if (node && (node.type === 'vis' || node.type === 'inspect')) {
setDataPropsForNode(node.id, { refresh: true });
}
});
pingOutputNodes(id);
console.log(json.responses);
setLastResponses(json.responses);
@ -291,13 +292,15 @@ const EvaluatorNode = ({ data, id }) => {
type="target"
position="left"
id="responseBatch"
style={{ top: '50%', background: '#555' }}
className="grouped-handle"
style={{ top: '50%' }}
/>
<Handle
type="source"
position="right"
id="output"
style={{ top: '50%', background: '#555' }}
className="grouped-handle"
style={{ top: '50%' }}
/>
<div className="core-mirror-field">
<div className="code-mirror-field-header">Define an <Code>evaluate</Code> func to map over each response:

View File

@ -69,7 +69,8 @@ const InspectorNode = ({ data, id }) => {
type="target"
position="left"
id="input"
style={{ top: "50%", background: '#555' }}
className="grouped-handle"
style={{ top: "50%" }}
onConnect={handleOnConnect}
/>
</div>

View File

@ -79,8 +79,6 @@ export default function LLMList({llms, onItemsChange}) {
}
));
// Replace the item in the list and re-save:
}, [items, updateItems]);
const onDragEnd = (result) => {

View File

@ -57,7 +57,7 @@ const TemperatureStatus = styled.span`
`;
export const DragItem = styled.div`
padding: 8px;
padding: 6px;
border-radius: 6px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.12), 0 1px 2px rgba(0, 0, 0, 0.24);
background: white;

View File

@ -93,7 +93,7 @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const getNode = useStore((state) => state.getNode);
// API Keys (set by user in popup GlobalSettingsModal)
@ -105,6 +105,7 @@ const PromptNode = ({ data, id }) => {
const [promptTextOnLastRun, setPromptTextOnLastRun] = useState(null);
const [status, setStatus] = useState('none');
const [numGenerations, setNumGenerations] = useState(data.n || 1);
const [numGenerationsLastRun, setNumGenerationsLastRun] = useState(data.n || 1);
// For displaying error messages to user
const alertModal = useRef(null);
@ -163,6 +164,12 @@ const PromptNode = ({ data, id }) => {
inspectModal.current.trigger();
}, [inspectModal, jsonResponses]);
// Signal that prompt node state is dirty; user should re-run:
const signalDirty = useCallback(() => {
if (promptTextOnLastRun !== null && status === 'ready')
setStatus('warning');
}, [promptTextOnLastRun, status])
const addModel = useCallback((model) => {
// Get the item for that model
let item = AvailableLLMs.find(llm => llm.base_model === model);
@ -185,12 +192,23 @@ const PromptNode = ({ data, id }) => {
// Add model to LLM list (regardless of it's present already or not).
setLLMItems(llmItemsCurrState.concat([item]))
}, [llmItemsCurrState]);
signalDirty();
}, [llmItemsCurrState, signalDirty]);
const onLLMListItemsChange = useCallback((new_items) => {
setLLMItemsCurrState(new_items);
setDataPropsForNode(id, { llms: new_items });
}, [setLLMItemsCurrState]);
// If there's been any change to the item list, signal dirty:
if (new_items.length !== llmItemsCurrState.length || !new_items.every(i => llmItemsCurrState.some(s => s.key === i.key))) {
signalDirty();
} else if (!new_items.every(itemA => {
const itemB = llmItemsCurrState.find(b => b.key === itemA.key);
return JSON.stringify(itemA.settings) === JSON.stringify(itemB.settings);
})) {
signalDirty();
}
}, [setLLMItemsCurrState, signalDirty]);
const refreshTemplateHooks = (text) => {
// Update template var fields + handles
@ -207,12 +225,8 @@ const PromptNode = ({ data, id }) => {
data['prompt'] = value;
// Update status icon, if need be:
if (promptTextOnLastRun !== null) {
if (status !== 'warning' && value !== promptTextOnLastRun) {
setStatus('warning');
} else if (status === 'warning' && value === promptTextOnLastRun) {
setStatus('ready');
}
if (promptTextOnLastRun !== null && status !== 'warning' && value !== promptTextOnLastRun) {
setStatus('warning');
}
refreshTemplateHooks(value);
@ -234,6 +248,14 @@ const PromptNode = ({ data, id }) => {
});
}, []);
// On upstream changes
useEffect(() => {
if (data.refresh && data.refresh === true) {
setDataPropsForNode(id, { refresh: false });
setStatus('warning');
}
}, [data]);
// Pull all inputs needed to request responses.
// Returns [prompt, vars dict]
const pullInputData = () => {
@ -522,6 +544,7 @@ const PromptNode = ({ data, id }) => {
// Save prompt text so we remember what prompt we have responses cache'd for:
setPromptTextOnLastRun(promptText);
setNumGenerationsLastRun(numGenerations);
// Save response texts as 'fields' of data, for any prompt nodes pulling the outputs
// First we need to get a unique key for a unique metavar for the LLM set that produced these responses,
@ -544,13 +567,7 @@ const PromptNode = ({ data, id }) => {
});
// Ping any inspect nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
output_nodes.forEach(n => {
const node = getNode(n);
if (node && node.type === 'inspect') {
setDataPropsForNode(node.id, { refresh: true });
}
});
pingOutputNodes(id);
} else {
setStatus('error');
triggerAlert(json.error || 'Unknown error when querying LLM');
@ -566,15 +583,17 @@ const PromptNode = ({ data, id }) => {
.catch(rejected);
};
const handleNumGenChange = (event) => {
const handleNumGenChange = useCallback((event) => {
let n = event.target.value;
if (!isNaN(n) && n.length > 0 && /^\d+$/.test(n)) {
// n is an integer; save it
n = parseInt(n);
if (n !== numGenerationsLastRun && status === 'ready')
setStatus('warning');
setNumGenerations(n);
setDataPropsForNode(id, {n: n});
}
};
}, [numGenerationsLastRun, setDataPropsForNode, status]);
const hideStatusIndicator = () => {
if (status !== 'none') { setStatus('none'); }
@ -624,15 +643,18 @@ const PromptNode = ({ data, id }) => {
</Box>
</Modal>
<Textarea ref={setRef}
autosize
className="prompt-field-fixed nodrag nowheel"
minRows="4"
maxRows="12"
defaultValue={data.prompt}
onChange={handleInputChange} />
<Handle
type="source"
position="right"
id="prompt"
style={{ top: '50%', background: '#555' }}
className="grouped-handle"
style={{ top: '50%' }}
/>
<TemplateHooks vars={templateVars} nodeId={id} startY={hooksY} />
<hr />

View File

@ -2,7 +2,7 @@ export default function StatusIndicator({ status }) {
switch (status) {
case 'warning': // Display mustard 'warning' icon
return (
<div className="status-icon warning-status">&#9888;<span className='status-tooltip'>Contents changed. Downstream results might be invalidated. Press Play to rerun.</span></div>
<div className="status-icon warning-status">&#9888;<span className='status-tooltip'>Something changed. Downstream results might be invalidated. Press Play to rerun.</span></div>
);
case 'ready': // Display green checkmark 'ready' icon
return (

View File

@ -42,6 +42,7 @@ const TabularDataNode = ({ data, id }) => {
return {...col};
}));
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const [contextMenuPos, setContextMenuPos] = useState({left: -100, top:0});
const [contextMenuOpened, setContextMenuOpened] = useState(false);
@ -61,6 +62,7 @@ const TabularDataNode = ({ data, id }) => {
const [renameColumnInitialVal, setRenameColumnInitialVal] = useState("");
const handleSaveCell = useCallback((rowIdx, columnKey, value) => {
pingOutputNodes(id);
if (rowIdx === -1) {
// Saving the column header
setTableColumns(tableColumns.map(col => {
@ -73,7 +75,7 @@ const TabularDataNode = ({ data, id }) => {
console.log('handleSaveCell', rowIdx, columnKey, value);
tableData[rowIdx][columnKey] = value;
setTableData([...tableData]);
}, [tableData, tableColumns]);
}, [tableData, tableColumns, pingOutputNodes]);
// Adds a new row to the table
const handleAddRow = useCallback(() => {
@ -138,7 +140,8 @@ const TabularDataNode = ({ data, id }) => {
setTableColumns([...tableColumns]);
setTableData([...tableData]);
}, [tableColumns, tableData]);
pingOutputNodes(id);
}, [tableColumns, tableData, pingOutputNodes]);
// Opens a modal popup to let user rename a column
const openRenameColumnModal = useCallback((col) => {
@ -146,6 +149,7 @@ const TabularDataNode = ({ data, id }) => {
if (renameColumnModal && renameColumnModal.current)
renameColumnModal.current.trigger();
}, [renameColumnModal]);
const handleRenameColumn = useCallback((new_header) => {
if (typeof renameColumnInitialVal !== 'object') {
console.error('Initial column value was not set.');
@ -157,7 +161,8 @@ const TabularDataNode = ({ data, id }) => {
return c;
});
setTableColumns([...new_cols]);
}, [tableColumns, renameColumnInitialVal]);
pingOutputNodes(id);
}, [tableColumns, renameColumnInitialVal, pingOutputNodes]);
// Removes a row of the table, at <table> index 'selectedRow'
const handleRemoveRow = useCallback(() => {
@ -257,6 +262,7 @@ const TabularDataNode = ({ data, id }) => {
// Save the new columns and rows
setTableColumns(cols);
setTableData(rows);
pingOutputNodes(id);
};
// Import tabular data from a file

View File

@ -33,6 +33,7 @@ const TextFieldsNode = ({ data, id }) => {
const [templateVars, setTemplateVars] = useState(data.vars || []);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const [textfieldsValues, setTextfieldsValues] = useState(data.fields || {});
const [fieldVisibility, setFieldVisibility] = useState(data.fields_visibility || {});
@ -62,7 +63,8 @@ const TextFieldsNode = ({ data, id }) => {
setTextfieldsValues(new_fields);
setFieldVisibility(new_vis);
setDataPropsForNode(id, {fields: new_fields, fields_visibility: new_vis});
}, [textfieldsValues, fieldVisibility, id, delButtonId, setDataPropsForNode]);
pingOutputNodes(id);
}, [textfieldsValues, fieldVisibility, id, delButtonId, setDataPropsForNode, pingOutputNodes]);
// Initialize fields (run once at init)
useEffect(() => {
@ -80,7 +82,8 @@ const TextFieldsNode = ({ data, id }) => {
new_fields[getUID()] = "";
setTextfieldsValues(new_fields);
setDataPropsForNode(id, { fields: new_fields });
}, [textfieldsValues, id, setDataPropsForNode]);
pingOutputNodes(id);
}, [textfieldsValues, id, setDataPropsForNode, pingOutputNodes]);
// Disable/hide a text field temporarily
const handleDisableField = useCallback((field_id) => {
@ -88,7 +91,8 @@ const TextFieldsNode = ({ data, id }) => {
vis[field_id] = fieldVisibility[field_id] === false; // toggles it
setFieldVisibility(vis);
setDataPropsForNode(id, { fields_visibility: vis });
}, [fieldVisibility, setDataPropsForNode]);
pingOutputNodes(id);
}, [fieldVisibility, setDataPropsForNode, pingOutputNodes]);
// Save the state of a textfield when it changes and update hooks
const handleTextFieldChange = useCallback((field_id, val) => {
@ -114,13 +118,13 @@ const TextFieldsNode = ({ data, id }) => {
// Update template var fields + handles, if there's a change in sets
const past_vars = new Set(templateVars);
if (!setsAreEqual(all_found_vars, past_vars)) {
console.log('set vars');
const new_vars_arr = Array.from(all_found_vars);
new_data.vars = new_vars_arr;
setTemplateVars(new_vars_arr);
}
setDataPropsForNode(id, new_data);
pingOutputNodes(id);
}, [textfieldsValues, templateVars, id]);
@ -152,6 +156,13 @@ const TextFieldsNode = ({ data, id }) => {
}
}, [ref]);
// Pass upstream changes down to later nodes in the chain
useEffect(() => {
if (data.refresh && data.refresh === true) {
pingOutputNodes(id);
}
}, [data, id, pingOutputNodes]);
return (
<div className="text-fields-node cfnode">
<NodeLabel title={data.title || 'TextFields Node'} nodeId={id} icon={<IconTextPlus size="16px" />} />
@ -160,7 +171,9 @@ const TextFieldsNode = ({ data, id }) => {
<div className="input-field" key={i}>
<Textarea id={i} name={i}
className="text-field-fixed nodrag nowheel"
autosize
minRows="2"
maxRows="8"
value={textfieldsValues[i]}
disabled={fieldVisibility[i] === false}
onChange={(event) => handleTextFieldChange(i, event.currentTarget.value)} />
@ -185,7 +198,8 @@ const TextFieldsNode = ({ data, id }) => {
type="source"
position="right"
id="output"
style={{ top: "50%", background: '#555' }}
className="grouped-handle"
style={{ top: "50%" }}
/>
<TemplateHooks vars={templateVars} nodeId={id} startY={hooksY} />
<div className="add-text-field-btn">

View File

@ -692,13 +692,7 @@ const VisNode = ({ data, id }) => {
// From here a React effect will detect the changes to these values and display a new plot
}
});
// Analyze its structure --how many 'vars'?
// Based on its structure, construct a Plotly data visualization
// :: For 1 var and 1 eval_res that's a number, plot {x: var, y: eval_res}
// :: For 2 vars and 1 eval_res that's a number, plot {x: var1, y: var2, z: eval_res}
// :: For all else, don't plot anything (at the moment)
}, [data]);
if (data.input) {
@ -788,7 +782,8 @@ const VisNode = ({ data, id }) => {
type="target"
position="left"
id="input"
style={{ top: '50%', background: '#555' }}
className="grouped-handle"
style={{ top: '50%' }}
onConnect={handleOnConnect}
/>
</div>

View File

@ -25,6 +25,8 @@ export const colorPalettes = {
var: varColorPalette,
}
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'textfields']);
// A global store of variables, used for maintaining state
// across ChainForge and ReactFlow components.
const useStore = create((set, get) => ({
@ -102,6 +104,15 @@ const useStore = create((set, get) => ({
outputEdgesForNode: (sourceNodeId) => {
return get().edges.filter(e => e.source == sourceNodeId);
},
pingOutputNodes: (sourceNodeId) => {
const out_nodes = get().outputEdgesForNode(sourceNodeId).map(e => e.target);
out_nodes.forEach(n => {
const node = get().getNode(n);
if (node && refreshableOutputNodeTypes.has(node.type)) {
get().setDataPropsForNode(node.id, { refresh: true });
}
});
},
output: (sourceNodeId, sourceHandleKey) => {
// Get the source node
const src_node = get().getNode(sourceNodeId);
@ -227,6 +238,11 @@ const useStore = create((set, get) => ({
get().setDataPropsForNode(target.id, { input: connection.source });
}
// Ping target node to fresh if necessary
if (target && refreshableOutputNodeTypes.has(target.type)) {
get().setDataPropsForNode(target.id, { refresh: true });
}
connection.interactionWidth = 100;
connection.markerEnd = {type: 'arrow', width: '22px', height: '22px'};

View File

@ -6,12 +6,22 @@
min-width: 200px;
}
path.react-flow__edge-path:hover {
stroke: #222;
stroke-width: 2px;
}
hr {
border: none;
background-color: #bbb;
height: 1px;
}
.grouped-handle {
background: #555;
height: 15px;
}
.small-standard-font {
font-size: 10pt;
font-family: -apple-system, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
@ -520,9 +530,10 @@
white-space: pre-wrap;
overflow-y: scroll;
max-height: 150px;
margin: 6px 0px;
font-family: monospace;
font-size: 10pt;
line-height: 1.2;
line-height: 1.0;
cursor: text;
}
th .content-editable-div {

View File

@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.2.1.4',
version='0.2.1.5',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",