From 1564c83ae9d3f326ca4477bdda1a20eb692ac39f Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Wed, 28 Jun 2023 18:05:09 -0400 Subject: [PATCH] Route Anthropic calls through Flask when running locally --- chainforge/flask_app.py | 113 +++++++++++++++++- chainforge/react-server/craco.config.js | 10 +- chainforge/react-server/package-lock.json | 1 + chainforge/react-server/package.json | 1 + chainforge/react-server/src/App.js | 8 +- chainforge/react-server/src/EvaluatorNode.js | 17 +-- chainforge/react-server/src/VisNode.js | 1 - .../src/backend/__test__/template.test.ts | 24 ++-- .../src/backend/__test__/utils.test.ts | 14 +-- .../react-server/src/backend/backend.ts | 72 +++++++++-- chainforge/react-server/src/backend/utils.ts | 109 +++++++++++++---- .../react-server/src/fetch_from_backend.js | 29 +---- 12 files changed, 304 insertions(+), 95 deletions(-) diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index ec337c7..1bddb50 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -8,7 +8,7 @@ from flask_cors import CORS from chainforge.promptengine.query import PromptLLM, PromptLLMDummy, LLMResponseException from chainforge.promptengine.template import PromptTemplate, PromptPermutationGenerator from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys - +import requests as py_requests """ ================= SETUP AND GLOBALS @@ -679,6 +679,82 @@ async def queryLLM(): ret.headers.add('Access-Control-Allow-Origin', '*') return ret +@app.route('/app/executepy', methods=['POST']) +def executepy(): + """ + Executes a Python function sent from JavaScript, + over all the `StandardizedLLMResponse` objects passed in from the front-end. + + POST'd data should be in the form: + { + 'id': # a unique ID to refer to this information. Used when cache'ing responses. + 'code': str, # the body of the lambda function to evaluate, in form: lambda responses: + 'responses': List[StandardizedLLMResponse] # the responses to run on. + 'scope': 'response' | 'batch' # the scope of responses to run on --a single response, or all across each batch. + # If batch, evaluator has access to 'responses'. Only matters if n > 1 for each prompt. + 'script_paths': unspecified | List[str] # the paths to scripts to be added to the path before the lambda function is evaluated + } + + NOTE: This should only be run on your server on code you trust. + There is no sandboxing; no safety. We assume you are the creator of the code. + """ + data = request.get_json() + + # Check that all required info is here: + if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}): + return jsonify({'error': 'POST data is improper format.'}) + if not isinstance(data['id'], str) or len(data['id']) == 0: + return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'}) + if data['scope'] not in ('response', 'batch'): + return jsonify({'error': "POST data scope is unknown. 
Must be either 'response' or 'batch'."}) + + # Check format of responses: + responses = data['responses'] + if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])): + return jsonify({'error': 'POST data responses is improper format.'}) + + # add the path to any scripts to the path: + try: + if 'script_paths' in data: + for script_path in data['script_paths']: + # get the folder of the script_path: + script_folder = os.path.dirname(script_path) + # check that the script_folder is valid, and it contains __init__.py + if not os.path.exists(script_folder): + print(script_folder, 'is not a valid script path.') + continue + + # add it to the path: + sys.path.append(script_folder) + print(f'added {script_folder} to sys.path') + except Exception as e: + return jsonify({'error': f'Could not add script path to sys.path. Error message:\n{str(e)}'}) + + # Create the evaluator function + # DANGER DANGER! + try: + exec(data['code'], globals()) + + # Double-check that there is an 'evaluate' method in our namespace. + # This will throw a NameError if not: + evaluate # noqa + except Exception as e: + return jsonify({'error': f'Could not compile evaluator code. Error message:\n{str(e)}'}) + + evald_responses = [] + logs = [] + try: + HIJACK_PYTHON_PRINT() + evald_responses = run_over_responses(evaluate, responses, scope=data['scope']) # noqa + logs = REVERT_PYTHON_PRINT() + except Exception as e: + logs = REVERT_PYTHON_PRINT() + return jsonify({'error': f'Error encountered while trying to run "evaluate" method:\n{str(e)}', 'logs': logs}) + + ret = jsonify({'responses': evald_responses, 'logs': logs}) + ret.headers.add('Access-Control-Allow-Origin', '*') + return ret + @app.route('/app/execute', methods=['POST']) def execute(): """ @@ -1047,7 +1123,7 @@ def fetchEnvironAPIKeys(): keymap = { 'OPENAI_API_KEY': 'OpenAI', 'ANTHROPIC_API_KEY': 'Anthropic', - 'GOOGLE_PALM_API_KEY': 'Google', + 'PALM_API_KEY': 'Google', 'AZURE_OPENAI_KEY': 'Azure_OpenAI', 'AZURE_OPENAI_ENDPOINT': 'Azure_OpenAI_Endpoint' } @@ -1056,6 +1132,39 @@ def fetchEnvironAPIKeys(): ret.headers.add('Access-Control-Allow-Origin', '*') return ret +@app.route('/app/makeFetchCall', methods=['POST']) +def makeFetchCall(): + """ + Use in place of JavaScript's 'fetch' (with POST method), in cases where + cross-origin policy blocks responses from client-side fetches. 
+ + POST'd data should be in form: + { + url: # the url to fetch from + headers: # a JSON object of the headers + body: # the request payload, as JSON + } + """ + # Verify post'd data + data = request.get_json() + if not set(data.keys()).issuperset({'url', 'headers', 'body'}): + return jsonify({'error': 'POST data is improper format.'}) + + url = data['url'] + headers = data['headers'] + body = data['body'] + + print(body) + + response = py_requests.post(url, headers=headers, json=body) + + if response.status_code == 200: + ret = jsonify({'response': response.json()}) + ret.headers.add('Access-Control-Allow-Origin', '*') + return ret + else: + return jsonify({'error': 'API request to Anthropic failed'}) + def run_server(host="", port=8000, cmd_args=None): if cmd_args is not None and cmd_args.dummy_responses: global PromptLLM diff --git a/chainforge/react-server/craco.config.js b/chainforge/react-server/craco.config.js index 67fbca8..9e23713 100644 --- a/chainforge/react-server/craco.config.js +++ b/chainforge/react-server/craco.config.js @@ -9,7 +9,7 @@ module.exports = { resolve: { fallback: { "process": require.resolve("process/browser"), - "buffer": require.resolve("buffer/"), + "buffer": require.resolve("buffer"), "https": require.resolve("https-browserify"), "querystring": require.resolve("querystring-es3"), "url": require.resolve("url/"), @@ -41,7 +41,13 @@ module.exports = { add: [ new webpack.ProvidePlugin({ process: 'process/browser.js', - }) + }), + + // Work around for Buffer is undefined: + // https://github.com/webpack/changelog-v5/issues/10 + new webpack.ProvidePlugin({ + Buffer: ['buffer', 'Buffer'], + }), ] }, diff --git a/chainforge/react-server/package-lock.json b/chainforge/react-server/package-lock.json index 0419a79..f70b7b2 100644 --- a/chainforge/react-server/package-lock.json +++ b/chainforge/react-server/package-lock.json @@ -61,6 +61,7 @@ "mathjs": "^11.8.2", "net": "^1.0.2", "net-browserify": "^0.2.4", + "node-fetch": "^2.6.11", "openai": "^3.3.0", "os-browserify": "^0.3.0", "papaparse": "^5.4.1", diff --git a/chainforge/react-server/package.json b/chainforge/react-server/package.json index 9a71ced..65f7dc6 100644 --- a/chainforge/react-server/package.json +++ b/chainforge/react-server/package.json @@ -56,6 +56,7 @@ "mathjs": "^11.8.2", "net": "^1.0.2", "net-browserify": "^0.2.4", + "node-fetch": "^2.6.11", "openai": "^3.3.0", "os-browserify": "^0.3.0", "papaparse": "^5.4.1", diff --git a/chainforge/react-server/src/App.js b/chainforge/react-server/src/App.js index fcd5453..04a852f 100644 --- a/chainforge/react-server/src/App.js +++ b/chainforge/react-server/src/App.js @@ -263,11 +263,11 @@ const App = () => { // Save! downloadJSON(flow_and_cache, `flow-${Date.now()}.cforge`); - }); - }, [rfInstance, nodes]); + }).catch(handleError); + }, [rfInstance, nodes, handleError]); // Import data to the cache stored on the local filesystem (in backend) - const importCache = (cache_data) => { + const importCache = useCallback((cache_data) => { return fetch_from_backend('importCache', { 'files': cache_data, }, handleError).then(function(json) { @@ -277,7 +277,7 @@ const App = () => { throw new Error('Error importing cache data:' + json.error); // Done! 
}, handleError).catch(handleError); - }; + }, [handleError]); const importFlowFromJSON = useCallback((flowJSON) => { // Detect if there's no cache data diff --git a/chainforge/react-server/src/EvaluatorNode.js b/chainforge/react-server/src/EvaluatorNode.js index 6d714db..ff47301 100644 --- a/chainforge/react-server/src/EvaluatorNode.js +++ b/chainforge/react-server/src/EvaluatorNode.js @@ -109,21 +109,24 @@ const EvaluatorNode = ({ data, id }) => { alertModal.current.trigger(err_msg); }; - // Get all the script nodes, and get all the folder paths - const script_nodes = nodes.filter(n => n.type === 'script'); - const script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat(); + // Get all the Python script nodes, and get all the folder paths + // NOTE: Python only! + let script_paths = []; + if (progLang === 'python') { + const script_nodes = nodes.filter(n => n.type === 'script'); + script_paths = script_nodes.map(n => Object.values(n.data.scriptFiles).filter(f => f !== '')).flat(); + } // Run evaluator in backend const codeTextOnRun = codeText + ''; - const execute_route = (progLang === 'python') ? 'execute' : 'executejs'; + const execute_route = (progLang === 'python') ? 'executepy' : 'executejs'; fetch_from_backend(execute_route, { id: id, code: codeTextOnRun, responses: input_node_ids, scope: mapScope, - reduce_vars: [], script_paths: script_paths, - }, rejected).then(function(json) { + }).then(function(json) { // Store any Python print output if (json?.logs) { let logs = json.logs; @@ -154,7 +157,7 @@ const EvaluatorNode = ({ data, id }) => { setCodeTextOnLastRun(codeTextOnRun); setLastRunSuccess(true); setStatus('ready'); - }, rejected); + }).catch((err) => rejected(err.message)); }; const handleOnMapScopeSelect = (event) => { diff --git a/chainforge/react-server/src/VisNode.js b/chainforge/react-server/src/VisNode.js index 5d1e92d..51c7614 100644 --- a/chainforge/react-server/src/VisNode.js +++ b/chainforge/react-server/src/VisNode.js @@ -577,7 +577,6 @@ const VisNode = ({ data, id }) => { plot_simple_boxplot(get_llm, 'llm'); } else if (varnames.length === 1) { - console.log(varnames); // 1 var; numeric eval if (llm_names.length === 1) { if (typeof_eval_res === 'Boolean') diff --git a/chainforge/react-server/src/backend/__test__/template.test.ts b/chainforge/react-server/src/backend/__test__/template.test.ts index 7479493..5cdbd08 100644 --- a/chainforge/react-server/src/backend/__test__/template.test.ts +++ b/chainforge/react-server/src/backend/__test__/template.test.ts @@ -25,18 +25,18 @@ test('string template escaped group', () => { test('single template', () => { let prompt_gen = new PromptPermutationGenerator('What is the {timeframe} when {person} was born?'); - let vars: {[key: string]: any} = { - 'timeframe': ['year', 'decade', 'century'], - 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding'] - }; - let num_prompts = 0; - for (const prompt of prompt_gen.generate(vars)) { - // console.log(prompt.toString()); - expect(prompt.fill_history).toHaveProperty('timeframe'); - expect(prompt.fill_history).toHaveProperty('person'); - num_prompts += 1; - } - expect(num_prompts).toBe(9); + let vars: {[key: string]: any} = { + 'timeframe': ['year', 'decade', 'century'], + 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding'] + }; + let num_prompts = 0; + for (const prompt of prompt_gen.generate(vars)) { + // console.log(prompt.toString()); + expect(prompt.fill_history).toHaveProperty('timeframe'); + 
expect(prompt.fill_history).toHaveProperty('person'); + num_prompts += 1; + } + expect(num_prompts).toBe(9); }); test('nested templates', () => { diff --git a/chainforge/react-server/src/backend/__test__/utils.test.ts b/chainforge/react-server/src/backend/__test__/utils.test.ts index e15b70b..5f95dcb 100644 --- a/chainforge/react-server/src/backend/__test__/utils.test.ts +++ b/chainforge/react-server/src/backend/__test__/utils.test.ts @@ -40,9 +40,9 @@ test('merge response objects', () => { expect(merge_response_objs(undefined, B)).toBe(B); }) -test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => { - // NOTE: API CALL TESTS ASSUME YOUR ENVIRONMENT VARIABLE IS SET! -}); +// test('UNCOMMENT BELOW API CALL TESTS WHEN READY', () => { +// // NOTE: API CALL TESTS ASSUME YOUR ENVIRONMENT VARIABLE IS SET! +// }); test('openai chat completions', async () => { // Call ChatGPT with a basic question, and n=2 @@ -106,11 +106,3 @@ test('google palm2 models', async () => { expect(typeof resps[0]).toBe('string'); console.log(JSON.stringify(resps)); }, 40000); - -// test('call_', async () => { -// // Call Anthropic's Claude with a basic question -// const [query, response] = await call_anthropic("Who invented modern playing cards? Keep your answer brief.", LLM.Claude_v1, 1, 1.0); -// console.log(response); -// expect(response).toHaveLength(1); -// expect(query).toHaveProperty('max_tokens_to_sample'); -// }, 20000); \ No newline at end of file diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 9a58e4c..dc6557b 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -10,10 +10,9 @@ // from chainforge.promptengine.utils import LLM, is_valid_filepath, get_files_at_dir, create_dir_if_not_exists, set_api_keys import { mean as __mean, std as __std, median as __median } from "mathjs"; - import { Dict, LLMResponseError, LLMResponseObject, StandardizedLLMResponse } from "./typing"; import { LLM } from "./models"; -import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys } from "./utils"; +import { APP_IS_RUNNING_LOCALLY, getEnumName, set_api_keys, FLASK_BASE_URL, call_flask_backend } from "./utils"; import StorageCache from "./cache"; import { PromptPipeline } from "./query"; import { PromptPermutationGenerator, PromptTemplate } from "./template"; @@ -23,9 +22,6 @@ import { PromptPermutationGenerator, PromptTemplate } from "./template"; // ================= // """ -/** Where the ChainForge Flask server is being hosted. */ -export const FLASK_BASE_URL = 'http://localhost:8000/'; - let LLM_NAME_MAP = {}; Object.entries(LLM).forEach(([key, value]) => { LLM_NAME_MAP[value] = key; @@ -704,10 +700,6 @@ export async function executejs(id: string, response_ids = [ response_ids ]; response_ids = response_ids as Array; - console.log('executing js'); - - // const iframe = document.createElement('iframe'); - // Instantiate the evaluator function by eval'ing the passed code // DANGER DANGER!! let iframe: HTMLElement | undefined; @@ -782,6 +774,64 @@ export async function executejs(id: string, return {responses: all_evald_responses, logs: all_logs}; } +/** + * Executes a Python 'evaluate' function over all cache'd responses with given id's. + * Requires user to be running on localhost, with Flask access. + * + * > **NOTE**: This should only be run on code you trust. + * There is no sandboxing; no safety. We assume you are the creator of the code. 
+ * + * @param id a unique ID to refer to this information. Used when cache'ing evaluation results. + * @param code the code to evaluate. Must include an 'evaluate()' function that takes a 'response' of type ResponseInfo. Alternatively, can be the evaluate function itself. + * @param response_ids the cache'd response to run on, which must be a unique ID or list of unique IDs of cache'd data + * @param scope the scope of responses to run on --a single response, or all across each batch. (If batch, evaluate() func has access to 'responses'.) + */ +export async function executepy(id: string, + code: string | ((rinfo: ResponseInfo) => any), + response_ids: string | string[], + scope: 'response' | 'batch', + script_paths?: string[]): Promise { + if (!APP_IS_RUNNING_LOCALLY()) { + // We can't execute Python if we're not running the local Flask server. Error out: + throw new Error("Cannot evaluate Python code: ChainForge does not appear to be running on localhost.") + } + + // Check format of response_ids + if (!Array.isArray(response_ids)) + response_ids = [ response_ids ]; + response_ids = response_ids as Array; + + // Load cache'd responses for all response_ids: + const {responses, error} = await grabResponses(response_ids); + + if (error !== undefined) + throw new Error(error); + + // All responses loaded; call our Python server to execute the evaluation code across all responses: + const flask_response = await call_flask_backend('executepy', { + id: id, + code: code, + responses: responses, + scope: scope, + script_paths: script_paths, + }).catch(err => { + throw new Error(err.message); + }); + + if (!flask_response || flask_response.error !== undefined) + throw new Error(flask_response?.error || 'Empty response received from Flask server'); + + // Grab the responses and logs from the Flask result object: + const all_evald_responses = flask_response.responses; + const all_logs = flask_response.logs; + + // Store the evaluated responses in a new cache json: + StorageCache.store(`${id}.json`, all_evald_responses); + + return {responses: all_evald_responses, logs: all_logs}; +} + + /** * Returns all responses with the specified id(s). * @param responses the ids to grab @@ -805,7 +855,7 @@ export async function grabResponses(responses: Array): Promise { grabbed_resps = grabbed_resps.concat(res); } - return {'responses': grabbed_resps}; + return {responses: grabbed_resps}; } /** @@ -828,7 +878,7 @@ export async function exportCache(ids: string[]) { export_data[key] = load_from_cache(key); }); } - return export_data; + return {files: export_data}; } diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 235549c..db806e7 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -11,9 +11,50 @@ import { StringTemplate } from './template'; /* LLM API SDKs */ import { Configuration as OpenAIConfig, OpenAIApi } from "openai"; import { OpenAIClient as AzureOpenAIClient, AzureKeyCredential } from "@azure/openai"; -import { AI_PROMPT, Client as AnthropicClient, HUMAN_PROMPT } from "@anthropic-ai/sdk"; -import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage"; -import { GoogleAuth } from "google-auth-library"; +import { AI_PROMPT, HUMAN_PROMPT } from "@anthropic-ai/sdk"; + +const fetch = require('node-fetch'); + +/** Where the ChainForge Flask server is being hosted. 
*/ +export const FLASK_BASE_URL = 'http://localhost:8000/'; + +export async function call_flask_backend(route: string, params: Dict | string): Promise { + return fetch(`${FLASK_BASE_URL}app/${route}`, { + method: 'POST', + headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, + body: JSON.stringify(params) + }).then(function(res) { + return res.json(); + }); +} + +/** + * Equivalent to a 'fetch' call, but routes it to the backend Flask server in + * case we are running a local server and prefer to not deal with CORS issues making API calls client-side. + */ +async function route_fetch(url: string, method: string, headers: Dict, body: Dict) { + if (APP_IS_RUNNING_LOCALLY()) { + return call_flask_backend('makeFetchCall', { + url: url, + method: method, + headers: headers, + body: body, + }).then(res => { + if (!res || res.error) + throw new Error(res.error); + return res.response; + }); + } else { + return fetch(url, { + method, + headers, + body: JSON.stringify(body), + }).then(res => res.json()); + } +} + +// import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage"; +// import { GoogleAuth } from "google-auth-library"; function get_environ(key: string): string | undefined { if (key in process_env) @@ -35,7 +76,7 @@ let AZURE_OPENAI_ENDPOINT = get_environ("AZURE_OPENAI_ENDPOINT"); */ export function set_api_keys(api_keys: StringDict): void { function key_is_present(name: string): boolean { - return name in api_keys && api_keys[name].trim().length > 0; + return name in api_keys && api_keys[name] && api_keys[name].trim().length > 0; } if (key_is_present('OpenAI')) OPENAI_API_KEY= api_keys['OpenAI']; @@ -228,9 +269,6 @@ export async function call_azure_openai(prompt: string, model: LLM, n: number = export async function call_anthropic(prompt: string, model: LLM, n: number = 1, temperature: number = 1.0, params?: Dict): Promise<[Dict, Dict]> { if (!ANTHROPIC_API_KEY) throw Error("Could not find an API key for Anthropic models. Double-check that your API key is set in Settings or in your local environment."); - - // Initialize Anthropic API client - const client = new AnthropicClient(ANTHROPIC_API_KEY); // Wrap the prompt in the provided template, or use the default Anthropic one const custom_prompt_wrapper: string = params?.custom_prompt_wrapper || (HUMAN_PROMPT + " {prompt}" + AI_PROMPT); @@ -239,6 +277,9 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1, const prompt_wrapper_template = new StringTemplate(custom_prompt_wrapper); const wrapped_prompt = prompt_wrapper_template.safe_substitute({prompt: prompt}); + if (params?.custom_prompt_wrapper !== undefined) + delete params.custom_prompt_wrapper; + // Required non-standard params const max_tokens_to_sample = params?.max_tokens_to_sample || 1024; const stop_sequences = params?.stop_sequences || [HUMAN_PROMPT]; @@ -255,10 +296,19 @@ export async function call_anthropic(prompt: string, model: LLM, n: number = 1, console.log(`Calling Anthropic model '${model}' with prompt '${prompt}' (n=${n}). 
Please be patient...`); + // Make a REST call to Anthropic + const url = 'https://api.anthropic.com/v1/complete'; + const headers = { + 'accept': 'application/json', + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + 'x-api-key': ANTHROPIC_API_KEY, + }; + // Repeat call n times, waiting for each response to come in: let responses: Array = []; while (responses.length < n) { - const resp = await client.complete(query); + const resp = await route_fetch(url, 'POST', headers, query); responses.push(resp); console.log(`${model} response ${responses.length} of ${n}:\n${resp}`); } @@ -275,9 +325,6 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1 throw Error("Could not find an API key for Google PaLM models. Double-check that your API key is set in Settings or in your local environment."); const is_chat_model = model.toString().includes('chat'); - const client = new (is_chat_model ? DiscussServiceClient : TextServiceClient)({ - authClient: new GoogleAuth().fromAPIKey(GOOGLE_PALM_API_KEY), - }); // Required non-standard params const max_output_tokens = params?.max_output_tokens || 800; @@ -317,18 +364,30 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1 } }); - console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`); - - // Call the correct model client - let completion; if (is_chat_model) { // Chat completions query.prompt = { messages: [{content: prompt}] }; - completion = await (client as DiscussServiceClient).generateMessage(query); } else { // Text completions query.prompt = { text: prompt }; - completion = await (client as TextServiceClient).generateText(query); + } + + console.log(`Calling Google PaLM model '${model}' with prompt '${prompt}' (n=${n}). Please be patient...`); + + // Call the correct model client + const method = is_chat_model ? 'generateMessage' : 'generateText'; + const url = `https://generativelanguage.googleapis.com/v1beta2/models/${model}:${method}?key=${GOOGLE_PALM_API_KEY}`; + const headers = {'Content-Type': 'application/json'}; + let res = await fetch(url, { + method: 'POST', + headers, + body: JSON.stringify(query) + }); + let completion: Dict = await res.json(); + + // Sometimes the REST call will give us an error; bubble this up the chain: + if (completion.error !== undefined) { + throw new Error(JSON.stringify(completion.error)); } // Google PaLM, unlike other chat models, will output empty @@ -336,10 +395,10 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1 // API has a (relatively undocumented) 'safety_settings' parameter, // the current chat completions API provides users no control over the blocking. // We need to detect this and fill the response with the safety reasoning: - if (completion[0].filters.length > 0) { + if (completion.filters && completion.filters.length > 0) { // Request was blocked. Output why in the response text, repairing the candidate dict to mock up 'n' responses const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(completion.filters)}` - completion[0].candidates = new Array(n).fill({'author': '1', 'content':block_error_msg}); + completion.candidates = new Array(n).fill({'author': '1', 'content':block_error_msg}); } // Weirdly, google ignores candidate_count if temperature is 0. 
@@ -348,7 +407,7 @@ export async function call_google_palm(prompt: string, model: LLM, n: number = 1 // copied_candidates = [completion_dict['candidates'][0]] * n // completion_dict['candidates'] = copied_candidates - return [query, completion[0]]; + return [query, completion]; } export async function call_dalai(prompt: string, model: LLM, n: number = 1, temperature: number = 0.7, params?: Dict): Promise<[Dict, Dict]> { @@ -580,8 +639,14 @@ export function merge_response_objs(resp_obj_A: LLMResponseObject | undefined, r } export function APP_IS_RUNNING_LOCALLY(): boolean { - const location = window.location; - return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === ""; + try { + const location = window.location; + return location.hostname === "localhost" || location.hostname === "127.0.0.1" || location.hostname === ""; + } catch (e) { + // ReferenceError --window or location does not exist. + // We must not be running client-side in a browser, in this case (e.g., we are running a Node.js server) + return false; + } } // def create_dir_if_not_exists(path: str) -> None: diff --git a/chainforge/react-server/src/fetch_from_backend.js b/chainforge/react-server/src/fetch_from_backend.js index f1f8847..640b7a3 100644 --- a/chainforge/react-server/src/fetch_from_backend.js +++ b/chainforge/react-server/src/fetch_from_backend.js @@ -1,8 +1,8 @@ -import { queryLLM, executejs, FLASK_BASE_URL, +import { queryLLM, executejs, executepy, FLASK_BASE_URL, fetchExampleFlow, fetchOpenAIEval, importCache, exportCache, countQueries, grabResponses, createProgressFile } from "./backend/backend"; -import { APP_IS_RUNNING_LOCALLY } from "./backend/utils"; +import { APP_IS_RUNNING_LOCALLY, call_flask_backend } from "./backend/utils"; const BACKEND_TYPES = { FLASK: 'flask', @@ -11,13 +11,7 @@ const BACKEND_TYPES = { export let BACKEND_TYPE = BACKEND_TYPES.JAVASCRIPT; function _route_to_flask_backend(route, params, rejected) { - return fetch(`${FLASK_BASE_URL}app/${route}`, { - method: 'POST', - headers: {'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*'}, - body: JSON.stringify(params) - }, rejected).then(function(res) { - return res.json(); - }); + return call_flask_backend(route, params).catch(rejected); } async function _route_to_js_backend(route, params) { @@ -32,6 +26,8 @@ async function _route_to_js_backend(route, params) { return queryLLM(params.id, params.llm, params.n, params.prompt, params.vars, params.api_keys, params.no_cache, params.progress_listener); case 'executejs': return executejs(params.id, params.code, params.responses, params.scope); + case 'executepy': + return executepy(params.id, params.code, params.responses, params.scope, params.script_paths); case 'importCache': return importCache(params.files); case 'exportCache': @@ -55,24 +51,11 @@ async function _route_to_js_backend(route, params) { export default function fetch_from_backend(route, params, rejected) { rejected = rejected || ((err) => {throw new Error(err)}); - if (route === 'execute') { // executing Python code - if (APP_IS_RUNNING_LOCALLY()) - return _route_to_flask_backend(route, params, rejected); - else { - // We can't execute Python if we're not running the local Flask server. 
Error out: - return new Promise((resolve, reject) => { - const msg = "Cannot run 'execute' route to evaluate Python code: ChainForge does not appear to be running on localhost."; - rejected(new Error(msg)); - reject(msg); - }); - } - } - switch (BACKEND_TYPE) { case BACKEND_TYPES.FLASK: // Fetch from Flask (python) backend return _route_to_flask_backend(route, params, rejected); case BACKEND_TYPES.JAVASCRIPT: // Fetch from client-side Javascript 'backend' - return _route_to_js_backend(route, params); + return _route_to_js_backend(route, params).catch(rejected); default: console.error('Unsupported backend type:', BACKEND_TYPE); break;
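
A minimal Python sketch of exercising the new /app/executepy route, assuming the Flask server is running at the FLASK_BASE_URL above (http://localhost:8000/). The wire format mirrors the route's docstring; the `text` attribute read inside the example evaluate function is an assumption about the object that run_over_responses hands to evaluators, which is not shown in this diff.

import requests

# Hypothetical evaluator body; /app/executepy exec()'s this server-side, so it
# must define an `evaluate` function. Reading `response.text` is an assumption
# about the object run_over_responses passes in (not defined in this diff).
evaluator_code = """
def evaluate(response):
    # Score each response by its character length, for illustration only.
    return len(response.text)
"""

payload = {
    'id': 'evaluator-node-1',   # unique ID, used when caching evaluation results
    'code': evaluator_code,
    'responses': [],            # list of StandardizedLLMResponse dicts (shape defined in typing.ts)
    'scope': 'response',        # or 'batch' to hand evaluate() all responses at once
}
r = requests.post('http://localhost:8000/app/executepy', json=payload)
print(r.json())                 # {'responses': [...], 'logs': [...]} or {'error': ...}

As the route's own note stresses, there is no sandboxing: this should only run locally, on code the user wrote themselves.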
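
A companion sketch of the payload that route_fetch() in utils.ts sends to the new /app/makeFetchCall proxy when an Anthropic completion is routed through Flask, assuming ANTHROPIC_API_KEY is set locally; the model name, prompt text, and sampling values are placeholders.

import os
import requests

# Mirrors the url, headers, and query that call_anthropic builds in utils.ts;
# the Flask route re-issues the POST server-side to sidestep CORS restrictions
# on client-side fetches.
payload = {
    'url': 'https://api.anthropic.com/v1/complete',
    'headers': {
        'accept': 'application/json',
        'anthropic-version': '2023-06-01',
        'content-type': 'application/json',
        'x-api-key': os.environ['ANTHROPIC_API_KEY'],
    },
    'body': {
        'model': 'claude-v1',   # placeholder model name
        'prompt': '\n\nHuman: Who invented playing cards?\n\nAssistant:',  # HUMAN_PROMPT/AI_PROMPT wrapping
        'max_tokens_to_sample': 1024,
        'stop_sequences': ['\n\nHuman:'],
        'temperature': 1.0,
    },
}
r = requests.post('http://localhost:8000/app/makeFetchCall', json=payload)
print(r.json())   # {'response': {...}} on HTTP 200, otherwise {'error': ...}

Note that route_fetch also includes a 'method' field, but the Flask handler only reads url, headers, and body and always issues a POST.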
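
Likewise, a sketch of the raw PaLM REST request that call_google_palm now makes in place of the @google-ai/generativelanguage client, assuming PALM_API_KEY is set in the environment (per the renamed key in fetchEnvironAPIKeys); the model name and parameter values are illustrative.

import os
import requests

# Mirrors the v1beta2 text-completion call built in utils.ts; 'text-bison-001'
# and the sampling values below are illustrative placeholders.
model = 'text-bison-001'
url = ('https://generativelanguage.googleapis.com/v1beta2/models/'
       f'{model}:generateText?key={os.environ["PALM_API_KEY"]}')
query = {
    'candidate_count': 3,       # n responses; ignored by PaLM when temperature is 0
    'temperature': 0.7,
    'max_output_tokens': 800,
    'prompt': {'text': 'What is the tallest mountain in the world?'},
}
r = requests.post(url, headers={'Content-Type': 'application/json'}, json=query)
completion = r.json()
print(completion.get('error') or completion.get('candidates'))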