mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 08:16:37 +00:00
Add Join node (#144)
* Add Join node * Bug fix chat histories with undefined content in messages * Slightly decrease TF width
This commit is contained in:
parent
beeffd0ebb
commit
b448e300c5
25
chainforge/react-server/src/App.js
vendored
25
chainforge/react-server/src/App.js
vendored
@ -8,7 +8,7 @@ import ReactFlow, {
|
||||
} from 'reactflow';
|
||||
import { Button, Menu, LoadingOverlay, Text, Box, List, Loader, Tooltip } from '@mantine/core';
|
||||
import { useClipboard } from '@mantine/hooks';
|
||||
import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2 } from '@tabler/icons-react';
|
||||
import { IconSettings, IconTextPlus, IconTerminal, IconCsv, IconSettingsAutomation, IconFileSymlink, IconRobot, IconRuler2, IconArrowMerge } from '@tabler/icons-react';
|
||||
import RemoveEdge from './RemoveEdge';
|
||||
import TextFieldsNode from './TextFieldsNode'; // Import a custom node
|
||||
import PromptNode from './PromptNode';
|
||||
@ -19,6 +19,7 @@ import ScriptNode from './ScriptNode';
|
||||
import AlertModal from './AlertModal';
|
||||
import CsvNode from './CsvNode';
|
||||
import TabularDataNode from './TabularDataNode';
|
||||
import JoinNode from './JoinNode';
|
||||
import CommentNode from './CommentNode';
|
||||
import GlobalSettingsModal from './GlobalSettingsModal';
|
||||
import ExampleFlowsModal from './ExampleFlowsModal';
|
||||
@ -87,6 +88,7 @@ const nodeTypes = {
|
||||
csv: CsvNode,
|
||||
table: TabularDataNode,
|
||||
comment: CommentNode,
|
||||
join: JoinNode,
|
||||
};
|
||||
|
||||
const edgeTypes = {
|
||||
@ -197,27 +199,27 @@ const App = () => {
|
||||
code = "function evaluate(response) {\n return response.text.length;\n}";
|
||||
addNode({ id: 'evalNode-'+Date.now(), type: 'evaluator', data: { language: progLang, code: code }, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addVisNode = (event) => {
|
||||
const addVisNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'visNode-'+Date.now(), type: 'vis', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addInspectNode = (event) => {
|
||||
const addInspectNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addScriptNode = (event) => {
|
||||
const addScriptNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addCsvNode = (event) => {
|
||||
const addCsvNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addTabularDataNode = (event) => {
|
||||
const addTabularDataNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'table-'+Date.now(), type: 'table', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addCommentNode = (event) => {
|
||||
const addCommentNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'comment-'+Date.now(), type: 'comment', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
@ -225,6 +227,10 @@ const App = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'llmeval-'+Date.now(), type: 'llmeval', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addJoinNode = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'join-'+Date.now(), type: 'join', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
|
||||
const onClickExamples = () => {
|
||||
if (examplesModal && examplesModal.current)
|
||||
@ -768,6 +774,11 @@ const App = () => {
|
||||
<Menu.Item onClick={addInspectNode} icon={'🔍'}> Inspect Node </Menu.Item>
|
||||
</MenuTooltip>
|
||||
<Menu.Divider />
|
||||
<Menu.Label>Processors</Menu.Label>
|
||||
<MenuTooltip label="Concatenate responses or input data together before passing into later nodes, within or across variables and LLMs.">
|
||||
<Menu.Item onClick={addJoinNode} icon={<IconArrowMerge size='14pt' />}> Join Node </Menu.Item>
|
||||
</MenuTooltip>
|
||||
<Menu.Divider />
|
||||
<Menu.Label>Misc</Menu.Label>
|
||||
<MenuTooltip label="Make a comment about your flow.">
|
||||
<Menu.Item onClick={addCommentNode} icon={'✏️'}> Comment Node </Menu.Item>
|
||||
|
392
chainforge/react-server/src/JoinNode.js
vendored
Normal file
392
chainforge/react-server/src/JoinNode.js
vendored
Normal file
@ -0,0 +1,392 @@
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import { Handle } from 'reactflow';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent';
|
||||
import fetch_from_backend from './fetch_from_backend';
|
||||
import { IconArrowMerge, IconList } from '@tabler/icons-react';
|
||||
import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core';
|
||||
import { useDisclosure } from '@mantine/hooks';
|
||||
|
||||
const formattingOptions = [
|
||||
{value: "\n\n", label:"double newline \\n\\n"},
|
||||
{value: "\n", label:"newline \\n"},
|
||||
{value: "-", label:"- dashed list"},
|
||||
{value: "1.", label:"1. numbered list"},
|
||||
{value: "[]", label:'["list", "of", "strings"]'}
|
||||
];
|
||||
|
||||
const joinTexts = (texts, formatting) => {
|
||||
if (formatting === "\n\n" || formatting === "\n")
|
||||
return texts.join(formatting);
|
||||
else if (formatting === "-")
|
||||
return texts.map((t) => ('- ' + t)).join("\n");
|
||||
else if (formatting === "1.")
|
||||
return texts.map((t, i) => (`${i+1}. ${t}`)).join("\n");
|
||||
else if (formatting === '[]')
|
||||
return JSON.stringify(texts);
|
||||
|
||||
console.error(`Could not join: Unknown formatting option: ${formatting}`);
|
||||
return texts;
|
||||
};
|
||||
|
||||
const getVarsAndMetavars = (input_data) => {
|
||||
// Find all vars and metavars in the input data (if any):
|
||||
let varnames = new Set();
|
||||
let metavars = new Set();
|
||||
Object.entries(input_data).forEach(([key, obj]) => {
|
||||
if (key !== '__input') varnames.add(key); // A "var" can also be other properties on input_data
|
||||
obj.forEach(resp_obj => {
|
||||
if (typeof resp_obj === "string") return;
|
||||
Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
|
||||
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
|
||||
});
|
||||
});
|
||||
varnames = Array.from(varnames);
|
||||
metavars = Array.from(metavars);
|
||||
return {
|
||||
vars: varnames,
|
||||
metavars: metavars,
|
||||
};
|
||||
};
|
||||
|
||||
const countNumLLMs = (resp_objs_or_dict) => {
|
||||
const resp_objs = Array.isArray(resp_objs_or_dict) ? resp_objs_or_dict : Object.values(resp_objs_or_dict).flat();
|
||||
return (new Set(resp_objs.filter(r => typeof r !== "string" && r.llm !== undefined).map(r => r.llm?.key || r.llm))).size;
|
||||
};
|
||||
|
||||
const tagMetadataWithLLM = (input_data) => {
|
||||
let new_data = {};
|
||||
Object.entries(input_data).forEach(([varname, resp_objs]) => {
|
||||
new_data[varname] = resp_objs.map(r => {
|
||||
if (!r || typeof r === 'string' || !r?.llm?.key) return r;
|
||||
let r_copy = JSON.parse(JSON.stringify(r));
|
||||
r_copy.metavars["__LLM_key"] = r.llm.key;
|
||||
return r_copy;
|
||||
});
|
||||
});
|
||||
return new_data;
|
||||
};
|
||||
const extractLLMLookup = (input_data) => {
|
||||
let llm_lookup = {};
|
||||
Object.entries(input_data).forEach(([varname, resp_objs]) => {
|
||||
resp_objs.forEach(r => {
|
||||
if (typeof r === 'string' || !r?.llm?.key || r.llm.key in llm_lookup) return;
|
||||
llm_lookup[r.llm.key] = r.llm;
|
||||
});
|
||||
});
|
||||
return llm_lookup;
|
||||
};
|
||||
const removeLLMTagFromMetadata = (metavars) => {
|
||||
if (!('__LLM_key' in metavars))
|
||||
return metavars;
|
||||
let mcopy = JSON.parse(JSON.stringify(metavars));
|
||||
delete metavars['__LLM_key'];
|
||||
return mcopy;
|
||||
};
|
||||
|
||||
const truncStr = (s, maxLen) => {
|
||||
if (s.length > maxLen) // Cut the name short if it's long
|
||||
return s.substring(0, maxLen) + '...'
|
||||
else
|
||||
return s;
|
||||
};
|
||||
const groupResponsesBy = (responses, keyFunc) => {
|
||||
let responses_by_key = {};
|
||||
let unspecified_group = [];
|
||||
responses.forEach(item => {
|
||||
const key = keyFunc(item);
|
||||
const d = key !== null ? responses_by_key : unspecified_group;
|
||||
if (key in d)
|
||||
d[key].push(item);
|
||||
else
|
||||
d[key] = [item];
|
||||
});
|
||||
return [responses_by_key, unspecified_group];
|
||||
};
|
||||
|
||||
const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" };
|
||||
|
||||
const displayJoinedTexts = (textInfos, getColorForLLM) => {
|
||||
const color_for_llm = (llm) => (getColorForLLM(llm) + '99');
|
||||
return textInfos.map((info, idx) => {
|
||||
|
||||
const vars = info.fill_history;
|
||||
let var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => {
|
||||
const v = truncStr(vars[varname].trim(), 72);
|
||||
return (<div key={varname} className="response-var-inline">
|
||||
<span className="response-var-name">{varname} = </span><span className="response-var-value">{v}</span>
|
||||
</div>);
|
||||
});
|
||||
|
||||
const ps = (<pre className='small-response'>{info.text || info}</pre>);
|
||||
|
||||
return (
|
||||
<div key={"r"+idx} className="response-box" style={{ backgroundColor: (info.llm ? color_for_llm(info.llm?.name) : '#ddd'), width: `100%`}}>
|
||||
<div className="response-var-inline-container">
|
||||
{var_tags}
|
||||
</div>
|
||||
{info.llm === undefined ?
|
||||
ps
|
||||
: (<div className="response-item-llm-name-wrapper">
|
||||
<h1>{info.llm?.name}</h1>
|
||||
{ps}
|
||||
</div>)
|
||||
}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
const JoinedTextsPopover = ({ textInfos, onHover, onClick, getColorForLLM }) => {
|
||||
const [opened, { close, open }] = useDisclosure(false);
|
||||
|
||||
const _onHover = useCallback(() => {
|
||||
onHover();
|
||||
open();
|
||||
}, [onHover, open]);
|
||||
|
||||
return (
|
||||
<Popover position="right-start" withArrow withinPortal shadow="rgb(38, 57, 77) 0px 10px 30px -14px" key="query-info" opened={opened} styles={{dropdown: {maxHeight: '500px', maxWidth: '400px', overflowY: 'auto', backgroundColor: '#fff'}}}>
|
||||
<Popover.Target>
|
||||
<Tooltip label='Click to view all joined inputs' withArrow>
|
||||
<button className='custom-button' onMouseEnter={_onHover} onMouseLeave={close} onClick={onClick} style={{border:'none'}}>
|
||||
<IconList size='12pt' color='gray' style={{marginBottom: '-4px'}} />
|
||||
</button>
|
||||
</Tooltip>
|
||||
</Popover.Target>
|
||||
<Popover.Dropdown sx={{ pointerEvents: 'none' }}>
|
||||
<Center><Text size='xs' fw={500} color='#666'>Preview of joined inputs ({textInfos?.length} total)</Text></Center>
|
||||
{displayJoinedTexts(textInfos, getColorForLLM)}
|
||||
</Popover.Dropdown>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
const JoinNode = ({ data, id }) => {
|
||||
|
||||
const [joinedTexts, setJoinedTexts] = useState([]);
|
||||
|
||||
// For an info pop-up that previews all the joined inputs
|
||||
const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] = useDisclosure(false);
|
||||
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
const pullInputData = useStore((state) => state.pullInputData);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
|
||||
// Global lookup for what color to use per LLM
|
||||
const getColorForLLMAndSetIfNotFound = useStore((state) => state.getColorForLLMAndSetIfNotFound);
|
||||
|
||||
const [inputHasLLMs, setInputHasLLMs] = useState(false);
|
||||
|
||||
const [groupByVars, setGroupByVars] = useState([DEFAULT_GROUPBY_VAR_ALL]);
|
||||
const [groupByVar, setGroupByVar] = useState("A");
|
||||
|
||||
const [groupByLLM, setGroupByLLM] = useState("within");
|
||||
const [formatting, setFormatting] = useState(formattingOptions[0].value);
|
||||
|
||||
const handleOnConnect = useCallback(() => {
|
||||
let input_data = pullInputData(["__input"], id);
|
||||
if (!input_data?.__input) {
|
||||
console.warn('Join Node: No input data detected.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Find all vars and metavars in the input data (if any):
|
||||
let {vars, metavars} = getVarsAndMetavars(input_data);
|
||||
|
||||
// Create lookup table for LLMs in input, indexed by llm key
|
||||
const llm_lookup = extractLLMLookup(input_data);
|
||||
|
||||
// Refresh the dropdown list with available vars/metavars:
|
||||
setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
|
||||
vars.map(varname => ({label: `by ${varname}`, value: `V${varname}`})))
|
||||
.concat(
|
||||
metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `by ${varname} (meta)`, value: `M${varname}`})))
|
||||
);
|
||||
|
||||
// Check whether more than one LLM is present in the inputs:
|
||||
const numLLMs = countNumLLMs(input_data);
|
||||
setInputHasLLMs(numLLMs > 1);
|
||||
|
||||
// Tag all response objects in the input data with a metavar for their LLM (using the llm key as a uid)
|
||||
input_data = tagMetadataWithLLM(input_data);
|
||||
|
||||
// A function to group the input (an array of texts/resp_objs) by the selected var
|
||||
// and then join the texts within the groups
|
||||
const joinByVar = (input) => {
|
||||
const varname = groupByVar.substring(1);
|
||||
const isMetavar = groupByVar[0] === 'M';
|
||||
const [groupedResps, unspecGroup] = groupResponsesBy(input,
|
||||
isMetavar ?
|
||||
(r) => (r.metavars ? r.metavars[varname] : undefined) :
|
||||
(r) => (r.fill_history ? r.fill_history[varname] : undefined)
|
||||
);
|
||||
|
||||
// Now join texts within each group:
|
||||
// (NOTE: We can do this directly here as response texts can't be templates themselves)
|
||||
let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => {
|
||||
if (resp_objs.length === 0) return "";
|
||||
const llm = (countNumLLMs(resp_objs) > 1) ? undefined : resp_objs[0].llm;
|
||||
let vars = {};
|
||||
if (groupByVar !== 'A')
|
||||
vars[varname] = var_val;
|
||||
return {
|
||||
text: joinTexts(resp_objs.map(r => r.text !== undefined ? r.text : r), formatting),
|
||||
fill_history: isMetavar ? {} : vars,
|
||||
metavars: isMetavar ? vars : {},
|
||||
llm: llm,
|
||||
// NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values.
|
||||
};
|
||||
});
|
||||
|
||||
// Add any data from unspecified group
|
||||
if (unspecGroup.length > 0) {
|
||||
const llm = (countNumLLMs(unspecGroup) > 1) ? undefined : unspecGroup[0].llm;
|
||||
joined_texts.push({
|
||||
text: joinTexts(unspecGroup.map(u => u.text !== undefined ? u.text : u), formatting),
|
||||
fill_history: {},
|
||||
metavars: {},
|
||||
llm: llm,
|
||||
});
|
||||
}
|
||||
|
||||
return joined_texts;
|
||||
};
|
||||
|
||||
// Generate (flatten) the inputs, which could be recursively chained templates
|
||||
// and a mix of LLM resp objects, templates, and strings.
|
||||
// (We tagged each object with its LLM key so that we can use built-in features to keep track of the LLM associated with each response object)
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: "{__input}",
|
||||
vars: input_data,
|
||||
}).then(promptTemplates => {
|
||||
|
||||
// Convert the templates into response objects
|
||||
let resp_objs = promptTemplates.map(p => ({
|
||||
text: p.toString(),
|
||||
fill_history: p.fill_history,
|
||||
llm: "__LLM_key" in p.metavars ? llm_lookup[p.metavars['__LLM_key']] : undefined,
|
||||
metavars: removeLLMTagFromMetadata(p.metavars),
|
||||
}));
|
||||
|
||||
// If there's multiple LLMs and groupByLLM is 'within', we need to
|
||||
// first group by the LLMs (and a possible 'undefined' group):
|
||||
if (numLLMs > 1 && groupByLLM === 'within') {
|
||||
let joined_texts = [];
|
||||
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(resp_objs, r => r.llm?.key || r.llm);
|
||||
Object.entries(groupedRespsByLLM).map(([llm_key, resp_objs]) => {
|
||||
// Group only within the LLM
|
||||
joined_texts = joined_texts.concat(joinByVar(resp_objs));
|
||||
});
|
||||
|
||||
if (nonLLMRespGroup.length > 0)
|
||||
joined_texts.push(joinTexts(nonLLMRespGroup, formatting));
|
||||
|
||||
setJoinedTexts(joined_texts);
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
} else {
|
||||
// Join across LLMs (join irrespective of LLM):
|
||||
if (groupByVar !== 'A') {
|
||||
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
|
||||
const joined_texts = joinByVar(resp_objs);
|
||||
setJoinedTexts(joined_texts);
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
} else {
|
||||
let joined_texts = joinTexts(resp_objs.map(r => ((typeof r === 'string') ? r : r.text)), formatting);
|
||||
|
||||
// If there is exactly 1 LLM and it's present across all inputs, keep track of it:
|
||||
if (numLLMs === 1 && resp_objs.every((r) => r.llm !== undefined))
|
||||
joined_texts = {text: joined_texts, fill_history: {}, llm: resp_objs[0].llm};
|
||||
|
||||
setJoinedTexts([joined_texts]);
|
||||
setDataPropsForNode(id, { fields: [joined_texts] });
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
}, [formatting, pullInputData, groupByVar, groupByLLM]);
|
||||
|
||||
if (data.input) {
|
||||
// If there's a change in inputs...
|
||||
if (data.input != pastInputs) {
|
||||
setPastInputs(data.input);
|
||||
handleOnConnect();
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh join output anytime the dropdowns change
|
||||
useEffect(() => {
|
||||
handleOnConnect();
|
||||
}, [groupByVar, groupByLLM, formatting])
|
||||
|
||||
useEffect(() => {
|
||||
if (data.refresh && data.refresh === true) {
|
||||
// Recreate the visualization:
|
||||
setDataPropsForNode(id, { refresh: false });
|
||||
handleOnConnect();
|
||||
}
|
||||
}, [data, id, handleOnConnect, setDataPropsForNode]);
|
||||
|
||||
return (
|
||||
<div className="join-node cfnode">
|
||||
<NodeLabel title={data.title || 'Join Node'}
|
||||
nodeId={id}
|
||||
icon={<IconArrowMerge size='14pt'/>}
|
||||
customButtons={[
|
||||
<JoinedTextsPopover key='joined-text-previews' textInfos={joinedTexts} onHover={handleOnConnect} onClick={openInfoModal} getColorForLLM={getColorForLLMAndSetIfNotFound} />
|
||||
]} />
|
||||
<Modal title={'List of joined inputs (' + joinedTexts.length + ' total)'} size='xl' opened={infoModalOpened} onClose={closeInfoModal} styles={{header: {backgroundColor: '#FFD700'}, root: {position: 'relative', left: '-5%'}}}>
|
||||
<Box size={600} m='lg' mt='xl'>
|
||||
{displayJoinedTexts(joinedTexts, getColorForLLMAndSetIfNotFound)}
|
||||
</Box>
|
||||
</Modal>
|
||||
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
|
||||
<Text mt='3px' mr='xs'>Join</Text>
|
||||
<NativeSelect onChange={(e) => setGroupByVar(e.target.value)}
|
||||
className='nodrag nowheel'
|
||||
data={groupByVars}
|
||||
size="xs"
|
||||
value={groupByVar}
|
||||
miw='80px'
|
||||
mr='xs' />
|
||||
</div>
|
||||
{inputHasLLMs ?
|
||||
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
|
||||
<NativeSelect onChange={(e) => setGroupByLLM(e.target.value)}
|
||||
className='nodrag nowheel'
|
||||
data={["within", "across"]}
|
||||
size="xs"
|
||||
value={groupByLLM}
|
||||
maw='80px'
|
||||
mr='xs'
|
||||
ml='40px' />
|
||||
<Text mt='3px'>LLMs</Text>
|
||||
</div>
|
||||
: <></>}
|
||||
<Divider my="xs" label="formatting" labelPosition="center" />
|
||||
<NativeSelect onChange={(e) => setFormatting(e.target.value)}
|
||||
className='nodrag nowheel'
|
||||
data={formattingOptions}
|
||||
size="xs"
|
||||
value={formatting}
|
||||
miw='80px' />
|
||||
<Handle
|
||||
type="target"
|
||||
position="left"
|
||||
id="__input"
|
||||
className="grouped-handle"
|
||||
style={{ top: "50%" }}
|
||||
onConnect={handleOnConnect}
|
||||
/>
|
||||
<Handle
|
||||
type="source"
|
||||
position="right"
|
||||
id="output"
|
||||
className="grouped-handle"
|
||||
style={{ top: "50%" }}
|
||||
/>
|
||||
</div>);
|
||||
};
|
||||
|
||||
export default JoinNode;
|
@ -475,7 +475,7 @@ const LLMResponseInspector = ({ jsonResponses, wideFormat }) => {
|
||||
value={multiSelectValue}
|
||||
clearSearchOnChange={true}
|
||||
clearSearchOnBlur={true}
|
||||
w='80%' />
|
||||
w={wideFormat ? '80%' : '100%'} />
|
||||
<Checkbox checked={onlyShowScores}
|
||||
label="Only show scores"
|
||||
onChange={(e) => setOnlyShowScores(e.currentTarget.checked)}
|
||||
|
71
chainforge/react-server/src/PromptNode.js
vendored
71
chainforge/react-server/src/PromptNode.js
vendored
@ -78,10 +78,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
|
||||
// Get state from the Zustand store:
|
||||
const edges = useStore((state) => state.edges);
|
||||
const output = useStore((state) => state.output);
|
||||
const pullInputData = useStore((state) => state.pullInputData);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
|
||||
const getNode = useStore((state) => state.getNode);
|
||||
|
||||
// API Keys (set by user in popup GlobalSettingsModal)
|
||||
const apiKeys = useStore((state) => state.apiKeys);
|
||||
@ -208,52 +207,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}
|
||||
}, [data]);
|
||||
|
||||
// Pull all inputs needed to request responses.
|
||||
// Returns [prompt, vars dict]
|
||||
const pullInputData = (_targetHandles) => {
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const store_data = (_texts, _varname, _data) => {
|
||||
if (_varname in _data)
|
||||
_data[_varname] = _data[_varname].concat(_texts);
|
||||
else
|
||||
_data[_varname] = _texts;
|
||||
};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
varnames.forEach(varname => {
|
||||
// Find the relevant edge(s):
|
||||
edges.forEach(e => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
// Get the immediate output:
|
||||
let out = output(e.source, e.sourceHandle);
|
||||
if (!out || !Array.isArray(out) || out.length === 0) return;
|
||||
|
||||
// Check the format of the output. Can be str or dict with 'text' and more attrs:
|
||||
if (typeof out[0] === 'object') {
|
||||
out.forEach(obj => store_data([obj], varname, pulled_data));
|
||||
}
|
||||
else {
|
||||
// Save the list of strings from the pulled output under the var 'varname'
|
||||
store_data(out, varname, pulled_data);
|
||||
}
|
||||
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
get_outputs(n_vars, e.source);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
get_outputs(_targetHandles, id);
|
||||
|
||||
return pulled_data;
|
||||
};
|
||||
|
||||
// Chat nodes only. Pulls input data attached to the 'past conversations' handle.
|
||||
// Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected.
|
||||
const pullInputChats = () => {
|
||||
const pulled_data = pullInputData(['__past_chats']);
|
||||
const pulled_data = pullInputData(['__past_chats'], id);
|
||||
if (!('__past_chats' in pulled_data)) return [undefined, undefined];
|
||||
|
||||
// For storing the unique LLMs in past_chats:
|
||||
@ -313,12 +270,12 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
const [promptPreviews, setPromptPreviews] = useState([]);
|
||||
const handlePreviewHover = () => {
|
||||
// Pull input data and prompt
|
||||
const pulled_vars = pullInputData(templateVars);
|
||||
const pulled_vars = pullInputData(templateVars, id);
|
||||
fetch_from_backend('generatePrompts', {
|
||||
prompt: promptText,
|
||||
vars: pulled_vars,
|
||||
}).then(prompts => {
|
||||
setPromptPreviews(prompts.map(p => (new PromptInfo(p))));
|
||||
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
|
||||
});
|
||||
|
||||
pullInputChats();
|
||||
@ -352,7 +309,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
}
|
||||
|
||||
// Pull the input data
|
||||
const pulled_vars = pullInputData(templateVars);
|
||||
const pulled_vars = pullInputData(templateVars, id);
|
||||
|
||||
const llms = _llmItemsCurrState.map(item => item.model);
|
||||
const num_llms = llms.length;
|
||||
@ -442,6 +399,20 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if pulled chats includes undefined content.
|
||||
// This could happen with Join nodes, where there is no longer a single "prompt" (user prompt)
|
||||
// of the chat provenance. Instead of blocking this behavior, we replace undefined with a blank string,
|
||||
// and output a warning to the console.
|
||||
if (!pulled_chats.every(c => c.messages.every(m => m.content !== undefined))) {
|
||||
console.warn("Chat history contains undefined content. This can happen if a Join Node was used, \
|
||||
as there is no longer a single prompt as the provenance of the conversation. \
|
||||
Soft failing by replacing undefined with empty strings.");
|
||||
pulled_chats.forEach(c => {c.messages = c.messages.map(m => {
|
||||
if (m.content !== undefined) return m;
|
||||
else return {...m, content: " "}; // the string contains a single space since PaLM2 refuses to answer with empty strings
|
||||
})});
|
||||
}
|
||||
|
||||
// Override LLM list with the past llm info (unique LLMs in prior responses)
|
||||
_llmItemsCurrState = past_chat_llms;
|
||||
|
||||
@ -462,13 +433,13 @@ const PromptNode = ({ data, id, type: node_type }) => {
|
||||
setProgressAnimated(true);
|
||||
|
||||
// Pull the data to fill in template input variables, if any
|
||||
const pulled_data = pullInputData(templateVars);
|
||||
const pulled_data = pullInputData(templateVars, id);
|
||||
const prompt_template = promptText;
|
||||
|
||||
const rejected = (err) => {
|
||||
setStatus('error');
|
||||
setContChatToggleDisabled(false);
|
||||
triggerAlert(err.message);
|
||||
triggerAlert(err.message || err);
|
||||
};
|
||||
|
||||
// Fetch info about the number of queries we'll need to make
|
||||
|
2
chainforge/react-server/src/RemoveEdge.js
vendored
2
chainforge/react-server/src/RemoveEdge.js
vendored
@ -45,7 +45,7 @@ export default function CustomEdge({
|
||||
|
||||
// Thanks in part to oshanley https://github.com/wbkd/react-flow/issues/1211#issuecomment-1585032930
|
||||
return (
|
||||
<EdgePathContainer onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)} onClick={()=>console.log('click')}>
|
||||
<EdgePathContainer onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)}>
|
||||
<BaseEdge path={edgePath} markerEnd={markerEnd} style={{...style, stroke: (hovering ? '#000' : '#999')}} />
|
||||
<EdgeLabelRenderer>
|
||||
<div
|
||||
|
@ -371,9 +371,9 @@ function run_over_responses(eval_func: (resp: ResponseInfo) => any, responses: A
|
||||
* @param vars a dict of the template variables to fill the prompt template with, by name. (See countQueries docstring for more info).
|
||||
* @returns An array of strings representing the prompts that will be sent out. Note that this could include unfilled template vars.
|
||||
*/
|
||||
export async function generatePrompts(root_prompt: string, vars: Dict): Promise<string[]> {
|
||||
export async function generatePrompts(root_prompt: string, vars: Dict): Promise<PromptTemplate[]> {
|
||||
const gen_prompts = new PromptPermutationGenerator(root_prompt);
|
||||
const all_prompt_permutations = Array.from(gen_prompts.generate(vars)).map(p => p.toString());
|
||||
const all_prompt_permutations = Array.from(gen_prompts.generate(vars));
|
||||
return all_prompt_permutations;
|
||||
}
|
||||
|
||||
|
52
chainforge/react-server/src/store.js
vendored
52
chainforge/react-server/src/store.js
vendored
@ -26,7 +26,7 @@ export const colorPalettes = {
|
||||
var: varColorPalette,
|
||||
}
|
||||
|
||||
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']);
|
||||
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval', 'join']);
|
||||
|
||||
export let initLLMProviders = [
|
||||
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
|
||||
@ -204,6 +204,56 @@ const useStore = create((set, get) => ({
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
// Pull all inputs needed to request responses.
|
||||
// Returns [prompt, vars dict]
|
||||
pullInputData: (_targetHandles, node_id) => {
|
||||
// Functions/data from the store:
|
||||
const getNode = get().getNode;
|
||||
const output = get().output;
|
||||
const edges = get().edges;
|
||||
|
||||
// Helper function to store collected data in dict:
|
||||
const store_data = (_texts, _varname, _data) => {
|
||||
if (_varname in _data)
|
||||
_data[_varname] = _data[_varname].concat(_texts);
|
||||
else
|
||||
_data[_varname] = _texts;
|
||||
};
|
||||
|
||||
// Pull data from each source recursively:
|
||||
const pulled_data = {};
|
||||
const get_outputs = (varnames, nodeId) => {
|
||||
varnames.forEach(varname => {
|
||||
// Find the relevant edge(s):
|
||||
edges.forEach(e => {
|
||||
if (e.target == nodeId && e.targetHandle == varname) {
|
||||
// Get the immediate output:
|
||||
let out = output(e.source, e.sourceHandle);
|
||||
if (!out || !Array.isArray(out) || out.length === 0) return;
|
||||
|
||||
// Check the format of the output. Can be str or dict with 'text' and more attrs:
|
||||
if (typeof out[0] === 'object') {
|
||||
out.forEach(obj => store_data([obj], varname, pulled_data));
|
||||
}
|
||||
else {
|
||||
// Save the list of strings from the pulled output under the var 'varname'
|
||||
store_data(out, varname, pulled_data);
|
||||
}
|
||||
|
||||
// Get any vars that the output depends on, and recursively collect those outputs as well:
|
||||
const n_vars = getNode(e.source).data.vars;
|
||||
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
|
||||
get_outputs(n_vars, e.source);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
get_outputs(_targetHandles, node_id);
|
||||
|
||||
return pulled_data;
|
||||
},
|
||||
|
||||
setDataPropsForNode: (id, data_props) => {
|
||||
set({
|
||||
nodes: (nds =>
|
||||
|
@ -409,6 +409,9 @@
|
||||
color: #444;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
.join-text-preview {
|
||||
margin: 0px 0px 10px 0px;
|
||||
}
|
||||
|
||||
.small-response {
|
||||
font-size: 8pt;
|
||||
@ -531,6 +534,10 @@
|
||||
border-color: #222;
|
||||
}
|
||||
|
||||
.join-node {
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.tabular-data-node {
|
||||
min-width: 280px;
|
||||
}
|
||||
@ -652,7 +659,7 @@
|
||||
.text-field-fixed .mantine-Textarea-wrapper textarea {
|
||||
resize: vertical;
|
||||
overflow-y: auto;
|
||||
width: 280px;
|
||||
width: 260px;
|
||||
padding: calc(0.5rem / 3);
|
||||
font-size: 10pt;
|
||||
font-family: monospace;
|
||||
|
Loading…
x
Reference in New Issue
Block a user