diff --git a/chain-forge/src/AlertModal.js b/chain-forge/src/AlertModal.js
index 53fa4d0..1d3bd53 100644
--- a/chain-forge/src/AlertModal.js
+++ b/chain-forge/src/AlertModal.js
@@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => {
// This gives the parent access to triggering the modal alert
const trigger = (msg) => {
+ if (!msg) msg = "Unknown error.";
console.error(msg);
setAlertMsg(msg);
open();
diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js
index 433a0fc..6410e8f 100644
--- a/chain-forge/src/InspectorNode.js
+++ b/chain-forge/src/InspectorNode.js
@@ -1,4 +1,4 @@
-import React, { useState } from 'react';
+import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
@@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => {
const [varSelects, setVarSelects] = useState([]);
const [pastInputs, setPastInputs] = useState([]);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
+ const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const handleVarValueSelect = () => {
}
@@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => {
}
}
+ useEffect(() => {
+ if (data.refresh === true) {
+ // Recreate the visualization:
+ setDataPropsForNode(id, { refresh: false });
+ handleOnConnect();
+ }
+ }, [data, id, handleOnConnect, setDataPropsForNode]);
+
return (
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ ... @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+ const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
const getNode = useStore((state) => state.getNode);
const [templateVars, setTemplateVars] = useState(data.vars || []);
@@ -392,6 +393,15 @@ const PromptNode = ({ data, id }) => {
);
}));
+ // Ping any inspect nodes attached to this node to refresh their contents:
+ const output_nodes = outputEdgesForNode(id).map(e => e.target);
+ output_nodes.forEach(n => {
+ const node = getNode(n);
+ if (node && node.type === 'inspect') {
+ setDataPropsForNode(node.id, { refresh: true });
+ }
+ });
+
// Log responses for debugging:
console.log(json.responses);
} else {
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index e63fea0..619ffb3 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -1,4 +1,4 @@
-import json, os, asyncio, sys, argparse, threading
+import json, os, asyncio, sys, argparse, threading, traceback
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
@@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
def index():
return render_template("index.html")
-LLM_NAME_MAP = {
- 'gpt3.5': LLM.ChatGPT,
- 'alpaca.7B': LLM.Alpaca7B,
- 'gpt4': LLM.GPT4,
-}
-LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
+LLM_NAME_MAP = {}
+for model in LLM:
+ LLM_NAME_MAP[model.value] = model
@dataclass
class ResponseInfo:
@@ -40,7 +37,7 @@ class ResponseInfo:
return self.text
def to_standard_format(r: dict) -> list:
- llm = LLM_NAME_MAP_INVERSE[r['llm']]
+ llm = r['llm']
resp_obj = {
'vars': r['info'],
'llm': llm,
@@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
resp_obj['eval_res'] = r['eval_res']
return resp_obj
-def get_llm_of_response(response: dict) -> LLM:
- return LLM_NAME_MAP[response['llm']]
-
def get_filenames_with_id(filenames: list, id: str) -> list:
return [
c for c in filenames
@@ -82,7 +76,7 @@ def run_over_responses(eval_func, responses: dict, scope: str) -> list:
text=r,
prompt=prompt,
var=resp_obj['info'],
- llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
+ llm=resp_obj['llm'])
) for r in res
]
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
@@ -281,7 +275,8 @@ async def queryLLM():
with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
- print('error generating responses:', e)
+ print(f'error generating responses for {llm}:', e)
+ print(traceback.format_exc())
raise e
return {'llm': llm, 'responses': resps}
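For context, a minimal sketch (assuming the new `promptengine.models.LLM` enum introduced below) of how the rebuilt `LLM_NAME_MAP` resolves the model strings sent by the front end directly to enum members, which is why `LLM_NAME_MAP_INVERSE` and `get_llm_of_response` are no longer needed. The dict comprehension is equivalent to the loop in the patch; `llm_str` is an illustrative value.

```python
from promptengine.models import LLM

# Keyed by each model's string value, e.g. "gpt-4" -> LLM.GPT4 (equivalent to the loop above).
LLM_NAME_MAP = {model.value: model for model in LLM}

llm_str = "claude-v1"          # illustrative payload value from the front end
llm = LLM_NAME_MAP[llm_str]    # -> LLM.Claude_v1
assert llm.value == llm_str    # responses can now store llm.value verbatim
```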
diff --git a/python-backend/promptengine/models.py b/python-backend/promptengine/models.py
new file mode 100644
index 0000000..c2ef9a3
--- /dev/null
+++ b/python-backend/promptengine/models.py
@@ -0,0 +1,39 @@
+"""
+ A list of all model APIs natively supported by ChainForge.
+"""
+from enum import Enum
+
+class LLM(str, Enum):
+ """ OpenAI Chat """
+ ChatGPT = "gpt-3.5-turbo"
+ GPT4 = "gpt-4"
+
+ """ Dalai-served models """
+ Alpaca7B = "alpaca.7B"
+
+ """ Anthropic """
+ # Our largest model, ideal for a wide range of more complex tasks. Using this model name
+ # will automatically switch you to newer versions of claude-v1 as they are released.
+ Claude_v1 = "claude-v1"
+
+ # An earlier version of claude-v1
+ Claude_v1_0 = "claude-v1.0"
+
+ # An improved version of claude-v1. It is slightly improved at general helpfulness,
+ # instruction following, coding, and other tasks. It is also considerably better with
+ # non-English languages. This model also has the ability to role play (in harmless ways)
+ # more consistently, and it defaults to writing somewhat longer and more thorough responses.
+ Claude_v1_2 = "claude-v1.2"
+
+ # A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
+ # against red-team inputs, better at precise instruction-following, better at code, and better
+ # and non-English dialogue and writing.
+ Claude_v1_3 = "claude-v1.3"
+
+ # A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
+ # is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
+ # less expensive and blazing fast. We believe that this model provides more than adequate performance
+ # on a range of tasks including text classification, summarization, and lightweight chat applications,
+ # as well as search result summarization. Using this model name will automatically switch you to newer
+ # versions of claude-instant-v1 as they are released.
+ Claude_v1_instant = "claude-instant-v1"
\ No newline at end of file
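Because `LLM` now subclasses both `str` and `Enum`, each member's `.value` is the provider-facing model name and members compare equal to their plain strings. A quick sketch of the behavior the rest of this patch relies on when it swaps `.name` for `.value`:

```python
from promptengine.models import LLM

assert LLM.ChatGPT.value == "gpt-3.5-turbo"
assert LLM.GPT4 == "gpt-4"               # str mixin: a member compares equal to its value
assert LLM("alpaca.7B") is LLM.Alpaca7B  # round-trip from a cached string back to the enum member
```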
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index d0cfdcc..dde875c 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
-from typing import List, Dict, Tuple, Iterator
+from typing import List, Dict, Tuple, Iterator, Union
import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
from promptengine.template import PromptTemplate, PromptPermutationGenerator
# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -65,7 +65,7 @@ class PromptPipeline:
"prompt": prompt_str,
"query": responses[prompt_str]["query"],
"response": responses[prompt_str]["response"],
- "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
+ "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
"info": responses[prompt_str]["info"],
}
continue
@@ -86,7 +86,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
- "llm": llm.name,
+ "llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@@ -96,7 +96,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
- "llm": llm.name,
+ "llm": llm.value,
"info": info,
}
@@ -114,7 +114,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
- "llm": llm.name,
+ "llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@@ -124,7 +124,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
- "llm": llm.name,
+ "llm": llm.value,
"info": info,
}
@@ -147,11 +147,13 @@ class PromptPipeline:
def clear_cached_responses(self) -> None:
self._cache_responses({})
- async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
+ async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
if llm is LLM.ChatGPT or llm is LLM.GPT4:
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
- query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
+ query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
+ elif llm.value[:6] == 'claude':
+ query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
else:
raise Exception(f"Language model {llm} is not supported.")
return prompt, query, response
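A small illustrative mirror of the dispatch rule `_prompt_llm` now follows (the `route` helper is hypothetical, for illustration only): OpenAI chat models go to `call_chatgpt`, `LLM.Alpaca7B` goes to `call_dalai`, and any member whose value starts with `claude` goes to `call_anthropic`.

```python
from promptengine.models import LLM

def route(llm: LLM) -> str:
    """Hypothetical mirror of the branch order in PromptPipeline._prompt_llm."""
    if llm is LLM.ChatGPT or llm is LLM.GPT4:
        return "call_chatgpt"
    elif llm is LLM.Alpaca7B:
        return "call_dalai"
    elif llm.value[:6] == "claude":
        return "call_anthropic"
    raise Exception(f"Language model {llm} is not supported.")

assert route(LLM.Claude_v1_3) == "call_anthropic"
assert all(route(m) == "call_anthropic" for m in LLM if m.value.startswith("claude"))
```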
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index d89b51e..31d2c45 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -1,28 +1,22 @@
-from typing import Dict, Tuple, List, Union
-from enum import Enum
-import openai
+from typing import Dict, Tuple, List, Union, Callable
import json, os, time, asyncio
+from promptengine.models import LLM
+
DALAI_MODEL = None
DALAI_RESPONSE = None
-openai.api_key = os.environ.get("OPENAI_API_KEY")
-
-""" Supported LLM coding assistants """
-class LLM(Enum):
- ChatGPT = 0
- Alpaca7B = 1
- GPT4 = 2
-
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
"""
Calls GPT3.5 via OpenAI's API.
Returns raw query and response JSON dicts.
+
+ NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
"""
- model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
- if model not in model_map:
- raise Exception(f"Could not find OpenAI chat model {model}")
- model = model_map[model]
+ import openai
+ if not openai.api_key:
+ openai.api_key = os.environ.get("OPENAI_API_KEY")
+ model = model.value
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
@@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
response = openai.ChatCompletion.create(**query)
return query, response
-async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
+async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0,
+ custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
+ max_tokens_to_sample=1024,
+ stop_sequences: Union[List[str], str]=["\n\nHuman:"],
+ async_mode=False,
+ **params) -> Tuple[Dict, List]:
+ """
+ Calls Anthropic API with the given model, passing in params.
+ Returns raw query and response JSON dicts.
+
+ Unique parameters:
+ - custom_prompt_wrapper: Anthropic models expect prompts in the form "\n\nHuman: ${prompt}\n\nAssistant:". If you wish to
+ use a custom prompt wrapper that deviates from this format, pass a function that maps the prompt string to your wrapped prompt.
+ If set to None, defaults to Anthropic's suggested prompt wrapper.
+ - max_tokens_to_sample: A maximum number of tokens to generate before stopping.
+ - stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
+ - async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
+ If you want to send all 'n' requests at once, you can set async_mode to True.
+
+ NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
+ """
+ import anthropic
+ client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+
+ # Format query
+ query = {
+ 'model': model.value,
+ 'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
+ 'max_tokens_to_sample': max_tokens_to_sample,
+ 'stop_sequences': stop_sequences,
+ 'temperature': temperature,
+ **params
+ }
+
+ print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
+
+ # Request responses using the passed async_mode
+ responses = []
+ if async_mode:
+ # Gather n responses by firing off all API requests at once
+ tasks = [client.acompletion(**query) for _ in range(n)]
+ responses = await asyncio.gather(*tasks)
+ else:
+ # Repeat call n times, waiting for each response to come in:
+ while len(responses) < n:
+ resp = await client.acompletion(**query)
+ responses.append(resp)
+ print(f'{model.value} response {len(responses)} of {n}:\n', resp)
+
+ return query, responses
+
+async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
"""
Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
Returns the raw query and response JSON dicts.
Parameters:
- - llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
+ - model: The LLM model, whose value is the model name known by Dalai; e.g., 'alpaca.7B'
- port: The port of the local server where Dalai is running. Usually 3000.
- prompt: The prompt to pass to the LLM.
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
@@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
# Create full query to Dalai
query = {
'prompt': prompt,
- 'model': llm_name,
+ 'model': model.value,
'id': str(round(time.time()*1000)),
'temp': temperature,
**def_params
@@ -132,15 +177,17 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
for i, c in enumerate(choices)
]
-def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
+def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
"""
Given a LLM and a response object from its API, extract the
text response(s) part of the response object.
"""
- if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
+ if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
return _extract_chatgpt_responses(response)
- elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
+ elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
return response["response"]
+ elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
+ return [r["completion"] for r in response["response"]]
else:
raise ValueError(f"LLM {llm} is unsupported.")