diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py
index ec337c7..1bddb50 100644
--- a/chainforge/flask_app.py
+++ b/chainforge/flask_app.py
@@ -8,7 +8,7 @@ from flask_cors import CORS
from chainforge.promptengine.query import PromptLLM, PromptLLMDummy, LLMResponseException
from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator
from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
-
+import requests as py_requests
""" =================
SETUP AND GLOBALS
@@ -679,6 +679,82 @@ async def queryLLM():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
+@app.route('/app/executepy', methods=['POST'])
+def executepy():
+ """
+ Executes a Python function sent from JavaScript,
+ over all the `StandardizedLLMResponse` objects passed in from the front-end.
+
+ POST'd data should be in the form:
+ {
+ 'id': # a unique ID to refer to this information. Used when cache'ing responses.
+ 'code': str, # the body of the lambda function to evaluate, in form: lambda responses:
+ 'responses': List[StandardizedLLMResponse] # the responses to run on.
+ 'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch.
+ # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt.
+ 'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated
+ }
+
+ NOTE: This should only be run on your server on code you trust.
+ There is no sandboxing; no safety. We assume you are the creator of the code.
+ """
+ data = request.get_json()
+
+ # Check that all required info is here:
+ if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
+ return jsonify({'error': 'POST data is improper format.'})
+ if not isinstance(data['id'], str) or len(data['id']) == 0:
+ return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
+ if data['scope'] not in ('response', 'batch'):
+ return jsonify({'error': "POST data scope is unknown. Must be either 'response' or 'batch'."})
+
+ # Check format of responses:
+ responses = data['responses']
+ if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])):
+ return jsonify({'error': 'POST data responses is improper format.'})
+
+ # add the path to any scripts to the path:
+ try:
+ if 'script_paths' in data:
+ for script_path in data['script_paths']:
+ # get the folder of the script_path:
+ script_folder = os.path.dirname(script_path)
+ # check that the script_folder is valid, and it contains __init__.py
+ if not os.path.exists(script_folder):
+ print(script_folder, 'is not a valid script path.')
+ continue
+
+ # add it to the path:
+ sys.path.append(script_folder)
+ print(f'added {script_folder} to sys.path')
+ except Exception as e:
+ return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'})
+
+ # Create the evaluator function
+ # DANGER DANGER!
+ try:
+ exec(data['code'], globals())
+
+ # Double-check that there is an 'evaluate' method in our namespace.
+ # This will throw a NameError if not:
+ evaluate # noqa
+ except Exception as e:
+ return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'})
+
+ evald_responses = []
+ logs = []
+ try:
+ HIJACK_PYTHON_PRINT()
+ evald_responses = run_over_responses(evaluate, responses, scope=data['scope']) # noqa
+ logs = REVERT_PYTHON_PRINT()
+ except Exception as e:
+ logs = REVERT_PYTHON_PRINT()
+ return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}', 'logs': logs})
+
+ ret = jsonify({'responses': evald_responses, 'logs': logs})
+ ret.headers.add('Access-Control-Allow-Origin', '*')
+ return ret
+
@app.route('/app/execute', methods=['POST'])
def execute():
"""
@@ -1047,7 +1123,7 @@ def fetchEnvironAPIKeys():
keymap = {
'OPENAI_API_KEY': 'OpenAI',
'ANTHROPIC_API_KEY': 'Anthropic',
- 'GOOGLE_PALM_API_KEY': 'Google',
+ 'PALM_API_KEY': 'Google',
'AZURE_OPENAI_KEY': 'Azure_OpenAI',
'AZURE_OPENAI_ENDPOINT': 'Azure_OpenAI_Endpoint'
}
@@ -1056,6 +1132,39 @@ def fetchEnvironAPIKeys():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
+@app.route('/app/makeFetchCall', methods=['POST'])
+def makeFetchCall():
+ """
+ Use in place of JavaScript's 'fetch' (with POST method), in cases where
+ cross-origin policy blocks responses from client-side fetches.
+
+ POST'd data should be in form:
+ {
+ url: # the url to fetch from
+ headers: # a JSON object of the headers
+ body: # the request payload, as JSON
+ }
+ """
+ # Verify post'd data
+ data = request.get_json()
+ if not set(data.keys()).issuperset({'url', 'headers', 'body'}):
+ return jsonify({'error': 'POST data is improper format.'})
+
+ url = data['url']
+ headers = data['headers']
+ body = data['body']
+
+ print(body)
+
+ response = py_requests.post(url, headers=headers, json=body)
+
+ if response.status_code == 200:
+ ret = jsonify({'response': response.json()})
+ ret.headers.add('Access-Control-Allow-Origin', '*')
+ return ret
+ else:
+ return jsonify({'error': 'API request to Anthropic failed'})
+
def run_server(host="", port=8000, cmd_args=None):
if cmd_args is not None and cmd_args.dummy_responses:
global PromptLLM
diff --git a/chainforge/react-server/craco.config.js b/chainforge/react-server/craco.config.js
index 67fbca8..9e23713 100644
--- a/chainforge/react-server/craco.config.js
+++ b/chainforge/react-server/craco.config.js
@@ -9,7 +9,7 @@ module.exports = {
resolve: {
fallback: {
"process": require.resolve("process/browser"),
- "buffer": require.resolve("buffer/"),
+ "buffer": require.resolve("buffer"),
"https": require.resolve("https-browserify"),
"querystring": require.resolve("querystring-es3"),
"url": require.resolve("url/"),
@@ -41,7 +41,13 @@ module.exports = {
add: [
new webpack.ProvidePlugin({
process: 'process/browser.js',
- })
+ }),
+
+ // Work around for Buffer is undefined:
+ // https://github.com/webpack/changelog-v5/issues/10
+ new webpack.ProvidePlugin({
+ Buffer: ['buffer', 'Buffer'],
+ }),
]
},
diff --git a/chainforge/react-server/package-lock.json b/chainforge/react-server/package-lock.json
index 0419a79..f70b7b2 100644
--- a/chainforge/react-server/package-lock.json
+++ b/chainforge/react-server/package-lock.json
@@ -61,6 +61,7 @@
"mathjs": "^11.8.2",
"net": "^1.0.2",
"net-browserify": "^0.2.4",
+ "node-fetch": "^2.6.11",
"openai": "^3.3.0",
"os-browserify": "^0.3.0",
"papaparse": "^5.4.1",
diff --git a/chainforge/react-server/package.json b/chainforge/react-server/package.json
index 9a71ced..65f7dc6 100644
--- a/chainforge/react-server/package.json
+++ b/chainforge/react-server/package.json
@@ -56,6 +56,7 @@
"mathjs": "^11.8.2",
"net": "^1.0.2",
"net-browserify": "^0.2.4",
+ "node-fetch": "^2.6.11",
"openai": "^3.3.0",
"os-browserify": "^0.3.0",
"papaparse": "^5.4.1",
diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js
index fcd5453..04a852f 100644
--- a/chainforge/react-server/src/App.js
+++ b/chainforge/react-server/src/App.js
@@ -263,11 +263,11 @@ const App = () => {
// Save!
downloadJSON(flow_and_cache, `flow-${Date.now()}.cforge`);
- });
- }, [rfInstance, nodes]);
+ }).catch(handleError);
+ }, [rfInstance, nodes, handleError]);
// Import data to the cache stored on the local filesystem (in backend)
- const importCache = (cache_data) => {
+ const importCache = useCallback((cache_data) => {
return fetch_from_backend('importCache', {
'files': cache_data,
}, handleError).then(function(json) {
@@ -277,7 +277,7 @@ const App = () => {
throw new Error('Error importing cache data:' + json.error);
// Done!
}, handleError).catch(handleError);
- };
+ }, [handleError]);
const importFlowFromJSON = useCallback((flowJSON) => {
// Detect if there's no cache data
diff --git a/chainforge/react-server/src/EvaluatorNode.js b/chainforge/react-server/src/EvaluatorNode.js
index 6d714db..ff47301 100644
--- a/chainforge/react-server/src/EvaluatorNode.js
+++ b/chainforge/react-server/src/EvaluatorNode.js
@@ -109,21 +109,24 @@ const EvaluatorNode = ({ data, id }) => {
alertModal.current.trigger(err_msg);
};
- // Get all the script nodes, and get all the folder paths
- const script_nodes = nodes.filter(n => n.type === 'script');
- const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
+ // Get all the Python script nodes, and get all the folder paths
+ // NOTE: Python only!
+ let script_paths = [];
+ if (progLang === 'python') {
+ const script_nodes = nodes.filter(n => n.type === 'script');
+ script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat();
+ }
// Run evaluator in backend
const codeTextOnRun = codeText + '';
- const execute_route = (progLang === 'python') ? 'execute' : 'executejs';
+ const execute_route = (progLang === 'python') ? 'executepy' : 'executejs';
fetch_from_backend(execute_route, {
id: id,
code: codeTextOnRun,
responses: input_node_ids,
scope: mapScope,
- reduce_vars: [],
script_paths: script_paths,
- }, rejected).then(function(json) {
+ }).then(function(json) {
// Store any Python print output
if (json?.logs) {
let logs = json.logs;
@@ -154,7 +157,7 @@ const EvaluatorNode = ({ data, id }) => {
setCodeTextOnLastRun(codeTextOnRun);
setLastRunSuccess(true);
setStatus('ready');
- }, rejected);
+ }).catch((err) => rejected(err.message));
};
const handleOnMapScopeSelect = (event) => {
diff --git a/chainforge/react-server/src/VisNode.js b/chainforge/react-server/src/VisNode.js
index 5d1e92d..51c7614 100644
--- a/chainforge/react-server/src/VisNode.js
+++ b/chainforge/react-server/src/VisNode.js
@@ -577,7 +577,6 @@ const VisNode = ({ data, id }) => {
plot_simple_boxplot(get_llm, 'llm');
}
else if (varnames.length === 1) {
- console.log(varnames);
// 1 var; numeric eval
if (llm_names.length === 1) {
if (typeof_eval_res === 'Boolean')
diff --git a/chainforge/react-server/src/backend/__test__/template.test.ts b/chainforge/react-server/src/backend/__test__/template.test.ts
index 7479493..5cdbd08 100644
--- a/chainforge/react-server/src/backend/__test__/template.test.ts
+++ b/chainforge/react-server/src/backend/__test__/template.test.ts
@@ -25,18 +25,18 @@ test('string template escaped group', () => {
test('single template', () => {
let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?');
- let vars: {[key: string]: any} = {
- 'timeframe': ['year', 'decade', 'century'],
- 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
- };
- let num_prompts = 0;
- for (const prompt of prompt_gen.generate(vars)) {
- // console.log(prompt.toString());
- expect(prompt.fill_history).toHaveProperty('timeframe');
- expect(prompt.fill_history).toHaveProperty('person');
- num_prompts += 1;
- }
- expect(num_prompts).toBe(9);
+ let vars: {[key: string]: any} = {
+ 'timeframe': ['year', 'decade', 'century'],
+ 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']
+ };
+ let num_prompts = 0;
+ for (const prompt of prompt_gen.generate(vars)) {
+ // console.log(prompt.toString());
+ expect(prompt.fill_history).toHaveProperty('timeframe');
+ expect(prompt.fill_history).toHaveProperty('person');
+ num_prompts += 1;
+ }
+ expect(num_prompts).toBe(9);
});
test('nested templates', () => {
diff --git a/chainforge/react-server/src/backend/__test__/utils.test.ts b/chainforge/react-server/src/backend/__test__/utils.test.ts
index e15b70b..5f95dcb 100644
--- a/chainforge/react-server/src/backend/__test__/utils.test.ts
+++ b/chainforge/react-server/src/backend/__test__/utils.test.ts
@@ -40,9 +40,9 @@ test('merge response objects', () => {
expect(merge_response_objs(undefined, B)).toBe(B);
})
-test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => {
- // NOTE: API CALL TESTS ASSUME YOUR ENVIRONMENT VARIABLE IS SET!
-});
+// test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => {
+// // NOTE: API CALL TESTS ASSUME YOUR ENVIRONMENT VARIABLE IS SET!
+// });
test('openai chat completions', async () => {
// Call ChatGPT with a basic question, and n=2
@@ -106,11 +106,3 @@ test('google palm2 models', async () => {
expect(typeof resps[0]).toBe('string');
console.log(JSON.stringify(resps));
}, 40000);
-
-// test('call_', async () => {
-// // Call Anthropic's Claude with a basic question
-// const [query, response] = await call_anthropic("Who invented modern playing cards? Keep your answer brief.", LLM.Claude_v1, 1, 1.0);
-// console.log(response);
-// expect(response).toHaveLength(1);
-// expect(query).toHaveProperty('max_tokens_to_sample');
-// }, 20000);
\ No newline at end of file
diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts
index 9a58e4c..dc6557b 100644
--- a/chainforge/react-server/src/backend/backend.ts
+++ b/chainforge/react-server/src/backend/backend.ts
@@ -10,10 +10,9 @@
// from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys
import { mean as __mean, std as __std, median as __median } from "mathjs";
-
import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing";
import { LLM } from "./models";
-import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys } from "./utils";
+import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils";
import StorageCache from "./cache";
import { PromptPipeline } from "./query";
import { PromptPermutationGenerator, PromptTemplate } from "./template";
@@ -23,9 +22,6 @@ import { PromptPermutationGenerator, PromptTemplate } from "./template";
// =================
// """
-/** Where the ChainForge Flask server is being hosted. */
-export const FLASK_BASE_URL = 'http://localhost:8000/';
-
let LLM_NAME_MAP = {};
Object.entries(LLM).forEach(([key, value]) => {
LLM_NAME_MAP[value] = key;
@@ -704,10 +700,6 @@ export async function executejs(id: string,
response_ids = [ response_ids ];
response_ids = response_ids as Array;
- console.log('executing js');
-
- // const iframe = document.createElement('iframe');
-
// Instantiate the evaluator function by eval'ing the passed code
// DANGER DANGER!!
let iframe: HTMLElement | undefined;
@@ -782,6 +774,64 @@ export async function executejs(id: string,
return {responses: all_evald_responses, logs: all_logs};
}
+/**
+ * Executes a Python 'evaluate' function over all cache'd responses with given id's.
+ * Requires user to be running on localhost, with Flask access.
+ *
+ * > **NOTE**: This should only be run on code you trust.
+ * There is no sandboxing; no safety. We assume you are the creator of the code.
+ *
+ * @param id a unique ID to refer to this information. Used when cache'ing evaluation results.
+ * @param code the code to evaluate. Must include an 'evaluate()' function that takes a 'response' of type ResponseInfo. Alternatively, can be the evaluate function itself.
+ * @param response_ids the cache'd response to run on, which must be a unique ID or list of unique IDs of cache'd data
+ * @param scope the scope of responses to run on --a single response, or all across each batch. (If batch, evaluate() func has access to 'responses'.)
+ */
+export async function executepy(id: string,
+ code: string | ((rinfo: ResponseInfo) => any),
+ response_ids: string | string[],
+ scope: 'response' | 'batch',
+ script_paths?: string[]): Promise {
+ if (!APP_IS_RUNNING_LOCALLY()) {
+ // We can't execute Python if we're not running the local Flask server. Error out:
+ throw new Error("Cannot evaluate Python code: ChainForge does not appear to be running on localhost.")
+ }
+
+ // Check format of response_ids
+ if (!Array.isArray(response_ids))
+ response_ids = [ response_ids ];
+ response_ids = response_ids as Array;
+
+ // Load cache'd responses for all response_ids:
+ const {responses, error} = await grabResponses(response_ids);
+
+ if (error !== undefined)
+ throw new Error(error);
+
+ // All responses loaded; call our Python server to execute the evaluation code across all responses:
+ const flask_response = await call_flask_backend('executepy', {
+ id: id,
+ code: code,
+ responses: responses,
+ scope: scope,
+ script_paths: script_paths,
+ }).catch(err => {
+ throw new Error(err.message);
+ });
+
+ if (!flask_response || flask_response.error !== undefined)
+ throw new Error(flask_response?.error || 'Empty response received from Flask server');
+
+ // Grab the responses and logs from the Flask result object:
+ const all_evald_responses = flask_response.responses;
+ const all_logs = flask_response.logs;
+
+ // Store the evaluated responses in a new cache json:
+ StorageCache.store(`${id}.json`, all_evald_responses);
+
+ return {responses: all_evald_responses, logs: all_logs};
+}
+
+
/**
* Returns all responses with the specified id(s).
* @param responses the ids to grab
@@ -805,7 +855,7 @@ export async function grabResponses(responses: Array): Promise {
grabbed_resps = grabbed_resps.concat(res);
}
- return {'responses': grabbed_resps};
+ return {responses: grabbed_resps};
}
/**
@@ -828,7 +878,7 @@ export async function exportCache(ids: string[]) {
export_data[key] = load_from_cache(key);
});
}
- return export_data;
+ return {files: export_data};
}
diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts
index 235549c..db806e7 100644
--- a/chainforge/react-server/src/backend/utils.ts
+++ b/chainforge/react-server/src/backend/utils.ts
@@ -11,9 +11,50 @@ import { StringTemplate } from './template';
/* LLM API SDKs */
import { Configuration as OpenAIConfig, OpenAIApi } from "openai";
import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai";
-import { AI_PROMPT, Client as AnthropicClient, HUMAN_PROMPT } from "@anthropic-ai/sdk";
-import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage";
-import { GoogleAuth } from "google-auth-library";
+import { AI_PROMPT, HUMAN_PROMPT } from "@anthropic-ai/sdk";
+
+const fetch = require('node-fetch');
+
+/** Where the ChainForge Flask server is being hosted. */
+export const FLASK_BASE_URL = 'http://localhost:8000/';
+
+export async function call_flask_backend(route: string, params: Dict | string): Promise {
+ return fetch(`${FLASK_BASE_URL}app/${route}`, {
+ method: 'POST',
+ headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
+ body: JSON.stringify(params)
+ }).then(function(res) {
+ return res.json();
+ });
+}
+
+/**
+ * Equivalent to a 'fetch' call, but routes it to the backend Flask server in
+ * case we are running a local server and prefer to not deal with CORS issues making API calls client-side.
+ */
+async function route_fetch(url: string, method: string, headers: Dict, body: Dict) {
+ if (APP_IS_RUNNING_LOCALLY()) {
+ return call_flask_backend('makeFetchCall', {
+ url: url,
+ method: method,
+ headers: headers,
+ body: body,
+ }).then(res => {
+ if (!res || res.error)
+ throw new Error(res.error);
+ return res.response;
+ });
+ } else {
+ return fetch(url, {
+ method,
+ headers,
+ body: JSON.stringify(body),
+ }).then(res => res.json());
+ }
+}
+
+// import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage";
+// import { GoogleAuth } from "google-auth-library";
function get_environ(key: string): string | undefined {
if (key in process_env)
@@ -35,7 +76,7 @@ let AZURE_OPENAI_ENDPOINT = get_environ("AZURE_OPENAI_ENDPOINT");
*/
export function set_api_keys(api_keys: StringDict): void {
function key_is_present(name: string): boolean {
- return name in api_keys && api_keys[name].trim().length > 0;
+ return name in api_keys && api_keys[name] && api_keys[name].trim().length > 0;
}
if (key_is_present('OpenAI'))
OPENAI_API_KEY= api_keys['OpenAI'];
@@ -228,9 +269,6 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number =
export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> {
if (!ANTHROPIC_API_KEY)
throw Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment.");
-
- // Initialize Anthropic API client
- const client = new AnthropicClient(ANTHROPIC_API_KEY);
// Wrap the prompt in the provided template, or use the default Anthropic one
const custom_prompt_wrapper: string = params?.custom_prompt_wrapper || (HUMAN_PROMPT + " {prompt}" + AI_PROMPT);
@@ -239,6 +277,9 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
const prompt_wrapper_template = new StringTemplate(custom_prompt_wrapper);
const wrapped_prompt = prompt_wrapper_template.safe_substitute({prompt: prompt});
+ if (params?.custom_prompt_wrapper !== undefined)
+ delete params.custom_prompt_wrapper;
+
// Required non-standard params
const max_tokens_to_sample = params?.max_tokens_to_sample || 1024;
const stop_sequences = params?.stop_sequences || [HUMAN_PROMPT];
@@ -255,10 +296,19 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1,
console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
+ // Make a REST call to Anthropic
+ const url = 'https://api.anthropic.com/v1/complete';
+ const headers = {
+ 'accept': 'application/json',
+ 'anthropic-version': '2023-06-01',
+ 'content-type': 'application/json',
+ 'x-api-key': ANTHROPIC_API_KEY,
+ };
+
// Repeat call n times, waiting for each response to come in:
let responses: Array = [];
while (responses.length < n) {
- const resp = await client.complete(query);
+ const resp = await route_fetch(url, 'POST', headers, query);
responses.push(resp);
console.log(`${model} response ${responses.length} of ${n}:\n${resp}`);
}
@@ -275,9 +325,6 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment.");
const is_chat_model = model.toString().includes('chat');
- const client = new (is_chat_model ? DiscussServiceClient : TextServiceClient)({
- authClient: new GoogleAuth().fromAPIKey(GOOGLE_PALM_API_KEY),
- });
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 800;
@@ -317,18 +364,30 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
}
});
- console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
-
- // Call the correct model client
- let completion;
if (is_chat_model) {
// Chat completions
query.prompt = { messages: [{content: prompt}] };
- completion = await (client as DiscussServiceClient).generateMessage(query);
} else {
// Text completions
query.prompt = { text: prompt };
- completion = await (client as TextServiceClient).generateText(query);
+ }
+
+ console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`);
+
+ // Call the correct model client
+ const method = is_chat_model ? 'generateMessage' : 'generateText';
+ const url = `https://generativelanguage.googleapis.com/v1beta2/models/${model}:${method}?key=${GOOGLE_PALM_API_KEY}`;
+ const headers = {'Content-Type': 'application/json'};
+ let res = await fetch(url, {
+ method: 'POST',
+ headers,
+ body: JSON.stringify(query)
+ });
+ let completion: Dict = await res.json();
+
+ // Sometimes the REST call will give us an error; bubble this up the chain:
+ if (completion.error !== undefined) {
+ throw new Error(JSON.stringify(completion.error));
}
// Google PaLM, unlike other chat models, will output empty
@@ -336,10 +395,10 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
// API has a (relatively undocumented) 'safety_settings' parameter,
// the current chat completions API provides users no control over the blocking.
// We need to detect this and fill the response with the safety reasoning:
- if (completion[0].filters.length > 0) {
+ if (completion.filters && completion.filters.length > 0) {
// Request was blocked. Output why in the response text, repairing the candidate dict to mock up 'n' responses
const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(completion.filters)}`
- completion[0].candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
+ completion.candidates = new Array(n).fill({'author': '1', 'content':block_error_msg});
}
// Weirdly, google ignores candidate_count if temperature is 0.
@@ -348,7 +407,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1
// copied_candidates = [completion_dict['candidates'][0]] * n
// completion_dict['candidates'] = copied_candidates
- return [query, completion[0]];
+ return [query, completion];
}
export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> {
@@ -580,8 +639,14 @@ export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, r
}
export function APP_IS_RUNNING_LOCALLY(): boolean {
- const location = window.location;
- return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
+ try {
+ const location = window.location;
+ return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === "";
+ } catch (e) {
+ // ReferenceError --window or location does not exist.
+ // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server)
+ return false;
+ }
}
// def create_dir_if_not_exists(path: str) -> None:
diff --git a/chainforge/react-server/src/fetch_from_backend.js b/chainforge/react-server/src/fetch_from_backend.js
index f1f8847..640b7a3 100644
--- a/chainforge/react-server/src/fetch_from_backend.js
+++ b/chainforge/react-server/src/fetch_from_backend.js
@@ -1,8 +1,8 @@
-import { queryLLM, executejs, FLASK_BASE_URL,
+import { queryLLM, executejs, executepy, FLASK_BASE_URL,
fetchExampleFlow, fetchOpenAIEval, importCache,
exportCache, countQueries, grabResponses,
createProgressFile } from "./backend/backend";
-import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
+import { APP_IS_RUNNING_LOCALLY, call_flask_backend } from "./backend/utils";
const BACKEND_TYPES = {
FLASK: 'flask',
@@ -11,13 +11,7 @@ const BACKEND_TYPES = {
export let BACKEND_TYPE = BACKEND_TYPES.JAVASCRIPT;
function _route_to_flask_backend(route, params, rejected) {
- return fetch(`${FLASK_BASE_URL}app/${route}`, {
- method: 'POST',
- headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'},
- body: JSON.stringify(params)
- }, rejected).then(function(res) {
- return res.json();
- });
+ return call_flask_backend(route, params).catch(rejected);
}
async function _route_to_js_backend(route, params) {
@@ -32,6 +26,8 @@ async function _route_to_js_backend(route, params) {
return queryLLM(params.id, params.llm, params.n, params.prompt, params.vars, params.api_keys, params.no_cache, params.progress_listener);
case 'executejs':
return executejs(params.id, params.code, params.responses, params.scope);
+ case 'executepy':
+ return executepy(params.id, params.code, params.responses, params.scope, params.script_paths);
case 'importCache':
return importCache(params.files);
case 'exportCache':
@@ -55,24 +51,11 @@ async function _route_to_js_backend(route, params) {
export default function fetch_from_backend(route, params, rejected) {
rejected = rejected || ((err) => {throw new Error(err)});
- if (route === 'execute') { // executing Python code
- if (APP_IS_RUNNING_LOCALLY())
- return _route_to_flask_backend(route, params, rejected);
- else {
- // We can't execute Python if we're not running the local Flask server. Error out:
- return new Promise((resolve, reject) => {
- const msg = "Cannot run 'execute' route to evaluate Python code: ChainForge does not appear to be running on localhost.";
- rejected(new Error(msg));
- reject(msg);
- });
- }
- }
-
switch (BACKEND_TYPE) {
case BACKEND_TYPES.FLASK: // Fetch from Flask (python) backend
return _route_to_flask_backend(route, params, rejected);
case BACKEND_TYPES.JAVASCRIPT: // Fetch from client-side Javascript 'backend'
- return _route_to_js_backend(route, params);
+ return _route_to_js_backend(route, params).catch(rejected);
default:
console.error('Unsupported backend type:', BACKEND_TYPE);
break;