join node wip

Ian Arawjo 2023-10-16 11:54:51 -04:00
parent d8e734d778
commit 095d77a71b
6 changed files with 289 additions and 119 deletions

@@ -1,58 +1,206 @@
import React, { useState, useEffect } from 'react';
import React, { useState, useEffect, useCallback } from 'react';
import { Handle } from 'reactflow';
import useStore from './store';
import NodeLabel from './NodeLabelComponent';
import fetch_from_backend from './fetch_from_backend';
import { IconArrowMerge } from '@tabler/icons-react';
import { Divider, NativeSelect, Text } from '@mantine/core';
import { IconArrowMerge, IconList } from '@tabler/icons-react';
import { Divider, NativeSelect, Text, Popover, Tooltip, Center, Modal, Box } from '@mantine/core';
import { useDisclosure } from '@mantine/hooks';
const formattingOptions = [
{value: "\n\n", label:"double newline \\n\\n"},
{value: "\n", label:"newline \\n"},
{value: "-", label:"- dashed list"},
{value: "[]", label:'["list", "of", "strings"]'}
];
const joinTexts = (texts, formatting) => {
if (formatting === "\n\n" || formatting === "\n")
return texts.join(formatting);
else if (formatting === "-")
return texts.map((t) => (' - ' + t)).join("\n");
else if (formatting === '[]')
return JSON.stringify(texts);
console.error(`Could not join: Unknown formatting option: ${formatting}`);
return texts;
};
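// Examples (illustrative):
// joinTexts(["a", "b"], "\n\n") => "a\n\nb"
// joinTexts(["a", "b"], "-") => " - a\n - b"
// joinTexts(["a", "b"], "[]") => '["a","b"]'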
const getVarsAndMetavars = (input_data) => {
// Find all vars and metavars in the input data (if any):
let varnames = new Set();
let metavars = new Set();
input_data.forEach(resp_obj => {
if (typeof resp_obj === "string") return;
if (resp_obj.fill_history) Object.keys(resp_obj.fill_history).forEach(v => varnames.add(v));
if (resp_obj.metavars) Object.keys(resp_obj.metavars).forEach(v => metavars.add(v));
});
varnames = Array.from(varnames);
metavars = Array.from(metavars);
return {
vars: varnames,
metavars: metavars,
};
}
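// Example (illustrative): given [{text: "hi", fill_history: {country: "USA"}, metavars: {LLM_0: "GPT3.5"}}],
// returns {vars: ["country"], metavars: ["LLM_0"]}. Bare strings in the input are skipped.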
// Returns true if the response objects span more than one distinct LLM:
const containsMultipleLLMs = (resp_objs) => {
return (new Set(resp_objs.map(r => r.llm.key || r.llm))).size > 1;
}
const truncStr = (s, maxLen) => {
if (s.length > maxLen) // Cut the string short if it's long
return s.substring(0, maxLen) + '...';
else
return s;
};
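// e.g. (illustrative) truncStr("Constantinople", 5) => "Const..."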
const groupResponsesBy = (responses, keyFunc) => {
let responses_by_key = {};
let unspecified_group = [];
responses.forEach(item => {
const key = keyFunc(item);
if (key === null || key === undefined) {
// No key for this response; collect it in the separate "unspecified" list:
unspecified_group.push(item);
return;
}
if (key in responses_by_key)
responses_by_key[key].push(item);
else
responses_by_key[key] = [item];
});
return [responses_by_key, unspecified_group];
};
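// Example (illustrative): groupResponsesBy(resps, (r) => r.fill_history["country"] ?? null)
// => [{USA: [...], Canada: [...]}, [...resps that have no "country" var]]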
const DEFAULT_GROUPBY_VAR_ALL = { label: "all text", value: "A" };
const displayJoinedTexts = (textInfos) =>
textInfos.map((info, idx) => {
const vars = info.fill_history;
const var_tags = vars === undefined ? [] : Object.keys(vars).map((varname) => {
const v = truncStr(vars[varname].trim(), 72);
return (<div key={varname} className="response-var-inline">
<span className="response-var-name">{varname}&nbsp;=&nbsp;</span><span className="response-var-value">{v}</span>
</div>);
});
return (
<div key={idx}>
{var_tags}
<pre className='prompt-preview join-text-preview'>
{info.text || info}
</pre>
</div>
);
});
const JoinedTextsPopover = ({ textInfos, onHover, onClick }) => {
const [opened, { close, open }] = useDisclosure(false);
const _onHover = useCallback(() => {
onHover();
open();
}, [onHover, open]);
return (
<Popover position="right-start" withArrow withinPortal shadow="rgb(38, 57, 77) 0px 10px 30px -14px" key="query-info" opened={opened} styles={{dropdown: {maxHeight: '500px', maxWidth: '400px', overflowY: 'auto', backgroundColor: '#fff'}}}>
<Popover.Target>
<Tooltip label='Click to view all joined inputs' withArrow>
<button className='custom-button' onMouseEnter={_onHover} onMouseLeave={close} onClick={onClick} style={{border:'none'}}>
<IconList size='12pt' color='gray' style={{marginBottom: '-4px'}} />
</button>
</Tooltip>
</Popover.Target>
<Popover.Dropdown sx={{ pointerEvents: 'none' }}>
<Center><Text size='xs' fw={500} color='#666'>Preview of joined inputs ({textInfos?.length} total)</Text></Center>
{displayJoinedTexts(textInfos)}
</Popover.Dropdown>
</Popover>
);
};
const JoinNode = ({ data, id }) => {
let is_fetching = false;
const [jsonResponses, setJSONResponses] = useState(null);
const [joinedTexts, setJoinedTexts] = useState([]);
// For an info pop-up that previews all the joined inputs
const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] = useDisclosure(false);
const [pastInputs, setPastInputs] = useState([]);
const pullInputData = useStore((state) => state.pullInputData);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const handleOnConnect = () => {
// For some reason, 'on connect' is called twice upon connection.
// We detect when an inspector node is already fetching, and disable the second call:
if (is_fetching) return;
const [inputHasLLMs, setInputHasLLMs] = useState(false);
const [inputHasMultiRespPerLLM, setInputHasMultiRespPerLLM] = useState(false);
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map(e => e.source);
is_fetching = true;
// Grab responses associated with those ids:
fetch_from_backend('grabResponses', {
'responses': input_node_ids
}).then(function(json) {
if (json.responses && json.responses.length > 0) {
setJSONResponses(json.responses);
}
is_fetching = false;
}).catch(() => {
is_fetching = false;
});
}
const [groupByVar, setGroupByVar] = useState("all text");
const handleChangeGroupByVar = (new_val) => {
setGroupByVar(new_val.target.value);
};
const [groupByVars, setGroupByVars] = useState([DEFAULT_GROUPBY_VAR_ALL]);
const [groupByVar, setGroupByVar] = useState("A");
const [groupByLLM, setGroupByLLM] = useState("within");
const handleChangeGroupByLLM = (new_val) => {
setGroupByLLM(new_val.target.value);
};
const [responsesPerPrompt, setResponsesPerPrompt] = useState("all");
const handleChangeResponsesPerPrompt = (new_val) => {
setResponsesPerPrompt(new_val.target.value);
};
const [formatting, setFormatting] = useState(formattingOptions[0].value);
const handleOnConnect = useCallback(() => {
const input_data = pullInputData(["__input"], id);
if (!input_data?.__input) {
console.warn('Join Node: No input data detected.');
return;
}
console.log(input_data);
// Find all vars and metavars in the input data (if any):
const {vars, metavars} = getVarsAndMetavars(input_data.__input);
// Refresh the dropdown list with available vars/metavars:
setGroupByVars([DEFAULT_GROUPBY_VAR_ALL].concat(
vars.map(varname => ({label: `within ${varname}`, value: `V${varname}`})))
.concat(
metavars.filter(varname => !varname.startsWith('LLM_')).map(varname => ({label: `within ${varname} (meta)`, value: `M${varname}`})))
);
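// (Illustrative note) Option values encode the grouping in their first character:
// "A" = all text, "V<name>" = group within var <name>, "M<name>" = group within metavar <name>;
// the substring(1) below strips that prefix to recover the variable name.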
// If groupByVar is set to non-ALL (not "A"), then we need to group responses by that variable first:
if (groupByVar !== 'A') {
const varname = groupByVar.substring(1);
const [groupedResps, unspecGroup] = groupResponsesBy(input_data.__input,
(groupByVar[0] === 'V') ?
(r) => r.fill_history[varname] :
(r) => r.metavars[varname]
);
console.log(groupedResps);
// Now join texts within each group:
// (NOTE: We can do this directly here as response texts can't be templates themselves)
let joined_texts = Object.entries(groupedResps).map(([var_val, resp_objs]) => {
if (resp_objs.length === 0) return "";
let llm = containsMultipleLLMs(resp_objs) ? undefined : resp_objs[0].llm;
let vars = {};
vars[varname] = var_val;
return {
text: joinTexts(resp_objs.map(r => r.text), formatting),
fill_history: vars,
llm: llm,
// NOTE: We lose all other metadata here, because we could've joined across other vars or metavars values.
};
});
setJoinedTexts(joined_texts);
console.log(joined_texts);
}
else {
// Since templates could be chained, we need to run this
// through the prompt generator:
fetch_from_backend('generatePrompts', {
prompt: "{__input}",
vars: input_data,
}).then(promptTemplates => {
const texts = promptTemplates.map(p => p.toString());
console.log(texts);
const joined_texts = joinTexts(texts, formatting);
setJoinedTexts([joined_texts]);
console.log(joined_texts);
});
}
}, [id, formatting, pullInputData, groupByVar]);
if (data.input) {
// If there's a change in inputs...
@@ -75,50 +223,62 @@ const JoinNode = ({ data, id }) => {
<NodeLabel title={data.title || 'Join Node'}
nodeId={id}
icon={<IconArrowMerge size='14pt'/>}
/>
customButtons={[
<JoinedTextsPopover key='joined-text-previews' textInfos={joinedTexts} onHover={handleOnConnect} onClick={openInfoModal} />
]} />
<Modal title={'List of joined inputs (' + joinedTexts.length + ' total)'} size='xl' opened={infoModalOpened} onClose={closeInfoModal} styles={{header: {backgroundColor: '#FFD700'}, root: {position: 'relative', left: '-5%'}}}>
<Box size={600} m='lg' mt='xl'>
{displayJoinedTexts(joinedTexts)}
</Box>
</Modal>
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
<Text mt='3px' mr='xs'>Join</Text>
<NativeSelect onChange={handleChangeGroupByVar}
<NativeSelect onChange={(e) => setGroupByVar(e.target.value)}
className='nodrag nowheel'
data={["all text", "by country", "by city"]}
data={groupByVars}
size="xs"
value={groupByVar}
miw='80px'
mr='xs' />
</div>
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
<NativeSelect onChange={handleChangeGroupByLLM}
className='nodrag nowheel'
data={["within", "across"]}
size="xs"
value={groupByLLM}
maw='80px'
mr='xs' />
<Text mt='3px'>LLM(s)</Text>
</div>
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%'}}>
<Text size='sm' mt='3px' mr='xs' color='gray' fs='italic'> take</Text>
<NativeSelect onChange={handleChangeResponsesPerPrompt}
className='nodrag nowheel'
data={["all", "1", "2", "3"]}
size="xs"
value={"1"}
maw='80px'
mr='xs'
color='gray' />
<Text size='sm' mt='3px' color='gray' fs='italic'>resp / prompt</Text>
</div>
{inputHasLLMs ?
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%', marginBottom: '10px'}}>
<NativeSelect onChange={(e) => setGroupByLLM(e.target.value)}
className='nodrag nowheel'
data={["within", "across"]}
size="xs"
value={groupByLLM}
maw='80px'
mr='xs'
ml='40px' />
<Text mt='3px'>LLM(s)</Text>
</div>
: <></>}
{inputHasMultiRespPerLLM ?
<div style={{display: 'flex', justifyContent: 'left', maxWidth: '100%'}}>
<NativeSelect onChange={(e) => setResponsesPerPrompt(e.target.value)}
className='nodrag nowheel'
data={["all", "1", "2", "3"]}
size="xs"
value={"1"}
maw='80px'
mr='xs'
ml='40px'
color='gray' />
<Text size='sm' mt='3px' color='gray' fs='italic'>resp / prompt</Text>
</div>
: <></>}
<Divider my="xs" label="formatting" labelPosition="center" />
<NativeSelect onChange={handleChangeResponsesPerPrompt}
<NativeSelect onChange={(e) => setFormatting(e.target.value)}
className='nodrag nowheel'
data={["double newline \\n\\n", "newline \\n", "- dashed list", '["list", "of", "strings"]']}
data={formattingOptions}
size="xs"
value={"double newline"}
value={formatting}
miw='80px' />
<Handle
type="target"
position="left"
id="input"
id="__input"
className="grouped-handle"
style={{ top: "50%" }}
onConnect={handleOnConnect}

@@ -78,10 +78,9 @@ const PromptNode = ({ data, id, type: node_type }) => {
// Get state from the Zustand store:
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const pullInputData = useStore((state) => state.pullInputData);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const getNode = useStore((state) => state.getNode);
// API Keys (set by user in popup GlobalSettingsModal)
const apiKeys = useStore((state) => state.apiKeys);
@@ -208,52 +207,10 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
}, [data]);
// Pull all inputs needed to request responses.
// Returns [prompt, vars dict]
const pullInputData = (_targetHandles) => {
// Pull data from each source recursively:
const pulled_data = {};
const store_data = (_texts, _varname, _data) => {
if (_varname in _data)
_data[_varname] = _data[_varname].concat(_texts);
else
_data[_varname] = _texts;
};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
if (!out || !Array.isArray(out) || out.length === 0) return;
// Check the format of the output. Can be str or dict with 'text' and more attrs:
if (typeof out[0] === 'object') {
out.forEach(obj => store_data([obj], varname, pulled_data));
}
else {
// Save the list of strings from the pulled output under the var 'varname'
store_data(out, varname, pulled_data);
}
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
});
};
get_outputs(_targetHandles, id);
return pulled_data;
};
// Chat nodes only. Pulls input data attached to the 'past conversations' handle.
// Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected.
const pullInputChats = () => {
const pulled_data = pullInputData(['__past_chats']);
const pulled_data = pullInputData(['__past_chats'], id);
if (!('__past_chats' in pulled_data)) return [undefined, undefined];
// For storing the unique LLMs in past_chats:
@@ -313,12 +270,12 @@ const PromptNode = ({ data, id, type: node_type }) => {
const [promptPreviews, setPromptPreviews] = useState([]);
const handlePreviewHover = () => {
// Pull input data and prompt
const pulled_vars = pullInputData(templateVars);
const pulled_vars = pullInputData(templateVars, id);
fetch_from_backend('generatePrompts', {
prompt: promptText,
vars: pulled_vars,
}).then(prompts => {
setPromptPreviews(prompts.map(p => (new PromptInfo(p))));
setPromptPreviews(prompts.map(p => (new PromptInfo(p.toString()))));
});
pullInputChats();
@@ -352,7 +309,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
}
// Pull the input data
const pulled_vars = pullInputData(templateVars);
const pulled_vars = pullInputData(templateVars, id);
const llms = _llmItemsCurrState.map(item => item.model);
const num_llms = llms.length;
@@ -462,7 +419,7 @@ const PromptNode = ({ data, id, type: node_type }) => {
setProgressAnimated(true);
// Pull the data to fill in template input variables, if any
const pulled_data = pullInputData(templateVars);
const pulled_data = pullInputData(templateVars, id);
const prompt_template = promptText;
const rejected = (err) => {

@@ -45,7 +45,7 @@ export default function CustomEdge({
// Thanks in part to oshanley https://github.com/wbkd/react-flow/issues/1211#issuecomment-1585032930
return (
<EdgePathContainer onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)} onClick={()=>console.log('click')}>
<EdgePathContainer onPointerEnter={()=>setHovering(true)} onPointerLeave={()=>setHovering(false)}>
<BaseEdge path={edgePath} markerEnd={markerEnd} style={{...style, stroke: (hovering ? '#000' : '#999')}} />
<EdgeLabelRenderer>
<div

@@ -371,9 +371,9 @@ function run_over_responses(eval_func: (resp: ResponseInfo) => any, responses: A
* @param vars a dict of the template variables to fill the prompt template with, by name. (See countQueries docstring for more info).
* @returns An array of PromptTemplate objects representing the prompts that will be sent out. Note that these could include unfilled template vars.
*/
export async function generatePrompts(root_prompt: string, vars: Dict): Promise<string[]> {
export async function generatePrompts(root_prompt: string, vars: Dict): Promise<PromptTemplate[]> {
const gen_prompts = new PromptPermutationGenerator(root_prompt);
const all_prompt_permutations = Array.from(gen_prompts.generate(vars)).map(p => p.toString());
const all_prompt_permutations = Array.from(gen_prompts.generate(vars));
return all_prompt_permutations;
}
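// Example (illustrative): await generatePrompts("Tell me a {adj} joke", {adj: ["funny", "dark"]})
// now resolves to two PromptTemplate objects rather than plain strings,
// so callers must invoke .toString() on each (as PromptNode does above).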

@@ -26,7 +26,7 @@ export const colorPalettes = {
var: varColorPalette,
}
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval']);
const refreshableOutputNodeTypes = new Set(['evaluator', 'prompt', 'inspect', 'vis', 'llmeval', 'textfields', 'chat', 'simpleval', 'join']);
export let initLLMProviders = [
{ name: "GPT3.5", emoji: "🤖", model: "gpt-3.5-turbo", base_model: "gpt-3.5-turbo", temp: 1.0 }, // The base_model designates what settings form will be used, and must be unique.
@@ -204,6 +204,56 @@ const useStore = create((set, get) => ({
return null;
}
},
// Pull all inputs needed to request responses.
// Returns a dict mapping each pulled handle (variable name) to an array of its data.
pullInputData: (_targetHandles, node_id) => {
// Functions/data from the store:
const getNode = get().getNode;
const output = get().output;
const edges = get().edges;
// Helper function to store collected data in dict:
const store_data = (_texts, _varname, _data) => {
if (_varname in _data)
_data[_varname] = _data[_varname].concat(_texts);
else
_data[_varname] = _texts;
};
// Pull data from each source recursively:
const pulled_data = {};
const get_outputs = (varnames, nodeId) => {
varnames.forEach(varname => {
// Find the relevant edge(s):
edges.forEach(e => {
if (e.target == nodeId && e.targetHandle == varname) {
// Get the immediate output:
let out = output(e.source, e.sourceHandle);
if (!out || !Array.isArray(out) || out.length === 0) return;
// Check the format of the output. Can be str or dict with 'text' and more attrs:
if (typeof out[0] === 'object') {
out.forEach(obj => store_data([obj], varname, pulled_data));
}
else {
// Save the list of strings from the pulled output under the var 'varname'
store_data(out, varname, pulled_data);
}
// Get any vars that the output depends on, and recursively collect those outputs as well:
const n_vars = getNode(e.source).data.vars;
if (n_vars && Array.isArray(n_vars) && n_vars.length > 0)
get_outputs(n_vars, e.source);
}
});
});
};
get_outputs(_targetHandles, node_id);
return pulled_data;
},
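// Example (illustrative): a Join node with one edge into its "__input" handle can call
// pullInputData(["__input"], joinNodeId) to get {__input: [...strings or response objects]},
// where joinNodeId is the calling node's id.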
setDataPropsForNode: (id, data_props) => {
set({
nodes: (nds =>

@@ -409,6 +409,9 @@
color: #444;
white-space: pre-wrap;
}
.join-text-preview {
margin: 0px 0px 10px 0px;
}
.small-response {
font-size: 8pt;