Merge pull request #6 from ianarawjo/pv

Adding support for external python scripts within evaluators
This commit is contained in:
ianarawjo 2023-05-03 12:04:31 -04:00 committed by GitHub
commit 62c33ee1ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 155 additions and 11 deletions

View File

@ -14,6 +14,7 @@ import PromptNode from './PromptNode';
import EvaluatorNode from './EvaluatorNode';
import VisNode from './VisNode';
import InspectNode from './InspectorNode';
import ScriptNode from './ScriptNode';
import './text-fields-node.css';
// State management (from https://reactflow.dev/docs/guides/state-management/)
@ -39,6 +40,7 @@ const nodeTypes = {
evaluator: EvaluatorNode,
vis: VisNode,
inspect: InspectNode,
script: ScriptNode
};
const connectionLineStyle = { stroke: '#ddd' };
@ -84,6 +86,10 @@ const App = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
};
const addScriptNode = (event) => {
const { x, y } = getViewportCenter();
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
};
/**
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
@ -194,6 +200,7 @@ const App = () => {
<button onClick={addEvalNode}>Add evaluator node</button>
<button onClick={addVisNode}>Add vis node</button>
<button onClick={addInspectNode}>Add inspect node</button>
<button onClick={addScriptNode}>Add script node</button>
<button onClick={saveFlow} style={{marginLeft: '12px'}}>Save</button>
<button onClick={loadFlowFromCache}>Load</button>
<button onClick={exportFlow} style={{marginLeft: '12px'}}>Export</button>

View File

@ -4,6 +4,7 @@ import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import AlertModal from './AlertModal'
import {BASE_URL} from './store';
// Ace code editor
import AceEditor from "react-ace";
@ -18,6 +19,7 @@ const EvaluatorNode = ({ data, id }) => {
const getNode = useStore((state) => state.getNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [status, setStatus] = useState('none');
const nodes = useStore((state) => state.nodes);
// For displaying error messages to user
const alertModal = useRef(null);
@ -49,6 +51,7 @@ const EvaluatorNode = ({ data, id }) => {
};
const handleRunClick = (event) => {
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map(e => e.source);
if (input_node_ids.length === 0) {
@ -72,9 +75,13 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(err_msg);
};
// Get all the script nodes, and get all the folder paths
const script_nodes = nodes.filter(n => n.type === 'script');
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
fetch('http://localhost:5000/execute', {
fetch(BASE_URL + 'execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -83,6 +90,7 @@ const EvaluatorNode = ({ data, id }) => {
scope: mapScope,
responses: input_node_ids,
reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
script_paths: script_paths,
// write an extra part here that takes in reduce func
}),
}, rejected).then(function(response) {

View File

@ -1,6 +1,7 @@
import React, { useState } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import {BASE_URL} from './store';
const bucketResponsesByLLM = (responses) => {
let responses_by_llm = {};
@ -30,7 +31,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);
// Grab responses associated with those ids:
fetch('http://localhost:5000/grabResponses', {
fetch(BASE_URL + 'grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@ -8,6 +8,7 @@ import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent'
import LLMList from './LLMListComponent'
import AlertModal from './AlertModal'
import {BASE_URL} from './store';
// Available LLMs
const allLLMs = [
@ -186,7 +187,7 @@ const PromptNode = ({ data, id }) => {
};
// Run all prompt permutations through the LLM to generate + cache responses:
fetch('http://localhost:5000/queryllm', {
fetch(BASE_URL + 'queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@ -0,0 +1,77 @@
import React, { useState, useRef, useEffect, useCallback } from 'react';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
const ScriptNode = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const delButtonId = 'del-';
// Handle a change in a scripts' input.
const handleInputChange = useCallback((event) => {
// Update the data for this script node's id.
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
new_data.scriptFiles[event.target.id] = event.target.value;
console.log(new_data);
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
// Handle delete script file.
const handleDelete = useCallback((event) => {
// Update the data for this script node's id.
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
var item_id = event.target.id.substring(delButtonId.length);
delete new_data.scriptFiles[item_id];
// if the new_data is empty, initialize it with one empty field
if (Object.keys(new_data.scriptFiles).length === 0) {
new_data.scriptFiles['f0'] = '';
}
console.log(new_data);
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
// Initialize fields (run once at init)
const [scriptFiles, setScriptFiles] = useState([]);
useEffect(() => {
if (!data.scriptFiles)
setDataPropsForNode(id, { scriptFiles: { f0: '' } });
}, [data.scriptFiles, id, setDataPropsForNode]);
// Whenever 'data' changes, update the input fields to reflect the current state.
useEffect(() => {
const f = data.scriptFiles ? Object.keys(data.scriptFiles) : ['f0'];
setScriptFiles(f.map((i) => {
const val = data.scriptFiles ? data.scriptFiles[i] : '';
return (
<div className="input-field" key={i}>
<input className='script-node-input' type='text' id={i} onChange={handleInputChange} value={val}/><button id={delButtonId + i} onClick={handleDelete}>x</button><br/>
</div>
)
}));
}, [data.scriptFiles, handleInputChange, handleDelete]);
// Add a field
const handleAddField = useCallback(() => {
// Update the data for this script node's id.
const num_files = data.scriptFiles ? Object.keys(data.scriptFiles).length : 0;
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
new_data.scriptFiles['f' + num_files.toString()] = "";
setDataPropsForNode(id, new_data);
}, [data, id, setDataPropsForNode]);
return (
<div className="script-node">
<div className="node-header">
<NodeLabel title={data.title || 'Global Scripts'} nodeId={id} />
</div>
<label htmlFor="num-generations" style={{fontSize: '10pt'}}>Enter folder paths for external modules you wish to import.</label> <br/><br/>
<div ref={ref}>
{scriptFiles}
</div>
<div className="add-text-field-btn">
<button onClick={handleAddField}>+</button>
</div>
</div>
);
};
export default ScriptNode;

View File

@ -4,7 +4,8 @@ import useStore from './store';
import Plot from 'react-plotly.js';
import { hover } from '@testing-library/user-event/dist/hover';
import { create } from 'zustand';
import NodeLabel from './NodeLabelComponent'
import NodeLabel from './NodeLabelComponent';
import {BASE_URL} from './store';
// Helper funcs
const truncStr = (s, maxLen) => {
@ -54,7 +55,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];
fetch('http://localhost:5000/grabResponses', {
fetch(BASE_URL + 'grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({

View File

@ -11,3 +11,14 @@ code {
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
monospace;
}
.script-node {
background-color: #fff;
padding: 10px;
border: 1px solid #000;
border-radius: 5px;
}
.script-node-input {
min-width: 300px;
}

View File

@ -33,6 +33,8 @@ const initialEdges = [
{ id: 'e1-2', source: initprompt, target: initeval, interactionWidth: 100},
];
export const BASE_URL = 'http://localhost:8000/';
// TypeScript only
// type RFState = {
// nodes: Node[];
@ -81,7 +83,7 @@ const useStore = create((set, get) => ({
)(get().nodes)
});
},
getNode: (id) => get().nodes.find(n => n.id == id),
getNode: (id) => get().nodes.find(n => n.id === id),
addNode: (newnode) => {
set({
nodes: get().nodes.concat(newnode)
@ -122,4 +124,4 @@ const useStore = create((set, get) => ({
},
}));
export default useStore;
export default useStore;

View File

@ -5,7 +5,8 @@ from flask import Flask, request, jsonify
from flask_cors import CORS
from promptengine.query import PromptLLM
from promptengine.template import PromptTemplate, PromptPermutationGenerator
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
import sys
app = Flask(__name__)
CORS(app)
@ -138,6 +139,10 @@ def reduce_responses(responses: list, vars: list) -> list:
return ret
@app.route('/test', methods=['GET'])
def test():
return "Hello, world!"
@app.route('/queryllm', methods=['POST'])
def queryLLM():
"""
@ -175,9 +180,13 @@ def queryLLM():
if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
# Create a cache dir if it doesn't exist:
create_dir_if_not_exists('cache')
# For each LLM, generate and cache responses:
responses = {}
params = data['params'] if 'params' in data else {}
for llm in llms:
# Check that storage path is valid:
@ -195,6 +204,7 @@ def queryLLM():
for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
responses[llm].append(response)
except Exception as e:
print('error generating responses:', e)
raise e
return jsonify({'error': str(e)})
@ -206,6 +216,7 @@ def queryLLM():
]
# Return all responses for all LLMs
print('returning responses:', res)
ret = jsonify({'responses': res})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@ -224,6 +235,7 @@ def execute():
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
}
NOTE: This should only be run on your server on code you trust.
@ -250,6 +262,24 @@ def execute():
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
# add the path to any scripts to the path:
try:
if 'script_paths' in data:
for script_path in data['script_paths']:
# get the folder of the script_path:
script_folder = os.path.dirname(script_path)
# check that the script_folder is valid, and it contains __init__.py
if not os.path.exists(script_folder):
print(script_folder, 'is not a valid script path.')
print(os.path.exists(script_folder))
continue
# add it to the path:
sys.path.append(script_folder)
print(f'added {script_folder} to sys.path')
except Exception as e:
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
# Create the evaluator function
# DANGER DANGER!
try:
@ -379,4 +409,4 @@ def grabResponses():
return ret
if __name__ == '__main__':
app.run(host="localhost", port=5000, debug=True)
app.run(host="localhost", port=8000, debug=True)

View File

@ -6,6 +6,8 @@ import json, os, time
DALAI_MODEL = None
DALAI_RESPONSE = None
openai.api_key = os.environ.get("OPENAI_API_KEY")
""" Supported LLM coding assistants """
class LLM(Enum):
ChatGPT = 0
@ -141,6 +143,10 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
else:
raise ValueError(f"LLM {llm} is unsupported.")
def create_dir_if_not_exists(path: str) -> None:
if not os.path.exists(path):
os.makedirs(path)
def is_valid_filepath(filepath: str) -> bool:
try:
with open(filepath, 'r'):

View File

@ -9,7 +9,7 @@
<script>
function test_exec() {
const response = fetch('http://localhost:5000/execute', {
const response = fetch(BASE_URL + 'execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@ -25,7 +25,7 @@
}
function test_query() {
const response = fetch('http://localhost:5000/queryllm', {
const response = fetch(BASE_URL + 'queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({