From 03f2847f53c332f9bb20e8da3c0e373e8b23a1f9 Mon Sep 17 00:00:00 2001
From: Priyan Vaithilingam
Date: Tue, 2 May 2023 18:32:12 -0400
Subject: [PATCH 1/5] basic script node implementation

---
 chain-forge/src/App.js               |  7 +++
 chain-forge/src/EvaluatorNode.js     |  9 +++-
 chain-forge/src/InspectorNode.js     |  2 +-
 chain-forge/src/PromptNode.js        |  2 +-
 chain-forge/src/ScriptNode.js        | 79 ++++++++++++++++++++++++++++
 chain-forge/src/VisNode.js           |  2 +-
 chain-forge/src/index.css            | 11 ++++
 python-backend/app.py                | 30 ++++++++++-
 python-backend/promptengine/utils.py |  8 +++
 python-backend/test.html             |  4 +-
 10 files changed, 146 insertions(+), 8 deletions(-)
 create mode 100644 chain-forge/src/ScriptNode.js

diff --git a/chain-forge/src/App.js b/chain-forge/src/App.js
index 5a06c60..50698a2 100644
--- a/chain-forge/src/App.js
+++ b/chain-forge/src/App.js
@@ -14,6 +14,7 @@ import PromptNode from './PromptNode';
 import EvaluatorNode from './EvaluatorNode';
 import VisNode from './VisNode';
 import InspectNode from './InspectorNode';
+import ScriptNode from './ScriptNode';
 import './text-fields-node.css';
 
 // State management (from https://reactflow.dev/docs/guides/state-management/)
@@ -39,6 +40,7 @@ const nodeTypes = {
   evaluator: EvaluatorNode,
   vis: VisNode,
   inspect: InspectNode,
+  script: ScriptNode
 };
 
 const connectionLineStyle = { stroke: '#ddd' };
@@ -84,6 +86,10 @@ const App = () => {
     const { x, y } = getViewportCenter();
     addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
   };
+  const addScriptNode = (event) => {
+    const { x, y } = getViewportCenter();
+    addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
+  };
 
   /**
    * SAVING / LOADING, IMPORT / EXPORT (from JSON)
@@ -194,6 +200,7 @@ const App = () => {
+          <button onClick={addScriptNode}>Add script node</button>
diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js
index 7265b6f..c779c81 100644
--- a/chain-forge/src/EvaluatorNode.js
+++ b/chain-forge/src/EvaluatorNode.js
@@ -43,6 +43,7 @@ const EvaluatorNode = ({ data, id }) => {
   const getNode = useStore((state) => state.getNode);
   const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
   const [status, setStatus] = useState('none');
+  const nodes = useStore((state) => state.nodes);
 
   // Mantine modal popover for alerts
   const [opened, { open, close }] = useDisclosure(false);
@@ -82,6 +83,7 @@ const EvaluatorNode = ({ data, id }) => {
   };
 
   const handleRunClick = (event) => {
+    // Get the ids from the connected input nodes:
    const input_node_ids = inputEdgesForNode(id).map(e => e.source);
 
    if (input_node_ids.length === 0) {
@@ -105,9 +107,13 @@ const EvaluatorNode = ({ data, id }) => {
      triggerErrorAlert(err_msg);
    };
 
+    // Get all the script nodes, and get all the folder paths
+    const script_nodes = nodes.filter(n => n.type === 'script');
+    const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
+    console.log(script_paths);
     // Run evaluator in backend
     const codeTextOnRun = codeText + '';
-    fetch('http://localhost:5000/execute', {
+    fetch('http://localhost:8000/execute', {
        method: 'POST',
        headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
        body: JSON.stringify({
@@ -116,6 +122,7 @@ const EvaluatorNode = ({ data, id }) => {
            scope: mapScope,
            responses: input_node_ids,
            reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
+            script_paths: script_paths,
            // write an extra part here that takes in reduce func
        }),
    }, rejected).then(function(response) {
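For reference, here is roughly what the evaluator node POSTs to the backend after this change. This is a Python sketch rather than the browser's fetch(); the field names mirror the fetch() body in the EvaluatorNode.js hunk above, while every concrete value (code string, node ids, script path) is illustrative only:

```python
import requests

# Illustrative /execute payload mirroring the fetch() body in EvaluatorNode.js.
payload = {
    'code': 'def evaluate(response):\n    return len(response)',  # evaluator code (illustrative)
    'id': 'evalNode-1683066732',              # this evaluator node's id (illustrative)
    'scope': 'response',
    'responses': ['promptNode-1683066701'],   # ids of connected input nodes (illustrative)
    'reduce_vars': [],
    'script_paths': ['/Users/me/my_scripts/helpers.py'],  # gathered from script nodes
}
r = requests.post('http://localhost:8000/execute', json=payload)
print(r.json())
```
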
diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index d1d38da..3d71a41 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -30,7 +30,7 @@ const InspectorNode = ({ data, id }) => {
     console.log(input_node_ids);
 
     // Grab responses associated with those ids:
-    fetch('http://localhost:5000/grabResponses', {
+    fetch('http://localhost:8000/grabResponses', {
       method: 'POST',
       headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
       body: JSON.stringify({
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index 98263ad..ff5b3e2 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -125,7 +125,7 @@ const PromptNode = ({ data, id }) => {
     const py_prompt_template = promptText.replace(/(?<!\\){(.*?)(?<!\\)}/g, "${$1}");
 
-    fetch('http://localhost:5000/queryllm', {
+    fetch('http://localhost:8000/queryllm', {
diff --git a/chain-forge/src/ScriptNode.js b/chain-forge/src/ScriptNode.js
new file mode 100644
--- /dev/null
+++ b/chain-forge/src/ScriptNode.js
@@ -0,0 +1,79 @@
+import React, { useState, useEffect, useCallback, useRef } from 'react';
+import useStore from './store';
+
+const union = (setA, setB) => {
+  const _union = new Set(setA);
+  for (const elem of setB) {
+    _union.add(elem);
+  }
+  return _union;
+}
+
+const ScriptNode = ({ data, id }) => {
+  const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+
+  // Handle a change in a text field's input.
+  const handleInputChange = useCallback((event) => {
+    // Update the data for this script node's id.
+    let new_data = { 'scriptFiles': { ...data.scriptFiles } };
+    new_data.scriptFiles[event.target.id] = event.target.value;
+    console.log(new_data);
+    setDataPropsForNode(id, new_data);
+  }, [data, id, setDataPropsForNode]);
+
+  // Initialize fields (run once at init)
+  const [scriptFiles, setScriptFiles] = useState([]);
+  useEffect(() => {
+    if (!data.scriptFiles)
+      setDataPropsForNode(id, { scriptFiles: { f0: '' } });
+  }, []);
+
+  // Whenever 'data' changes, update the input fields to reflect the current state.
+  useEffect(() => {
+    const f = data.scriptFiles ? Object.keys(data.scriptFiles) : ['f0'];
+    setScriptFiles(f.map((i) => {
+      const val = data.scriptFiles ? data.scriptFiles[i] : '';
+      return (
+        <div className="input-field" key={i}>
+          <input id={i} name={i} type="text" defaultValue={val} onChange={handleInputChange} className='script-node-input'/>
+        </div>
+      )
+    }));
+  }, [data.scriptFiles]);
+
+  // Add a field
+  const handleAddField = useCallback(() => {
+    // Update the data for this script node's id.
+    const num_files = data.scriptFiles ? Object.keys(data.scriptFiles).length : 0;
+    let new_data = { 'scriptFiles': { ...data.scriptFiles } };
+    new_data.scriptFiles['f' + num_files.toString()] = "";
+    setDataPropsForNode(id, new_data);
+  }, [data, id, setDataPropsForNode]);
+
+  // Dynamically update the y-position of the template hooks
+  const ref = useRef(null);
+  const [hooksY, setHooksY] = useState(120);
+  useEffect(() => {
+    const node_height = ref.current.clientHeight;
+    setHooksY(node_height + 70);
+  }, [scriptFiles]);
+
+  return (
+    <div className="script-node">
+      <div ref={ref}>
+        {scriptFiles}
+      </div>
+      <div>
+        <button onClick={handleAddField}>+</button>
+      </div>
+    </div>
+  );
+};
+
+export default ScriptNode;
\ No newline at end of file
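A script node stores its text fields in `data.scriptFiles`, an object mapping generated field ids (`f0`, `f1`, ...) to user-entered paths. The evaluator node flattens these objects across all script nodes and drops empty entries; in Python terms, the transformation is roughly the following (the node data here is hypothetical):

```python
# Python sketch of the JS map/filter/flat logic in EvaluatorNode.js.
script_nodes = [  # hypothetical flow state: one script node with two fields
    {'type': 'script', 'data': {'scriptFiles': {'f0': '/Users/me/my_scripts', 'f1': ''}}},
]
script_paths = [
    path
    for node in script_nodes
    for path in node['data']['scriptFiles'].values()
    if path != ''  # empty text fields are skipped
]
assert script_paths == ['/Users/me/my_scripts']
```
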
+ ); +}; + +export default ScriptNode; \ No newline at end of file diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js index 80e14c2..4a334a7 100644 --- a/chain-forge/src/VisNode.js +++ b/chain-forge/src/VisNode.js @@ -54,7 +54,7 @@ const VisNode = ({ data, id }) => { // Grab the input node ids const input_node_ids = [data.input]; - fetch('http://localhost:5000/grabResponses', { + fetch('http://localhost:8000/grabResponses', { method: 'POST', headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, body: JSON.stringify({ diff --git a/chain-forge/src/index.css b/chain-forge/src/index.css index ec2585e..346977a 100644 --- a/chain-forge/src/index.css +++ b/chain-forge/src/index.css @@ -11,3 +11,14 @@ code { font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; } + +.script-node { + background-color: #fff; + padding: 10px; + border: 1px solid #000; + border-radius: 5px; +} + +.script-node-input { + min-width: 300px; +} \ No newline at end of file diff --git a/python-backend/app.py b/python-backend/app.py index 484034e..d83f4c4 100644 --- a/python-backend/app.py +++ b/python-backend/app.py @@ -5,7 +5,8 @@ from flask import Flask, request, jsonify from flask_cors import CORS from promptengine.query import PromptLLM from promptengine.template import PromptTemplate, PromptPermutationGenerator -from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir +from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists +import sys app = Flask(__name__) CORS(app) @@ -137,6 +138,10 @@ def reduce_responses(responses: list, vars: list) -> list: return ret +@app.route('/test', methods=['GET']) +def test(): + return "Hello, world!" + @app.route('/queryllm', methods=['POST']) def queryLLM(): """ @@ -154,6 +159,8 @@ def queryLLM(): """ data = request.get_json() + print('got a request!', data) + # Check that all required info is here: if not set(data.keys()).issuperset({'llm', 'prompt', 'vars', 'id'}): return jsonify({'error': 'POST data is improper format.'}) @@ -174,9 +181,13 @@ def queryLLM(): if 'no_cache' in data and data['no_cache'] is True: remove_cached_responses(data['id']) + # Create a cache dir if it doesn't exist: + create_dir_if_not_exists('cache') + # For each LLM, generate and cache responses: responses = {} params = data['params'] if 'params' in data else {} + for llm in llms: # Check that storage path is valid: @@ -194,6 +205,7 @@ def queryLLM(): for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params): responses[llm].append(response) except Exception as e: + print('error generating responses:', e) raise e return jsonify({'error': str(e)}) @@ -205,6 +217,7 @@ def queryLLM(): ] # Return all responses for all LLMs + print('returning responses:', res) ret = jsonify({'responses': res}) ret.headers.add('Access-Control-Allow-Origin', '*') return ret @@ -223,6 +236,7 @@ def execute(): 'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch. # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt. 'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range) + 'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated } NOTE: This should only be run on your server on code you trust. 
@@ -249,6 +263,18 @@ def execute():
     elif isinstance(data['responses'], str):
         data['responses'] = [ data['responses'] ]
 
+    # Add each script's folder to the import path:
+    try:
+        if 'script_paths' in data:
+            for script_path in data['script_paths']:
+                # get the folder of the script_path:
+                script_folder = os.path.dirname(script_path)
+                # add it to the path:
+                sys.path.append(script_folder)
+                print(f'added {script_folder} to sys.path')
+    except Exception as e:
+        return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
+
     # Create the evaluator function
     # DANGER DANGER!
     try:
@@ -378,4 +404,4 @@ def grabResponses():
     return ret
 
 if __name__ == '__main__':
-    app.run(host="localhost", port=5000, debug=True)
\ No newline at end of file
+    app.run(host="localhost", port=8000, debug=True)
\ No newline at end of file
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index 0637530..73a7cf5 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -6,6 +6,10 @@ import json, os, time
 DALAI_MODEL = None
 DALAI_RESPONSE = None
 
+openai.api_key = os.environ.get("OPENAI_API_KEY")
+
+
+
 """ Supported LLM coding assistants """
 class LLM(Enum):
     ChatGPT = 0
@@ -136,6 +140,10 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
     else:
         raise ValueError(f"LLM {llm} is unsupported.")
 
+def create_dir_if_not_exists(path: str) -> None:
+    if not os.path.exists(path):
+        os.makedirs(path)
+
 def is_valid_filepath(filepath: str) -> bool:
     try:
         with open(filepath, 'r'):
diff --git a/python-backend/test.html b/python-backend/test.html
index 0d74bc3..b340f9b 100644
--- a/python-backend/test.html
+++ b/python-backend/test.html
@@ -9,7 +9,7 @@
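Taken together, the patch enables a workflow like the following, where a script node points at a folder containing a user-written Python module and evaluator code imports it. All file names and function names below are hypothetical:

```python
# my_scripts/helpers.py -- a user-provided module (hypothetical)
def word_count(text: str) -> int:
    return len(text.split())
```

```python
# Code typed into the evaluator node (hypothetical; the exact evaluator
# signature is whatever the evaluator node expects). The import succeeds
# because /execute appended the script's folder to sys.path.
import helpers

def evaluate(response):
    return helpers.word_count(response)
```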