mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Add GPT4 support
This commit is contained in:
parent
bb13abbd18
commit
75a4efaa4f
@ -292,7 +292,9 @@ const PromptNode = ({ data, id }) => {
|
||||
<div className="nodrag">
|
||||
<input type="checkbox" id="gpt3.5" name="gpt3.5" value="gpt3.5" defaultChecked={true} onChange={handleLLMChecked} />
|
||||
<label htmlFor="gpt3.5">GPT3.5 </label>
|
||||
<input type="checkbox" id="alpaca.7B" name="alpaca.7B" value="alpaca.7B" onChange={handleLLMChecked} />
|
||||
<input type="checkbox" id="gpt4" name="gpt4" value="gpt4" defaultChecked={false} onChange={handleLLMChecked} />
|
||||
<label htmlFor="gpt4">GPT4 </label>
|
||||
<input type="checkbox" id="alpaca.7B" name="alpaca.7B" value="alpaca.7B" defaultChecked={false} onChange={handleLLMChecked} />
|
||||
<label htmlFor="alpaca.7B">Alpaca 7B</label>
|
||||
</div>
|
||||
<hr />
|
||||
|
@ -71,7 +71,7 @@ const TextFieldsNode = ({ data, id }) => {
|
||||
const val = data.fields ? data.fields[i] : '';
|
||||
return (
|
||||
<div className="input-field" key={i}>
|
||||
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="3" cols="40" defaultValue={val} onChange={handleInputChange} />
|
||||
<textarea id={i} name={i} className="text-field-fixed nodrag" rows="2" cols="40" defaultValue={val} onChange={handleInputChange} />
|
||||
</div>
|
||||
)}));
|
||||
}, [data.fields, handleInputChange]);
|
||||
|
@ -13,6 +13,7 @@ CORS(app)
|
||||
LLM_NAME_MAP = {
|
||||
'gpt3.5': LLM.ChatGPT,
|
||||
'alpaca.7B': LLM.Alpaca7B,
|
||||
'gpt4': LLM.GPT4,
|
||||
}
|
||||
LLM_NAME_MAP_INVERSE = {val.name: key for key, val in LLM_NAME_MAP.items()}
|
||||
|
||||
|
@ -106,8 +106,8 @@ class PromptPipeline:
|
||||
self._cache_responses({})
|
||||
|
||||
def _prompt_llm(self, llm: LLM, prompt: str, n: int = 1, temperature: float = 1.0) -> Tuple[Dict, Dict]:
|
||||
if llm is LLM.ChatGPT:
|
||||
return call_chatgpt(prompt, n=n, temperature=temperature)
|
||||
if llm is LLM.ChatGPT or llm is LLM.GPT4:
|
||||
return call_chatgpt(prompt, model=llm, n=n, temperature=temperature)
|
||||
elif llm is LLM.Alpaca7B:
|
||||
return call_dalai(llm_name='alpaca.7B', port=4000, prompt=prompt, n=n, temperature=temperature)
|
||||
else:
|
||||
|
@ -10,15 +10,20 @@ DALAI_RESPONSE = None
|
||||
class LLM(Enum):
|
||||
ChatGPT = 0
|
||||
Alpaca7B = 1
|
||||
GPT4 = 2
|
||||
|
||||
def call_chatgpt(prompt: str, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
|
||||
def call_chatgpt(prompt: str, model: LLM, n: int = 1, temperature: float = 1.0, system_msg: Union[str, None]=None) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Calls GPT3.5 via OpenAI's API.
|
||||
Returns raw query and response JSON dicts.
|
||||
"""
|
||||
model_map = { LLM.ChatGPT: 'gpt-3.5-turbo', LLM.GPT4: 'gpt-4' }
|
||||
if model not in model_map:
|
||||
raise Exception(f"Could not find OpenAI chat model {model}")
|
||||
model = model_map[model]
|
||||
system_msg = "You are a helpful assistant." if system_msg is None else system_msg
|
||||
query = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_msg},
|
||||
{"role": "user", "content": prompt},
|
||||
@ -129,7 +134,7 @@ def extract_responses(response: Union[list, dict], llm: LLM) -> List[dict]:
|
||||
Given a LLM and a response object from its API, extract the
|
||||
text response(s) part of the response object.
|
||||
"""
|
||||
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name:
|
||||
if llm is LLM.ChatGPT or llm == LLM.ChatGPT.name or llm is LLM.GPT4 or llm == LLM.GPT4.name:
|
||||
return _extract_chatgpt_responses(response)
|
||||
elif llm is LLM.Alpaca7B or llm == LLM.Alpaca7B.name:
|
||||
return response["response"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user