mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
merge
This commit is contained in:
commit
a5ed261e36
@ -14,6 +14,7 @@ import PromptNode from './PromptNode';
|
||||
import EvaluatorNode from './EvaluatorNode';
|
||||
import VisNode from './VisNode';
|
||||
import InspectNode from './InspectorNode';
|
||||
import ScriptNode from './ScriptNode';
|
||||
import './text-fields-node.css';
|
||||
|
||||
// State management (from https://reactflow.dev/docs/guides/state-management/)
|
||||
@ -39,6 +40,7 @@ const nodeTypes = {
|
||||
evaluator: EvaluatorNode,
|
||||
vis: VisNode,
|
||||
inspect: InspectNode,
|
||||
script: ScriptNode
|
||||
};
|
||||
|
||||
const connectionLineStyle = { stroke: '#ddd' };
|
||||
@ -84,6 +86,10 @@ const App = () => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
const addScriptNode = (event) => {
|
||||
const { x, y } = getViewportCenter();
|
||||
addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
|
||||
};
|
||||
|
||||
/**
|
||||
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
|
||||
@ -194,6 +200,7 @@ const App = () => {
|
||||
<button onClick={addEvalNode}>Add evaluator node</button>
|
||||
<button onClick={addVisNode}>Add vis node</button>
|
||||
<button onClick={addInspectNode}>Add inspect node</button>
|
||||
<button onClick={addScriptNode}>Add script node</button>
|
||||
<button onClick={saveFlow} style={{marginLeft: '12px'}}>Save</button>
|
||||
<button onClick={loadFlowFromCache}>Load</button>
|
||||
<button onClick={exportFlow} style={{marginLeft: '12px'}}>Export</button>
|
||||
|
@ -5,6 +5,7 @@ import StatusIndicator from './StatusIndicatorComponent'
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import AlertModal from './AlertModal'
|
||||
import { IconTerminal } from '@tabler/icons-react'
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
// Ace code editor
|
||||
import AceEditor from "react-ace";
|
||||
@ -19,6 +20,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
const getNode = useStore((state) => state.getNode);
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const [status, setStatus] = useState('none');
|
||||
const nodes = useStore((state) => state.nodes);
|
||||
|
||||
// For displaying error messages to user
|
||||
const alertModal = useRef(null);
|
||||
@ -50,6 +52,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
};
|
||||
|
||||
const handleRunClick = (event) => {
|
||||
|
||||
// Get the ids from the connected input nodes:
|
||||
const input_node_ids = inputEdgesForNode(id).map(e => e.source);
|
||||
if (input_node_ids.length === 0) {
|
||||
@ -73,9 +76,13 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
alertModal.current.trigger(err_msg);
|
||||
};
|
||||
|
||||
// Get all the script nodes, and get all the folder paths
|
||||
const script_nodes = nodes.filter(n => n.type === 'script');
|
||||
const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
|
||||
console.log(script_paths);
|
||||
// Run evaluator in backend
|
||||
const codeTextOnRun = codeText + '';
|
||||
fetch('http://localhost:5000/execute', {
|
||||
fetch(BASE_URL + 'execute', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
@ -84,6 +91,7 @@ const EvaluatorNode = ({ data, id }) => {
|
||||
scope: mapScope,
|
||||
responses: input_node_ids,
|
||||
reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
|
||||
script_paths: script_paths,
|
||||
// write an extra part here that takes in reduce func
|
||||
}),
|
||||
}, rejected).then(function(response) {
|
||||
|
@ -2,6 +2,7 @@ import React, { useState } from 'react';
|
||||
import { Handle } from 'react-flow-renderer';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
const bucketResponsesByLLM = (responses) => {
|
||||
let responses_by_llm = {};
|
||||
@ -31,7 +32,7 @@ const InspectorNode = ({ data, id }) => {
|
||||
console.log(input_node_ids);
|
||||
|
||||
// Grab responses associated with those ids:
|
||||
fetch('http://localhost:5000/grabResponses', {
|
||||
fetch(BASE_URL + 'grabResponses', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
|
@ -8,6 +8,7 @@ import NodeLabel from './NodeLabelComponent'
|
||||
import TemplateHooks from './TemplateHooksComponent'
|
||||
import LLMList from './LLMListComponent'
|
||||
import AlertModal from './AlertModal'
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
// Available LLMs
|
||||
const allLLMs = [
|
||||
@ -186,7 +187,7 @@ const PromptNode = ({ data, id }) => {
|
||||
};
|
||||
|
||||
// Run all prompt permutations through the LLM to generate + cache responses:
|
||||
fetch('http://localhost:5000/queryllm', {
|
||||
fetch(BASE_URL + 'queryllm', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
|
77
chain-forge/src/ScriptNode.js
Normal file
77
chain-forge/src/ScriptNode.js
Normal file
@ -0,0 +1,77 @@
|
||||
import React, { useState, useRef, useEffect, useCallback } from 'react';
|
||||
import useStore from './store';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
|
||||
|
||||
const ScriptNode = ({ data, id }) => {
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const delButtonId = 'del-';
|
||||
// Handle a change in a scripts' input.
|
||||
const handleInputChange = useCallback((event) => {
|
||||
// Update the data for this script node's id.
|
||||
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
|
||||
new_data.scriptFiles[event.target.id] = event.target.value;
|
||||
console.log(new_data);
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
// Handle delete script file.
|
||||
const handleDelete = useCallback((event) => {
|
||||
// Update the data for this script node's id.
|
||||
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
|
||||
var item_id = event.target.id.substring(delButtonId.length);
|
||||
delete new_data.scriptFiles[item_id];
|
||||
// if the new_data is empty, initialize it with one empty field
|
||||
if (Object.keys(new_data.scriptFiles).length === 0) {
|
||||
new_data.scriptFiles['f0'] = '';
|
||||
}
|
||||
console.log(new_data);
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
// Initialize fields (run once at init)
|
||||
const [scriptFiles, setScriptFiles] = useState([]);
|
||||
useEffect(() => {
|
||||
if (!data.scriptFiles)
|
||||
setDataPropsForNode(id, { scriptFiles: { f0: '' } });
|
||||
}, [data.scriptFiles, id, setDataPropsForNode]);
|
||||
|
||||
// Whenever 'data' changes, update the input fields to reflect the current state.
|
||||
useEffect(() => {
|
||||
const f = data.scriptFiles ? Object.keys(data.scriptFiles) : ['f0'];
|
||||
setScriptFiles(f.map((i) => {
|
||||
const val = data.scriptFiles ? data.scriptFiles[i] : '';
|
||||
return (
|
||||
<div className="input-field" key={i}>
|
||||
<input className='script-node-input' type='text' id={i} onChange={handleInputChange} value={val}/><button id={delButtonId + i} onClick={handleDelete}>x</button><br/>
|
||||
</div>
|
||||
)
|
||||
}));
|
||||
}, [data.scriptFiles, handleInputChange, handleDelete]);
|
||||
|
||||
// Add a field
|
||||
const handleAddField = useCallback(() => {
|
||||
// Update the data for this script node's id.
|
||||
const num_files = data.scriptFiles ? Object.keys(data.scriptFiles).length : 0;
|
||||
let new_data = { 'scriptFiles': { ...data.scriptFiles } };
|
||||
new_data.scriptFiles['f' + num_files.toString()] = "";
|
||||
setDataPropsForNode(id, new_data);
|
||||
}, [data, id, setDataPropsForNode]);
|
||||
|
||||
return (
|
||||
<div className="script-node">
|
||||
<div className="node-header">
|
||||
<NodeLabel title={data.title || 'Global Scripts'} nodeId={id} />
|
||||
</div>
|
||||
<label htmlFor="num-generations" style={{fontSize: '10pt'}}>Enter folder paths for external modules you wish to import.</label> <br/><br/>
|
||||
<div ref={ref}>
|
||||
{scriptFiles}
|
||||
</div>
|
||||
<div className="add-text-field-btn">
|
||||
<button onClick={handleAddField}>+</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ScriptNode;
|
@ -4,7 +4,8 @@ import useStore from './store';
|
||||
import Plot from 'react-plotly.js';
|
||||
import { hover } from '@testing-library/user-event/dist/hover';
|
||||
import { create } from 'zustand';
|
||||
import NodeLabel from './NodeLabelComponent'
|
||||
import NodeLabel from './NodeLabelComponent';
|
||||
import {BASE_URL} from './store';
|
||||
|
||||
// Helper funcs
|
||||
const truncStr = (s, maxLen) => {
|
||||
@ -54,7 +55,7 @@ const VisNode = ({ data, id }) => {
|
||||
// Grab the input node ids
|
||||
const input_node_ids = [data.input];
|
||||
|
||||
fetch('http://localhost:5000/grabResponses', {
|
||||
fetch(BASE_URL + 'grabResponses', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
|
@ -11,3 +11,14 @@ code {
|
||||
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
|
||||
monospace;
|
||||
}
|
||||
|
||||
.script-node {
|
||||
background-color: #fff;
|
||||
padding: 10px;
|
||||
border: 1px solid #000;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.script-node-input {
|
||||
min-width: 300px;
|
||||
}
|
@ -33,6 +33,8 @@ const initialEdges = [
|
||||
{ id: 'e1-2', source: initprompt, target: initeval, interactionWidth: 100},
|
||||
];
|
||||
|
||||
export const BASE_URL = 'http://localhost:8000/';
|
||||
|
||||
// TypeScript only
|
||||
// type RFState = {
|
||||
// nodes: Node[];
|
||||
@ -81,7 +83,7 @@ const useStore = create((set, get) => ({
|
||||
)(get().nodes)
|
||||
});
|
||||
},
|
||||
getNode: (id) => get().nodes.find(n => n.id == id),
|
||||
getNode: (id) => get().nodes.find(n => n.id === id),
|
||||
addNode: (newnode) => {
|
||||
set({
|
||||
nodes: get().nodes.concat(newnode)
|
||||
@ -122,4 +124,4 @@ const useStore = create((set, get) => ({
|
||||
},
|
||||
}));
|
||||
|
||||
export default useStore;
|
||||
export default useStore;
|
||||
|
@ -5,7 +5,8 @@ from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
from promptengine.query import PromptLLM
|
||||
from promptengine.template import PromptTemplate, PromptPermutationGenerator
|
||||
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir
|
||||
from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
|
||||
import sys
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
@ -138,6 +139,10 @@ def reduce_responses(responses: list, vars: list) -> list:
|
||||
|
||||
return ret
|
||||
|
||||
@app.route('/test', methods=['GET'])
|
||||
def test():
|
||||
return "Hello, world!"
|
||||
|
||||
@app.route('/queryllm', methods=['POST'])
|
||||
def queryLLM():
|
||||
"""
|
||||
@ -175,9 +180,13 @@ def queryLLM():
|
||||
if 'no_cache' in data and data['no_cache'] is True:
|
||||
remove_cached_responses(data['id'])
|
||||
|
||||
# Create a cache dir if it doesn't exist:
|
||||
create_dir_if_not_exists('cache')
|
||||
|
||||
# For each LLM, generate and cache responses:
|
||||
responses = {}
|
||||
params = data['params'] if 'params' in data else {}
|
||||
|
||||
for llm in llms:
|
||||
|
||||
# Check that storage path is valid:
|
||||
@ -195,6 +204,7 @@ def queryLLM():
|
||||
for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
|
||||
responses[llm].append(response)
|
||||
except Exception as e:
|
||||
print('error generating responses:', e)
|
||||
raise e
|
||||
return jsonify({'error': str(e)})
|
||||
|
||||
@ -206,6 +216,7 @@ def queryLLM():
|
||||
]
|
||||
|
||||
# Return all responses for all LLMs
|
||||
print('returning responses:', res)
|
||||
ret = jsonify({'responses': res})
|
||||
ret.headers.add('Access-Control-Allow-Origin', '*')
|
||||
return ret
|
||||
@ -224,6 +235,7 @@ def execute():
|
||||
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
|
||||
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
|
||||
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
|
||||
'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
|
||||
}
|
||||
|
||||
NOTE: This should only be run on your server on code you trust.
|
||||
@ -250,6 +262,24 @@ def execute():
|
||||
elif isinstance(data['responses'], str):
|
||||
data['responses'] = [ data['responses'] ]
|
||||
|
||||
# add the path to any scripts to the path:
|
||||
try:
|
||||
if 'script_paths' in data:
|
||||
for script_path in data['script_paths']:
|
||||
# get the folder of the script_path:
|
||||
script_folder = os.path.dirname(script_path)
|
||||
# check that the script_folder is valid, and it contains __init__.py
|
||||
if not os.path.exists(script_folder):
|
||||
print(script_folder, 'is not a valid script path.')
|
||||
print(os.path.exists(script_folder))
|
||||
continue
|
||||
|
||||
# add it to the path:
|
||||
sys.path.append(script_folder)
|
||||
print(f'added {script_folder} to sys.path')
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
|
||||
|
||||
# Create the evaluator function
|
||||
# DANGER DANGER!
|
||||
try:
|
||||
@ -379,4 +409,4 @@ def grabResponses():
|
||||
return ret
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host="localhost", port=5000, debug=True)
|
||||
app.run(host="localhost", port=8000, debug=True)
|
@ -6,6 +6,8 @@ import json, os, time
|
||||
DALAI_MODEL = None
|
||||
DALAI_RESPONSE = None
|
||||
|
||||
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
""" Supported LLM coding assistants """
|
||||
class LLM(Enum):
|
||||
ChatGPT = 0
|
||||
@ -141,6 +143,10 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
|
||||
else:
|
||||
raise ValueError(f"LLM {llm} is unsupported.")
|
||||
|
||||
def create_dir_if_not_exists(path: str) -> None:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
def is_valid_filepath(filepath: str) -> bool:
|
||||
try:
|
||||
with open(filepath, 'r'):
|
||||
|
@ -9,7 +9,7 @@
|
||||
<script>
|
||||
|
||||
function test_exec() {
|
||||
const response = fetch('http://localhost:5000/execute', {
|
||||
const response = fetch(BASE_URL + 'execute', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
@ -25,7 +25,7 @@
|
||||
}
|
||||
|
||||
function test_query() {
|
||||
const response = fetch('http://localhost:5000/queryllm', {
|
||||
const response = fetch(BASE_URL + 'queryllm', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
|
||||
body: JSON.stringify({
|
||||
|
Loading…
x
Reference in New Issue
Block a user