Merge remote-tracking branch 'origin/main' into dev/ian

Ian Arawjo 2023-05-10 13:22:32 -04:00
commit d11b23e23d
17 changed files with 667 additions and 160 deletions

.gitignore vendored (3 changed lines)
View File

@ -5,4 +5,5 @@ __pycache__
python-backend/cache
# venv
venv
venv
node_modules

CONTRIBUTOR_GUIDE.md Normal file (67 changed lines)
View File

@ -0,0 +1,67 @@
# Contributor Guide
This is a guide to running the current version of ChainForge, for people who want to develop or extend it.
Note that this document will change in the future.
## Getting Started
### Install requirements
Before you can run ChainForge, you need to install dependencies. `cd` into `python-backend` and run
```bash
pip install -r requirements.txt
```
to install requirements. (Ideally, you will run this in a `virtualenv`.)
To install Node.js requirements, first make sure you have Node.js installed. Then `cd` into `chain-forge` and run:
```bash
npm install
```
## Running ChainForge
To serve ChainForge, you have two options:
1. Run everything from a single Python script, which requires building the React app to static files, or
2. Serve the React front-end separately from the Flask back-end and take advantage of React hot reloading.
We recommend the former option for end-users, and the latter for developers.
### Option 1: Build React app as static files (end-users)
`cd` into `chain-forge` directory and run:
```bash
npm run build
```
Wait a moment while it builds the React app to static files.
### Option 2: Serve React front-end with hot reloading (developers)
`cd` into `chain-forge` directory and run the following to serve the React front-end:
```bash
npm run start
```
### Serving the backend
Regardless of which option you chose, `cd` into `python-backend` and run:
```bash
python app.py
```
> **Note**
> You can add the `--dummy-responses` flag in case you're worried about making calls to OpenAI. This will spoof all LLM responses as random strings, and is great for testing the interface without accidentally spending $$.
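> For example: `python app.py --dummy-responses`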
This script spins up two servers, the main one on port 8000 and a SocketIO server on port 8001 (used for streaming progress updates).
If you built the React app statically, go to `localhost:8000` in a web browser to view the app.
If you served the React app with hot reloading via `npm run start`, go to the address it is being served on (usually `localhost:3000`).
## Contributing to ChainForge
If you have access to the main repository, we request that you add a branch `dev/<your_first_name>` and develop changes from there. When you are ready to push changes, say to address an open Issue, make a Pull Request on the main repository and assign the main developer (Ian Arawjo) to it.

View File

@ -19,30 +19,23 @@ Taken together, these three features let you easily:
# Installation
To install, use `pip`. From the command line:
To get started, see `CONTRIBUTOR_GUIDE.md` for now. Below are the planned installation steps (which are not yet active):
```
pip install chainforge
```
[TODO: Upload CF to PyPI]
[TODO: Create a command-line alias (?) so you can run `chainforge serve <react_port?> <py_port?>` and spin up both React and the Python backend automatically.]
To run simply, type:
```
chainforge serve
```
This spins up two local servers: a React server through npm, and a Python backend, powered by Flask. For more options, such as port numbers, type `chainforge --help`.
### Sharing prompt chains
All ChainForge node graphs are importable/exportable as JSON specs. You can freely share prompt chains you develop (alongside any custom analysis code), whether to the public or within your organization.
>
> To install, use `pip`. From the command line:
>
> ```
> pip install chainforge
> ```
> To run, type:
> ```
> chainforge serve
> ```
> This spins up two local servers: a React server through npm, and a Python backend, powered by Flask. For more options, such as port numbers, type `chainforge --help`.
## Example: Test LLM robustness to prompt injection
...
> ...
# Development

View File

@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => {
// This gives the parent access to triggering the modal alert
const trigger = (msg) => {
if (!msg) msg = "Unknown error.";
console.error(msg);
setAlertMsg(msg);
open();

View File

@ -45,3 +45,8 @@ path.react-flow__edge-path:hover {
transform: rotate(360deg);
}
}
.rich-editor {
min-width: 500px;
min-height: 500px;
}

View File

@ -15,6 +15,7 @@ import EvaluatorNode from './EvaluatorNode';
import VisNode from './VisNode';
import InspectNode from './InspectorNode';
import ScriptNode from './ScriptNode';
import CsvNode from './CsvNode';
import './text-fields-node.css';
// State management (from https://reactflow.dev/docs/guides/state-management/)
@ -40,7 +41,8 @@ const nodeTypes = {
evaluator: EvaluatorNode,
vis: VisNode,
inspect: InspectNode,
script: ScriptNode
script: ScriptNode,
csv: CsvNode,
};
const connectionLineStyle = { stroke: '#ddd' };
@ -91,6 +93,11 @@ const App = () => {
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
};
const addCsvNode = (event) => {
const { x, y } = getViewportCenter();
addNode({ id: 'csvNode-'+Date.now(), type: 'csv', data: {}, position: {x: x-200, y:y-100} });
};
/**
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
*/
@ -201,6 +208,7 @@ const App = () => {
<button onClick={addVisNode}>Add vis node</button>
<button onClick={addInspectNode}>Add inspect node</button>
<button onClick={addScriptNode}>Add script node</button>
<button onClick={addCsvNode}>Add csv node</button>
<button onClick={saveFlow} style={{marginLeft: '12px'}}>Save</button>
<button onClick={loadFlowFromCache}>Load</button>
<button onClick={exportFlow} style={{marginLeft: '12px'}}>Export</button>

chain-forge/src/CsvNode.js Normal file (129 changed lines)
View File

@ -0,0 +1,129 @@
import React, { useState, useRef, useEffect, useCallback } from 'react';
import { Badge, Text } from '@mantine/core';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
import { IconCsv } from '@tabler/icons-react';
import { Handle } from 'react-flow-renderer';
const CsvNode = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [contentDiv, setContentDiv] = useState(null);
const [isEditing, setIsEditing] = useState(true);
const [csvInput, setCsvInput] = useState(null);
const [countText, setCountText] = useState(null);
// initializing
useEffect(() => {
if (!data.fields) {
setDataPropsForNode(id, { text: '', fields: [] });
}
}, []);
const processCsv = (csv) => {
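// Match either a double-quoted segment (which may contain commas), a run of non-comma characters, or a lone comma (an empty field, dropped below), each followed by a comma or the end of the string.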
var matches = csv.match(/(\s*"[^"]+"\s*|\s*[^,]+|,)(?=,|$)/g);
for (var n = 0; n < matches.length; ++n) {
matches[n] = matches[n].trim();
if (matches[n] == ',') matches[n] = '';
}
return matches.map(e => e.trim()).filter(e => e.length > 0);
}
// const processCsv = (csv) => {
// if (!csv) return [];
// // Split the input string by rows, and merge
// var res = csv.split('\n').join(',');
// // remove all the empty or whitespace-only elements
// return res.split(',').map(e => e.trim()).filter(e => e.length > 0);
// }
// Handle a change in a text fields' input.
const handleInputChange = useCallback((event) => {
// Update the data for this text fields' id.
let new_data = { 'text': event.target.value, 'fields': processCsv(event.target.value) };
setDataPropsForNode(id, new_data);
}, [id, setDataPropsForNode]);
const handKeyDown = useCallback((event) => {
if (event.key === 'Enter') {
setIsEditing(false);
setCsvInput(null);
}
}, []);
// handling Div Click
const handleDivOnClick = useCallback((event) => {
setIsEditing(true);
}, []);
const handleOnBlur = useCallback((event) => {
setIsEditing(false);
setCsvInput(null);
}, []);
// render csv div
const renderCsvDiv = useCallback(() => {
// Take the data.text as csv (only 1 row), and get individual elements
const elements = data.fields;
// generate a HTML code that highlights the elements
const html = [];
elements.forEach((e, idx) => {
// html.push(<Badge color="orange" size="lg" radius="sm">{e}</Badge>)
html.push(<span key={idx} className="csv-element">{e}</span>);
if (idx < elements.length - 1) {
html.push(<span key={idx + 'comma'} className="csv-comma">,</span>);
}
});
setContentDiv(<div className='csv-div nowheel' onClick={handleDivOnClick}>
{html}
</div>);
setCountText(<Text size="xs" style={{ marginTop: '5px' }} color='blue' align='right'>{elements.length} elements</Text>);
}, [data.text, handleDivOnClick]);
// When isEditing changes, add input
useEffect(() => {
if (!isEditing) {
setCsvInput(null);
renderCsvDiv();
return;
}
if (!csvInput) {
var text_val = data.text || '';
setCsvInput(
<div className="input-field" key={id}>
<textarea id={id} name={id} className="text-field-fixed nodrag csv-input" rows="2" cols="40" defaultValue={text_val} onChange={handleInputChange} placeholder='Paste your CSV text here' onKeyDown={handKeyDown} onBlur={handleOnBlur} autoFocus={true}/>
</div>
);
setContentDiv(null);
setCountText(null);
}
}, [isEditing]);
// when data.text changes, update the content div
useEffect(() => {
// When in editing mode, don't update the content div
if (isEditing) return;
if (!data.text) return;
renderCsvDiv();
}, [id, data.text]);
return (
<div className="text-fields-node cfnode">
<NodeLabel title={data.title || 'CSV Node'} nodeId={id} icon={<IconCsv size="16px" />} />
{csvInput}
{contentDiv}
{countText ? countText : <></>}
<Handle
type="source"
position="right"
id="output"
style={{ top: "50%", background: '#555' }}
/>
</div>
);
};
export default CsvNode;

View File

@ -71,7 +71,7 @@ const EvaluatorNode = ({ data, id }) => {
// Get all the script nodes, and get all the folder paths
const script_nodes = nodes.filter(n => n.type === 'script');
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch(BASE_URL + 'app/execute', {
@ -95,6 +95,8 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(json ? json.error : 'Unknown error encountered when requesting evaluations: empty response returned.');
return;
}
console.log(json.responses);
// Ping any vis nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
@ -178,7 +180,7 @@ const EvaluatorNode = ({ data, id }) => {
:</div>
{/* <span className="code-style">response</span>: */}
<div className="nodrag">
<div className="ace-editor-container nodrag">
<AceEditor
mode="python"
theme="xcode"
@ -187,8 +189,14 @@ const EvaluatorNode = ({ data, id }) => {
name={"aceeditor_"+id}
editorProps={{ $blockScrolling: true }}
width='400px'
height='200px'
height='100px'
tabSize={2}
onLoad={editorInstance => { // Make Ace Editor div resizeable.
editorInstance.container.style.resize = "both";
document.addEventListener("mouseup", e => (
editorInstance.resize()
));
}}
/>
</div>
{/* <CodeMirror

View File

@ -1,4 +1,4 @@
import React, { useState } from 'react';
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => {
const [varSelects, setVarSelects] = useState([]);
const [pastInputs, setPastInputs] = useState([]);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const handleVarValueSelect = () => {
}
@ -83,7 +84,7 @@ const InspectorNode = ({ data, id }) => {
);
});
return (
<div key={llm} className="llm-response-container nowheel">
<div key={llm} className="llm-response-container">
<h1>{llm}</h1>
{res_divs}
</div>
@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => {
}
}
useEffect(() => {
if (data.refresh && data.refresh === true) {
// Recreate the visualization:
setDataPropsForNode(id, { refresh: false });
handleOnConnect();
}
}, [data, id, handleOnConnect, setDataPropsForNode]);
return (
<div className="inspector-node cfnode">
<NodeLabel title={data.title || 'Inspect Node'}
@ -123,7 +132,9 @@ const InspectorNode = ({ data, id }) => {
{/* <div className="var-select-toolbar">
{varSelects}
</div> */}
<div className="inspect-response-container nowheel nodrag">
{responses}
</div>
<Handle
type="target"
position="left"

View File

@ -11,10 +11,10 @@ import io from 'socket.io-client';
// Available LLMs
const allLLMs = [
{ name: "GPT3.5", emoji: "🙂", model: "gpt3.5", temp: 1.0 },
{ name: "GPT4", emoji: "🥵", model: "gpt4", temp: 1.0 },
{ name: "GPT3.5", emoji: "🙂", model: "gpt-3.5-turbo", temp: 1.0 },
{ name: "GPT4", emoji: "🥵", model: "gpt-4", temp: 1.0 },
{ name: "Alpaca 7B", emoji: "🦙", model: "alpaca.7B", temp: 0.5 },
{ name: "Claude v1", emoji: "📚", model: "claude.v1", temp: 0.5 },
{ name: "Claude v1", emoji: "📚", model: "claude-v1", temp: 0.5 },
{ name: "Ian Chatbot", emoji: "💩", model: "test", temp: 0.5 }
];
const initLLMs = [allLLMs[0]];
@ -50,6 +50,7 @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
const getNode = useStore((state) => state.getNode);
const [templateVars, setTemplateVars] = useState(data.vars || []);
@ -181,13 +182,15 @@ const PromptNode = ({ data, id }) => {
prompt: prompt,
vars: vars,
llms: llms,
id: id,
n: numGenerations,
})}, rejected).then(function(response) {
return response.json();
}, rejected).then(function(json) {
if (!json || !json.counts) {
throw new Error('Request was sent and received by backend server, but there was no response.');
}
return json.counts;
return [json.counts, json.total_num_responses];
}, rejected);
};
@ -199,6 +202,12 @@ const PromptNode = ({ data, id }) => {
return;
}
// Check if the PromptNode is not already waiting for a response...
if (status === 'loading') {
setRunTooltip('Fetching responses...');
return;
}
// Get input data and prompt
const [py_prompt, pulled_vars] = pullInputData();
const llms = llmItemsCurrState.map(item => item.model);
@ -207,10 +216,48 @@ const PromptNode = ({ data, id }) => {
// Fetch response counts from backend
fetchResponseCounts(py_prompt, pulled_vars, llms, (err) => {
console.warn(err.message); // soft fail
}).then((counts) => {
const n = counts[Object.keys(counts)[0]];
const req = n > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${n} ${req}` + (num_llms > 1 ? ' per LLM' : ''));
}).then(([counts, total_num_responses]) => {
// Check for empty counts (means no requests will be sent!)
const num_llms_missing = Object.keys(counts).length;
if (num_llms_missing === 0) {
setRunTooltip('Will load responses from cache');
return;
}
// Tally how many queries per LLM:
let queries_per_llm = {};
Object.keys(counts).forEach(llm => {
queries_per_llm[llm] = Object.keys(counts[llm]).reduce(
(acc, prompt) => acc + counts[llm][prompt]
, 0);
});
// Check if all counts are the same:
if (num_llms_missing > 1) {
const some_llm_num = queries_per_llm[Object.keys(queries_per_llm)[0]];
const all_same_num_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc && queries_per_llm[llm] === some_llm_num, true)
if (num_llms_missing === num_llms && all_same_num_queries) { // Counts are the same
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ' per LLM' : ''));
}
else if (all_same_num_queries) {
const req = some_llm_num > 1 ? 'requests' : 'request';
setRunTooltip(`Will send ${some_llm_num} new ${req}` + (num_llms > 1 ? ` to ${num_llms_missing} LLMs` : ''));
}
else { // Counts are different
const sum_queries = Object.keys(queries_per_llm).reduce((acc, llm) => acc + queries_per_llm[llm], 0);
setRunTooltip(`Will send a variable # of queries to LLM(s) (total=${sum_queries})`);
}
} else {
const llm_name = Object.keys(queries_per_llm)[0];
const llm_count = queries_per_llm[llm_name];
const req = llm_count > 1 ? 'queries' : 'query';
if (num_llms > num_llms_missing)
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name} and load others`);
else
setRunTooltip(`Will send ${llm_count} ${req} to ${llm_name}`)
}
});
};
@ -263,7 +310,7 @@ const PromptNode = ({ data, id }) => {
py_prompt_template, pulled_data, llmItemsCurrState.map(item => item.model), rejected);
// Open a socket to listen for progress
const open_progress_listener_socket = (response_counts) => {
const open_progress_listener_socket = ([response_counts, total_num_responses]) => {
// With the counts information we can create progress bars. Now we load a socket connection to
// the socketio server that will stream to us the current progress:
const socket = io('http://localhost:8001/' + 'queryllm', {
@ -271,7 +318,7 @@ const PromptNode = ({ data, id }) => {
cors: {origin: "http://localhost:8000/"},
});
const max_responses = Object.keys(response_counts).reduce((acc, llm) => acc + response_counts[llm], 0);
const max_responses = Object.keys(total_num_responses).reduce((acc, llm) => acc + total_num_responses[llm], 0);
// On connect to the server, ask it to give us the current progress
// for task 'queryllm' with id 'id', and stop when it reads progress >= 'max'.
@ -392,6 +439,15 @@ const PromptNode = ({ data, id }) => {
);
}));
// Ping any inspect nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
output_nodes.forEach(n => {
const node = getNode(n);
if (node && node.type === 'inspect') {
setDataPropsForNode(node.id, { refresh: true });
}
});
// Log responses for debugging:
console.log(json.responses);
} else {

View File

@ -66,48 +66,75 @@ const VisNode = ({ data, id }) => {
else
responses_by_llm[item.llm] = [item];
});
const llm_names = Object.keys(responses_by_llm);
// Create Plotly spec here
const varnames = Object.keys(json.responses[0].vars);
let spec = {};
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
let spec = [];
let layout = {
width: 420, height: 300, title: '', margin: {
l: 40, r: 20, b: 20, t: 20, pad: 2
l: 100, r: 20, b: 20, t: 20, pad: 0
}
}
if (varnames.length === 1) {
const plot_grouped_boxplot = (resp_to_x) => {
llm_names.forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
x: rs.map(r => r.eval_res.items).flat(),
y: rs.map(r => Array(r.eval_res.items.length).fill(resp_to_x(r))).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text} <b><i>(%{x})</i></b>',
orientation: 'h',
});
});
layout.boxmode = 'group';
};
if (varnames.length === 0) {
// No variables means they used a single prompt (no template) to generate responses
// (Users are likely evaluating differences in responses between LLMs)
plot_grouped_boxplot((r) => truncStr(r.prompt.trim(), 12));
// llm_names.forEach((llm, idx) => {
// // Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
// const rs = responses_by_llm[llm];
// const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
// spec.push({
// type: 'scatter',
// name: llm,
// marker: {color: colors[idx % colors.length]},
// y: rs.map(r => r.eval_res.items).flat(),
// x: rs.map(r => Array(r.eval_res.items.length).fill(truncStr(r.prompt.trim(), 12))).flat(), // use the prompt str as var name
// // boxpoints: 'all',
// mode: 'markers',
// text: hover_texts,
// hovertemplate: '%{text} <b><i>(%{y})</i></b>',
// });
// });
// layout.scattermode = 'group';
}
else if (varnames.length === 1) {
// 1 var; numeric eval
if (Object.keys(responses_by_llm).length === 1) {
if (llm_names.length === 1) {
// Simple box plot, as there is only a single LLM in the response
spec = json.responses.map(r => {
// Use the var value to 'name' this group of points:
const s = truncStr(r.vars[varnames[0]].trim(), 12);
return {type: 'box', y: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}'};
return {type: 'box', x: r.eval_res.items, name: s, boxpoints: 'all', text: createHoverTexts(r.responses), hovertemplate: '%{text}', orientation: 'h'};
});
layout.hovermode = 'closest';
} else {
// There are multiple LLMs in the response; do a grouped box plot by LLM.
// Note that 'name' is now the LLM, and 'x' stores the value of the var:
spec = [];
const colors = ['#cbf078', '#f1b963', '#e46161', '#f8f398', '#defcf9', '#cadefc', '#c3bef0', '#cca8e9'];
Object.keys(responses_by_llm).forEach((llm, idx) => {
// Create HTML for hovering over a single datapoint. We must use 'br' to specify line breaks.
const rs = responses_by_llm[llm];
const hover_texts = rs.map(r => createHoverTexts(r.responses)).flat();
spec.push({
type: 'box',
name: llm,
marker: {color: colors[idx % colors.length]},
y: rs.map(r => r.eval_res.items).flat(),
x: rs.map(r => Array(r.eval_res.items.length).fill(r.vars[varnames[0]].trim())).flat(),
boxpoints: 'all',
text: hover_texts,
hovertemplate: '%{text}',
});
});
layout.boxmode = 'group';
plot_grouped_boxplot((r) => r.vars[varnames[0]].trim());
}
}
else if (varnames.length === 2) {

View File

@ -21,4 +21,36 @@ code {
.script-node-input {
min-width: 300px;
}
.csv-element {
position: relative;
color: #8a3e07;
background-color: #FFE8CC;
font-size: inherit;
font-family: monospace;
padding: .2em .4em;
border: none;
text-align: center;
text-decoration: none;
display: inline-block;
margin: 4px 2px;
cursor: pointer;
border-radius: 10px;
}
/* set a muted text */
.csv-comma {
color: #FFC107;
}
.csv-div {
width: 350px;
max-height: 250px;
overflow-y: auto;
}
.csv-input {
width: 350px;
height: 150px;
}

View File

@ -144,6 +144,10 @@
border-radius: 5px;
}
.ace-editor-container {
resize:vertical;
}
.vis-node {
background-color: #fff;
padding: 10px;
@ -156,8 +160,12 @@
padding: 10px;
border: 1px solid #999;
border-radius: 5px;
}
.inspect-response-container {
overflow-y: auto;
max-height: 400px;
width: 450px;
max-height: 350px;
resize: both;
}
.small-response {
@ -172,7 +180,7 @@
}
.llm-response-container {
max-width: 450px;
max-width: 100%;
}
.llm-response-container h1 {
font-weight: 400;

View File

@ -1,4 +1,4 @@
import json, os, asyncio, sys, argparse, threading
import json, os, asyncio, sys, argparse, threading, traceback
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
def index():
return render_template("index.html")
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
'alpaca.7B': LLM.Alpaca7B,
'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
LLM_NAME_MAP = {}
for model in LLM:
LLM_NAME_MAP[model.value] = model
@dataclass
class ResponseInfo:
@ -40,7 +37,7 @@ class ResponseInfo:
return self.text
def to_standard_format(r: dict) -> list:
llm = LLM_NAME_MAP_INVERSE[r['llm']]
llm = r['llm']
resp_obj = {
'vars': r['info'],
'llm': llm,
@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
resp_obj['eval_res'] = r['eval_res']
return resp_obj
def get_llm_of_response(response: dict) -> LLM:
return LLM_NAME_MAP[response['llm']]
def get_filenames_with_id(filenames: list, id: str) -> list:
return [
c for c in filenames
@ -72,17 +66,17 @@ def load_cache_json(filepath: str) -> dict:
responses = json.load(f)
return responses
def run_over_responses(eval_func, responses: dict, scope: str) -> list:
for prompt, resp_obj in responses.items():
res = extract_responses(resp_obj, resp_obj['llm'])
def run_over_responses(eval_func, responses: list, scope: str) -> list:
for resp_obj in responses:
res = resp_obj['responses']
if scope == 'response':
evals = [ # Run evaluator func over every individual response text
eval_func(
ResponseInfo(
text=r,
prompt=prompt,
var=resp_obj['info'],
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
prompt=resp_obj['prompt'],
var=resp_obj['vars'],
llm=resp_obj['llm'])
) for r in res
]
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
@ -148,6 +142,30 @@ def reduce_responses(responses: list, vars: list) -> list:
return ret
def load_all_cached_responses(cache_ids):
if not isinstance(cache_ids, list):
cache_ids = [cache_ids]
# Load all responses with the given ID:
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in cache_ids:
cache_files = [fname for fname in get_filenames_with_id(all_cache_files, cache_id) if fname != f"{cache_id}.json"]
if len(cache_files) == 0:
continue
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
return responses
@app.route('/app/countQueriesRequired', methods=['POST'])
def countQueries():
"""
@ -158,24 +176,57 @@ def countQueries():
'prompt': str # the prompt template, with any {{}} vars
'vars': dict # a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.)
'llms': list # the list of LLMs you will query
'n': int # how many responses expected per prompt
'id': str (optional) # a unique ID of the node with cache'd responses. If missing, assumes no cache will be used.
}
"""
data = request.get_json()
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms'}):
if not set(data.keys()).issuperset({'prompt', 'vars', 'llms', 'n'}):
return jsonify({'error': 'POST data is improper format.'})
n = int(data['n'])
try:
gen_prompts = PromptPermutationGenerator(PromptTemplate(data['prompt']))
all_prompt_permutations = list(gen_prompts(data['vars']))
except Exception as e:
return jsonify({'error': str(e)})
if 'id' in data:
# Load all cache'd responses with the given id:
cached_resps = load_all_cached_responses(data['id'])
else:
cached_resps = []
missing_queries = {}
num_responses_req = {}
def add_to_missing_queries(llm, prompt, num):
if llm not in missing_queries:
missing_queries[llm] = {}
missing_queries[llm][prompt] = num
def add_to_num_responses_req(llm, num):
if llm not in num_responses_req:
num_responses_req[llm] = 0
num_responses_req[llm] += num
# Iterate through all prompt permutations and check how many responses there are in the cache for that prompt
for prompt in all_prompt_permutations:
prompt = str(prompt)
matching_resps = [r for r in cached_resps if r['prompt'] == prompt]
for llm in data['llms']:
add_to_num_responses_req(llm, n)
match_per_llm = [r for r in matching_resps if r['llm'] == llm]
if len(match_per_llm) == 0:
add_to_missing_queries(llm, prompt, n)
elif len(match_per_llm) == 1:
# Check how many were stored; if not enough, add how many missing queries:
num_resps = len(match_per_llm[0]['responses'])
if n > len(match_per_llm[0]['responses']):
add_to_missing_queries(llm, prompt, n - num_resps)
else:
raise Exception(f"More than one response found for the same prompt ({prompt}) and LLM ({llm})")
# TODO: Send more informative data back including how many queries per LLM based on cache'd data
num_queries = {} # len(all_prompt_permutations) * len(data['llms'])
for llm in data['llms']:
num_queries[llm] = len(all_prompt_permutations)
ret = jsonify({'counts': num_queries})
ret = jsonify({'counts': missing_queries, 'total_num_responses': num_responses_req})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
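For reference, here is a minimal sketch of what a call to this endpoint could look like from Python. It is not part of the commit; it assumes the backend is running locally on its default port 8000 (see `CONTRIBUTOR_GUIDE.md`), and the payload values are hypothetical, following the docstring above.

```python
import requests

# Hypothetical payload, matching the fields documented for /app/countQueriesRequired:
payload = {
    "prompt": "What is the capital of {{country}}?",  # prompt template with {{}} vars
    "vars": {"country": ["France", "Japan"]},
    "llms": ["gpt-3.5-turbo", "claude-v1"],
    "n": 3,                     # responses expected per prompt
    "id": "promptNode-12345",   # optional: node whose cached responses should be counted
}

resp = requests.post("http://localhost:8000/app/countQueriesRequired", json=payload).json()
print(resp["counts"])               # missing queries, per LLM and per prompt
print(resp["total_num_responses"])  # total responses required, per LLM
```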
@ -245,6 +296,11 @@ async def queryLLM():
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')
# Check that the filepath used to cache eval'd responses is valid:
cache_filepath_last_run = os.path.join('cache', f"{data['id']}.json")
if not is_valid_filepath(cache_filepath_last_run):
return jsonify({'error': f'Invalid filepath: {cache_filepath_last_run}'})
# For each LLM, generate and cache responses:
responses = {}
llms = data['llm']
@ -265,23 +321,27 @@ async def queryLLM():
# Prompt the LLM with all permutations of the input prompt template:
# NOTE: If the responses are already cache'd, this just loads them (no LLM is queried, saving $$$)
resps = []
num_resps = 0
try:
print(f'Querying {llm}...')
async for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
resps.append(response)
print(f"collected response from {llm.name}:", str(response))
num_resps += len(extract_responses(response, llm))
# Save the number of responses collected to a temp file on disk
with open(tempfilepath, 'r') as f:
txt = f.read().strip()
cur_data = json.loads(txt) if len(txt) > 0 else {}
cur_data[llm_str] = len(resps)
cur_data[llm_str] = num_resps
with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
print('error generating responses:', e)
print(f'error generating responses for {llm}:', e)
print(traceback.format_exc())
raise e
return {'llm': llm, 'responses': resps}
@ -308,6 +368,10 @@ async def queryLLM():
# Remove the temp file used to stream progress updates:
if os.path.exists(tempfilepath):
os.remove(tempfilepath)
# Save the responses *of this run* to the disk, for further recall:
with open(cache_filepath_last_run, "w") as f:
json.dump(res, f)
# Return all responses for all LLMs
print('returning responses:', res)
@ -388,38 +452,29 @@ def execute():
all_cache_files = get_files_at_dir('cache/')
all_evald_responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
fname = f"{cache_id}.json"
if fname not in all_cache_files:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
# To avoid loading all response files into memory at once, we'll run the evaluator on each file:
for filename in cache_files:
# Load the raw responses from the cache
responses = load_cache_json(os.path.join('cache', fname))
if len(responses) == 0: continue
# Load the raw responses from the cache
responses = load_cache_json(os.path.join('cache', filename))
if len(responses) == 0: continue
# Run the evaluator over them:
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
try:
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
except Exception as e:
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
# Run the evaluator over them:
# NOTE: 'evaluate' here was defined dynamically from 'exec' above.
try:
evald_responses = run_over_responses(evaluate, responses, scope=data['scope'])
except Exception as e:
return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}'})
# Perform any reduction operations:
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
evald_responses = reduce_responses(
evald_responses,
vars=data['reduce_vars']
)
# Convert to standard format:
std_evald_responses = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in evald_responses.items()
]
# Perform any reduction operations:
if 'reduce_vars' in data and len(data['reduce_vars']) > 0:
std_evald_responses = reduce_responses(
std_evald_responses,
vars=data['reduce_vars']
)
all_evald_responses.extend(std_evald_responses)
all_evald_responses.extend(evald_responses)
# Store the evaluated responses in a new cache json:
with open(cache_filepath, "w") as f:
@ -481,19 +536,18 @@ def grabResponses():
all_cache_files = get_files_at_dir('cache/')
responses = []
for cache_id in data['responses']:
cache_files = get_filenames_with_id(all_cache_files, cache_id)
if len(cache_files) == 0:
fname = f"{cache_id}.json"
if fname not in all_cache_files:
return jsonify({'error': f'Did not find cache file for id {cache_id}'})
for filename in cache_files:
res = load_cache_json(os.path.join('cache', filename))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
res = load_cache_json(os.path.join('cache', fname))
if isinstance(res, dict):
# Convert to standard response format
res = [
to_standard_format({'prompt': prompt, **res_obj})
for prompt, res_obj in res.items()
]
responses.extend(res)
ret = jsonify({'responses': responses})
ret.headers.add('Access-Control-Allow-Origin', '*')

View File

@ -0,0 +1,39 @@
"""
A list of all model APIs natively supported by ChainForge.
"""
from enum import Enum
class LLM(str, Enum):
""" OpenAI Chat """
ChatGPT = "gpt-3.5-turbo"
GPT4 = "gpt-4"
""" Dalai-served models """
Alpaca7B = "alpaca.7B"
""" Anthropic """
# Our largest model, ideal for a wide range of more complex tasks. Using this model name
# will automatically switch you to newer versions of claude-v1 as they are released.
Claude_v1 = "claude-v1"
# An earlier version of claude-v1
Claude_v1_0 = "claude-v1.0"
# An improved version of claude-v1. It is slightly improved at general helpfulness,
# instruction following, coding, and other tasks. It is also considerably better with
# non-English languages. This model also has the ability to role play (in harmless ways)
# more consistently, and it defaults to writing somewhat longer and more thorough responses.
Claude_v1_2 = "claude-v1.2"
# A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
# against red-team inputs, better at precise instruction-following, better at code, and better
# at non-English dialogue and writing.
Claude_v1_3 = "claude-v1.3"
# A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
# is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
# less expensive and blazing fast. We believe that this model provides more than adequate performance
# on a range of tasks including text classification, summarization, and lightweight chat applications,
# as well as search result summarization. Using this model name will automatically switch you to newer
# versions of claude-instant-v1 as they are released.
Claude_v1_instant = "claude-instant-v1"
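As a quick illustration (not part of the commit), here is how this string-valued enum is consumed elsewhere in this changeset, e.g. the `LLM_NAME_MAP` loop in `app.py` and the `'claude'` prefix check in `query.py`/`utils.py`:

```python
from promptengine.models import LLM

# LLM subclasses str, so each member's value is the model name sent to the API and stored in caches.
print(LLM.ChatGPT.value)     # "gpt-3.5-turbo"
print(LLM.Claude_v1.value)   # "claude-v1"

# Build a name -> member lookup, mirroring the LLM_NAME_MAP loop in app.py above:
LLM_NAME_MAP = {model.value: model for model in LLM}
assert LLM_NAME_MAP["gpt-4"] is LLM.GPT4

# The prefix check used to route all Claude variants to the Anthropic API:
print(LLM.Claude_v1_3.value[:6] == "claude")   # True
```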

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import List, Dict, Tuple, Iterator
from typing import List, Dict, Tuple, Iterator, Union
import json, os, asyncio, random, string
from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json, cull_responses, extract_responses
from promptengine.template import PromptTemplate, PromptPermutationGenerator
# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@ -59,13 +59,14 @@ class PromptPipeline:
prompt_str = str(prompt)
# First check if there is already a response for this item. If so, we can save an LLM call:
if prompt_str in responses:
if prompt_str in responses and len(extract_responses(responses[prompt_str], llm)) >= n:
print(f" - Found cache'd response for prompt {prompt_str}. Using...")
responses[prompt_str] = cull_responses(responses[prompt_str], llm, n)
yield {
"prompt": prompt_str,
"query": responses[prompt_str]["query"],
"response": responses[prompt_str]["response"],
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
"info": responses[prompt_str]["info"],
}
continue
@ -86,7 +87,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@ -96,7 +97,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
@ -114,7 +115,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@ -124,7 +125,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
@ -147,11 +148,13 @@ class PromptPipeline:
def clear_cached_responses(self) -> None:
self._cache_responses({})
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
if llm is LLM.ChatGPT or llm is LLM.GPT4:
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
elif llm.value[:6] == 'claude':
query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
else:
raise Exception(f"Language model {llm} is not supported.")
return prompt, query, response
@ -175,5 +178,7 @@ class PromptLLM(PromptPipeline):
"""
class PromptLLMDummy(PromptLLM):
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
# Wait a random amount of time, to simulate wait times from real queries
await asyncio.sleep(random.uniform(0.1, 3))
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(40)) for _ in range(n)])
# Return a random string of characters of random length (within a predefined range)
return prompt, *({'prompt': str(prompt)}, [''.join(random.choice(string.ascii_letters) for i in range(random.randint(25, 80))) for _ in range(n)])

View File

@ -1,28 +1,22 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
from typing import Dict, Tuple, List, Union, Callable
import json, os, time, asyncio
from promptengine.models import LLM
DALAI_MODEL = None
DALAI_RESPONSE = None
openai.api_key = os.environ.get("OPENAI_API_KEY")
""" Supported LLM coding assistants """
class LLM(Enum):
ChatGPT = 0
Alpaca7B = 1
GPT4 = 2
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
"""
Calls GPT3.5 via OpenAI's API.
Returns raw query and response JSON dicts.
NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
"""
model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
if model not in model_map:
raise Exception(f"Could not find OpenAI chat model {model}")
model = model_map[model]
import openai
if not openai.api_key:
openai.api_key = os.environ.get("OPENAI_API_KEY")
model = model.value
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
response = openai.ChatCompletion.create(**query)
return query, response
async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
max_tokens_to_sample=1024,
stop_sequences: Union[List[str], str]=["\n\nHuman:"],
async_mode=False,
**params) -> Tuple[Dict, Dict]:
"""
Calls Anthropic API with the given model, passing in params.
Returns raw query and response JSON dicts.
Unique parameters:
- custom_prompt_wrapper: Anthropic models expect prompts in the form "\n\nHuman: ${prompt}\n\nAssistant:". If you wish to
explore custom prompt wrappers that deviate from this, write a function that maps from 'prompt' to your custom wrapper.
If set to None, defaults to Anthropic's suggested prompt wrapper.
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
- stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
- async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
If you want to send all 'n' requests at once, you can set async_mode to True.
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
"""
import anthropic
client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
# Format query
query = {
'model': model.value,
'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
'max_tokens_to_sample': max_tokens_to_sample,
'stop_sequences': stop_sequences,
'temperature': temperature,
**params
}
print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
# Request responses using the passed async_mode
responses = []
if async_mode:
# Gather n responses by firing off all API requests at once
tasks = [client.acompletion(**query) for _ in range(n)]
responses = await asyncio.gather(*tasks)
else:
# Repeat call n times, waiting for each response to come in:
while len(responses) < n:
resp = await client.acompletion(**query)
responses.append(resp)
print(f'{model.value} response {len(responses)} of {n}:\n', resp)
return query, responses
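A hedged usage sketch of the function above (not part of the commit): it assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is set, and the prompt is hypothetical.

```python
import asyncio
from promptengine.models import LLM
from promptengine.utils import call_anthropic, extract_responses

async def main():
    # Ask Claude v1 for two completions of the same (hypothetical) prompt:
    query, responses = await call_anthropic("What is the capital of France?", model=LLM.Claude_v1, n=2)
    # extract_responses expects the cached wrapper format, i.e. a dict with a "response" list:
    print(extract_responses({"response": responses}, LLM.Claude_v1))

asyncio.run(main())
```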
async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
"""
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
Returns the raw query and response JSON dicts.
Parameters:
- llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
- model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
- port: The port of the local server where Dalai is running. Usually 3000.
- prompt: The prompt to pass to the LLM.
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
# Create full query to Dalai
query = {
'prompt': prompt,
'model': llm_name,
'model': model.value,
'id': str(round(time.time()*1000)),
'temp': temperature,
**def_params
@ -129,18 +174,36 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
choices = response["response"]["choices"]
return [
c["message"]["content"]
for i, c in enumerate(choices)
for c in choices
]
def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
"""
Given a LLM and a response object from its API, extract the
text response(s) part of the response object.
"""
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
return _extract_chatgpt_responses(response)
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
return response["response"]
elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
return [r["completion"] for r in response["response"]]
else:
raise ValueError(f"LLM {llm} is unsupported.")
def cull_responses(response: Union[list, dict], llm: Union[LLM, str], n: int) -> Union[list, dict]:
"""
Returns the same 'response' but with only 'n' responses.
"""
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
response["response"]["choices"] = response["response"]["choices"][:n]
return response
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
response["response"] = response["response"][:n]
return response
elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
response["response"] = response["response"][:n]
return response
else:
raise ValueError(f"LLM {llm} is unsupported.")