diff --git a/chain-forge/src/App.js b/chain-forge/src/App.js
index 5a06c60..50698a2 100644
--- a/chain-forge/src/App.js
+++ b/chain-forge/src/App.js
@@ -14,6 +14,7 @@ import PromptNode from './PromptNode';
import EvaluatorNode from './EvaluatorNode';
import VisNode from './VisNode';
import InspectNode from './InspectorNode';
+import ScriptNode from './ScriptNode';
import './text-fields-node.css';
// State management (from https://reactflow.dev/docs/guides/state-management/)
@@ -39,6 +40,7 @@ const nodeTypes = {
evaluator: EvaluatorNode,
vis: VisNode,
inspect: InspectNode,
+ script: ScriptNode
};
const connectionLineStyle = { stroke: '#ddd' };
@@ -84,6 +86,10 @@ const App = () => {
const { x, y } = getViewportCenter();
addNode({ id: 'inspectNode-'+Date.now(), type: 'inspect', data: {}, position: {x: x-200, y:y-100} });
};
+ const addScriptNode = (event) => {
+ const { x, y } = getViewportCenter();
+ addNode({ id: 'scriptNode-'+Date.now(), type: 'script', data: {}, position: {x: x-200, y:y-100} });
+ };
/**
* SAVING / LOADING, IMPORT / EXPORT (from JSON)
@@ -194,6 +200,7 @@ const App = () => {
+
diff --git a/chain-forge/src/EvaluatorNode.js b/chain-forge/src/EvaluatorNode.js
index a7841e8..34edb91 100644
--- a/chain-forge/src/EvaluatorNode.js
+++ b/chain-forge/src/EvaluatorNode.js
@@ -4,6 +4,7 @@ import useStore from './store';
import StatusIndicator from './StatusIndicatorComponent'
import NodeLabel from './NodeLabelComponent'
import AlertModal from './AlertModal'
+import {BASE_URL} from './store';
// Ace code editor
import AceEditor from "react-ace";
@@ -18,6 +19,7 @@ const EvaluatorNode = ({ data, id }) => {
const getNode = useStore((state) => state.getNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const [status, setStatus] = useState('none');
+ const nodes = useStore((state) => state.nodes);
// For displaying error messages to user
const alertModal = useRef(null);
@@ -49,6 +51,7 @@ const EvaluatorNode = ({ data, id }) => {
};
const handleRunClick = (event) => {
+
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map(e => e.source);
if (input_node_ids.length === 0) {
@@ -72,9 +75,13 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(err_msg);
};
+ // Collect the folder paths from every script node, dropping empty entries
+ const script_nodes = nodes.filter(n => n.type === 'script');
+ const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
+ console.log(script_paths);
// Run evaluator in backend
const codeTextOnRun = codeText + '';
- fetch('http://localhost:5000/execute', {
+ fetch(BASE_URL + 'execute', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
@@ -83,6 +90,7 @@ const EvaluatorNode = ({ data, id }) => {
scope: mapScope,
responses: input_node_ids,
reduce_vars: reduceMethod === 'avg' ? reduceVars : [],
+ script_paths: script_paths,
// write an extra part here that takes in reduce func
}),
}, rejected).then(function(response) {
diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index e00ebc3..9788d85 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -1,6 +1,7 @@
import React, { useState } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
+import {BASE_URL} from './store';
const bucketResponsesByLLM = (responses) => {
let responses_by_llm = {};
@@ -30,7 +31,7 @@ const InspectorNode = ({ data, id }) => {
console.log(input_node_ids);
// Grab responses associated with those ids:
- fetch('http://localhost:5000/grabResponses', {
+ fetch(BASE_URL + 'grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
index e7290f5..be89611 100644
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -8,6 +8,7 @@ import NodeLabel from './NodeLabelComponent'
import TemplateHooks from './TemplateHooksComponent'
import LLMList from './LLMListComponent'
import AlertModal from './AlertModal'
+import {BASE_URL} from './store';
// Available LLMs
const allLLMs = [
@@ -186,7 +187,7 @@ const PromptNode = ({ data, id }) => {
};
// Run all prompt permutations through the LLM to generate + cache responses:
- fetch('http://localhost:5000/queryllm', {
+ fetch(BASE_URL + 'queryllm', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
diff --git a/chain-forge/src/ScriptNode.js b/chain-forge/src/ScriptNode.js
new file mode 100644
index 0000000..adb662c
--- /dev/null
+++ b/chain-forge/src/ScriptNode.js
@@ -0,0 +1,77 @@
+import React, { useState, useRef, useEffect, useCallback } from 'react';
+import useStore from './store';
+import NodeLabel from './NodeLabelComponent'
+
+
+const ScriptNode = ({ data, id }) => {
+ const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+ const delButtonId = 'del-';
+ // Handle a change in a script's input field.
+ const handleInputChange = useCallback((event) => {
+ // Update the data for this script node's id.
+ let new_data = { 'scriptFiles': { ...data.scriptFiles } };
+ new_data.scriptFiles[event.target.id] = event.target.value;
+ console.log(new_data);
+ setDataPropsForNode(id, new_data);
+ }, [data, id, setDataPropsForNode]);
+
+ // Handle deletion of a script-file field (re-seeds one empty field if none remain).
+ const handleDelete = useCallback((event) => {
+ // Update the data for this script node's id.
+ let new_data = { 'scriptFiles': { ...data.scriptFiles } };
+ var item_id = event.target.id.substring(delButtonId.length);
+ delete new_data.scriptFiles[item_id];
+ // if the new_data is empty, initialize it with one empty field
+ if (Object.keys(new_data.scriptFiles).length === 0) {
+ new_data.scriptFiles['f0'] = '';
+ }
+ console.log(new_data);
+ setDataPropsForNode(id, new_data);
+ }, [data, id, setDataPropsForNode]);
+
+ // Initialize fields (run once at init)
+ const [scriptFiles, setScriptFiles] = useState([]);
+ useEffect(() => {
+ if (!data.scriptFiles)
+ setDataPropsForNode(id, { scriptFiles: { f0: '' } });
+ }, [data.scriptFiles, id, setDataPropsForNode]);
+
+ // Whenever 'data' changes, update the input fields to reflect the current state.
+ useEffect(() => {
+ const f = data.scriptFiles ? Object.keys(data.scriptFiles) : ['f0'];
+ setScriptFiles(f.map((i) => {
+ const val = data.scriptFiles ? data.scriptFiles[i] : '';
+ return (
+
+
+
+ )
+ }));
+ }, [data.scriptFiles, handleInputChange, handleDelete]);
+
+ // Add a field
+ const handleAddField = useCallback(() => {
+ // Update the data for this script node's id.
+ const num_files = data.scriptFiles ? Object.keys(data.scriptFiles).length : 0;
+ let new_data = { 'scriptFiles': { ...data.scriptFiles } };
+ new_data.scriptFiles['f' + num_files.toString()] = "";
+ setDataPropsForNode(id, new_data);
+ }, [data, id, setDataPropsForNode]);
+
+ return (
+
+
+
+
+
+
+ {scriptFiles}
+
+
+
+
+
+ );
+};
+
+export default ScriptNode;
\ No newline at end of file
diff --git a/chain-forge/src/VisNode.js b/chain-forge/src/VisNode.js
index 80e14c2..615f55f 100644
--- a/chain-forge/src/VisNode.js
+++ b/chain-forge/src/VisNode.js
@@ -4,7 +4,8 @@ import useStore from './store';
import Plot from 'react-plotly.js';
import { hover } from '@testing-library/user-event/dist/hover';
import { create } from 'zustand';
-import NodeLabel from './NodeLabelComponent'
+import NodeLabel from './NodeLabelComponent';
+import {BASE_URL} from './store';
// Helper funcs
const truncStr = (s, maxLen) => {
@@ -54,7 +55,7 @@ const VisNode = ({ data, id }) => {
// Grab the input node ids
const input_node_ids = [data.input];
- fetch('http://localhost:5000/grabResponses', {
+ fetch(BASE_URL + 'grabResponses', {
method: 'POST',
headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
body: JSON.stringify({
diff --git a/chain-forge/src/index.css b/chain-forge/src/index.css
index ec2585e..346977a 100644
--- a/chain-forge/src/index.css
+++ b/chain-forge/src/index.css
@@ -11,3 +11,14 @@ code {
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
monospace;
}
+
+.script-node {
+ background-color: #fff;
+ padding: 10px;
+ border: 1px solid #000;
+ border-radius: 5px;
+}
+
+.script-node-input {
+ min-width: 300px;
+}
\ No newline at end of file
diff --git a/chain-forge/src/store.js b/chain-forge/src/store.js
index 0354426..f058c3f 100644
--- a/chain-forge/src/store.js
+++ b/chain-forge/src/store.js
@@ -33,6 +33,8 @@ const initialEdges = [
{ id: 'e1-2', source: initprompt, target: initeval, interactionWidth: 100},
];
+export const BASE_URL = 'http://localhost:8000/';
+
// TypeScript only
// type RFState = {
// nodes: Node[];
@@ -81,7 +83,7 @@ const useStore = create((set, get) => ({
)(get().nodes)
});
},
- getNode: (id) => get().nodes.find(n => n.id == id),
+ getNode: (id) => get().nodes.find(n => n.id === id),
addNode: (newnode) => {
set({
nodes: get().nodes.concat(newnode)
@@ -122,4 +124,4 @@ const useStore = create((set, get) => ({
},
}));
-export default useStore;
\ No newline at end of file
+export default useStore;
diff --git a/python-backend/app.py b/python-backend/app.py
index fbafb5a..0ef553a 100644
--- a/python-backend/app.py
+++ b/python-backend/app.py
@@ -5,7 +5,8 @@ from flask import Flask, request, jsonify
from flask_cors import CORS
from promptengine.query import PromptLLM
from promptengine.template import PromptTemplate, PromptPermutationGenerator
-from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir
+from promptengine.utils import LLM, extract_responses, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists
+import sys
app = Flask(__name__)
CORS(app)
@@ -138,6 +139,10 @@ def reduce_responses(responses: list, vars: list) -> list:
return ret
+@app.route('/test', methods=['GET'])
+def test():
+ return "Hello, world!"
+
@app.route('/queryllm', methods=['POST'])
def queryLLM():
"""
@@ -175,9 +180,13 @@ def queryLLM():
if 'no_cache' in data and data['no_cache'] is True:
remove_cached_responses(data['id'])
+ # Create a cache dir if it doesn't exist:
+ create_dir_if_not_exists('cache')
+
# For each LLM, generate and cache responses:
responses = {}
params = data['params'] if 'params' in data else {}
+
for llm in llms:
# Check that storage path is valid:
@@ -195,6 +204,7 @@ def queryLLM():
for response in prompter.gen_responses(properties=data['vars'], llm=llm, **params):
responses[llm].append(response)
except Exception as e:
+ print('error generating responses:', e)
raise e
return jsonify({'error': str(e)})
@@ -206,6 +216,7 @@ def queryLLM():
]
# Return all responses for all LLMs
+ print('returning responses:', res)
ret = jsonify({'responses': res})
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
@@ -224,6 +235,7 @@ def execute():
'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
# If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
'reduce_vars': unspecified | List[str] # the 'vars' to average over (mean, median, stdev, range)
+ 'script_paths': unspecified | List[str] # script file paths whose parent folders are appended to sys.path before the evaluator runs
}
NOTE: This should only be run on your server on code you trust.
@@ -250,6 +262,24 @@ def execute():
elif isinstance(data['responses'], str):
data['responses'] = [ data['responses'] ]
+ # add the path to any scripts to the path:
+ try:
+ if 'script_paths' in data:
+ for script_path in data['script_paths']:
+ # get the folder of the script_path:
+ script_folder = os.path.dirname(script_path)
+ # check that the script folder exists (note: no __init__.py check is performed)
+ if not os.path.exists(script_folder):
+ print(script_folder, 'is not a valid script path.')
+ print(os.path.exists(script_folder))
+ continue
+
+ # add it to the path:
+ sys.path.append(script_folder)
+ print(f'added {script_folder} to sys.path')
+ except Exception as e:
+ return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
+
# Create the evaluator function
# DANGER DANGER!
try:
@@ -379,4 +409,4 @@ def grabResponses():
return ret
if __name__ == '__main__':
- app.run(host="localhost", port=5000, debug=True)
\ No newline at end of file
+ app.run(host="localhost", port=8000, debug=True)
\ No newline at end of file
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index 2d5b937..456a7e6 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -6,6 +6,8 @@ import json, os, time
DALAI_MODEL = None
DALAI_RESPONSE = None
+openai.api_key = os.environ.get("OPENAI_API_KEY")
+
""" Supported LLM coding assistants """
class LLM(Enum):
ChatGPT = 0
@@ -141,6 +143,10 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
else:
raise ValueError(f"LLM {llm} is unsupported.")
def create_dir_if_not_exists(path: str) -> None:
    """Create directory `path` (including any missing parents) if absent.

    Uses ``exist_ok=True`` so creation is atomic with respect to the check,
    avoiding the TOCTOU race in the original ``os.path.exists`` + ``makedirs``
    pair (another process creating the dir between check and call raised).
    """
    os.makedirs(path, exist_ok=True)

def is_valid_filepath(filepath: str) -> bool:
try:
with open(filepath, 'r'):
diff --git a/python-backend/test.html b/python-backend/test.html
index 0d74bc3..110ca67 100644
--- a/python-backend/test.html
+++ b/python-backend/test.html
@@ -9,7 +9,7 @@