From 16135934f4d766c9a955807a6328053cfd2fcdc2 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Sun, 7 May 2023 12:13:25 -0400 Subject: [PATCH] Add Claude (Anthropic) support. --- chain-forge/src/AlertModal.js | 1 + chain-forge/src/InspectorNode.js | 11 +++- chain-forge/src/PromptNode.js | 16 ++++- python-backend/flask_app.py | 21 +++---- python-backend/promptengine/models.py | 39 ++++++++++++ python-backend/promptengine/query.py | 20 +++--- python-backend/promptengine/utils.py | 89 ++++++++++++++++++++------- 7 files changed, 150 insertions(+), 47 deletions(-) create mode 100644 python-backend/promptengine/models.py diff --git a/chain-forge/src/AlertModal.js b/chain-forge/src/AlertModal.js index 53fa4d0..1d3bd53 100644 --- a/chain-forge/src/AlertModal.js +++ b/chain-forge/src/AlertModal.js @@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => { // This gives the parent access to triggering the modal alert const trigger = (msg) => { + if (!msg) msg = "Unknown error."; console.error(msg); setAlertMsg(msg); open(); diff --git a/chain-forge/src/InspectorNode.js b/chain-forge/src/InspectorNode.js index 433a0fc..6410e8f 100644 --- a/chain-forge/src/InspectorNode.js +++ b/chain-forge/src/InspectorNode.js @@ -1,4 +1,4 @@ -import React, { useState } from 'react'; +import React, { useState, useEffect } from 'react'; import { Handle } from 'react-flow-renderer'; import useStore from './store'; import NodeLabel from './NodeLabelComponent' @@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => { const [varSelects, setVarSelects] = useState([]); const [pastInputs, setPastInputs] = useState([]); const inputEdgesForNode = useStore((state) => state.inputEdgesForNode); + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); const handleVarValueSelect = () => { } @@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => { } } + useEffect(() => { + if (data.refresh && data.refresh === true) { + // Recreate the visualization: + setDataPropsForNode(id, { refresh: false }); + handleOnConnect(); + } +}, [data, id, handleOnConnect, setDataPropsForNode]); + return (
diff --git a/chain-forge/src/PromptNode.js b/chain-forge/src/PromptNode.js
--- a/chain-forge/src/PromptNode.js
+++ b/chain-forge/src/PromptNode.js
@@ -…,6 +…,7 @@ const PromptNode = ({ data, id }) => {
   const edges = useStore((state) => state.edges);
   const output = useStore((state) => state.output);
   const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
+  const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
   const getNode = useStore((state) => state.getNode);
 
   const [templateVars, setTemplateVars] = useState(data.vars || []);
@@ -392,6 +393,15 @@ const PromptNode = ({ data, id }) => {
         );
       }));
 
+      // Ping any inspect nodes attached to this node to refresh their contents:
+      const output_nodes = outputEdgesForNode(id).map(e => e.target);
+      output_nodes.forEach(n => {
+        const node = getNode(n);
+        if (node && node.type === 'inspect') {
+          setDataPropsForNode(node.id, { refresh: true });
+        }
+      });
+
       // Log responses for debugging:
       console.log(json.responses);
     } else {
diff --git a/python-backend/flask_app.py b/python-backend/flask_app.py
index e63fea0..619ffb3 100644
--- a/python-backend/flask_app.py
+++ b/python-backend/flask_app.py
@@ -1,4 +1,4 @@
-import json, os, asyncio, sys, argparse, threading
+import json, os, asyncio, sys, argparse, threading, traceback
 from dataclasses import dataclass
 from statistics import mean, median, stdev
 from flask import Flask, request, jsonify, render_template, send_from_directory
@@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
 def index():
     return render_template("index.html")
 
-LLM_NAME_MAP = {
-    'gpt3.5': LLM.ChatGPT,
-    'alpaca.7B': LLM.Alpaca7B,
-    'gpt4': LLM.GPT4,
-}
-LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
+LLM_NAME_MAP = {}
+for model in LLM:
+    LLM_NAME_MAP[model.value] = model
 
 @dataclass
 class ResponseInfo:
@@ -40,7 +37,7 @@ class ResponseInfo:
         return self.text
 
 def to_standard_format(r: dict) -> list:
-    llm = LLM_NAME_MAP_INVERSE[r['llm']]
+    llm = r['llm']
     resp_obj = {
         'vars': r['info'],
         'llm': llm,
@@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
         resp_obj['eval_res'] = r['eval_res']
     return resp_obj
 
-def get_llm_of_response(response: dict) -> LLM:
-    return LLM_NAME_MAP[response['llm']]
-
 def get_filenames_with_id(filenames: list, id: str) -> list:
     return [
         c for c in filenames
@@ -82,7 +76,7 @@ def run_over_responses(eval_func, responses: dict, scope: str) -> list:
                             text=r,
                             prompt=prompt,
                             var=resp_obj['info'],
-                            llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
+                            llm=resp_obj['llm'])
                     ) for r in res
                 ]
                 resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
@@ -281,7 +275,8 @@ async def queryLLM():
                 with open(tempfilepath, 'w') as f:
                     json.dump(cur_data, f)
         except Exception as e:
-            print('error generating responses:', e)
+            print(f'error generating responses for {llm}:', e)
+            print(traceback.format_exc())
             raise e
 
         return {'llm': llm, 'responses': resps}
diff --git a/python-backend/promptengine/models.py b/python-backend/promptengine/models.py
new file mode 100644
index 0000000..c2ef9a3
--- /dev/null
+++ b/python-backend/promptengine/models.py
@@ -0,0 +1,39 @@
+"""
+    A list of all model APIs natively supported by ChainForge.
+"""
+from enum import Enum
+
+class LLM(str, Enum):
+    """ OpenAI Chat """
+    ChatGPT = "gpt-3.5-turbo"
+    GPT4 = "gpt-4"
+
+    """ Dalai-served models """
+    Alpaca7B = "alpaca.7B"
+
+    """ Anthropic """
+    # Our largest model, ideal for a wide range of more complex tasks. Using this model name
+    # will automatically switch you to newer versions of claude-v1 as they are released.
+    Claude_v1 = "claude-v1"
+
+    # An earlier version of claude-v1
+    Claude_v1_0 = "claude-v1.0"
+
+    # An improved version of claude-v1. It is slightly improved at general helpfulness,
+    # instruction following, coding, and other tasks. It is also considerably better with
+    # non-English languages. This model also has the ability to role play (in harmless ways)
+    # more consistently, and it defaults to writing somewhat longer and more thorough responses.
+    Claude_v1_2 = "claude-v1.2"
+
+    # A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
+    # against red-team inputs, better at precise instruction-following, better at code, and better
+    # at non-English dialogue and writing.
+    Claude_v1_3 = "claude-v1.3"
+
+    # A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
+    # is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
+    # less expensive and blazing fast. We believe that this model provides more than adequate performance
+    # on a range of tasks including text classification, summarization, and lightweight chat applications,
+    # as well as search result summarization. Using this model name will automatically switch you to newer
+    # versions of claude-instant-v1 as they are released.
+    Claude_v1_instant = "claude-instant-v1"
\ No newline at end of file
diff --git a/python-backend/promptengine/query.py b/python-backend/promptengine/query.py
index d0cfdcc..dde875c 100644
--- a/python-backend/promptengine/query.py
+++ b/python-backend/promptengine/query.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod
-from typing import List, Dict, Tuple, Iterator
+from typing import List, Dict, Tuple, Iterator, Union
 import json, os, asyncio, random, string
-from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
+from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
 from promptengine.template import PromptTemplate, PromptPermutationGenerator
 
 # LLM APIs often have rate limits, which control number of requests.
 # E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -65,7 +65,7 @@ class PromptPipeline:
                     "prompt": prompt_str,
                     "query": responses[prompt_str]["query"],
                     "response": responses[prompt_str]["response"],
-                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
+                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
                     "info": responses[prompt_str]["info"],
                 }
                 continue
@@ -86,7 +86,7 @@ class PromptPipeline:
             responses[str(prompt)] = {
                 "query": query,
                 "response": response,
-                "llm": llm.name,
+                "llm": llm.value,
                 "info": info,
             }
             self._cache_responses(responses)
@@ -96,7 +96,7 @@ class PromptPipeline:
                 "prompt":str(prompt),
                 "query":query,
                 "response":response,
-                "llm": llm.name,
+                "llm": llm.value,
                 "info": info,
             }
 
@@ -114,7 +114,7 @@ class PromptPipeline:
             responses[str(prompt)] = {
                 "query": query,
                 "response": response,
-                "llm": llm.name,
+                "llm": llm.value,
                 "info": info,
             }
             self._cache_responses(responses)
@@ -124,7 +124,7 @@ class PromptPipeline:
                 "prompt":str(prompt),
                 "query":query,
                 "response":response,
-                "llm": llm.name,
+                "llm": llm.value,
                 "info": info,
             }
 
@@ -147,11 +147,13 @@ class PromptPipeline:
     def clear_cached_responses(self) -> None:
         self._cache_responses({})
 
-    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
+    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
         if llm is LLM.ChatGPT or llm is LLM.GPT4:
             query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
         elif llm is LLM.Alpaca7B:
-            query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
+            query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
+        elif llm.value[:6] == 'claude':
+            query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
         else:
             raise Exception(f"Language model {llm} is not supported.")
         return prompt, query, response
diff --git a/python-backend/promptengine/utils.py b/python-backend/promptengine/utils.py
index d89b51e..31d2c45 100644
--- a/python-backend/promptengine/utils.py
+++ b/python-backend/promptengine/utils.py
@@ -1,28 +1,22 @@
-from typing import Dict, Tuple, List, Union
-from enum import Enum
-import openai
+from typing import Dict, Tuple, List, Union, Callable
 import json, os, time, asyncio
 
+from promptengine.models import LLM
+
 DALAI_MODEL = None
 DALAI_RESPONSE = None
 
-openai.api_key = os.environ.get("OPENAI_API_KEY")
-
-""" Supported LLM coding assistants """
-class LLM(Enum):
-    ChatGPT = 0
-    Alpaca7B = 1
-    GPT4 = 2
-
 async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
     """
        Calls GPT3.5 via OpenAI's API.
        Returns raw query and response JSON dicts.
+
+       NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
     """
-    model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
-    if model not in model_map:
-        raise Exception(f"Could not find OpenAI chat model {model}")
-    model = model_map[model]
+    import openai
+    if not openai.api_key:
+        openai.api_key = os.environ.get("OPENAI_API_KEY")
+    model = model.value
     print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
     system_msg = "You are a helpful assistant." if system_msg is None else system_msg
     query = {
@@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
     response = openai.ChatCompletion.create(**query)
     return query, response
 
-async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
+async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float= 1.0,
+                         custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
+                         max_tokens_to_sample=1024,
+                         stop_sequences: Union[List[str], str]=["\n\nHuman:"],
+                         async_mode=False,
+                         **params) -> Tuple[Dict, Dict]:
+    """
+        Calls Anthropic API with the given model, passing in params.
+        Returns raw query and response JSON dicts.
+
+        Unique parameters:
+            - custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
+                                     explore custom prompt wrappers that deviate, write a function that maps from 'prompt' to custom wrapper.
+                                     If set to None, defaults to Anthropic's suggested prompt wrapper.
+            - max_tokens_to_sample: A maximum number of tokens to generate before stopping.
+            - stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
+            - async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
+                          If you want to send all 'n' requests at once, you can set async_mode to True.
+
+        NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
+    """
+    import anthropic
+    client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+
+    # Format query
+    query = {
+        'model': model.value,
+        'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
+        'max_tokens_to_sample': max_tokens_to_sample,
+        'stop_sequences': stop_sequences,
+        'temperature': temperature,
+        **params
+    }
+
+    print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
+
+    # Request responses using the passed async_mode
+    responses = []
+    if async_mode:
+        # Gather n responses by firing off all API requests at once
+        tasks = [client.acompletion(**query) for _ in range(n)]
+        responses = await asyncio.gather(*tasks)
+    else:
+        # Repeat call n times, waiting for each response to come in:
+        while len(responses) < n:
+            resp = await client.acompletion(**query)
+            responses.append(resp)
+            print(f'{model.value} response {len(responses)} of {n}:\n', resp)
+
+    return query, responses
+
+async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
     """ Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
         Returns the raw query and response JSON dicts.
 
         Parameters:
-            - llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
+            - model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
            - port: The port of the local server where Dalai is running. Usually 3000.
            - prompt: The prompt to pass to the LLM.
            - n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
@@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat # Create full query to Dalai query = { 'prompt': prompt, - 'model': llm_name, + 'model': model.value, 'id': str(round(time.time()*1000)), 'temp': temperature, **def_params @@ -132,15 +177,17 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]: for i, c in enumerate(choices) ] -def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]: +def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]: """ Given a LLM and a response object from its API, extract the text response(s) part of the response object. """ - if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name: + if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value: return _extract_chatgpt_responses(response) - elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name: + elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value: return response["response"] + elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'): + return [r["completion"] for r in response["response"]] else: raise ValueError(f"LLM {llm} is unsupported.")
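
Below is a minimal usage sketch (not part of the diff) showing how the new Anthropic path introduced above can be exercised end-to-end. It assumes the anthropic Python package is installed, that ANTHROPIC_API_KEY is set in the environment (as the docstring recommends), and that the installed anthropic version provides the Client.acompletion interface the patch calls. It only uses names added or changed in this patch (LLM, call_anthropic, extract_responses); the prompt text and the filename example_claude_query.py are hypothetical.

    # example_claude_query.py -- illustrative sketch only, not included in the patch
    import asyncio

    from promptengine.models import LLM
    from promptengine.utils import call_anthropic, extract_responses

    async def main():
        # call_anthropic returns (query, responses): `query` echoes the request params,
        # `responses` is a list of n raw Anthropic completion dicts.
        query, responses = await call_anthropic(
            prompt="What is the capital of France?",
            model=LLM.Claude_v1_3,
            n=2,
            temperature=0.7)

        # extract_responses expects the cached shape {'response': [...]} that
        # PromptPipeline stores, and returns just the completion texts.
        texts = extract_responses({'response': responses}, llm=LLM.Claude_v1_3)
        print(texts)

    asyncio.run(main())

The same {'response': [...]} shape is what PromptPipeline writes to its cache, which is why extract_responses can be applied identically to cached results.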