Add Claude (Anthropic) support.

Ian Arawjo 2023-05-07 12:13:25 -04:00
parent 1c367d3080
commit 16135934f4
7 changed files with 150 additions and 47 deletions

View File

@@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => {
// This gives the parent access to triggering the modal alert
const trigger = (msg) => {
if (!msg) msg = "Unknown error.";
console.error(msg);
setAlertMsg(msg);
open();

View File

@@ -1,4 +1,4 @@
import React, { useState } from 'react';
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
@@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => {
const [varSelects, setVarSelects] = useState([]);
const [pastInputs, setPastInputs] = useState([]);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const handleVarValueSelect = () => {
}
@@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => {
}
}
useEffect(() => {
if (data.refresh && data.refresh === true) {
// Recreate the visualization:
setDataPropsForNode(id, { refresh: false });
handleOnConnect();
}
}, [data, id, handleOnConnect, setDataPropsForNode]);
return (
<div className="inspector-node cfnode">
<NodeLabel title={data.title || 'Inspect Node'}

View File

@@ -11,10 +11,10 @@ import io from 'socket.io-client';
// Available LLMs
const allLLMs = [
{ name: "GPT3.5", emoji: "🙂", model: "gpt3.5", temp: 1.0 },
{ name: "GPT4", emoji: "🥵", model: "gpt4", temp: 1.0 },
{ name: "GPT3.5", emoji: "🙂", model: "gpt-3.5-turbo", temp: 1.0 },
{ name: "GPT4", emoji: "🥵", model: "gpt-4", temp: 1.0 },
{ name: "Alpaca 7B", emoji: "🦙", model: "alpaca.7B", temp: 0.5 },
{ name: "Claude v1", emoji: "📚", model: "claude.v1", temp: 0.5 },
{ name: "Claude v1", emoji: "📚", model: "claude-v1", temp: 0.5 },
{ name: "Ian Chatbot", emoji: "💩", model: "test", temp: 0.5 }
];
const initLLMs = [allLLMs[0]];
@@ -50,6 +50,7 @@ const PromptNode = ({ data, id }) => {
const edges = useStore((state) => state.edges);
const output = useStore((state) => state.output);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
const getNode = useStore((state) => state.getNode);
const [templateVars, setTemplateVars] = useState(data.vars || []);
@@ -392,6 +393,15 @@ const PromptNode = ({ data, id }) => {
);
}));
// Ping any inspect nodes attached to this node to refresh their contents:
const output_nodes = outputEdgesForNode(id).map(e => e.target);
output_nodes.forEach(n => {
const node = getNode(n);
if (node && node.type === 'inspect') {
setDataPropsForNode(node.id, { refresh: true });
}
});
// Log responses for debugging:
console.log(json.responses);
} else {

View File

@@ -1,4 +1,4 @@
import json, os, asyncio, sys, argparse, threading
import json, os, asyncio, sys, argparse, threading, traceback
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
@@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
def index():
return render_template("index.html")
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
'alpaca.7B': LLM.Alpaca7B,
'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
LLM_NAME_MAP = {}
for model in LLM:
LLM_NAME_MAP[model.value] = model
@dataclass
class ResponseInfo:
@@ -40,7 +37,7 @@ class ResponseInfo:
return self.text
def to_standard_format(r: dict) -> list:
llm = LLM_NAME_MAP_INVERSE[r['llm']]
llm = r['llm']
resp_obj = {
'vars': r['info'],
'llm': llm,
@@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
resp_obj['eval_res'] = r['eval_res']
return resp_obj
def get_llm_of_response(response: dict) -> LLM:
return LLM_NAME_MAP[response['llm']]
def get_filenames_with_id(filenames: list, id: str) -> list:
return [
c for c in filenames
@@ -82,7 +76,7 @@ def run_over_responses(eval_func, responses: dict, scope: str) -> list:
text=r,
prompt=prompt,
var=resp_obj['info'],
llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
llm=resp_obj['llm'])
) for r in res
]
resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
@@ -281,7 +275,8 @@ async def queryLLM():
with open(tempfilepath, 'w') as f:
json.dump(cur_data, f)
except Exception as e:
print('error generating responses:', e)
print(f'error generating responses for {llm}:', e)
print(traceback.format_exc())
raise e
return {'llm': llm, 'responses': resps}

View File

@@ -0,0 +1,39 @@
"""
A list of all model APIs natively supported by ChainForge.
"""
from enum import Enum
class LLM(str, Enum):
""" OpenAI Chat """
ChatGPT = "gpt-3.5-turbo"
GPT4 = "gpt-4"
""" Dalai-served models """
Alpaca7B = "alpaca.7B"
""" Anthropic """
# Our largest model, ideal for a wide range of more complex tasks. Using this model name
# will automatically switch you to newer versions of claude-v1 as they are released.
Claude_v1 = "claude-v1"
# An earlier version of claude-v1
Claude_v1_0 = "claude-v1.0"
# An improved version of claude-v1. It is slightly improved at general helpfulness,
# instruction following, coding, and other tasks. It is also considerably better with
# non-English languages. This model also has the ability to role play (in harmless ways)
# more consistently, and it defaults to writing somewhat longer and more thorough responses.
Claude_v1_2 = "claude-v1.2"
# A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
# against red-team inputs, better at precise instruction-following, better at code, and better
# at non-English dialogue and writing.
Claude_v1_3 = "claude-v1.3"
# A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
# is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
# less expensive and blazing fast. We believe that this model provides more than adequate performance
# on a range of tasks including text classification, summarization, and lightweight chat applications,
# as well as search result summarization. Using this model name will automatically switch you to newer
# versions of claude-instant-v1 as they are released.
Claude_v1_instant = "claude-instant-v1"
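The comments above mirror Anthropic's model descriptions; the enum itself doubles as the canonical model-name string shared by the Flask backend and the React frontend. A minimal usage sketch (the promptengine.models import path comes from the utils.py change below; everything else is illustrative):

from promptengine.models import LLM

print(LLM.Claude_v1.value)   # "claude-v1"
print(LLM.ChatGPT.value)     # "gpt-3.5-turbo"

# Because LLM mixes in str, members compare equal to their raw model-name strings:
assert LLM.Claude_v1 == "claude-v1"

# The Flask app's LLM_NAME_MAP (built in the loop shown above) is equivalent to:
LLM_NAME_MAP = {model.value: model for model in LLM}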

View File

@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import List, Dict, Tuple, Iterator
from typing import List, Dict, Tuple, Iterator, Union
import json, os, asyncio, random, string
from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
from promptengine.template import PromptTemplate, PromptPermutationGenerator
# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -65,7 +65,7 @@ class PromptPipeline:
"prompt": prompt_str,
"query": responses[prompt_str]["query"],
"response": responses[prompt_str]["response"],
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
"llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
"info": responses[prompt_str]["info"],
}
continue
@@ -86,7 +86,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@@ -96,7 +96,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
@@ -114,7 +114,7 @@ class PromptPipeline:
responses[str(prompt)] = {
"query": query,
"response": response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
self._cache_responses(responses)
@@ -124,7 +124,7 @@ class PromptPipeline:
"prompt":str(prompt),
"query":query,
"response":response,
"llm": llm.name,
"llm": llm.value,
"info": info,
}
@@ -147,11 +147,13 @@ class PromptPipeline:
def clear_cached_responses(self) -> None:
self._cache_responses({})
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
if llm is LLM.ChatGPT or llm is LLM.GPT4:
query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
elif llm.value[:6] == 'claude':
query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
else:
raise Exception(f"Language model {llm} is not supported.")
return prompt, query, response

View File

@@ -1,28 +1,22 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
from typing import Dict, Tuple, List, Union, Callable
import json, os, time, asyncio
from promptengine.models import LLM
DALAI_MODEL = None
DALAI_RESPONSE = None
openai.api_key = os.environ.get("OPENAI_API_KEY")
""" Supported LLM coding assistants """
class LLM(Enum):
ChatGPT = 0
Alpaca7B = 1
GPT4 = 2
async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
"""
Calls an OpenAI chat model (GPT-3.5 or GPT-4) via OpenAI's API.
Returns raw query and response JSON dicts.
NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
"""
model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
if model not in model_map:
raise Exception(f"Could not find OpenAI chat model {model}")
model = model_map[model]
import openai
if not openai.api_key:
openai.api_key = os.environ.get("OPENAI_API_KEY")
model = model.value
print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
@@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
response = openai.ChatCompletion.create(**query)
return query, response
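For reference, a minimal sketch of driving call_chatgpt directly, with the model string now taken straight from the enum's value (assumes OPENAI_API_KEY is set in the environment; the prompt is illustrative):

import asyncio
from promptengine.models import LLM
from promptengine.utils import call_chatgpt

async def main():
    # Pass the enum member; the function reads model.value ("gpt-3.5-turbo") internally.
    query, response = await call_chatgpt("What is the capital of France?",
                                         model=LLM.ChatGPT, n=1, temperature=0.7)
    # `response` is the raw ChatCompletion payload from the openai package of this era:
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())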
async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0,
custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
max_tokens_to_sample=1024,
stop_sequences: Union[List[str], str]=["\n\nHuman:"],
async_mode=False,
**params) -> Tuple[Dict, Dict]:
"""
Calls Anthropic API with the given model, passing in params.
Returns raw query and response JSON dicts.
Unique parameters:
- custom_prompt_wrapper: Anthropic models expect prompts in the form "\n\nHuman: ${prompt}\n\nAssistant:". If you wish to
explore custom prompt wrappers that deviate from this, pass a function that maps the prompt string to your custom wrapper.
If set to None, defaults to Anthropic's suggested prompt wrapper.
- max_tokens_to_sample: A maximum number of tokens to generate before stopping.
- stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next human turn in the dialogue.
- async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
If you want to send all 'n' requests at once, you can set async_mode to True.
NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
"""
import anthropic
client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
# Format query
query = {
'model': model.value,
'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
'max_tokens_to_sample': max_tokens_to_sample,
'stop_sequences': stop_sequences,
'temperature': temperature,
**params
}
print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")
# Request responses using the passed async_mode
responses = []
if async_mode:
# Gather n responses by firing off all API requests at once
tasks = [client.acompletion(**query) for _ in range(n)]
responses = await asyncio.gather(*tasks)
else:
# Repeat call n times, waiting for each response to come in:
while len(responses) < n:
resp = await client.acompletion(**query)
responses.append(resp)
print(f'{model.value} response {len(responses)} of {n}:\n', resp)
return query, responses
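A minimal sketch of what a direct call to the new helper might look like (assumes the anthropic package is installed and ANTHROPIC_API_KEY is set; the prompt and n are illustrative):

import asyncio
from promptengine.models import LLM
from promptengine.utils import call_anthropic

async def main():
    query, responses = await call_anthropic(
        prompt="Summarize the plot of Hamlet in one sentence.",
        model=LLM.Claude_v1,
        n=2,
        temperature=0.7)
    # `responses` is a list of n raw completion dicts; the generated text sits under 'completion'.
    for r in responses:
        print(r["completion"])

asyncio.run(main())

Setting async_mode=True would instead fire both requests concurrently via asyncio.gather; the default serial loop matches the one-call-at-a-time evaluation access noted in the docstring.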
async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
"""
Calls a Dalai server running LLMs Alpaca, Llama, etc. locally.
Returns the raw query and response JSON dicts.
Parameters:
- llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
- model: The LLM model, whose value is the model name known by Dalai; e.g., 'alpaca.7B'
- port: The port of the local server where Dalai is running. Usually 3000.
- prompt: The prompt to pass to the LLM.
- n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
@@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
# Create full query to Dalai
query = {
'prompt': prompt,
'model': llm_name,
'model': model.value,
'id': str(round(time.time()*1000)),
'temp': temperature,
**def_params
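Correspondingly, the renamed model parameter is now an LLM enum member rather than a bare string. A minimal sketch (assumes a Dalai server is serving the alpaca.7B weights locally; port 4000 matches the PromptPipeline call above; the prompt is illustrative):

import asyncio
from promptengine.models import LLM
from promptengine.utils import call_dalai

async def main():
    query, response = await call_dalai(
        model=LLM.Alpaca7B,
        port=4000,
        prompt="Write a haiku about the ocean.",
        n=1)
    print(query["model"])  # "alpaca.7B", taken from LLM.Alpaca7B.value
    print(response)

asyncio.run(main())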
@@ -132,15 +177,17 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
for i, c in enumerate(choices)
]
def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
"""
Given an LLM and a response object from its API, extract the
text response(s) from the response object.
"""
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
return _extract_chatgpt_responses(response)
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
return response["response"]
elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
return [r["completion"] for r in response["response"]]
else:
raise ValueError(f"LLM {llm} is unsupported.")
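A minimal sketch of the new Claude branch, using the dict shape implied by the extraction code above (a "response" field holding the raw completion dicts returned by call_anthropic); the example strings are illustrative:

from promptengine.models import LLM
from promptengine.utils import extract_responses

cached = {"response": [
    {"completion": "Paris is the capital of France."},
    {"completion": "The capital of France is Paris."},
]}

# Both the enum member and a raw model-name string route to the 'claude' prefix branch:
print(extract_responses(cached, llm=LLM.Claude_v1))   # ['Paris is the capital of France.', ...]
print(extract_responses(cached, llm="claude-v1.3"))   # same result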