diff --git a/chain-forge/src/App.js b/chain-forge/src/App.js index 5a06c60..50698a2 100644 --- a/chain-forge/src/App.js +++ b/chain-forge/src/App.js @@ -14,6 +14,7 @@ import PromptNode from './PromptNode'; import EvaluatorNode from './EvaluatorNode'; import VisNode from './VisNode'; import InspectNode from './InspectorNode'; +import ScriptNode from './ScriptNode'; import './text-fields-node.css'; // State management (from https://reactflow.dev/docs/guides/state-management/) @@ -39,6 +40,7 @@ const nodeTypes = { evaluator: EvaluatorNode, vis: VisNode, inspect: InspectNode, + script: ScriptNode }; const connectionLineStyle = { stroke: '#ddd' }; @@ -84,6 +86,10 @@ const App = () => { const { x, y } = getViewportCenter(); addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} }); }; + const addScriptNode = (event) => { + const { x, y } = getViewportCenter(); + addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} }); + }; /** * SAVING / LOADING, IMPORT / EXPORT (from JSON) @@ -194,6 +200,7 @@ const App = () => { + diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js index a7841e8..34edb91 100644 --- a/chain-forge/src/EvaluatorNode.js +++ b/chain-forge/src/EvaluatorNode.js @@ -4,6 +4,7 @@ import useStore from './store'; import StatusIndicator from './StatusIndicatorComponent' import NodeLabel from './NodeLabelComponent' import AlertModal from './AlertModal' +import {BASE_URL} from './store'; // Ace code editor import AceEditor from "react-ace"; @@ -18,6 +19,7 @@ const EvaluatorNode = ({ data, id }) => { const getNode = useStore((state) => state.getNode); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const [status, setStatus] = useState('none'); + const nodes = useStore((state) => state.nodes); // For displaying error messages to user const alertModal = useRef(null); @@ -49,6 +51,7 @@ const EvaluatorNode = ({ data, id }) => { }; const handleRunClick = (event) => { + // Get the ids from the connected input nodes: const input_node_ids = inputEdgesForNode(id).map(e => e.source); if (input_node_ids.length === 0) { @@ -72,9 +75,13 @@ const EvaluatorNode = ({ data, id }) => { alertModal.current.trigger(err_msg); }; + // Get all the script nodes, and get all the folder paths + const script_nodes = nodes.filter(n => n.type === 'script'); + const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat(); + console.log(script_paths); // Run evaluator in backend const codeTextOnRun = codeText + ''; - fetch('http://localhost:5000/execute', { + fetch(BASE_URL + 'execute', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ @@ -83,6 +90,7 @@ const EvaluatorNode = ({ data, id }) => { scope: mapScope, responses: input_node_ids, reduce_vars: reduceMethod === 'avg' ? reduceVars : [], + script_paths: script_paths, // write an extra part here that takes in reduce func }), }, rejected).then(function(response) { diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js index e00ebc3..9788d85 100644 --- a/chain-forge/src/InspectorNode.js +++ b/chain-forge/src/InspectorNode.js @@ -1,6 +1,7 @@ import React, { useState } from 'react'; import { Handle } from 'react-flow-renderer'; import useStore from './store'; +import {BASE_URL} from './store'; const bucketResponsesByLLM = (responses) => { let responses_by_llm = {}; @@ -30,7 +31,7 @@ const InspectorNode = ({ data, id }) => { console.log(input_node_ids); // Grab responses associated with those ids: - fetch('http://localhost:5000/grabResponses', { + fetch(BASE_URL + 'grabResponses', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js index e7290f5..be89611 100644 --- a/chain-forge/src/PromptNode.js +++ b/chain-forge/src/PromptNode.js @@ -8,6 +8,7 @@ import NodeLabel from './NodeLabelComponent' import TemplateHooks from './TemplateHooksComponent' import LLMList from './LLMListComponent' import AlertModal from './AlertModal' +import {BASE_URL} from './store'; // Available LLMs const allLLMs = [ @@ -186,7 +187,7 @@ const PromptNode = ({ data, id }) => { }; // Run all prompt permutations through the LLM to generate + cache responses: - fetch('http://localhost:5000/queryllm', { + fetch(BASE_URL + 'queryllm', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ diff --git a/chain-forge/src/ScriptNode.js b/chain-forge/src/ScriptNode.js new file mode 100644 index 0000000..adb662c --- /dev/null +++ b/chain-forge/src/ScriptNode.js @@ -0,0 +1,77 @@ +import React, { useState, useRef, useEffect, useCallback } from 'react'; +import useStore from './store'; +import NodeLabel from './NodeLabelComponent' + + +const ScriptNode = ({ data, id }) => { + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + const delButtonId = 'del-'; + // Handle a change in a scripts' input. + const handleInputChange = useCallback((event) => { + // Update the data for this script node's id. + let new_data = { 'scriptFiles': { ...data.scriptFiles } }; + new_data.scriptFiles[event.target.id] = event.target.value; + console.log(new_data); + setDataPropsForNode(id, new_data); + }, [data, id, setDataPropsForNode]); + + // Handle delete script file. + const handleDelete = useCallback((event) => { + // Update the data for this script node's id. + let new_data = { 'scriptFiles': { ...data.scriptFiles } }; + var item_id = event.target.id.substring(delButtonId.length); + delete new_data.scriptFiles[item_id]; + // if the new_data is empty, initialize it with one empty field + if (Object.keys(new_data.scriptFiles).length === 0) { + new_data.scriptFiles['f0'] = ''; + } + console.log(new_data); + setDataPropsForNode(id, new_data); + }, [data, id, setDataPropsForNode]); + + // Initialize fields (run once at init) + const [scriptFiles, setScriptFiles] = useState([]); + useEffect(() => { + if (!data.scriptFiles) + setDataPropsForNode(id, { scriptFiles: { f0: '' } }); + }, [data.scriptFiles, id, setDataPropsForNode]); + + // Whenever 'data' changes, update the input fields to reflect the current state. + useEffect(() => { + const f = data.scriptFiles ? Object.keys(data.scriptFiles) : ['f0']; + setScriptFiles(f.map((i) => { + const val = data.scriptFiles ? data.scriptFiles[i] : ''; + return ( +
+
+
+ ) + })); + }, [data.scriptFiles, handleInputChange, handleDelete]); + + // Add a field + const handleAddField = useCallback(() => { + // Update the data for this script node's id. + const num_files = data.scriptFiles ? Object.keys(data.scriptFiles).length : 0; + let new_data = { 'scriptFiles': { ...data.scriptFiles } }; + new_data.scriptFiles['f' + num_files.toString()] = ""; + setDataPropsForNode(id, new_data); + }, [data, id, setDataPropsForNode]); + + return ( +
+
+ +
+

+
+ {scriptFiles} +
+
+ +
+
+ ); +}; + +export default ScriptNode; \ No newline at end of file diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js index 80e14c2..615f55f 100644 --- a/chain-forge/src/VisNode.js +++ b/chain-forge/src/VisNode.js @@ -4,7 +4,8 @@ import useStore from './store'; import Plot from 'react-plotly.js'; import { hover } from '@testing-library/user-event/dist/hover'; import { create } from 'zustand'; -import NodeLabel from './NodeLabelComponent' +import NodeLabel from './NodeLabelComponent'; +import {BASE_URL} from './store'; // Helper funcs const truncStr = (s, maxLen) => { @@ -54,7 +55,7 @@ const VisNode = ({ data, id }) => { // Grab the input node ids const input_node_ids = [data.input]; - fetch('http://localhost:5000/grabResponses', { + fetch(BASE_URL + 'grabResponses', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ diff --git a/chain-forge/src/index.css b/chain-forge/src/index.css index ec2585e..346977a 100644 --- a/chain-forge/src/index.css +++ b/chain-forge/src/index.css @@ -11,3 +11,14 @@ code { font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; } + +.script-node { + background-color: #fff; + padding: 10px; + border: 1px solid #000; + border-radius: 5px; +} + +.script-node-input { + min-width: 300px; +} \ No newline at end of file diff --git a/chain-forge/src/store.js b/chain-forge/src/store.js index 0354426..f058c3f 100644 --- a/chain-forge/src/store.js +++ b/chain-forge/src/store.js @@ -33,6 +33,8 @@ const initialEdges = [ { id: 'e1-2', source: initprompt, target: initeval, interactionWidth: 100}, ]; +export const BASE_URL = 'http://localhost:8000/'; + // TypeScript only // type RFState = { // nodes: Node[]; @@ -81,7 +83,7 @@ const useStore = create((set, get) => ({ )(get().nodes) }); }, - getNode: (id) => get().nodes.find(n => n.id == id), + getNode: (id) => get().nodes.find(n => n.id === id), addNode: (newnode) => { set({ nodes: get().nodes.concat(newnode) @@ -122,4 +124,4 @@ const useStore = create((set, get) => ({ }, })); -export default useStore; \ No newline at end of file +export default useStore; diff --git a/python-backend/app.py b/python-backend/app.py index fbafb5a..0ef553a 100644 --- a/python-backend/app.py +++ b/python-backend/app.py @@ -5,7 +5,8 @@ from flask import Flask, request, jsonify from flask_cors import CORS from promptengine.query import PromptLLM from promptengine.template import PromptTemplate, PromptPermutationGenerator -from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir +from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists +import sys app = Flask(__name__) CORS(app) @@ -138,6 +139,10 @@ def reduce_responses(responses: list, vars: list) -> list: return ret +@app.route('/test', methods=['GET']) +def test(): + return "Hello, world!" + @app.route('/queryllm', methods=['POST']) def queryLLM(): """ @@ -175,9 +180,13 @@ def queryLLM(): if 'no_cache' in data and data['no_cache'] is True: remove_cached_responses(data['id']) + # Create a cache dir if it doesn't exist: + create_dir_if_not_exists('cache') + # For each LLM, generate and cache responses: responses = {} params = data['params'] if 'params' in data else {} + for llm in llms: # Check that storage path is valid: @@ -195,6 +204,7 @@ def queryLLM(): for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params): responses[llm].append(response) except Exception as e: + print('error generating responses:', e) raise e return jsonify({'error': str(e)}) @@ -206,6 +216,7 @@ def queryLLM(): ] # Return all responses for all LLMs + print('returning responses:', res) ret = jsonify({'responses': res}) ret.headers.add('Access-Control-Allow-Origin', '*') return ret @@ -224,6 +235,7 @@ def execute(): 'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch. # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt. 'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range) + 'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated } NOTE: This should only be run on your server on code you trust. @@ -250,6 +262,24 @@ def execute(): elif isinstance(data['responses'], str): data['responses'] = [ data['responses'] ] + # add the path to any scripts to the path: + try: + if 'script_paths' in data: + for script_path in data['script_paths']: + # get the folder of the script_path: + script_folder = os.path.dirname(script_path) + # check that the script_folder is valid, and it contains __init__.py + if not os.path.exists(script_folder): + print(script_folder, 'is not a valid script path.') + print(os.path.exists(script_folder)) + continue + + # add it to the path: + sys.path.append(script_folder) + print(f'added {script_folder} to sys.path') + except Exception as e: + return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'}) + # Create the evaluator function # DANGER DANGER! try: @@ -379,4 +409,4 @@ def grabResponses(): return ret if __name__ == '__main__': - app.run(host="localhost", port=5000, debug=True) \ No newline at end of file + app.run(host="localhost", port=8000, debug=True) \ No newline at end of file diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py index 2d5b937..456a7e6 100644 --- a/python-backend/promptengine/utils.py +++ b/python-backend/promptengine/utils.py @@ -6,6 +6,8 @@ import json, os, time DALAI_MODEL = None DALAI_RESPONSE = None +openai.api_key = os.environ.get("OPENAI_API_KEY") + """ Supported LLM coding assistants """ class LLM(Enum): ChatGPT = 0 @@ -141,6 +143,10 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]: else: raise ValueError(f"LLM {llm} is unsupported.") +def create_dir_if_not_exists(path: str) -> None: + if not os.path.exists(path): + os.makedirs(path) + def is_valid_filepath(filepath: str) -> bool: try: with open(filepath, 'r'): diff --git a/python-backend/test.html b/python-backend/test.html index 0d74bc3..110ca67 100644 --- a/python-backend/test.html +++ b/python-backend/test.html @@ -9,7 +9,7 @@