Mirror of https://github.com/ianarawjo/ChainForge.git (synced 2025-03-14 16:26:45 +00:00)
Add Claude (Anthropic) support.
This commit is contained in:
  parent 1c367d3080
  commit 16135934f4
@@ -10,6 +10,7 @@ const AlertModal = forwardRef((props, ref) => {

  // This gives the parent access to triggering the modal alert
  const trigger = (msg) => {
    if (!msg) msg = "Unknown error.";
    console.error(msg);
    setAlertMsg(msg);
    open();
@@ -1,4 +1,4 @@
import React, { useState } from 'react';
import React, { useState, useEffect } from 'react';
import { Handle } from 'react-flow-renderer';
import useStore from './store';
import NodeLabel from './NodeLabelComponent'
@@ -21,6 +21,7 @@ const InspectorNode = ({ data, id }) => {
  const [varSelects, setVarSelects] = useState([]);
  const [pastInputs, setPastInputs] = useState([]);
  const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
  const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);

  const handleVarValueSelect = () => {
  }
@@ -115,6 +116,14 @@ const InspectorNode = ({ data, id }) => {
      }
  }

  useEffect(() => {
    if (data.refresh && data.refresh === true) {
      // Recreate the visualization:
      setDataPropsForNode(id, { refresh: false });
      handleOnConnect();
    }
  }, [data, id, handleOnConnect, setDataPropsForNode]);

  return (
    <div className="inspector-node cfnode">
      <NodeLabel title={data.title || 'Inspect Node'}
@@ -11,10 +11,10 @@ import io from 'socket.io-client';

// Available LLMs
const allLLMs = [
  { name: "GPT3.5", emoji: "🙂", model: "gpt3.5", temp: 1.0 },
  { name: "GPT4", emoji: "🥵", model: "gpt4", temp: 1.0 },
  { name: "GPT3.5", emoji: "🙂", model: "gpt-3.5-turbo", temp: 1.0 },
  { name: "GPT4", emoji: "🥵", model: "gpt-4", temp: 1.0 },
  { name: "Alpaca 7B", emoji: "🦙", model: "alpaca.7B", temp: 0.5 },
  { name: "Claude v1", emoji: "📚", model: "claude.v1", temp: 0.5 },
  { name: "Claude v1", emoji: "📚", model: "claude-v1", temp: 0.5 },
  { name: "Ian Chatbot", emoji: "💩", model: "test", temp: 0.5 }
];
const initLLMs = [allLLMs[0]];
@@ -50,6 +50,7 @@ const PromptNode = ({ data, id }) => {
  const edges = useStore((state) => state.edges);
  const output = useStore((state) => state.output);
  const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
  const outputEdgesForNode = useStore((state) => state.outputEdgesForNode);
  const getNode = useStore((state) => state.getNode);

  const [templateVars, setTemplateVars] = useState(data.vars || []);
@@ -392,6 +393,15 @@ const PromptNode = ({ data, id }) => {
          );
        }));

        // Ping any inspect nodes attached to this node to refresh their contents:
        const output_nodes = outputEdgesForNode(id).map(e => e.target);
        output_nodes.forEach(n => {
          const node = getNode(n);
          if (node && node.type === 'inspect') {
            setDataPropsForNode(node.id, { refresh: true });
          }
        });

        // Log responses for debugging:
        console.log(json.responses);
      } else {
@@ -1,4 +1,4 @@
import json, os, asyncio, sys, argparse, threading
import json, os, asyncio, sys, argparse, threading, traceback
from dataclasses import dataclass
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template, send_from_directory
@@ -21,12 +21,9 @@ cors = CORS(app, resources={r"/*": {"origins": "*"}})
def index():
    return render_template("index.html")

LLM_NAME_MAP = {
    'gpt3.5': LLM.ChatGPT,
    'alpaca.7B': LLM.Alpaca7B,
    'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
LLM_NAME_MAP = {}
for model in LLM:
    LLM_NAME_MAP[model.value] = model
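A quick illustration (not part of this diff) of what the rebuilt map holds: its keys are now the enum's string values, so newly added models such as Claude resolve without any further edits to this file. It assumes the promptengine.models module introduced below.

from promptengine.models import LLM

LLM_NAME_MAP = {model.value: model for model in LLM}
assert LLM_NAME_MAP["gpt-3.5-turbo"] is LLM.ChatGPT   # previously keyed as 'gpt3.5'
assert LLM_NAME_MAP["claude-v1"] is LLM.Claude_v1     # picked up automatically from the enum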

@dataclass
class ResponseInfo:
@@ -40,7 +37,7 @@ class ResponseInfo:
        return self.text

def to_standard_format(r: dict) -> list:
    llm = LLM_NAME_MAP_INVERSE[r['llm']]
    llm = r['llm']
    resp_obj = {
        'vars': r['info'],
        'llm': llm,
@@ -52,9 +49,6 @@ def to_standard_format(r: dict) -> list:
        resp_obj['eval_res'] = r['eval_res']
    return resp_obj

def get_llm_of_response(response: dict) -> LLM:
    return LLM_NAME_MAP[response['llm']]

def get_filenames_with_id(filenames: list, id: str) -> list:
    return [
        c for c in filenames
@@ -82,7 +76,7 @@ def run_over_responses(eval_func, responses: dict, scope: str) -> list:
                    text=r,
                    prompt=prompt,
                    var=resp_obj['info'],
                    llm=LLM_NAME_MAP_INVERSE[resp_obj['llm']])
                    llm=resp_obj['llm'])
                ) for r in res
            ]
            resp_obj['eval_res'] = { # NOTE: assumes this is numeric data
@@ -281,7 +275,8 @@ async def queryLLM():
            with open(tempfilepath, 'w') as f:
                json.dump(cur_data, f)
        except Exception as e:
            print('error generating responses:', e)
            print(f'error generating responses for {llm}:', e)
            print(traceback.format_exc())
            raise e

        return {'llm': llm, 'responses': resps}
python-backend/promptengine/models.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""
|
||||
A list of all model APIs natively supported by ChainForge.
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
class LLM(str, Enum):
|
||||
""" OpenAI Chat """
|
||||
ChatGPT = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
|
||||
""" Dalai-served models """
|
||||
Alpaca7B = "alpaca.7B"
|
||||
|
||||
""" Anthropic """
|
||||
# Our largest model, ideal for a wide range of more complex tasks. Using this model name
|
||||
# will automatically switch you to newer versions of claude-v1 as they are released.
|
||||
Claude_v1 = "claude-v1"
|
||||
|
||||
# An earlier version of claude-v1
|
||||
Claude_v1_0 = "claude-v1.0"
|
||||
|
||||
# An improved version of claude-v1. It is slightly improved at general helpfulness,
|
||||
# instruction following, coding, and other tasks. It is also considerably better with
|
||||
# non-English languages. This model also has the ability to role play (in harmless ways)
|
||||
# more consistently, and it defaults to writing somewhat longer and more thorough responses.
|
||||
Claude_v1_2 = "claude-v1.2"
|
||||
|
||||
# A significantly improved version of claude-v1. Compared to claude-v1.2, it's more robust
|
||||
# against red-team inputs, better at precise instruction-following, better at code, and better
|
||||
# and non-English dialogue and writing.
|
||||
Claude_v1_3 = "claude-v1.3"
|
||||
|
||||
# A smaller model with far lower latency, sampling at roughly 40 words/sec! Its output quality
|
||||
# is somewhat lower than claude-v1 models, particularly for complex tasks. However, it is much
|
||||
# less expensive and blazing fast. We believe that this model provides more than adequate performance
|
||||
# on a range of tasks including text classification, summarization, and lightweight chat applications,
|
||||
# as well as search result summarization. Using this model name will automatically switch you to newer
|
||||
# versions of claude-instant-v1 as they are released.
|
||||
Claude_v1_instant = "claude-instant-v1"
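Because LLM mixes in str, each member doubles as its model-name string, which is what the rest of the backend now passes around. A minimal sketch (illustrative, not part of the commit) of the semantics this relies on:

from promptengine.models import LLM

assert LLM.Claude_v1.value == "claude-v1"
assert LLM.Claude_v1 == "claude-v1"                        # str mixin: members compare equal to their values
assert LLM("claude-instant-v1") is LLM.Claude_v1_instant   # reverse lookup by value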
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import List, Dict, Tuple, Iterator
from typing import List, Dict, Tuple, Iterator, Union
import json, os, asyncio, random, string
from promptengine.utils import LLM, call_chatgpt, call_dalai, is_valid_filepath, is_valid_json
from promptengine.utils import LLM, call_chatgpt, call_dalai, call_anthropic, is_valid_filepath, is_valid_json
from promptengine.template import PromptTemplate, PromptPermutationGenerator

# LLM APIs often have rate limits, which control number of requests. E.g., OpenAI: https://platform.openai.com/account/rate-limits
@@ -65,7 +65,7 @@ class PromptPipeline:
                    "prompt": prompt_str,
                    "query": responses[prompt_str]["query"],
                    "response": responses[prompt_str]["response"],
                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.name,
                    "llm": responses[prompt_str]["llm"] if "llm" in responses[prompt_str] else LLM.ChatGPT.value,
                    "info": responses[prompt_str]["info"],
                }
                continue
@@ -86,7 +86,7 @@ class PromptPipeline:
            responses[str(prompt)] = {
                "query": query,
                "response": response,
                "llm": llm.name,
                "llm": llm.value,
                "info": info,
            }
            self._cache_responses(responses)
@@ -96,7 +96,7 @@ class PromptPipeline:
                "prompt":str(prompt),
                "query":query,
                "response":response,
                "llm": llm.name,
                "llm": llm.value,
                "info": info,
            }

@@ -114,7 +114,7 @@ class PromptPipeline:
            responses[str(prompt)] = {
                "query": query,
                "response": response,
                "llm": llm.name,
                "llm": llm.value,
                "info": info,
            }
            self._cache_responses(responses)
@@ -124,7 +124,7 @@ class PromptPipeline:
                "prompt":str(prompt),
                "query":query,
                "response":response,
                "llm": llm.name,
                "llm": llm.value,
                "info": info,
            }

@@ -147,11 +147,13 @@ class PromptPipeline:
    def clear_cached_responses(self) -> None:
        self._cache_responses({})

    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Dict]:
    async def _prompt_llm(self, llm: LLM, prompt: PromptTemplate, n: int = 1, temperature: float = 1.0) -> Tuple[str, Dict, Union[List, Dict]]:
        if llm is LLM.ChatGPT or llm is LLM.GPT4:
            query, response = await call_chatgpt(str(prompt), model=llm, n=n, temperature=temperature)
        elif llm is LLM.Alpaca7B:
            query, response = await call_dalai(llm_name='alpaca.7B', port=4000, prompt=str(prompt), n=n, temperature=temperature)
            query, response = await call_dalai(model=llm, port=4000, prompt=str(prompt), n=n, temperature=temperature)
        elif llm.value[:6] == 'claude':
            query, response = await call_anthropic(prompt=str(prompt), model=llm, n=n, temperature=temperature)
        else:
            raise Exception(f"Language model {llm} is not supported.")
        return prompt, query, response
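The Claude branch keys off a simple string prefix rather than an explicit member check, so every Claude variant in the enum (and any future one following the same naming scheme) routes to call_anthropic. A small check of that prefix test, for illustration only:

from promptengine.models import LLM

claude_models = [m.value for m in LLM if m.value[:6] == 'claude']
print(claude_models)
# ['claude-v1', 'claude-v1.0', 'claude-v1.2', 'claude-v1.3', 'claude-instant-v1']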
@@ -1,28 +1,22 @@
from typing import Dict, Tuple, List, Union
from enum import Enum
import openai
from typing import Dict, Tuple, List, Union, Callable
import json, os, time, asyncio

from promptengine.models import LLM

DALAI_MODEL = None
DALAI_RESPONSE = None

openai.api_key = os.environ.get("OPENAI_API_KEY")

""" Supported LLM coding assistants """
class LLM(Enum):
    ChatGPT = 0
    Alpaca7B = 1
    GPT4 = 2

async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
    """
    Calls GPT3.5 via OpenAI's API.
    Returns raw query and response JSON dicts.

    NOTE: It is recommended to set an environment variable OPENAI_API_KEY with your OpenAI API key
    """
    model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
    if model not in model_map:
        raise Exception(f"Could not find OpenAI chat model {model}")
    model = model_map[model]
    import openai
    if not openai.api_key:
        openai.api_key = os.environ.get("OPENAI_API_KEY")
    model = model.value
    print(f"Querying OpenAI model '{model}' with prompt '{prompt}'...")
    system_msg = "You are a helpful assistant." if system_msg is None else system_msg
    query = {
@@ -37,13 +31,64 @@ async def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float =
    response = openai.ChatCompletion.create(**query)
    return query, response

async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
async def call_anthropic(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0,
                         custom_prompt_wrapper: Union[Callable[[str], str], None]=None,
                         max_tokens_to_sample=1024,
                         stop_sequences: Union[List[str], str]=["\n\nHuman:"],
                         async_mode=False,
                         **params) -> Tuple[Dict, Dict]:
    """
    Calls Anthropic API with the given model, passing in params.
    Returns raw query and response JSON dicts.

    Unique parameters:
        - custom_prompt_wrapper: Anthropic models expect prompts in form "\n\nHuman: ${prompt}\n\nAssistant". If you wish to
                                 explore custom prompt wrappers that deviate, write a function that maps from 'prompt' to custom wrapper.
                                 If set to None, defaults to Anthropic's suggested prompt wrapper.
        - max_tokens_to_sample: A maximum number of tokens to generate before stopping.
        - stop_sequences: A list of strings upon which to stop generating. Defaults to ["\n\nHuman:"], the cue for the next turn in the dialog agent.
        - async_mode: Evaluation access to Claude limits calls to 1 at a time, meaning we can't take advantage of async.
                      If you want to send all 'n' requests at once, you can set async_mode to True.

    NOTE: It is recommended to set an environment variable ANTHROPIC_API_KEY with your Anthropic API key
    """
    import anthropic
    client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])

    # Format query
    query = {
        'model': model.value,
        'prompt': f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" if not custom_prompt_wrapper else custom_prompt_wrapper(prompt),
        'max_tokens_to_sample': max_tokens_to_sample,
        'stop_sequences': stop_sequences,
        'temperature': temperature,
        **params
    }

    print(f"Calling Anthropic model '{model.value}' with prompt '{prompt}' (n={n}). Please be patient...")

    # Request responses using the passed async_mode
    responses = []
    if async_mode:
        # Gather n responses by firing off all API requests at once
        tasks = [client.acompletion(**query) for _ in range(n)]
        responses = await asyncio.gather(*tasks)
    else:
        # Repeat call n times, waiting for each response to come in:
        while len(responses) < n:
            resp = await client.acompletion(**query)
            responses.append(resp)
            print(f'{model.value} response {len(responses)} of {n}:\n', resp)

    return query, responses
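A hypothetical driver (not in the commit) showing how call_anthropic might be invoked on its own; it assumes the anthropic package is installed, ANTHROPIC_API_KEY is set, and that each raw response dict carries a 'completion' field as extract_responses below expects:

import asyncio
from promptengine.models import LLM
from promptengine.utils import call_anthropic

async def main():
    # Default (sequential) mode: each of the n completions is awaited in turn.
    query, responses = await call_anthropic(prompt="What is the capital of France?",
                                            model=LLM.Claude_v1,
                                            n=2,
                                            temperature=0.5)
    for resp in responses:
        print(resp.get("completion"))

asyncio.run(main())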

async def call_dalai(model: LLM, port: int, prompt: str, n: int = 1, temperature: float = 0.5, **params) -> Tuple[Dict, Dict]:
    """
    Calls a Dalai server running LLMs Alpaca, Llama, etc locally.
    Returns the raw query and response JSON dicts.

    Parameters:
        - llm_name: The LLM's name as known by Dalai; e.g., 'alpaca.7b'
        - model: The LLM model, whose value is the name known by Dalai; e.g. 'alpaca.7b'
        - port: The port of the local server where Dalai is running. Usually 3000.
        - prompt: The prompt to pass to the LLM.
        - n: How many times to query. If n > 1, this will continue to query the LLM 'n' times and collect all responses.
@@ -75,7 +120,7 @@ async def call_dalai(llm_name: str, port: int, prompt: str, n: int = 1, temperat
    # Create full query to Dalai
    query = {
        'prompt': prompt,
        'model': llm_name,
        'model': model.value,
        'id': str(round(time.time()*1000)),
        'temp': temperature,
        **def_params
@@ -132,15 +177,17 @@ def _extract_chatgpt_responses(response: dict) -> List[dict]:
        for i, c in enumerate(choices)
    ]

def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
def extract_responses(response: Union[list, dict], llm: Union[LLM, str]) -> List[dict]:
    """
    Given a LLM and a response object from its API, extract the
    text response(s) part of the response object.
    """
    if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
    if llm is LLM.ChatGPT or llm == LLM.ChatGPT.value or llm is LLM.GPT4 or llm == LLM.GPT4.value:
        return _extract_chatgpt_responses(response)
    elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
    elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.value:
        return response["response"]
    elif (isinstance(llm, LLM) and llm.value[:6] == 'claude') or (isinstance(llm, str) and llm[:6] == 'claude'):
        return [r["completion"] for r in response["response"]]
    else:
        raise ValueError(f"LLM {llm} is unsupported.")
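Illustrative only (not in the diff): the Claude branch expects the cached response object to wrap the raw Anthropic completions under a "response" key, matching what PromptPipeline stores above, and it accepts either the enum member or its raw string for llm:

from promptengine.utils import extract_responses

cached = {"response": [{"completion": " Paris."},
                       {"completion": " The capital of France is Paris."}]}
print(extract_responses(cached, llm="claude-v1"))
# [' Paris.', ' The capital of France is Paris.']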