Add GPT4 support

This commit is contained in:
Ian Arawjo 2023-05-02 10:38:11 -04:00
parent bb13abbd18
commit 75a4efaa4f
5 changed files with 15 additions and 7 deletions

View File

@ -292,7 +292,9 @@ const PromptNode = ({ data, id }) => {
<div className="nodrag">
<input type="checkbox" id="gpt3.5" name="gpt3.5" value="gpt3.5" defaultChecked={true} onChange={handleLLMChecked} />
<label htmlFor="gpt3.5">GPT3.5 </label>
<input type="checkbox" id="alpaca.7B" name="alpaca.7B" value="alpaca.7B" onChange={handleLLMChecked} />
<input type="checkbox" id="gpt4" name="gpt4" value="gpt4" defaultChecked={false} onChange={handleLLMChecked} />
<label htmlFor="gpt4">GPT4 </label>
<input type="checkbox" id="alpaca.7B" name="alpaca.7B" value="alpaca.7B" defaultChecked={false} onChange={handleLLMChecked} />
<label htmlFor="alpaca.7B">Alpaca 7B</label>
</div>
<hr />

View File

@ -71,7 +71,7 @@ const TextFieldsNode = ({ data, id }) => {
const val = data.fields ? data.fields[i] : '';
return (
<div className="input-field" key={i}>
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="3" cols="40" defaultValue={val} onChange={handleInputChange} />
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" defaultValue={val} onChange={handleInputChange} />
</div>
)}));
}, [data.fields, handleInputChange]);

View File

@ -13,6 +13,7 @@ CORS(app)
LLM_NAME_MAP = {
'gpt3.5': LLM.ChatGPT,
'alpaca.7B': LLM.Alpaca7B,
'gpt4': LLM.GPT4,
}
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}

View File

@ -106,8 +106,8 @@ class PromptPipeline:
self._cache_responses({})
def _prompt_llm(self, llm: LLM, prompt: str, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
if llm is LLM.ChatGPT:
return call_chatgpt(prompt, n=n, temperature=temperature)
if llm is LLM.ChatGPT or llm is LLM.GPT4:
return call_chatgpt(prompt, model=llm, n=n, temperature=temperature)
elif llm is LLM.Alpaca7B:
return call_dalai(llm_name='alpaca.7B', port=4000, prompt=prompt, n=n, temperature=temperature)
else:

View File

@ -10,15 +10,20 @@ DALAI_RESPONSE = None
class LLM(Enum):
ChatGPT = 0
Alpaca7B = 1
GPT4 = 2
def call_chatgpt(prompt: str, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
"""
Calls GPT3.5 via OpenAI's API.
Returns raw query and response JSON dicts.
"""
model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
if model not in model_map:
raise Exception(f"Could not find OpenAI chat model {model}")
model = model_map[model]
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
query = {
"model": "gpt-3.5-turbo",
"model": model,
"messages": [
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt},
@ -129,7 +134,7 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
Given a LLM and a response object from its API, extract the
text response(s) part of the response object.
"""
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name:
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
return _extract_chatgpt_responses(response)
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
return response["response"]