Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 08:16:37 +00:00)

Merge remote-tracking branch 'origin/main' into dev/ian

Commit d11b23e23d

.gitignore (vendored, 3 lines changed)
@ -5,4 +5,5 @@ __pycache__
python-backend/cache
# venv
venv
node_modules
CONTRIBUTOR_GUIDE.md (new file, 67 lines)
@ -0,0 +1,67 @@
# Contributor Guide

This is a guide to running the current version of ChainForge, for people who want to develop or extend it.

Note that this document will change in the future.

## Getting Started

### Install requirements

Before you can run ChainForge, you need to install its dependencies. `cd` into `python-backend` and run

```bash
pip install -r requirements.txt
```

to install the Python requirements. (Ideally, you should run this inside a `virtualenv`; see the sketch below.)
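A minimal sketch of that setup, using Python's built-in `venv` module (the environment folder name `venv` matches this commit's `.gitignore` entry):

```bash
# Create and activate a virtual environment, then install the Python dependencies into it.
python -m venv venv
source venv/bin/activate   # on Windows: venv\Scripts\activate
pip install -r requirements.txt
```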
To install Node.js requirements, first make sure you have Node.js installed. Then `cd` into `chain-forge` and run:

```bash
npm install
```
## Running ChainForge

To serve ChainForge, you have two options:

1. Run everything from a single Python script, which requires building the React app to static files, or
2. Serve the React front-end separately from the Flask back-end and take advantage of React hot reloading.

We recommend the former option for end-users, and the latter for developers.
### Option 1: Build React app as static files (end-users)

`cd` into the `chain-forge` directory and run:

```bash
npm run build
```

Wait a moment while it builds the React app to static files.
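Putting this together with the backend step described under "Serving the backend" below, a sketch of the full end-user flow (assuming `chain-forge` and `python-backend` are sibling directories at the repository root):

```bash
# Build the React front-end to static files, then serve everything from the Flask backend.
cd chain-forge
npm run build
cd ../python-backend
python app.py
```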
### Option 2: Serve React front-end with hot reloading (developers)

`cd` into the `chain-forge` directory and run the following to serve the React front-end:

```bash
npm run start
```
### Serving the backend

Regardless of which option you chose, `cd` into `python-backend` and run:

```bash
python app.py
```

> **Note**
> You can add the `--dummy-responses` flag if you're worried about making calls to OpenAI. This spoofs all LLM responses as random strings, which is great for testing the interface without accidentally spending money (see the example below).
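For example:

```bash
# Serve the backend with spoofed LLM responses; no calls to OpenAI are made.
python app.py --dummy-responses
```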
This script spins up two servers: the main one on port 8000, and a SocketIO server on port 8001 (used for streaming progress updates).

If you built the React app statically, go to `localhost:8000` in a web browser to view the app.
If you served the React app with hot reloading via `npm run start`, go to the address it is served at (usually `localhost:3000`).
## Contributing to ChainForge

If you have access to the main repository, we request that you add a branch `dev/<your_first_name>` and develop changes from there. When you are ready to push changes, say to address an open Issue, make a Pull Request on the `main` repository and assign the main developer (Ian Arawjo) to it. A sketch of this branching workflow is shown below.
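For example, with a hypothetical first name `alice`:

```bash
# Create a personal development branch off of main (name per the dev/<your_first_name> convention).
git checkout main
git pull
git checkout -b dev/alice

# ...commit your changes on this branch, then push it and open a Pull Request...
git push -u origin dev/alice
```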
README.md (33 lines changed)
@ -19,30 +19,23 @@ Taken together, these three features let you easily:
# Installation

To install, use `pip`. From the command line:

To get started, currently see the `CONTRIBUTOR_GUIDE.md`. Below are the planned installation steps (which are not yet active):

```
pip install chainforge
```

[TODO: Upload CF to PyPI]
[TODO: Create a command-line alias (?) so you can run `chainforge serve <react_port?> <py_port?>` and spin up both React and the Python backend automatically.]

To run simply, type:

```
chainforge serve
```

This spins up two local servers: a React server through npm, and a Python backend, powered by Flask. For more options, such as port numbers, type `chainforge --help`.

### Sharing prompt chains

All ChainForge node graphs are importable/exportable as JSON specs. You can freely share prompt chains you develop (alongside any custom analysis code), whether to the public or within your organization.

>
> To install, use `pip`. From the command line:
>
> ```
> pip install chainforge
> ```
> To run, type:
> ```
> chainforge serve
> ```
> This spins up two local servers: a React server through npm, and a Python backend, powered by Flask. For more options, such as port numbers, type `chainforge --help`.

## Example: Test LLM robustness to prompt injection

...
> ...

# Development
@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => {

  // This gives the parent access to triggering the modal alert
  const trigger = (msg) => {
    if (!msg) msg = "Unknown error.";
    console.error(msg);
    setAlertMsg(msg);
    open();
@ -45,3 +45,8 @@ path.react-flow__edge-path:hover {
    transform: rotate(360deg);
  }
}

.rich-editor {
  min-width: 500px;
  min-height: 500px;
}
@ -15,6 +15,7 @@ import EvaluatorNode from './EvaluatorNode';
|
||||
import VisNode from './VisNode';
|
||||
import InspectNode from './InspectorNode';
|
||||
import ScriptNode from './ScriptNode';
|
||||
import CsvNode from './CsvNode';
|
||||
import './text-fields-node.css';
|
||||
|
||||
// State management (from https://reactflow.dev/docs/guides/state-management/)
|
||||
@ -40,7 +41,8 @@ const nodeTypes = {
|
||||
evaluator: EvaluatorNode,
|
||||
vis: VisNode,
|
||||
inspect: InspectNode,
|
||||
script: ScriptNode
|
||||
script: ScriptNode,
|
||||
csv: CsvNode,
|
||||
};
|
||||
|
||||
const connectionLineStyle = { stroke: '#ddd' };
|
||||
@ -91,6 +93,11 @@ const App = () => {
|
||||
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
|
||||
const addCsvNode = (event) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
|
||||
/**
|
||||
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
|
||||
*/
|
||||
@ -201,6 +208,7 @@ const App = () => {
|
||||
<button onClick={addVisNode}>Add vis node</button>
|
||||
<button onClick={addInspectNode}>Add inspect node</button>
|
||||
<button onClick={addScriptNode}>Add script node</button>
|
||||
<button onClick={addCsvNode}>Add csv node</button>
|
||||
<button onClick={saveFlow} style={{marginLeft: '12px'}}>Save</button>
|
||||
<button onClick={loadFlowFromCache}>Load</button>
|
||||
<button onClick={exportFlow} style={{marginLeft: '12px'}}>Export</button>
|
||||
|
chain-forge/src/CsvNode.js (new file, 129 lines)
@ -0,0 +1,129 @@
|
||||
import React, { useState, useRef, useEffect, useCallback } from 'react';
|
||||
import { Badge, Text } from '@mantine/core';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import { IconCsv } from '@tabler/icons-react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
|
||||
const CsvNode = ({ data, id }) => {
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const [contentDiv, setContentDiv] = useState(null);
|
||||
const [isEditing, setIsEditing] = useState(true);
|
||||
const [csvInput, setCsvInput] = useState(null);
|
||||
const [countText, setCountText] = useState(null);
|
||||
|
||||
// initializing
|
||||
useEffect(() => {
|
||||
if (!data.fields) {
|
||||
setDataPropsForNode(id, { text: '', fields: [] });
|
||||
}
|
||||
}, []);
|
||||
|
||||
const processCsv = (csv) => {
|
||||
var matches = csv.match(/(\s*"[^"]+"\s*|\s*[^,]+|,)(?=,|$)/g);
|
||||
for (var n = 0; n < matches.length; ++n) {
|
||||
matches[n] = matches[n].trim();
|
||||
if (matches[n] == ',') matches[n] = '';
|
||||
}
|
||||
return matches.map(e => e.trim()).filter(e => e.length > 0);
|
||||
}
|
||||
|
||||
// const processCsv = (csv) => {
|
||||
// if (!csv) return [];
|
||||
// // Split the input string by rows, and merge
|
||||
// var res = csv.split('\n').join(',');
|
||||
|
||||
// // remove all the empty or whitespace-only elements
|
||||
// return res.split(',').map(e => e.trim()).filter(e => e.length > 0);
|
||||
// }
|
||||
|
||||
// Handle a change in a text fields' input.
|
||||
const handleInputChange = useCallback((event) => {
|
||||
// Update the data for this text fields' id.
|
||||
let new_data = { 'text': event.target.value, 'fields': processCsv(event.target.value) };
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [id, setDataPropsForNode]);
|
||||
|
||||
const handKeyDown = useCallback((event) => {
|
||||
if (event.key === 'Enter') {
|
||||
setIsEditing(false);
|
||||
setCsvInput(null);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// handling Div Click
|
||||
const handleDivOnClick = useCallback((event) => {
|
||||
setIsEditing(true);
|
||||
}, []);
|
||||
|
||||
const handleOnBlur = useCallback((event) => {
|
||||
setIsEditing(false);
|
||||
setCsvInput(null);
|
||||
}, []);
|
||||
|
||||
// render csv div
|
||||
const renderCsvDiv = useCallback(() => {
|
||||
// Take the data.text as csv (only 1 row), and get individual elements
|
||||
const elements = data.fields;
|
||||
|
||||
// generate HTML that highlights the elements
|
||||
const html = [];
|
||||
elements.forEach((e, idx) => {
|
||||
// html.push(<Badge color="orange" size="lg" radius="sm">{e}</Badge>)
|
||||
html.push(<span key={idx} className="csv-element">{e}</span>);
|
||||
if (idx < elements.length - 1) {
|
||||
html.push(<span key={idx + 'comma'} className="csv-comma">,</span>);
|
||||
}
|
||||
});
|
||||
|
||||
setContentDiv(<div className='csv-div nowheel' onClick={handleDivOnClick}>
|
||||
{html}
|
||||
</div>);
|
||||
setCountText(<Text size="xs" style={{ marginTop: '5px' }} color='blue' align='right'>{elements.length} elements</Text>);
|
||||
}, [data.text, handleDivOnClick]);
|
||||
|
||||
// When isEditing changes, add input
|
||||
useEffect(() => {
|
||||
if (!isEditing) {
|
||||
setCsvInput(null);
|
||||
renderCsvDiv();
|
||||
return;
|
||||
}
|
||||
if (!csvInput) {
|
||||
var text_val = data.text || '';
|
||||
setCsvInput(
|
||||
<div className="input-field" key={id}>
|
||||
<textarea id={id} name={id} className="text-field-fixed nodrag csv-input" rows="2" cols="40" defaultValue={text_val} onChange={handleInputChange} placeholder='Paste your CSV text here' onKeyDown={handKeyDown} onBlur={handleOnBlur} autoFocus={true}/>
|
||||
</div>
|
||||
);
|
||||
setContentDiv(null);
|
||||
setCountText(null);
|
||||
}
|
||||
}, [isEditing]);
|
||||
|
||||
// when data.text changes, update the content div
|
||||
useEffect(() => {
|
||||
// When in editing mode, don't update the content div
|
||||
if (isEditing) return;
|
||||
if (!data.text) return;
|
||||
renderCsvDiv();
|
||||
|
||||
}, [id, data.text]);
|
||||
|
||||
return (
|
||||
<div className="text-fields-node cfnode">
|
||||
<NodeLabel title={data.title || 'CSV Node'} nodeId={id} icon={<IconCsv size="16px" />} />
|
||||
{csvInput}
|
||||
{contentDiv}
|
||||
{countText ? countText : <></>}
|
||||
<Handle
|
||||
type="source"
|
||||
position="right"
|
||||
id="output"
|
||||
style={{ top: "50%", background: '#555' }}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default CsvNode;
|
@ -71,7 +71,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
// Get all the script nodes, and get all the folder paths
|
||||
const script_nodes = nodes.filter(n => n.type === 'script');
|
||||
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
|
||||
console.log(script_paths);
|
||||
|
||||
// Run evaluator in backend
|
||||
const codeTextOnRun = codeText + '';
|
||||
fetch(BASE_URL + 'app/execute', {
|
||||
@ -95,6 +95,8 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
alertModal.current.trigger(json ? json.error : 'Unknown error encountered when requesting evaluations: empty response returned.');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(json.responses);
|
||||
|
||||
// Ping any vis nodes attached to this node to refresh their contents:
|
||||
const output_nodes = outputEdgesForNode(id).map(e => e.target);
|
||||
@ -178,7 +180,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
:</div>
|
||||
|
||||
{/* <span className="code-style">response</span>: */}
|
||||
<div className="nodrag">
|
||||
<div className="ace-editor-container nodrag">
|
||||
<AceEditor
|
||||
mode="python"
|
||||
theme="xcode"
|
||||
@ -187,8 +189,14 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
name={"aceeditor_"+id}
|
||||
editorProps={{ $blockScrolling: true }}
|
||||
width='400px'
|
||||
height='200px'
|
||||
height='100px'
|
||||
tabSize={2}
|
||||
onLoad={editorInstance => { // Make Ace Editor div resizeable.
|
||||
editorInstance.container.style.resize = "both";
|
||||
document.addEventListener("mouseup", e => (
|
||||
editorInstance.resize()
|
||||
));
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/* <CodeMirror
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { useState } from 'react';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => {
|
||||
const [varSelects, setVarSelects] = useState([]);
|
||||
const [pastInputs, setPastInputs] = useState([]);
|
||||
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
|
||||
const handleVarValueSelect = () => {
|
||||
}
|
||||
@ -83,7 +84,7 @@ const InspectorNode = ({ data, id }) => {
|
||||
);
|
||||
});
|
||||
return (
|
||||
<div key={llm} className="llm-response-container nowheel">
|
||||
<div key={llm} className="llm-response-container">
|
||||
<h1>{llm}</h1>
|
||||
{res_divs}
|
||||
</div>
|
||||
@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => {
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (data.refresh && data.refresh === true) {
|
||||
// Recreate the visualization:
|
||||
setDataPropsForNode(id, { refresh: false });
|
||||
handleOnConnect();
|
||||
}
|
||||
}, [data, id, handleOnConnect, setDataPropsForNode]);
|
||||
|
||||
return (
|
||||
<div className="inspector-node cfnode">
|
||||
<NodeLabel title={data.title || 'Inspect Node'}
|
||||
@ -123,7 +132,9 @@ const InspectorNode = ({ data, id }) => {
|
||||
{/* <div className="var-select-toolbar">
|
||||
{varSelects}
|
||||
</div> */}
|
||||
<div className="inspect-response-container nowheel nodrag">
|
||||
{responses}
|
||||
</div>
|
||||
<Handle
|
||||
type="target"
|
||||
position="left"
|
||||
|
@ -11,10 +11,10 @@ import io from 'socket.io-client';
|
||||
|
||||
// Available LLMs
|
||||
const allLLMs = [
|
||||
{ name: "GPT3.5", emoji: "🙂", model: "gpt3.5", temp: 1.0 },
|
||||
{ name: "GPT4", emoji: "🥵", model: "gpt4", temp: 1.0 },
|
||||
{ name: "GPT3.5", emoji: "🙂", model: "gpt-3.5-turbo", temp: 1.0 },
|
||||
{ name: "GPT4", emoji: "🥵", model: "gpt-4", temp: 1.0 },
|
||||
{ name: "Alpaca 7B", emoji: "🦙", model: "alpaca.7B", temp: 0.5 },
|
||||
{ name: "Claude v1", emoji: "📚", model: "claude.v1", temp: 0.5 },
|
||||
{ name: "Claude v1", emoji: "📚", model: "claude-v1", temp: 0.5 },
|
||||
{ name: "Ian Chatbot", emoji: "💩", model: "test", temp: 0.5 }
|
||||
];
|
||||
const initLLMs = [allLLMs[0]];
|
||||
@ -50,6 +50,7 @@ const PromptNode = ({ data, id }) => {
|
||||
const edges = useStore((state) => state.edges);
|
||||
const output = useStore((state) => state.output);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
|
||||
const getNode = useStore((state) => state.getNode);
|
||||
|
||||
const [templateVars, setTemplateVars] = useState(data.vars || []);
|
||||
@ -181,13 +182,15 @@ const PromptNode = ({ data, id }) => {
|
||||
prompt: prompt,
|
||||
vars: vars,
|
||||
llms: llms,
|
||||
id: id,
|
||||
n: numGenerations,
|
||||
})}, rejected).then(function(response) {
|
||||
return response.json();
|
||||
}, rejected).then(function(json) {
|
||||
if (!json || !json.counts) {
|
||||
throw new Error('Request was sent and received by backend server, but there was no response.');
|
||||
}
|
||||
return json.counts;
|
||||
return [json.counts, json.total_num_responses];
|
||||
}, rejected);
|
||||
};
|
||||
|
||||
@ -199,6 +202,12 @@ const PromptNode = ({ data, id }) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the PromptNode is not already waiting for a response...
|
||||
if (status === 'loading') {
|
||||
setRunTooltip('Fetching responses...');
|
||||
return;
|
||||
}
|
||||
|
||||
// Get input data and prompt
|
||||
const [py_prompt, pulled_vars] = pullInputData();
|
||||
const llms = llmItemsCurrState.map(item => item.model);
|
||||
@ -207,10 +216,48 @@ const PromptNode = ({ data, id }) => {
|
||||
// Fetch response counts from backend
|
||||
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
|
||||
console.warn(err.message); // soft fail
|
||||
}).then((counts) => {
|
||||
const n = counts[Object.keys(counts)[0]];
|
||||
const req = n > 1 ? 'requests' : 'request';
|
||||
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
|
||||
}).then(([counts, total_num_responses]) => {
|
||||
// Check for empty counts (means no requests will be sent!)
|
||||
const num_llms_missing = Object.keys(counts).length;
|
||||
if (num_llms_missing === 0) {
|
||||
setRunTooltip('Will load responses from cache');
|
||||
return;
|
||||
}
|
||||
|
||||
// Tally how many queries per LLM:
|
||||
let queries_per_llm = {};
|
||||
Object.keys(counts).forEach(llm => {
|
||||
queries_per_llm[llm] = Object.keys(counts[llm]).reduce(
|
||||
(acc, prompt) => acc + counts[llm][prompt]
|
||||
, 0);
|
||||
});
|
||||
|
||||
// Check if all counts are the same:
|
||||
if (num_llms_missing > 1) {
|
||||
const some_llm_num = queries_per_llm[Object.keys(queries_per_llm)[0]];
|
||||
const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
|
||||
if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
|
||||
const req = some_llm_num > 1 ? 'requests' : 'request';
|
||||
setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ' per LLM' : ''));
|
||||
}
|
||||
else if (all_same_num_queries) {
|
||||
const req = some_llm_num > 1 ? 'requests' : 'request';
|
||||
setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
|
||||
}
|
||||
else { // Counts are different
|
||||
const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
|
||||
setRunTooltip(`Will send a variable # of queries to LLM(s) (total=${sum_queries})`);
|
||||
}
|
||||
} else {
|
||||
const llm_name = Object.keys(queries_per_llm)[0];
|
||||
const llm_count = queries_per_llm[llm_name];
|
||||
const req = llm_count > 1 ? 'queries' : 'query';
|
||||
if (num_llms > num_llms_missing)
|
||||
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name} and load others`);
|
||||
else
|
||||
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name}`)
|
||||
}
|
||||
|
||||
});
|
||||
};
|
||||
|
||||
@ -263,7 +310,7 @@ const PromptNode = ({ data, id }) => {
|
||||
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
|
||||
|
||||
// Open a socket to listen for progress
|
||||
const open_progress_listener_socket = (response_counts) => {
|
||||
const open_progress_listener_socket = ([response_counts, total_num_responses]) => {
|
||||
// With the counts information we can create progress bars. Now we load a socket connection to
|
||||
// the socketio server that will stream to us the current progress:
|
||||
const socket = io('http://localhost:8001/' + 'queryllm', {
|
||||
@ -271,7 +318,7 @@ const PromptNode = ({ data, id }) => {
|
||||
cors: {origin: "http://localhost:8000/"},
|
||||
});
|
||||
|
||||
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
|
||||
const max_responses = Object.keys(total_num_responses).reduce((acc, llm) => acc + total_num_responses[llm], 0);
|
||||
|
||||
// On connect to the server, ask it to give us the current progress
|
||||
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
|
||||
@ -392,6 +439,15 @@ const PromptNode = ({ data, id }) => {
|
||||
);
|
||||
}));
|
||||
|
||||
// Ping any inspect nodes attached to this node to refresh their contents:
|
||||
const output_nodes = outputEdgesForNode(id).map(e => e.target);
|
||||
output_nodes.forEach(n => {
|
||||
const node = getNode(n);
|
||||
if (node && node.type === 'inspect') {
|
||||
setDataPropsForNode(node.id, { refresh: true });
|
||||
}
|
||||
});
|
||||
|
||||
// Log responses for debugging:
|
||||
console.log(json.responses);
|
||||
} else {
|
||||
|
@ -66,48 +66,75 @@ const VisNode = ({ data, id }) => {
|
||||
else
|
||||
responses_by_llm[item.llm] = [item];
|
||||
});
|
||||
const llm_names = Object.keys(responses_by_llm);
|
||||
|
||||
// Create Plotly spec here
|
||||
const varnames = Object.keys(json.responses[0].vars);
|
||||
let spec = {};
|
||||
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
|
||||
let spec = [];
|
||||
let layout = {
|
||||
width: 420, height: 300, title: '', margin: {
|
||||
l: 40, r: 20, b: 20, t: 20, pad: 2
|
||||
l: 100, r: 20, b: 20, t: 20, pad: 0
|
||||
}
|
||||
}
|
||||
|
||||
if (varnames.length === 1) {
|
||||
const plot_grouped_boxplot = (resp_to_x) => {
|
||||
llm_names.forEach((llm, idx) => {
|
||||
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
|
||||
const rs = responses_by_llm[llm];
|
||||
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
|
||||
spec.push({
|
||||
type: 'box',
|
||||
name: llm,
|
||||
marker: {color: colors[idx % colors.length]},
|
||||
x: rs.map(r => r.eval_res.items).flat(),
|
||||
y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
|
||||
boxpoints: 'all',
|
||||
text: hover_texts,
|
||||
hovertemplate: '%{text} <b><i>(%{x})</i></b>',
|
||||
orientation: 'h',
|
||||
});
|
||||
});
|
||||
layout.boxmode = 'group';
|
||||
};
|
||||
|
||||
if (varnames.length === 0) {
|
||||
// No variables means they used a single prompt (no template) to generate responses
|
||||
// (Users are likely evaluating differences in responses between LLMs)
|
||||
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
|
||||
// llm_names.forEach((llm, idx) => {
|
||||
// // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
|
||||
// const rs = responses_by_llm[llm];
|
||||
// const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
|
||||
// spec.push({
|
||||
// type: 'scatter',
|
||||
// name: llm,
|
||||
// marker: {color: colors[idx % colors.length]},
|
||||
// y: rs.map(r => r.eval_res.items).flat(),
|
||||
// x: rs.map(r => Array(r.eval_res.items.length).fill(truncStr(r.prompt.trim(), 12))).flat(), // use the prompt str as var name
|
||||
// // boxpoints: 'all',
|
||||
// mode: 'markers',
|
||||
// text: hover_texts,
|
||||
// hovertemplate: '%{text} <b><i>(%{y})</i></b>',
|
||||
// });
|
||||
// });
|
||||
// layout.scattermode = 'group';
|
||||
}
|
||||
else if (varnames.length === 1) {
|
||||
// 1 var; numeric eval
|
||||
if (Object.keys(responses_by_llm).length === 1) {
|
||||
if (llm_names.length === 1) {
|
||||
// Simple box plot, as there is only a single LLM in the response
|
||||
spec = json.responses.map(r => {
|
||||
// Use the var value to 'name' this group of points:
|
||||
const s = truncStr(r.vars[varnames[0]].trim(), 12);
|
||||
|
||||
return {type: 'box', y: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}'};
|
||||
return {type: 'box', x: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}', orientation: 'h'};
|
||||
});
|
||||
layout.hovermode = 'closest';
|
||||
} else {
|
||||
// There are multiple LLMs in the response; do a grouped box plot by LLM.
|
||||
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
|
||||
spec = [];
|
||||
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
|
||||
Object.keys(responses_by_llm).forEach((llm, idx) => {
|
||||
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
|
||||
const rs = responses_by_llm[llm];
|
||||
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
|
||||
spec.push({
|
||||
type: 'box',
|
||||
name: llm,
|
||||
marker: {color: colors[idx % colors.length]},
|
||||
y: rs.map(r => r.eval_res.items).flat(),
|
||||
x: rs.map(r => Array(r.eval_res.items.length).fill(r.vars[varnames[0]].trim())).flat(),
|
||||
boxpoints: 'all',
|
||||
text: hover_texts,
|
||||
hovertemplate: '%{text}',
|
||||
});
|
||||
});
|
||||
layout.boxmode = 'group';
|
||||
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
|
||||
}
|
||||
}
|
||||
else if (varnames.length === 2) {
|
||||
|
@ -21,4 +21,36 @@ code {
|
||||
|
||||
.script-node-input {
|
||||
min-width: 300px;
|
||||
}
|
||||
|
||||
.csv-element {
|
||||
position: relative;
|
||||
color: #8a3e07;
|
||||
background-color: #FFE8CC;
|
||||
font-size: inherit;
|
||||
font-family: monospace;
|
||||
padding: .2em .4em;
|
||||
border: none;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
display: inline-block;
|
||||
margin: 4px 2px;
|
||||
cursor: pointer;
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
/* set a muted text */
|
||||
.csv-comma {
|
||||
color: #FFC107;
|
||||
}
|
||||
|
||||
.csv-div {
|
||||
width: 350px;
|
||||
max-height: 250px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.csv-input {
|
||||
width: 350px;
|
||||
height: 150px;
|
||||
}
|
@ -144,6 +144,10 @@
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.ace-editor-container {
|
||||
resize:vertical;
|
||||
}
|
||||
|
||||
.vis-node {
|
||||
background-color: #fff;
|
||||
padding: 10px;
|
||||
@ -156,8 +160,12 @@
|
||||
padding: 10px;
|
||||
border: 1px solid #999;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.inspect-response-container {
|
||||
overflow-y: auto;
|
||||
max-height: 400px;
|
||||
width: 450px;
|
||||
max-height: 350px;
|
||||
resize: both;
|
||||
}
|
||||
|
||||
.small-response {
|
||||
@ -172,7 +180,7 @@
|
||||
}
|
||||
|
||||
.llm-response-container {
|
||||
max-width: 450px;
|
||||
max-width: 100%;
|
||||
}
|
||||
.llm-response-container h1 {
|
||||
font-weight: 400;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import json, os, asyncio, sys, argparse, threading
|
||||
import json, os, asyncio, sys, argparse, threading, traceback
|
||||
from dataclasses import dataclass
|
||||
from statistics import mean, median, stdev
|
||||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
||||
@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
|
||||
def index():
|
||||
return render_template("index.html")
|
||||
|
||||
LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
LLM_NAME_MAP = {}
|
||||
for model in LLM:
|
||||
LLM_NAME_MAP[model.value] = model
|
||||
|
||||
@dataclass
|
||||
class ResponseInfo:
|
||||
@ -40,7 +37,7 @@ class ResponseInfo:
|
||||
return self.text
|
||||
|
||||
def to_standard_format(r: dict) -> list:
|
||||
llm = LLM_NAME_MAP_INVERSE[r['llm']]
|
||||
llm = r['llm']
|
||||
resp_obj = {
|
||||
'vars': r['info'],
|
||||
'llm': llm,
|
||||
@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
|
||||
resp_obj['eval_res'] = r['eval_res']
|
||||
return resp_obj
|
||||
|
||||
def get_llm_of_response(response: dict) -> LLM:
|
||||
return LLM_NAME_MAP[response['llm']]
|
||||
|
||||
def get_filenames_with_id(filenames: list, id: str) -> list:
|
||||
return [
|
||||
c for c in filenames
|
||||
@ -72,17 +66,17 @@ def load_cache_json(filepath: str) -> dict:
|
||||
responses = json.load(f)
|
||||
return responses
|
||||
|
||||
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
|
||||
for prompt, resp_obj in responses.items():
|
||||
res = extract_responses(resp_obj, resp_obj['llm'])
|
||||
def run_over_responses(eval_func, responses: list, scope: str) -> list:
|
||||
for resp_obj in responses:
|
||||
res = resp_obj['responses']
|
||||
if scope == 'response':
|
||||
evals = [ # Run evaluator func over every individual response text
|
||||
eval_func(
|
||||
ResponseInfo(
|
||||
text=r,
|
||||
prompt=prompt,
|
||||
var=resp_obj['info'],
|
||||
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
|
||||
prompt=resp_obj['prompt'],
|
||||
var=resp_obj['vars'],
|
||||
llm=resp_obj['llm'])
|
||||
) for r in res
|
||||
]
|
||||
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
|
||||
@ -148,6 +142,30 @@ def reduce_responses(responses: list, vars: list) -> list:
|
||||
|
||||
return ret
|
||||
|
||||
def load_all_cached_responses(cache_ids):
|
||||
if not isinstance(cache_ids, list):
|
||||
cache_ids = [cache_ids]
|
||||
|
||||
# Load all responses with the given ID:
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
responses = []
|
||||
for cache_id in cache_ids:
|
||||
cache_files = [fname for fname in get_filenames_with_id(all_cache_files, cache_id) if fname != f"{cache_id}.json"]
|
||||
if len(cache_files) == 0:
|
||||
continue
|
||||
|
||||
for filename in cache_files:
|
||||
res = load_cache_json(os.path.join('cache', filename))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
return responses
|
||||
|
||||
@app.route('/app/countQueriesRequired', methods=['POST'])
|
||||
def countQueries():
|
||||
"""
|
||||
@ -158,24 +176,57 @@ def countQueries():
|
||||
'prompt': str # the prompt template, with any {{}} vars
|
||||
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
|
||||
'llms': list # the list of LLMs you will query
|
||||
'n': int # how many responses expected per prompt
|
||||
'id': str (optional) # a unique ID of the node with cache'd responses. If missing, assumes no cache will be used.
|
||||
}
|
||||
"""
|
||||
data = request.get_json()
|
||||
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
|
||||
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms', 'n'}):
|
||||
return jsonify({'error': 'POST data is improper format.'})
|
||||
|
||||
n = int(data['n'])
|
||||
|
||||
try:
|
||||
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
|
||||
all_prompt_permutations = list(gen_prompts(data['vars']))
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
if 'id' in data:
|
||||
# Load all cache'd responses with the given id:
|
||||
cached_resps = load_all_cached_responses(data['id'])
|
||||
else:
|
||||
cached_resps = []
|
||||
|
||||
missing_queries = {}
|
||||
num_responses_req = {}
|
||||
def add_to_missing_queries(llm, prompt, num):
|
||||
if llm not in missing_queries:
|
||||
missing_queries[llm] = {}
|
||||
missing_queries[llm][prompt] = num
|
||||
def add_to_num_responses_req(llm, num):
|
||||
if llm not in num_responses_req:
|
||||
num_responses_req[llm] = 0
|
||||
num_responses_req[llm] += num
|
||||
|
||||
# Iterate through all prompt permutations and check how many responses there are in the cache with that prompt
|
||||
for prompt in all_prompt_permutations:
|
||||
prompt = str(prompt)
|
||||
matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
|
||||
for llm in data['llms']:
|
||||
add_to_num_responses_req(llm, n)
|
||||
match_per_llm = [r for r in matching_resps if r['llm'] == llm]
|
||||
if len(match_per_llm) == 0:
|
||||
add_to_missing_queries(llm, prompt, n)
|
||||
elif len(match_per_llm) == 1:
|
||||
# Check how many were stored; if not enough, add how many missing queries:
|
||||
num_resps = len(match_per_llm[0]['responses'])
|
||||
if n > len(match_per_llm[0]['responses']):
|
||||
add_to_missing_queries(llm, prompt, n - num_resps)
|
||||
else:
|
||||
raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
|
||||
|
||||
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
|
||||
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
|
||||
for llm in data['llms']:
|
||||
num_queries[llm] = len(all_prompt_permutations)
|
||||
|
||||
ret = jsonify({'counts': num_queries})
|
||||
ret = jsonify({'counts': missing_queries, 'total_num_responses': num_responses_req})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
|
||||
@ -245,6 +296,11 @@ async def queryLLM():
|
||||
# Create a cache dir if it doesn't exist:
|
||||
create_dir_if_not_exists('cache')
|
||||
|
||||
# Check that the filepath used to cache eval'd responses is valid:
|
||||
cache_filepath_last_run = os.path.join('cache', f"{data['id']}.json")
|
||||
if not is_valid_filepath(cache_filepath_last_run):
|
||||
return jsonify({'error': f'Invalid filepath: {cache_filepath_last_run}'})
|
||||
|
||||
# For each LLM, generate and cache responses:
|
||||
responses = {}
|
||||
llms = data['llm']
|
||||
@ -265,23 +321,27 @@ async def queryLLM():
|
||||
# Prompt the LLM with all permutations of the input prompt template:
|
||||
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
|
||||
resps = []
|
||||
num_resps = 0
|
||||
try:
|
||||
print(f'Querying {llm}...')
|
||||
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
|
||||
resps.append(response)
|
||||
print(f"collected response from {llm.name}:", str(response))
|
||||
|
||||
num_resps += len(extract_responses(response, llm))
|
||||
|
||||
# Save the number of responses collected to a temp file on disk
|
||||
with open(tempfilepath, 'r') as f:
|
||||
txt = f.read().strip()
|
||||
|
||||
cur_data = json.loads(txt) if len(txt) > 0 else {}
|
||||
cur_data[llm_str] = len(resps)
|
||||
cur_data[llm_str] = num_resps
|
||||
|
||||
with open(tempfilepath, 'w') as f:
|
||||
json.dump(cur_data, f)
|
||||
except Exception as e:
|
||||
print('error generating responses:', e)
|
||||
print(f'error generating responses for {llm}:', e)
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
return {'llm': llm, 'responses': resps}
|
||||
@ -308,6 +368,10 @@ async def queryLLM():
|
||||
# Remove the temp file used to stream progress updates:
|
||||
if os.path.exists(tempfilepath):
|
||||
os.remove(tempfilepath)
|
||||
|
||||
# Save the responses *of this run* to the disk, for further recall:
|
||||
with open(cache_filepath_last_run, "w") as f:
|
||||
json.dump(res, f)
|
||||
|
||||
# Return all responses for all LLMs
|
||||
print('returning responses:', res)
|
||||
@ -388,38 +452,29 @@ def execute():
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
all_evald_responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
fname = f"{cache_id}.json"
|
||||
if fname not in all_cache_files:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
|
||||
for filename in cache_files:
|
||||
# Load the raw responses from the cache
|
||||
responses = load_cache_json(os.path.join('cache', fname))
|
||||
if len(responses) == 0: continue
|
||||
|
||||
# Load the raw responses from the cache
|
||||
responses = load_cache_json(os.path.join('cache', filename))
|
||||
if len(responses) == 0: continue
|
||||
# Run the evaluator over them:
|
||||
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
|
||||
try:
|
||||
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
|
||||
|
||||
# Run the evaluator over them:
|
||||
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
|
||||
try:
|
||||
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
|
||||
# Perform any reduction operations:
|
||||
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
|
||||
evald_responses = reduce_responses(
|
||||
evald_responses,
|
||||
vars=data['reduce_vars']
|
||||
)
|
||||
|
||||
# Convert to standard format:
|
||||
std_evald_responses = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in evald_responses.items()
|
||||
]
|
||||
|
||||
# Perform any reduction operations:
|
||||
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
|
||||
std_evald_responses = reduce_responses(
|
||||
std_evald_responses,
|
||||
vars=data['reduce_vars']
|
||||
)
|
||||
|
||||
all_evald_responses.extend(std_evald_responses)
|
||||
all_evald_responses.extend(evald_responses)
|
||||
|
||||
# Store the evaluated responses in a new cache json:
|
||||
with open(cache_filepath, "w") as f:
|
||||
@ -481,19 +536,18 @@ def grabResponses():
|
||||
all_cache_files = get_files_at_dir('cache/')
|
||||
responses = []
|
||||
for cache_id in data['responses']:
|
||||
cache_files = get_filenames_with_id(all_cache_files, cache_id)
|
||||
if len(cache_files) == 0:
|
||||
fname = f"{cache_id}.json"
|
||||
if fname not in all_cache_files:
|
||||
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
|
||||
|
||||
for filename in cache_files:
|
||||
res = load_cache_json(os.path.join('cache', filename))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
res = load_cache_json(os.path.join('cache', fname))
|
||||
if isinstance(res, dict):
|
||||
# Convert to standard response format
|
||||
res = [
|
||||
to_standard_format({'prompt': prompt, **res_obj})
|
||||
for prompt, res_obj in res.items()
|
||||
]
|
||||
responses.extend(res)
|
||||
|
||||
ret = jsonify({'responses': responses})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
|
python-backend/promptengine/models.py (new file, 39 lines)
@ -0,0 +1,39 @@
|
||||
"""
|
||||
A list of all model APIs natively supported by ChainForge.
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
class LLM(str, Enum):
|
||||
""" OpenAI Chat """
|
||||
ChatGPT = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
|
||||
""" Dalai-served models """
|
||||
Alpaca7B = "alpaca.7B"
|
||||
|
||||
""" Anthropic """
|
||||
# Our largest model, ideal for a wide range of more complex tasks. Using this model name
|
||||
# will automatically switch you to newer versions of claude-v1 as they are released.
|
||||
Claude_v1 = "claude-v1"
|
||||
|
||||
# An earlier version of claude-v1
|
||||
Claude_v1_0 = "claude-v1.0"
|
||||
|
||||
# An improved version of claude-v1. It is slightly improved at general helpfulness,
|
||||
# instruction following, coding, and other tasks. It is also considerably better with
|
||||
# non-English languages. This model also has the ability to role play (in harmless ways)
|
||||
# more consistently, and it defaults to writing somewhat longer and more thorough responses.
|
||||
Claude_v1_2 = "claude-v1.2"
|
||||
|
||||
# A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
|
||||
# against red-team inputs, better at precise instruction-following, better at code, and better
|
||||
# at non-English dialogue and writing.
|
||||
Claude_v1_3 = "claude-v1.3"
|
||||
|
||||
# A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
|
||||
# is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
|
||||
# less expensive and blazing fast. We believe that this model provides more than adequate performance
|
||||
# on a range of tasks including text classification, summarization, and lightweight chat applications,
|
||||
# as well as search result summarization. Using this model name will automatically switch you to newer
|
||||
# versions of claude-instant-v1 as they are released.
|
||||
Claude_v1_instant = "claude-instant-v1"
|
@ -1,7 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Dict, Tuple, Iterator
|
||||
from typing import List, Dict, Tuple, Iterator, Union
|
||||
import json, os, asyncio, random, string
|
||||
from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
|
||||
from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, cull_responses, extract_responses
|
||||
from promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
|
||||
# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
|
||||
@ -59,13 +59,14 @@ class PromptPipeline:
|
||||
prompt_str = str(prompt)
|
||||
|
||||
# First check if there is already a response for this item. If so, we can save an LLM call:
|
||||
if prompt_str in responses:
|
||||
if prompt_str in responses and len(extract_responses(responses[prompt_str], llm)) >= n:
|
||||
print(f" - Found cache'd response for prompt {prompt_str}. Using...")
|
||||
responses[prompt_str] = cull_responses(responses[prompt_str], llm, n)
|
||||
yield {
|
||||
"prompt": prompt_str,
|
||||
"query": responses[prompt_str]["query"],
|
||||
"response": responses[prompt_str]["response"],
|
||||
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
|
||||
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
|
||||
"info": responses[prompt_str]["info"],
|
||||
}
|
||||
continue
|
||||
@ -86,7 +87,7 @@ class PromptPipeline:
|
||||
responses[str(prompt)] = {
|
||||
"query": query,
|
||||
"response": response,
|
||||
"llm": llm.name,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
}
|
||||
self._cache_responses(responses)
|
||||
@ -96,7 +97,7 @@ class PromptPipeline:
|
||||
"prompt":str(prompt),
|
||||
"query":query,
|
||||
"response":response,
|
||||
"llm": llm.name,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
}
|
||||
|
||||
@ -114,7 +115,7 @@ class PromptPipeline:
|
||||
responses[str(prompt)] = {
|
||||
"query": query,
|
||||
"response": response,
|
||||
"llm": llm.name,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
}
|
||||
self._cache_responses(responses)
|
||||
@ -124,7 +125,7 @@ class PromptPipeline:
|
||||
"prompt":str(prompt),
|
||||
"query":query,
|
||||
"response":response,
|
||||
"llm": llm.name,
|
||||
"llm": llm.value,
|
||||
"info": info,
|
||||
}
|
||||
|
||||
@ -147,11 +148,13 @@ class PromptPipeline:
|
||||
def clear_cached_responses(self) -> None:
|
||||
self._cache_responses({})
|
||||
|
||||
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
|
||||
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
|
||||
if llm is LLM.ChatGPT or llm is LLM.GPT4:
|
||||
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
|
||||
elif llm is LLM.Alpaca7B:
|
||||
query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
|
||||
query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
|
||||
elif llm.value[:6] == 'claude':
|
||||
query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
|
||||
else:
|
||||
raise Exception(f"Language model {llm} is not supported.")
|
||||
return prompt, query, response
|
||||
@ -175,5 +178,7 @@ class PromptLLM(PromptPipeline):
|
||||
"""
|
||||
class PromptLLMDummy(PromptLLM):
|
||||
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
|
||||
# Wait a random amount of time, to simulate wait times from real queries
|
||||
await asyncio.sleep(random.uniform(0.1, 3))
|
||||
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(40)) for _ in range(n)])
|
||||
# Return a random string of characters of random length (within a predefined range)
|
||||
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)])
|
@ -1,28 +1,22 @@
|
||||
from typing import Dict, Tuple, List, Union
|
||||
from enum import Enum
|
||||
import openai
|
||||
from typing import Dict, Tuple, List, Union, Callable
|
||||
import json, os, time, asyncio
|
||||
|
||||
from promptengine.models import LLM
|
||||
|
||||
DALAI_MODEL = None
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
""" Supported LLM coding assistants """
|
||||
class LLM(Enum):
|
||||
ChatGPT = 0
|
||||
Alpaca7B = 1
|
||||
GPT4 = 2
|
||||
|
||||
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls GPT3.5 via OpenAI's API.
|
||||
Returns raw query and response JSON dicts.
|
||||
|
||||
NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
|
||||
"""
|
||||
model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
|
||||
if model not in model_map:
|
||||
raise Exception(f"Could not find OpenAI chat model {model}")
|
||||
model = model_map[model]
|
||||
import openai
|
||||
if not openai.api_key:
|
||||
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
model = model.value
|
||||
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
|
||||
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
|
||||
query = {
|
||||
@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
|
||||
response = openai.ChatCompletion.create(**query)
|
||||
return query, response
|
||||
|
||||
async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
|
||||
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
|
||||
custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
|
||||
max_tokens_to_sample=1024,
|
||||
stop_sequences: Union[List[str], str]=["\n\nHuman:"],
|
||||
async_mode=False,
|
||||
**params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls Anthropic API with the given model, passing in params.
|
||||
Returns raw query and response JSON dicts.
|
||||
|
||||
Unique parameters:
|
||||
- custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
|
||||
explore custom prompt wrappers that deviate, write a function that maps from 'prompt' to custom wrapper.
|
||||
If set to None, defaults to Anthropic's suggested prompt wrapper.
|
||||
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
|
||||
- stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
|
||||
- async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
|
||||
If you want to send all 'n' requests at once, you can set async_mode to True.
|
||||
|
||||
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
|
||||
"""
|
||||
import anthropic
|
||||
client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
|
||||
|
||||
# Format query
|
||||
query = {
|
||||
'model': model.value,
|
||||
'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
|
||||
'max_tokens_to_sample': max_tokens_to_sample,
|
||||
'stop_sequences': stop_sequences,
|
||||
'temperature': temperature,
|
||||
**params
|
||||
}
|
||||
|
||||
print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
|
||||
|
||||
# Request responses using the passed async_mode
|
||||
responses = []
|
||||
if async_mode:
|
||||
# Gather n responses by firing off all API requests at once
|
||||
tasks = [client.acompletion(**query) for _ in range(n)]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
else:
|
||||
# Repeat call n times, waiting for each response to come in:
|
||||
while len(responses) < n:
|
||||
resp = await client.acompletion(**query)
|
||||
responses.append(resp)
|
||||
print(f'{model.value} response {len(responses)} of {n}:\n', resp)
|
||||
|
||||
return query, responses
|
||||
|
||||
async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
|
||||
Returns the raw query and response JSON dicts.
|
||||
|
||||
Parameters:
|
||||
- llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
|
||||
- model: The LLM model, whose value is the name known by Dalai; e.g., 'alpaca.7b'
|
||||
- port: The port of the local server where Dalai is running. Usually 3000.
|
||||
- prompt: The prompt to pass to the LLM.
|
||||
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
|
||||
@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
|
||||
# Create full query to Dalai
|
||||
query = {
|
||||
'prompt': prompt,
|
||||
'model': llm_name,
|
||||
'model': model.value,
|
||||
'id': str(round(time.time()*1000)),
|
||||
'temp': temperature,
|
||||
**def_params
|
||||
@ -129,18 +174,36 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
|
||||
choices = response["response"]["choices"]
|
||||
return [
|
||||
c["message"]["content"]
|
||||
for i, c in enumerate(choices)
|
||||
for c in choices
|
||||
]
|
||||
|
||||
def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
|
||||
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
|
||||
"""
|
||||
Given a LLM and a response object from its API, extract the
|
||||
text response(s) part of the response object.
|
||||
"""
|
||||
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
|
||||
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
|
||||
return _extract_chatgpt_responses(response)
|
||||
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
|
||||
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
|
||||
return response["response"]
|
||||
elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
|
||||
return [r["completion"] for r in response["response"]]
|
||||
else:
|
||||
raise ValueError(f"LLM {llm} is unsupported.")
|
||||
|
||||
def cull_responses(response: Union[list, dict], llm: Union[LLM, str], n: int) -> Union[list, dict]:
|
||||
"""
|
||||
Returns the same 'response' but with only 'n' responses.
|
||||
"""
|
||||
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
|
||||
response["response"]["choices"] = response["response"]["choices"][:n]
|
||||
return response
|
||||
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
|
||||
response["response"] = response["response"][:n]
|
||||
return response
|
||||
elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
|
||||
response["response"] = response["response"][:n]
|
||||
return response
|
||||
else:
|
||||
raise ValueError(f"LLM {llm} is unsupported.")
|