Mirror of https://github.com/ianarawjo/ChainForge.git, synced 2025-03-14 08:16:37 +00:00

Compare commits (14 commits):

2c7447d5f7
3b929880dc
6ed42fe518
98a8184a6a
9d7c458b7a
723022fb31
6b7e7935d0
1e215f4238
1206d62b7b
0174b53aff
0e96fa2e1c
9ec7a3a4fc
f5882768ba
ff813c7255
README.md (20 changes)
@@ -67,7 +67,9 @@ Now you can open the browser of your choice and open `http://127.0.0.1:8000`.

- OpenAI
- Anthropic
- Google (Gemini, PaLM2)
- DeepSeek
- HuggingFace (Inference and Endpoints)
- Together.ai
- [Ollama](https://github.com/jmorganca/ollama) (locally-hosted models)
- Microsoft Azure OpenAI Endpoints
- [AlephAlpha](https://app.aleph-alpha.com/)

@@ -133,7 +135,7 @@ For more specific details, see our [documentation](https://chainforge.ai/docs/no

# Development

ChainForge was created by [Ian Arawjo](http://ianarawjo.com/index.html), a postdoctoral scholar in Harvard HCI's [Glassman Lab](http://glassmanlab.seas.harvard.edu/) with support from the Harvard HCI community. Collaborators include PhD students [Priyan Vaithilingam](https://priyan.info) and [Chelse Swoopes](https://seas.harvard.edu/person/chelse-swoopes), Harvard undergraduate [Sean Yang](https://shawsean.com), and faculty members [Elena Glassman](http://glassmanlab.seas.harvard.edu/glassman.html) and [Martin Wattenberg](https://www.bewitched.com/about.html).
ChainForge was created by [Ian Arawjo](http://ianarawjo.com/index.html), a postdoctoral scholar in Harvard HCI's [Glassman Lab](http://glassmanlab.seas.harvard.edu/) with support from the Harvard HCI community. Collaborators include PhD students [Priyan Vaithilingam](https://priyan.info) and [Chelse Swoopes](https://seas.harvard.edu/person/chelse-swoopes), Harvard undergraduate [Sean Yang](https://shawsean.com), and faculty members [Elena Glassman](http://glassmanlab.seas.harvard.edu/glassman.html) and [Martin Wattenberg](https://www.bewitched.com/about.html). Additional collaborators include UC Berkeley PhD student Shreya Shankar and Université de Montréal undergraduate Cassandre Hamel.

This work was partially funded by the NSF grants IIS-2107391, IIS-2040880, and IIS-1955699. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.

@@ -156,17 +158,15 @@ We welcome open-source collaborators. If you want to report a bug or request a f

# Cite Us

If you use ChainForge for research purposes, or build upon the source code, we ask that you cite our [arXiv pre-print](https://arxiv.org/abs/2309.09128) in any related publications.
The BibTeX you can use is:
If you use ChainForge for research purposes, whether by building upon the source code or investigating LLM behavior using the tool, we ask that you cite our [CHI research paper](https://dl.acm.org/doi/full/10.1145/3613904.3642016) in any related publications. The BibTeX you can use is:

```bibtex
@misc{arawjo2023chainforge,
  title={ChainForge: A Visual Toolkit for Prompt Engineering and LLM Hypothesis Testing},
  author={Ian Arawjo and Chelse Swoopes and Priyan Vaithilingam and Martin Wattenberg and Elena Glassman},
  year={2023},
  eprint={2309.09128},
  archivePrefix={arXiv},
  primaryClass={cs.HC}
@inproceedings{arawjo2024chainforge,
  title={ChainForge: A Visual Toolkit for Prompt Engineering and LLM Hypothesis Testing},
  author={Arawjo, Ian and Swoopes, Chelse and Vaithilingam, Priyan and Wattenberg, Martin and Glassman, Elena L},
  booktitle={Proceedings of the CHI Conference on Human Factors in Computing Systems},
  pages={1--18},
  year={2024}
}
```
Example flow files under chainforge/examples/ (new files; diff contents suppressed because the files are too large or their lines are too long):

chainforge/examples/animal-images.cforge (1060 lines)
chainforge/examples/audit-bias.cforge (18701 lines)
chainforge/examples/book-beginnings.cforge (3844 lines)
chainforge/examples/chat-sycophancy.cforge (32424 lines)
chainforge/examples/compare-prompts.cforge (26835 lines)
chainforge/examples/comparing-formats.cforge (7207 lines)
chainforge/examples/mosquito-knowledge.cforge (1677 lines)
chainforge/examples/python-coding-eval.cforge (124439 lines)
chainforge/examples/red-team-stereotypes.cforge (63762 lines)
chainforge/examples/structured-outputs.cforge (1287 lines)
chainforge/examples/tweet-multi-eval.cforge (41255 lines)
@@ -3,11 +3,13 @@ from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from datetime import datetime
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests
from platformdirs import user_data_dir

""" =================
    SETUP AND GLOBALS
@@ -26,6 +28,7 @@ app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
cors = CORS(app, resources={r"/*": {"origins": "*"}})

# The cache and examples files base directories
FLOWS_DIR = user_data_dir("chainforge")  # platform-agnostic local storage that persists outside the package install location
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')

@@ -472,6 +475,7 @@ def fetchEnvironAPIKeys():
        'AWS_REGION': 'AWS_Region',
        'AWS_SESSION_TOKEN': 'AWS_Session_Token',
        'TOGETHER_API_KEY': 'Together',
        'DEEPSEEK_API_KEY': 'DeepSeek',
    }
    d = { alias: os.environ.get(key) for key, alias in keymap.items() }
    ret = jsonify(d)
@@ -508,7 +512,7 @@ def makeFetchCall():
        ret.headers.add('Access-Control-Allow-Origin', '*')
        return ret
    else:
        err_msg = "API request to Anthropic failed"
        err_msg = "API request failed"
        ret = response.json()
        if "error" in ret and "message" in ret["error"]:
            err_msg += ": " + ret["error"]["message"]
@@ -720,6 +724,109 @@ async def callCustomProvider():
    # Return the response
    return jsonify({'response': response})

"""
LOCALLY SAVED FLOWS
"""
@app.route('/api/flows', methods=['GET'])
def get_flows():
    """Return a list of all saved flows. If the directory does not exist, try to create it."""
    os.makedirs(FLOWS_DIR, exist_ok=True)  # Creates the directory if it doesn't exist
    flows = [
        {
            "name": f,
            "last_modified": datetime.fromtimestamp(os.path.getmtime(os.path.join(FLOWS_DIR, f))).isoformat()
        }
        for f in os.listdir(FLOWS_DIR)
        if f.endswith('.cforge') and f != "__autosave.cforge"  # ignore the special autosave file
    ]

    # Sort the flow files by last modified date in descending order (most recent first)
    flows.sort(key=lambda x: x["last_modified"], reverse=True)

    return jsonify({
        "flow_dir": FLOWS_DIR,
        "flows": flows
    })

@app.route('/api/flows/<filename>', methods=['GET'])
def get_flow(filename):
    """Return the content of a specific flow"""
    if not filename.endswith('.cforge'):
        filename += '.cforge'
    try:
        with open(os.path.join(FLOWS_DIR, filename), 'r') as f:
            return jsonify(json.load(f))
    except FileNotFoundError:
        return jsonify({"error": "Flow not found"}), 404

@app.route('/api/flows/<filename>', methods=['DELETE'])
def delete_flow(filename):
    """Delete a flow"""
    if not filename.endswith('.cforge'):
        filename += '.cforge'
    try:
        os.remove(os.path.join(FLOWS_DIR, filename))
        return jsonify({"message": f"Flow {filename} deleted successfully"})
    except FileNotFoundError:
        return jsonify({"error": "Flow not found"}), 404

@app.route('/api/flows/<filename>', methods=['PUT'])
def save_or_rename_flow(filename):
    """Save or rename a flow"""
    data = request.json

    if not filename.endswith('.cforge'):
        filename += '.cforge'

    if data.get('flow'):
        # Save flow (overwriting any existing flow file with the same name)
        flow_data = data.get('flow')

        try:
            filepath = os.path.join(FLOWS_DIR, filename)
            with open(filepath, 'w') as f:
                json.dump(flow_data, f)
            return jsonify({"message": f"Flow '{filename}' saved!"})
        except FileNotFoundError:
            return jsonify({"error": f"Could not save flow '{filename}' to local filesystem. See terminal for more details."}), 404

    elif data.get('newName'):
        # Rename flow
        new_name = data.get('newName')

        if not new_name.endswith('.cforge'):
            new_name += '.cforge'

        try:
            # Check for name clashes (if a flow already exists with the new name)
            if os.path.isfile(os.path.join(FLOWS_DIR, new_name)):
                raise Exception("A flow with that name already exists.")
            os.rename(os.path.join(FLOWS_DIR, filename), os.path.join(FLOWS_DIR, new_name))
            return jsonify({"message": f"Flow renamed from {filename} to {new_name}"})
        except Exception as error:
            return jsonify({"error": str(error)}), 404

@app.route('/api/getUniqueFlowFilename', methods=['PUT'])
def get_unique_flow_name():
    """Return a non-name-clashing filename to store in the local disk."""
    data = request.json
    filename = data.get("name")

    try:
        base, ext = os.path.splitext(filename)
        if ext is None or len(ext) == 0:
            ext = ".cforge"
        unique_filename = base + ext
        i = 1

        # Find the first non-clashing filename of the form <filename>(i).cforge where i=1,2,3 etc
        while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)):
            unique_filename = f"{base}({i}){ext}"
            i += 1

        return jsonify(unique_filename.replace(".cforge", ""))
    except Exception as e:
        return jsonify({"error": str(e)}), 404

def run_server(host="", port=8000, cmd_args=None):
    global HOSTNAME, PORT
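For orientation, the routes added above form a small JSON-over-HTTP API for flows stored in FLOWS_DIR. A minimal client-side sketch, not part of these commits, assuming the ChainForge server is running locally on its default port 8000, the `requests` package is installed, and the flow payload and filenames are placeholders:

```python
# Hypothetical client sketch of the /api/flows routes added above.
import requests

BASE = "http://localhost:8000/api"

# List saved flows (returns the flows directory and a list of {name, last_modified})
flows = requests.get(f"{BASE}/flows").json()
print(flows["flow_dir"], [f["name"] for f in flows["flows"]])

# Save (or overwrite) a flow named "my-experiment"; the body is a placeholder flow object
requests.put(f"{BASE}/flows/my-experiment", json={"flow": {"nodes": [], "edges": []}})

# Rename it
requests.put(f"{BASE}/flows/my-experiment", json={"newName": "my-experiment-v2"})

# Fetch its contents, then delete it
flow = requests.get(f"{BASE}/flows/my-experiment-v2").json()
requests.delete(f"{BASE}/flows/my-experiment-v2")
```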
chainforge/react-server/package-lock.json (52 changes, generated)
@ -16,7 +16,7 @@
|
||||
"@emoji-mart/react": "^1.1.1",
|
||||
"@fontsource/geist-mono": "^5.0.1",
|
||||
"@google-ai/generativelanguage": "^0.2.0",
|
||||
"@google/generative-ai": "^0.1.3",
|
||||
"@google/generative-ai": "^0.21.0",
|
||||
"@mantine/core": "^6.0.9",
|
||||
"@mantine/dates": "^6.0.13",
|
||||
"@mantine/dropzone": "^6.0.19",
|
||||
@ -65,7 +65,7 @@
|
||||
"lodash": "^4.17.21",
|
||||
"lz-string": "^1.5.0",
|
||||
"mantine-contextmenu": "^6.1.0",
|
||||
"mantine-react-table": "^1.0.0-beta.8",
|
||||
"mantine-react-table": "^1.3.4",
|
||||
"markdown-it": "^13.0.1",
|
||||
"mathjs": "^11.8.2",
|
||||
"mdast-util-from-markdown": "^2.0.0",
|
||||
@ -112,7 +112,7 @@
|
||||
"@types/react-edit-text": "^5.0.4",
|
||||
"@types/react-plotly.js": "^2.6.3",
|
||||
"@types/styled-components": "^5.1.34",
|
||||
"eslint": "^8.56.0",
|
||||
"eslint": "<9.0.0",
|
||||
"eslint-config-prettier": "^9.1.0",
|
||||
"eslint-config-semistandard": "^17.0.0",
|
||||
"eslint-config-standard": "^17.1.0",
|
||||
@ -3785,9 +3785,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@google/generative-ai": {
|
||||
"version": "0.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.3.tgz",
|
||||
"integrity": "sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ==",
|
||||
"version": "0.21.0",
|
||||
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.21.0.tgz",
|
||||
"integrity": "sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==",
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
@ -6191,11 +6191,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/react-table": {
|
||||
"version": "8.10.0",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.0.tgz",
|
||||
"integrity": "sha512-FNhKE3525hryvuWw90xRbP16qNiq7OLJkDZopOKcwyktErLi1ibJzAN9DFwA/gR1br9SK4StXZh9JPvp9izrrQ==",
|
||||
"version": "8.10.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.6.tgz",
|
||||
"integrity": "sha512-D0VEfkIYnIKdy6SHiBNEaMc4SxO+MV7ojaPhRu8jP933/gbMi367+Wul2LxkdovJ5cq6awm0L1+jgxdS/unzIg==",
|
||||
"dependencies": {
|
||||
"@tanstack/table-core": "8.10.0"
|
||||
"@tanstack/table-core": "8.10.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
@ -6210,11 +6210,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/react-virtual": {
|
||||
"version": "3.0.0-beta.60",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.60.tgz",
|
||||
"integrity": "sha512-F0wL9+byp7lf/tH6U5LW0ZjBqs+hrMXJrj5xcIGcklI0pggvjzMNW9DdIBcyltPNr6hmHQ0wt8FDGe1n1ZAThA==",
|
||||
"version": "3.0.0-beta.63",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.63.tgz",
|
||||
"integrity": "sha512-n4aaZs3g9U2oZjFp8dAeT1C2g4rr/3lbCo2qWbD9NquajKnGx7R+EfLBAHJ6pVMmfsTMZ0XCBwkIs7U74R/s0A==",
|
||||
"dependencies": {
|
||||
"@tanstack/virtual-core": "3.0.0-beta.60"
|
||||
"@tanstack/virtual-core": "3.0.0-beta.63"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
@ -6225,9 +6225,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/table-core": {
|
||||
"version": "8.10.0",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.0.tgz",
|
||||
"integrity": "sha512-e701yAJ18aGDP6mzVworlFAmQ+gi3Wtqx5mGZUe2BUv4W4D80dJxUfkHdtEGJ6GryAnlCCNTej7eaJiYmPhyYg==",
|
||||
"version": "8.10.6",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.6.tgz",
|
||||
"integrity": "sha512-9t8brthhAmCBIjzk7fCDa/kPKoLQTtA31l9Ir76jYxciTlHU61r/6gYm69XF9cbg9n88gVL5y7rNpeJ2dc1AFA==",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@ -6237,9 +6237,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/virtual-core": {
|
||||
"version": "3.0.0-beta.60",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.60.tgz",
|
||||
"integrity": "sha512-QlCdhsV1+JIf0c0U6ge6SQmpwsyAT0oQaOSZk50AtEeAyQl9tQrd6qCHAslxQpgphrfe945abvKG8uYvw3hIGA==",
|
||||
"version": "3.0.0-beta.63",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.63.tgz",
|
||||
"integrity": "sha512-KhhfRYSoQpl0y+2axEw+PJZd/e/9p87PDpPompxcXnweNpt9ZHCT/HuNx7MKM9PVY/xzg9xJSWxwnSCrO+d6PQ==",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
@ -17692,13 +17692,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/mantine-react-table": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/mantine-react-table/-/mantine-react-table-1.3.0.tgz",
|
||||
"integrity": "sha512-ljAd9ZI7S89glI8OGbM5DsNHzInZPfigbeDXboR5gLGPuuseLVgezlVjqV1h8MUQkY06++d9ah9zsMLr2VNxlQ==",
|
||||
"version": "1.3.4",
|
||||
"resolved": "https://registry.npmjs.org/mantine-react-table/-/mantine-react-table-1.3.4.tgz",
|
||||
"integrity": "sha512-rD0CaeC4RCU7k/ZKvfj5ijFFMd4clGpeg/EwMcogYFioZjj8aNfD78osTNNYr90AnOAFGnd7ZnderLK89+W1ZQ==",
|
||||
"dependencies": {
|
||||
"@tanstack/match-sorter-utils": "8.8.4",
|
||||
"@tanstack/react-table": "8.10.0",
|
||||
"@tanstack/react-virtual": "3.0.0-beta.60"
|
||||
"@tanstack/react-table": "8.10.6",
|
||||
"@tanstack/react-virtual": "3.0.0-beta.63"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
@ -17712,7 +17712,7 @@
|
||||
"@mantine/core": "^6.0",
|
||||
"@mantine/dates": "^6.0",
|
||||
"@mantine/hooks": "^6.0",
|
||||
"@tabler/icons-react": ">=2.23.0",
|
||||
"@tabler/icons-react": ">=2.23",
|
||||
"react": ">=18.0",
|
||||
"react-dom": ">=18.0"
|
||||
}
|
||||
|
@@ -14,7 +14,7 @@
    "@emoji-mart/react": "^1.1.1",
    "@fontsource/geist-mono": "^5.0.1",
    "@google-ai/generativelanguage": "^0.2.0",
    "@google/generative-ai": "^0.1.3",
    "@google/generative-ai": "^0.21.0",
    "@mantine/core": "^6.0.9",
    "@mantine/dates": "^6.0.13",
    "@mantine/dropzone": "^6.0.19",
@@ -63,7 +63,7 @@
    "lodash": "^4.17.21",
    "lz-string": "^1.5.0",
    "mantine-contextmenu": "^6.1.0",
    "mantine-react-table": "^1.0.0-beta.8",
    "mantine-react-table": "^1.3.4",
    "markdown-it": "^13.0.1",
    "mathjs": "^11.8.2",
    "mdast-util-from-markdown": "^2.0.0",
@ -42,6 +42,7 @@ import {
|
||||
VarsContext,
|
||||
} from "./backend/typing";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { StringLookup } from "./backend/cache";
|
||||
|
||||
const zeroGap = { gap: "0rem" };
|
||||
const popoverShadow = "rgb(38, 57, 77) 0px 10px 30px -14px";
|
||||
@ -298,8 +299,9 @@ export function AIGenReplaceTablePopover({
|
||||
// Check if there are any non-empty rows
|
||||
const nonEmptyRows = useMemo(
|
||||
() =>
|
||||
values.filter((row) => Object.values(row).some((val) => val?.trim()))
|
||||
.length,
|
||||
values.filter((row) =>
|
||||
Object.values(row).some((val) => StringLookup.get(val)?.trim()),
|
||||
).length,
|
||||
[values],
|
||||
);
|
||||
|
||||
@ -368,7 +370,9 @@ export function AIGenReplaceTablePopover({
|
||||
const tableRows = values
|
||||
.slice(0, -1) // Remove the last empty row
|
||||
.map((row) =>
|
||||
tableColumns.map((col) => row[col]?.trim() || "").join(" | "),
|
||||
tableColumns
|
||||
.map((col) => StringLookup.get(row[col])?.trim() || "")
|
||||
.join(" | "),
|
||||
);
|
||||
|
||||
const tableInput = {
|
||||
@ -418,7 +422,9 @@ export function AIGenReplaceTablePopover({
|
||||
const tableRows = values
|
||||
.slice(0, emptyLastRow ? -1 : values.length)
|
||||
.map((row) =>
|
||||
tableColumns.map((col) => row[col.key]?.trim() || "").join(" | "),
|
||||
tableColumns
|
||||
.map((col) => StringLookup.get(row[col.key])?.trim() || "")
|
||||
.join(" | "),
|
||||
);
|
||||
|
||||
const tableInput = {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -58,11 +58,10 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
|
||||
// Remove the node, after user confirmation dialog
|
||||
const handleRemoveNode = useCallback(() => {
|
||||
// Open the 'are you sure' modal:
|
||||
if (deleteConfirmModal && deleteConfirmModal.current)
|
||||
deleteConfirmModal.current.trigger();
|
||||
deleteConfirmModal?.current?.trigger();
|
||||
}, [deleteConfirmModal]);
|
||||
|
||||
const handleOpenContextMenu = (e: Dict) => {
|
||||
const handleOpenContextMenu = useCallback((e: Dict) => {
|
||||
// Ignore all right-clicked elements that aren't children of the parent,
|
||||
// and that aren't divs (for instance, textfields should still have normal right-click)
|
||||
if (e.target?.localName !== "div") return;
|
||||
@ -91,23 +90,22 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
|
||||
},
|
||||
});
|
||||
setContextMenuOpened(true);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// A BaseNode is just a div with "cfnode" as a class, and optional other className(s) for the specific node.
|
||||
// It adds a context menu to all nodes upon right-click of the node itself (the div), to duplicate or delete the node.
|
||||
return (
|
||||
<div
|
||||
className={classes}
|
||||
onPointerDown={() => setContextMenuOpened(false)}
|
||||
onContextMenu={handleOpenContextMenu}
|
||||
style={style}
|
||||
>
|
||||
const areYouSureModal = useMemo(
|
||||
() => (
|
||||
<AreYouSureModal
|
||||
ref={deleteConfirmModal}
|
||||
title="Delete node"
|
||||
message="Are you sure you want to delete this node? This action is irreversible."
|
||||
onConfirm={() => removeNode(nodeId)}
|
||||
/>
|
||||
),
|
||||
[removeNode, nodeId, deleteConfirmModal],
|
||||
);
|
||||
|
||||
const contextMenu = useMemo(
|
||||
() => (
|
||||
<Menu
|
||||
opened={contextMenuOpened}
|
||||
withinPortal={true}
|
||||
@ -132,6 +130,29 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
|
||||
</Menu.Item>
|
||||
</Menu.Dropdown>
|
||||
</Menu>
|
||||
),
|
||||
[
|
||||
handleDuplicateNode,
|
||||
handleRemoveNode,
|
||||
contextMenuExts,
|
||||
children,
|
||||
contextMenuStyle,
|
||||
contextMenuOpened,
|
||||
setContextMenuOpened,
|
||||
],
|
||||
);
|
||||
|
||||
// A BaseNode is just a div with "cfnode" as a class, and optional other className(s) for the specific node.
|
||||
// It adds a context menu to all nodes upon right-click of the node itself (the div), to duplicate or delete the node.
|
||||
return (
|
||||
<div
|
||||
className={classes}
|
||||
onPointerDown={() => setContextMenuOpened(false)}
|
||||
onContextMenu={handleOpenContextMenu}
|
||||
style={style}
|
||||
>
|
||||
{areYouSureModal}
|
||||
{contextMenu}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
@@ -53,6 +53,7 @@ import {
import { Status } from "./StatusIndicatorComponent";
import { executejs, executepy, grabResponses } from "./backend/backend";
import { AlertModalContext } from "./AlertModal";
import { StringLookup } from "./backend/cache";

// Whether we are running on localhost or not, and hence whether
// we have access to the Flask backend for, e.g., Python code evaluation.
@@ -540,7 +541,12 @@ The Python interpeter in the browser is Pyodide. You may not be able to run some
            resp_obj.responses.map((r) => {
              // Carry over the response text, prompt, prompt fill history (vars), and llm data
              const o: TemplateVarInfo = {
                text: typeof r === "string" ? escapeBraces(r) : undefined,
                text:
                  typeof r === "number"
                    ? escapeBraces(StringLookup.get(r)!)
                    : typeof r === "string"
                      ? escapeBraces(r)
                      : undefined,
                image:
                  typeof r === "object" && r.t === "img" ? r.d : undefined,
                prompt: resp_obj.prompt,
@@ -550,6 +556,11 @@ The Python interpeter in the browser is Pyodide. You may not be able to run some
                uid: resp_obj.uid,
              };

              o.text =
                o.text !== undefined
                  ? StringLookup.intern(o.text as string)
                  : undefined;

              // Carry over any chat history
              if (resp_obj.chat_history)
                o.chat_history = resp_obj.chat_history;
@ -1,5 +1,13 @@
|
||||
import React, { forwardRef, useImperativeHandle } from "react";
|
||||
import { SimpleGrid, Card, Modal, Text, Button, Tabs } from "@mantine/core";
|
||||
import {
|
||||
SimpleGrid,
|
||||
Card,
|
||||
Modal,
|
||||
Text,
|
||||
Button,
|
||||
Tabs,
|
||||
Stack,
|
||||
} from "@mantine/core";
|
||||
import { useDisclosure } from "@mantine/hooks";
|
||||
import { IconChartDots3 } from "@tabler/icons-react";
|
||||
import { Dict } from "./backend/typing";
|
||||
@ -331,28 +339,38 @@ const ExampleFlowCard: React.FC<ExampleFlowCardProps> = ({
|
||||
onSelect,
|
||||
}) => {
|
||||
return (
|
||||
<Card shadow="sm" padding="lg" radius="md" withBorder>
|
||||
<Text mb="xs" weight={500}>
|
||||
{title}
|
||||
</Text>
|
||||
<Card
|
||||
shadow="sm"
|
||||
radius="md"
|
||||
withBorder
|
||||
style={{ padding: "16px 10px 16px 10px" }}
|
||||
>
|
||||
<Stack justify="space-between" spacing="sm" h={160}>
|
||||
<div>
|
||||
<Text mb="xs" weight={500} lh={1.1} align="center">
|
||||
{title}
|
||||
</Text>
|
||||
|
||||
<Text size="sm" color="dimmed" lh={1.3}>
|
||||
{description}
|
||||
</Text>
|
||||
<Text size="sm" color="dimmed" lh={1.1} align="center">
|
||||
{description}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (onSelect) onSelect(filename);
|
||||
}}
|
||||
variant="light"
|
||||
color="blue"
|
||||
fullWidth
|
||||
size="sm"
|
||||
mt="md"
|
||||
radius="md"
|
||||
>
|
||||
{buttonText ?? "Try me"}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (onSelect) onSelect(filename);
|
||||
}}
|
||||
variant="light"
|
||||
color="blue"
|
||||
h={32}
|
||||
mih={32}
|
||||
fullWidth
|
||||
size="sm"
|
||||
radius="md"
|
||||
>
|
||||
{buttonText ?? "Try me"}
|
||||
</Button>
|
||||
</Stack>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
@ -414,39 +432,105 @@ const ExampleFlowsModal = forwardRef<
|
||||
<Tabs.Panel value="examples" pt="xs">
|
||||
<SimpleGrid cols={3} spacing="sm" verticalSpacing="sm">
|
||||
<ExampleFlowCard
|
||||
title="Compare length of responses across LLMs"
|
||||
title="📑 Compare between prompt templates"
|
||||
description="Compare between prompt templates using template chaining. Visualize response quality across models."
|
||||
filename="compare-prompts"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="📊 Compare prompt across models"
|
||||
description="A simple evaluation with a prompt template, some inputs, and three models to prompt. Visualizes variability in response length."
|
||||
filename="basic-comparison"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="Robustness to prompt injection attacks"
|
||||
description="Get a sense of different model's robustness against prompt injection attacks."
|
||||
filename="prompt-injection-test"
|
||||
title="🤖 Compare system prompts"
|
||||
description="Compares response quality across different system prompts. Visualizes how well it sticks to the instructions to only print Racket code."
|
||||
filename="comparing-system-msg"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="Chain prompts together"
|
||||
title="📗 Testing knowledge of book beginnings"
|
||||
description="Test whether different LLMs know the first sentences of famous books."
|
||||
filename="book-beginnings"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="⛓️ Extract data with prompt chaining"
|
||||
description="Chain one prompt into another to extract entities from a text response. Plots number of entities."
|
||||
filename="chaining-prompts"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="Measure impact of system message on response"
|
||||
description="Compares response quality across different ChatGPT system prompts. Visualizes how well it sticks to the instructions to only print Racket code."
|
||||
filename="comparing-system-msg"
|
||||
title="💬🙇 Estimate chat model sycophancy"
|
||||
description="Estimate how sycophantic a chat model is: ask it for a well-known fact, then tell it it's wrong, and check whether it apologizes or changes its answer."
|
||||
filename="chat-sycophancy"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="Ground truth evaluation for math problems"
|
||||
description="Uses a tabular data node to evaluate LLM performance on basic math problems. Compares responses to expected answer and plots performance across LLMs."
|
||||
title="🧪 Audit models for gender bias"
|
||||
description="Asks an LLM to estimate the gender of a person, given a profession and salary."
|
||||
filename="audit-bias"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🛑 Red-teaming of stereotypes about nationalities"
|
||||
description="Check for whether models refuse to generate stereotypes about people from different countries."
|
||||
filename="red-team-stereotypes"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🐦 Multi-evals of prompt to extract structured data from tweets"
|
||||
description="Extracts named entities from a dataset of tweets, and double-checks the output against multiple eval criteria."
|
||||
filename="tweet-multi-eval"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🧮 Produce structured outputs"
|
||||
description="Extract information from a dataset and output it in a structured JSON format using OpenAI's structured outputs feature."
|
||||
filename="structured-outputs"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🔨 Detect whether tool is triggered"
|
||||
description="Basic example showing whether a given prompt triggered tool usage."
|
||||
filename="basic-function-calls"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="📑 Compare output format"
|
||||
description="Check whether asking for a different format (YAML, XML, JSON, etc.) changes the content."
|
||||
filename="comparing-formats"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🧑💻️ HumanEvals Python coding benchmark"
|
||||
description="Run the HumanEvals Python coding benchmark to evaluate LLMs on Python code completion, entirely in your browser. A classic!"
|
||||
filename="python-coding-eval"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🗯 Check robustness to prompt injection attacks"
|
||||
description="Get a sense of different model's robustness against prompt injection attacks."
|
||||
filename="prompt-injection-test"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🔢 Ground truth evaluation for math problems"
|
||||
description="Uses a Tabular Data Node to evaluate LLM performance on basic math problems. Compares responses to expected answer and plots performance."
|
||||
filename="basic-math"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="Detect whether OpenAI function call was triggered"
|
||||
description="Basic example showing whether a given prompt triggered an OpenAI function call. Also shows difference between ChatGPT prior to function calls, and function call version."
|
||||
filename="basic-function-calls"
|
||||
title="🦟 Test knowledge of mosquitos"
|
||||
description="Uses an LLM scorer to test whether LLMs know the difference between lifetimes of male and female mosquitos."
|
||||
filename="mosquito-knowledge"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
<ExampleFlowCard
|
||||
title="🖼 Generate images of animals"
|
||||
description="Shows images of a fox, sparrow, and a pig as a computer scientist and a gamer, using Dall-E2."
|
||||
filename="animal-images"
|
||||
onSelect={onSelect}
|
||||
/>
|
||||
</SimpleGrid>
|
||||
@ -462,10 +546,10 @@ const ExampleFlowsModal = forwardRef<
|
||||
>
|
||||
OpenAI evals
|
||||
</a>
|
||||
{`benchmarking package. We currently load evals with a common system
|
||||
{` benchmarking package. We currently load evals with a common system
|
||||
message, a single 'turn' (prompt), and evaluation types of
|
||||
'includes', 'match', and 'fuzzy match', and a reasonable number of
|
||||
prompts. `}
|
||||
prompts. `}
|
||||
<i>
|
||||
Warning: some evals include tables with 1000 prompts or more.{" "}
|
||||
</i>
|
||||
|
chainforge/react-server/src/FlowSidebar.tsx (312 lines, new file)
@ -0,0 +1,312 @@
|
||||
import React, { useState, useEffect, useContext } from "react";
|
||||
import {
|
||||
IconEdit,
|
||||
IconTrash,
|
||||
IconMenu2,
|
||||
IconX,
|
||||
IconCheck,
|
||||
} from "@tabler/icons-react";
|
||||
import axios from "axios";
|
||||
import { AlertModalContext } from "./AlertModal";
|
||||
import { Dict } from "./backend/typing";
|
||||
import {
|
||||
ActionIcon,
|
||||
Box,
|
||||
Drawer,
|
||||
Group,
|
||||
Stack,
|
||||
TextInput,
|
||||
Text,
|
||||
Flex,
|
||||
Divider,
|
||||
ScrollArea,
|
||||
} from "@mantine/core";
|
||||
import { FLASK_BASE_URL } from "./backend/utils";
|
||||
|
||||
interface FlowFile {
|
||||
name: string;
|
||||
last_modified: string;
|
||||
}
|
||||
|
||||
interface FlowSidebarProps {
|
||||
/** The name of flow that's currently loaded in the front-end, if defined. */
|
||||
currentFlow?: string;
|
||||
onLoadFlow: (flowFile?: Dict<any>, flowName?: string) => void;
|
||||
}
|
||||
|
||||
const FlowSidebar: React.FC<FlowSidebarProps> = ({
|
||||
onLoadFlow,
|
||||
currentFlow,
|
||||
}) => {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [savedFlows, setSavedFlows] = useState<FlowFile[]>([]);
|
||||
const [editName, setEditName] = useState<string | null>(null);
|
||||
const [newEditName, setNewEditName] = useState<string>("newName");
|
||||
|
||||
// The name of the local directory where flows are stored
|
||||
const [flowDir, setFlowDir] = useState<string | undefined>(undefined);
|
||||
|
||||
// For displaying alerts
|
||||
const showAlert = useContext(AlertModalContext);
|
||||
|
||||
// Fetch saved flows from the Flask backend
|
||||
const fetchSavedFlowList = async () => {
|
||||
try {
|
||||
const response = await axios.get(`${FLASK_BASE_URL}api/flows`);
|
||||
const flows = response.data.flows as FlowFile[];
|
||||
setFlowDir(response.data.flow_dir);
|
||||
setSavedFlows(
|
||||
flows.map((item) => ({
|
||||
name: item.name.replace(".cforge", ""),
|
||||
last_modified: new Date(item.last_modified).toLocaleString(),
|
||||
})),
|
||||
);
|
||||
} catch (error) {
|
||||
console.error("Error fetching saved flows:", error);
|
||||
}
|
||||
};
|
||||
|
||||
// Load a flow when clicked, and push it to the caller
|
||||
const handleLoadFlow = async (filename: string) => {
|
||||
try {
|
||||
// Fetch the flow
|
||||
const response = await axios.get(
|
||||
`${FLASK_BASE_URL}api/flows/${filename}`,
|
||||
);
|
||||
|
||||
// Push the flow to the ReactFlow UI. We also pass the filename
|
||||
// so that the caller can use that info to save the right flow when the user presses save.
|
||||
onLoadFlow(response.data, filename);
|
||||
|
||||
setIsOpen(false); // Close sidebar after loading
|
||||
} catch (error) {
|
||||
console.error(`Error loading flow ${filename}:`, error);
|
||||
if (showAlert) showAlert(error as Error);
|
||||
}
|
||||
};
|
||||
|
||||
// Delete a flow
|
||||
const handleDeleteFlow = async (
|
||||
filename: string,
|
||||
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
|
||||
) => {
|
||||
event.stopPropagation(); // Prevent triggering the parent click
|
||||
if (window.confirm(`Are you sure you want to delete "${filename}"?`)) {
|
||||
try {
|
||||
await axios.delete(`${FLASK_BASE_URL}api/flows/${filename}`);
|
||||
fetchSavedFlowList(); // Refresh the list
|
||||
} catch (error) {
|
||||
console.error(`Error deleting flow ${filename}:`, error);
|
||||
if (showAlert) showAlert(error as Error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Start editing a flow name
|
||||
const handleEditClick = (
|
||||
flowFile: string,
|
||||
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
|
||||
) => {
|
||||
event.stopPropagation(); // Prevent triggering the parent click
|
||||
setEditName(flowFile);
|
||||
setNewEditName(flowFile);
|
||||
};
|
||||
|
||||
// Cancel editing
|
||||
const handleCancelEdit = (
|
||||
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
|
||||
) => {
|
||||
event.stopPropagation(); // Prevent triggering the parent click
|
||||
setEditName(null);
|
||||
};
|
||||
|
||||
// Save the edited flow name
|
||||
const handleSaveEdit = async (
|
||||
oldFilename: string,
|
||||
newFilename: string,
|
||||
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
|
||||
) => {
|
||||
event?.stopPropagation(); // Prevent triggering the parent click
|
||||
if (newFilename && newFilename !== oldFilename) {
|
||||
await axios
|
||||
.put(`${FLASK_BASE_URL}api/flows/${oldFilename}`, {
|
||||
newName: newFilename,
|
||||
})
|
||||
.then(() => {
|
||||
onLoadFlow(undefined, newFilename); // Tell the parent that the filename has changed. This won't replace the flow.
|
||||
fetchSavedFlowList(); // Refresh the list
|
||||
})
|
||||
.catch((error) => {
|
||||
let msg: string;
|
||||
if (error.response) {
|
||||
msg = `404 Error: ${error.response.status === 404 ? error.response.data?.error ?? "Not Found" : error.response.data}`;
|
||||
} else if (error.request) {
|
||||
// Request was made but no response was received
|
||||
msg = "No response received from server.";
|
||||
} else {
|
||||
// Something else happened in setting up the request
|
||||
msg = `Unknown Error: ${error.message}`;
|
||||
}
|
||||
console.error(msg);
|
||||
if (showAlert) showAlert(msg);
|
||||
});
|
||||
}
|
||||
|
||||
// No longer editing
|
||||
setEditName(null);
|
||||
setNewEditName("newName");
|
||||
};
|
||||
|
||||
// Load flows when component mounts
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
fetchSavedFlowList();
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
{/* <RenameValueModal title="Rename flow" label="Edit name" initialValue="" onSubmit={handleEditName} /> */}
|
||||
|
||||
{/* Toggle Button */}
|
||||
<ActionIcon
|
||||
variant="gradient"
|
||||
size="1.625rem"
|
||||
style={{
|
||||
position: "absolute",
|
||||
top: "10px",
|
||||
left: "10px",
|
||||
// left: isOpen ? "250px" : "10px",
|
||||
// transition: "left 0.3s ease-in-out",
|
||||
zIndex: 10,
|
||||
}}
|
||||
onClick={() => setIsOpen(!isOpen)}
|
||||
>
|
||||
{isOpen ? <IconX /> : <IconMenu2 />}
|
||||
</ActionIcon>
|
||||
|
||||
{/* Sidebar */}
|
||||
<Drawer
|
||||
opened={isOpen}
|
||||
onClose={() => setIsOpen(false)}
|
||||
title="Saved Flows"
|
||||
position="left"
|
||||
size="250px" // Adjust sidebar width
|
||||
padding="md"
|
||||
withCloseButton={true}
|
||||
scrollAreaComponent={ScrollArea.Autosize}
|
||||
>
|
||||
<Divider />
|
||||
<Stack spacing="4px" mt="0px" mb="120px">
|
||||
{savedFlows.length === 0 ? (
|
||||
<Text color="dimmed">No saved flows found</Text>
|
||||
) : (
|
||||
savedFlows.map((flow) => (
|
||||
<Box
|
||||
key={flow.name}
|
||||
p="6px"
|
||||
sx={(theme) => ({
|
||||
borderRadius: theme.radius.sm,
|
||||
cursor: "pointer",
|
||||
"&:hover": {
|
||||
backgroundColor:
|
||||
theme.colorScheme === "dark"
|
||||
? theme.colors.dark[6]
|
||||
: theme.colors.gray[0],
|
||||
},
|
||||
})}
|
||||
onClick={() => {
|
||||
if (editName !== flow.name) handleLoadFlow(flow.name);
|
||||
}}
|
||||
>
|
||||
{editName === flow.name ? (
|
||||
<Group spacing="xs">
|
||||
<TextInput
|
||||
value={newEditName}
|
||||
onChange={(e) => setNewEditName(e.target.value)}
|
||||
style={{ flex: 1 }}
|
||||
autoFocus
|
||||
/>
|
||||
<ActionIcon
|
||||
color="green"
|
||||
onClick={(e) => handleSaveEdit(editName, newEditName, e)}
|
||||
>
|
||||
<IconCheck size={18} />
|
||||
</ActionIcon>
|
||||
<ActionIcon color="gray" onClick={handleCancelEdit}>
|
||||
<IconX size={18} />
|
||||
</ActionIcon>
|
||||
</Group>
|
||||
) : (
|
||||
<>
|
||||
<Flex
|
||||
justify="space-between"
|
||||
align="center"
|
||||
gap="0px"
|
||||
h="auto"
|
||||
>
|
||||
{currentFlow === flow.name ? (
|
||||
<Box
|
||||
ml="-15px"
|
||||
mr="5px"
|
||||
bg="green"
|
||||
w="10px"
|
||||
h="10px"
|
||||
style={{ borderRadius: "50%" }}
|
||||
></Box>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
<Text size="sm" mr="auto">
|
||||
{flow.name}
|
||||
</Text>
|
||||
<Flex gap="0px">
|
||||
<ActionIcon
|
||||
color="blue"
|
||||
onClick={(e) => handleEditClick(flow.name, e)}
|
||||
>
|
||||
<IconEdit size={18} />
|
||||
</ActionIcon>
|
||||
<ActionIcon
|
||||
color="red"
|
||||
onClick={(e) => handleDeleteFlow(flow.name, e)}
|
||||
>
|
||||
<IconTrash size={18} />
|
||||
</ActionIcon>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Text size="xs" color="gray">
|
||||
{flow.last_modified}
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
<Divider />
|
||||
</Box>
|
||||
))
|
||||
)}
|
||||
</Stack>
|
||||
|
||||
{/* Sticky footer */}
|
||||
<div
|
||||
style={{
|
||||
position: "fixed",
|
||||
bottom: 0,
|
||||
background: "white",
|
||||
padding: "10px",
|
||||
borderTop: "1px solid #ddd",
|
||||
}}
|
||||
>
|
||||
{flowDir ? (
|
||||
<Text size="xs" color="gray">
|
||||
Local flows are saved at: {flowDir}
|
||||
</Text>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</Drawer>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default FlowSidebar;
|
@@ -271,7 +271,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(

    // Load image compression settings from cache (if possible)
    const cachedImgCompr = (
      StorageCache.loadFromLocalStorage("imageCompression") as Dict
      StorageCache.loadFromLocalStorage("imageCompression", false) as Dict
    )?.value as boolean | undefined;
    if (typeof cachedImgCompr === "boolean") {
      setImageCompression(cachedImgCompr);
@@ -400,6 +400,12 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
              {...form.getInputProps("Google")}
            />
            <br />
            <TextInput
              label="DeepSeek API Key"
              placeholder="Paste your DeepSeek API key here"
              {...form.getInputProps("DeepSeek")}
            />
            <br />
            <TextInput
              label="Aleph Alpha API Key"
              placeholder="Paste your Aleph Alpha API key here"
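A note on the new DeepSeek field: the fetchEnvironAPIKeys() keymap in the backend hunk earlier also maps the DEEPSEEK_API_KEY environment variable to the same "DeepSeek" alias, so the key can be supplied via the environment instead of this settings form. A minimal sketch, with a placeholder key value:

```python
# Hypothetical: supply the DeepSeek key via the environment instead of the settings UI.
# fetchEnvironAPIKeys() maps 'DEEPSEEK_API_KEY' to the 'DeepSeek' alias this form reads.
import os

os.environ["DEEPSEEK_API_KEY"] = "sk-..."  # placeholder; set before launching `chainforge serve`
```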
@@ -94,6 +94,7 @@ const InspectorNode: React.FC<InspectorNodeProps> = ({ data, id }) => {
    >
      <LLMResponseInspector
        jsonResponses={jsonResponses ?? []}
        isOpen={true}
        wideFormat={false}
      />
    </div>
@ -27,7 +27,7 @@ import {
|
||||
getVarsAndMetavars,
|
||||
cleanMetavarsFilterFunc,
|
||||
} from "./backend/utils";
|
||||
import StorageCache from "./backend/cache";
|
||||
import StorageCache, { StringLookup } from "./backend/cache";
|
||||
import { ResponseBox } from "./ResponseBoxes";
|
||||
import {
|
||||
Dict,
|
||||
@ -78,9 +78,9 @@ const displayJoinedTexts = (
|
||||
return textInfos.map((info, idx) => {
|
||||
const llm_name =
|
||||
typeof info !== "string"
|
||||
? typeof info.llm === "string"
|
||||
? info.llm
|
||||
: info.llm?.name
|
||||
? typeof info.llm === "string" || typeof info.llm === "number"
|
||||
? StringLookup.get(info.llm)
|
||||
: StringLookup.get(info.llm?.name)
|
||||
: "";
|
||||
const ps = (
|
||||
<pre className="small-response">
|
||||
@ -267,7 +267,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
|
||||
if (groupByVar !== "A") vars[varname] = var_val;
|
||||
return {
|
||||
text: joinTexts(
|
||||
resp_objs.map((r) => (typeof r === "string" ? r : r.text ?? "")),
|
||||
resp_objs.map(
|
||||
(r) =>
|
||||
(typeof r === "string" || typeof r === "number"
|
||||
? StringLookup.get(r)
|
||||
: StringLookup.get(r.text)) ?? "",
|
||||
),
|
||||
formatting,
|
||||
),
|
||||
fill_history: isMetavar ? {} : vars,
|
||||
@ -284,7 +289,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
|
||||
countNumLLMs(unspecGroup) > 1 ? undefined : unspecGroup[0].llm;
|
||||
joined_texts.push({
|
||||
text: joinTexts(
|
||||
unspecGroup.map((u) => (typeof u === "string" ? u : u.text ?? "")),
|
||||
unspecGroup.map(
|
||||
(u) =>
|
||||
(typeof u === "string" || typeof u === "number"
|
||||
? StringLookup.get(u)
|
||||
: StringLookup.get(u.text)) ?? "",
|
||||
),
|
||||
formatting,
|
||||
),
|
||||
fill_history: {},
|
||||
@ -326,7 +336,10 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
|
||||
let joined_texts: (TemplateVarInfo | string)[] = [];
|
||||
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(
|
||||
resp_objs,
|
||||
(r) => (typeof r.llm === "string" ? r.llm : r.llm?.key),
|
||||
(r) =>
|
||||
typeof r.llm === "string" || typeof r.llm === "number"
|
||||
? StringLookup.get(r.llm)
|
||||
: r.llm?.key,
|
||||
);
|
||||
// eslint-disable-next-line
|
||||
Object.entries(groupedRespsByLLM).forEach(([llm_key, resp_objs]) => {
|
||||
@ -337,7 +350,7 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
|
||||
if (nonLLMRespGroup.length > 0)
|
||||
joined_texts.push(
|
||||
joinTexts(
|
||||
nonLLMRespGroup.map((t) => t.text ?? ""),
|
||||
nonLLMRespGroup.map((t) => StringLookup.get(t.text) ?? ""),
|
||||
formatting,
|
||||
),
|
||||
);
|
||||
@ -353,7 +366,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
|
||||
setDataPropsForNode(id, { fields: joined_texts });
|
||||
} else {
|
||||
let joined_texts: string | TemplateVarInfo = joinTexts(
|
||||
resp_objs.map((r) => (typeof r === "string" ? r : r.text ?? "")),
|
||||
resp_objs.map(
|
||||
(r) =>
|
||||
(typeof r === "string" || typeof r === "number"
|
||||
? StringLookup.get(r)
|
||||
: StringLookup.get(r.text)) ?? "",
|
||||
),
|
||||
formatting,
|
||||
);
|
||||
|
||||
|
@ -21,11 +21,27 @@ import LLMResponseInspectorModal, {
|
||||
} from "./LLMResponseInspectorModal";
|
||||
import InspectFooter from "./InspectFooter";
|
||||
import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer";
|
||||
import { genDebounceFunc, stripLLMDetailsFromResponses } from "./backend/utils";
|
||||
import {
|
||||
extractSettingsVars,
|
||||
genDebounceFunc,
|
||||
stripLLMDetailsFromResponses,
|
||||
} from "./backend/utils";
|
||||
import { AlertModalContext } from "./AlertModal";
|
||||
import { Dict, LLMResponse, LLMSpec, QueryProgress } from "./backend/typing";
|
||||
import {
|
||||
Dict,
|
||||
LLMResponse,
|
||||
LLMResponseData,
|
||||
LLMSpec,
|
||||
QueryProgress,
|
||||
} from "./backend/typing";
|
||||
import { Status } from "./StatusIndicatorComponent";
|
||||
import { evalWithLLM, grabResponses } from "./backend/backend";
|
||||
import { evalWithLLM, generatePrompts, grabResponses } from "./backend/backend";
|
||||
import { UserForcedPrematureExit } from "./backend/errors";
|
||||
import CancelTracker from "./backend/canceler";
|
||||
import { PromptInfo, PromptListModal, PromptListPopover } from "./PromptNode";
|
||||
import { useDisclosure } from "@mantine/hooks";
|
||||
import { PromptTemplate } from "./backend/template";
|
||||
import { StringLookup } from "./backend/cache";
|
||||
|
||||
// The default prompt shown in gray highlights to give people a good example of an evaluation prompt.
|
||||
const PLACEHOLDER_PROMPT =
|
||||
@ -57,7 +73,9 @@ const DEFAULT_LLM_ITEM = (() => {
|
||||
const item = [initLLMProviders.find((i) => i.base_model === "gpt-4")].map(
|
||||
(i) => ({
|
||||
key: uuid(),
|
||||
settings: getDefaultModelSettings((i as LLMSpec).base_model),
|
||||
settings: getDefaultModelSettings(
|
||||
StringLookup.get(i?.base_model) as string,
|
||||
),
|
||||
...i,
|
||||
}),
|
||||
)[0];
|
||||
@ -69,12 +87,15 @@ export interface LLMEvaluatorComponentRef {
|
||||
run: (
|
||||
input_node_ids: string[],
|
||||
onProgressChange?: (progress: QueryProgress) => void,
|
||||
cancelId?: string | number,
|
||||
) => Promise<LLMResponse[]>;
|
||||
cancel: (cancelId: string | number, cancelProgress: () => void) => void;
|
||||
serialize: () => {
|
||||
prompt: string;
|
||||
format: string;
|
||||
grader?: LLMSpec;
|
||||
};
|
||||
getPromptTemplate: () => string;
|
||||
}
|
||||
|
||||
export interface LLMEvaluatorComponentProps {
|
||||
@ -149,20 +170,26 @@ export const LLMEvaluatorComponent = forwardRef<
|
||||
[setExpectedFormat, onFormatChange],
|
||||
);
|
||||
|
||||
const getPromptTemplate = () => {
|
||||
const formatting_instr = OUTPUT_FORMAT_PROMPTS[expectedFormat] ?? "";
|
||||
return (
|
||||
"You are evaluating text that will be pasted below. " +
|
||||
promptText +
|
||||
" " +
|
||||
formatting_instr +
|
||||
"\n```\n{__input}\n```"
|
||||
);
|
||||
};
|
||||
|
||||
// Runs the LLM evaluator over the inputs, returning the results in a Promise.
|
||||
// Errors are raised as a rejected Promise.
|
||||
const run = (
|
||||
input_node_ids: string[],
|
||||
onProgressChange?: (progress: QueryProgress) => void,
|
||||
cancelId?: string | number,
|
||||
) => {
|
||||
// Create prompt template to wrap user-specified scorer prompt and input data
|
||||
const formatting_instr = OUTPUT_FORMAT_PROMPTS[expectedFormat] ?? "";
|
||||
const template =
|
||||
"You are evaluating text that will be pasted below. " +
|
||||
promptText +
|
||||
" " +
|
||||
formatting_instr +
|
||||
"\n```\n{input}\n```";
|
||||
const template = getPromptTemplate();
|
||||
const llm_key = llmScorers[0].key ?? "";
|
||||
|
||||
// Fetch info about the number of queries we'll need to make
|
||||
@ -176,12 +203,16 @@ export const LLMEvaluatorComponent = forwardRef<
|
||||
);
|
||||
return onProgressChange
|
||||
? (progress_by_llm: Dict<QueryProgress>) =>
|
||||
onProgressChange({
|
||||
success:
|
||||
(100 * progress_by_llm[llm_key].success) / num_resps_required,
|
||||
error:
|
||||
(100 * progress_by_llm[llm_key].error) / num_resps_required,
|
||||
})
|
||||
// Debounce the progress bars UI update to ensure we don't re-render too often:
|
||||
debounce(() => {
|
||||
onProgressChange({
|
||||
success:
|
||||
(100 * progress_by_llm[llm_key].success) /
|
||||
num_resps_required,
|
||||
error:
|
||||
(100 * progress_by_llm[llm_key].error) / num_resps_required,
|
||||
});
|
||||
}, 30)()
|
||||
: undefined;
|
||||
})
|
||||
.then((progress_listener) => {
|
||||
@ -193,9 +224,13 @@ export const LLMEvaluatorComponent = forwardRef<
|
||||
input_node_ids,
|
||||
apiKeys ?? {},
|
||||
progress_listener,
|
||||
cancelId,
|
||||
);
|
||||
})
|
||||
.then(function (res) {
|
||||
// eslint-disable-next-line
|
||||
debounce(() => {}, 1)(); // erase any pending debounces
|
||||
|
||||
// Check if there's an error; if so, bubble it up to user and exit:
|
||||
if (res.errors && res.errors.length > 0) throw new Error(res.errors[0]);
|
||||
else if (res.responses === undefined)
|
||||
@ -208,6 +243,12 @@ export const LLMEvaluatorComponent = forwardRef<
|
||||
});
|
||||
};
|
||||
|
||||
const cancel = (cancelId: string | number, cancelProgress: () => void) => {
|
||||
CancelTracker.add(cancelId);
|
||||
// eslint-disable-next-line
|
||||
debounce(cancelProgress, 1)(); // erase any pending debounces
|
||||
};
|
||||
|
||||
// Export the current internal state as JSON
|
||||
const serialize = () => ({
|
||||
prompt: promptText,
|
||||
@ -218,7 +259,9 @@ export const LLMEvaluatorComponent = forwardRef<
|
||||
// Define functions accessible from the parent component
|
||||
useImperativeHandle(ref, () => ({
|
||||
run,
|
||||
cancel,
|
||||
serialize,
|
||||
getPromptTemplate,
|
||||
}));
|
||||
|
||||
return (
|
||||
@ -289,11 +332,20 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
|
||||
const [status, setStatus] = useState<Status>(Status.NONE);
|
||||
const showAlert = useContext(AlertModalContext);
|
||||
|
||||
// Cancelation of pending queries
|
||||
const [cancelId, setCancelId] = useState(Date.now());
|
||||
const refreshCancelId = () => setCancelId(Date.now());
|
||||
|
||||
const inspectModal = useRef<LLMResponseInspectorModalRef>(null);
|
||||
// eslint-disable-next-line
|
||||
const [uninspectedResponses, setUninspectedResponses] = useState(false);
|
||||
const [showDrawer, setShowDrawer] = useState(false);
|
||||
|
||||
// For an info pop-up that shows all the prompts that will be sent off
|
||||
// NOTE: This is the 'full' version of the PromptListPopover that activates on hover.
|
||||
const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] =
|
||||
useDisclosure(false);
|
||||
|
||||
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
|
||||
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
|
||||
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
|
||||
@ -306,6 +358,56 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
|
||||
undefined,
|
||||
);
|
||||
|
||||
// On hover over the 'info' button, to preview the prompts that will be sent out
|
||||
const [promptPreviews, setPromptPreviews] = useState<PromptInfo[]>([]);
|
||||
const handlePreviewHover = () => {
|
||||
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map((e) => e.source);
if (input_node_ids.length === 0) {
console.warn("No inputs for evaluator node.");
return;
}

const promptText = llmEvaluatorRef?.current?.getPromptTemplate();
if (!promptText) return;

// Pull input data
try {
grabResponses(input_node_ids)
.then(function (resp_objs) {
const inputs = resp_objs
.map((obj: LLMResponse) =>
obj.responses.map((r: LLMResponseData) => ({
text:
typeof r === "string" || typeof r === "number"
? r
: undefined,
image: typeof r === "object" && r.t === "img" ? r.d : undefined,
fill_history: obj.vars,
metavars: obj.metavars,
})),
)
.flat();
return generatePrompts(promptText, { __input: inputs });
})
.then(function (prompts) {
setPromptPreviews(
prompts.map(
(p: PromptTemplate) =>
new PromptInfo(
p.toString(),
extractSettingsVars(p.fill_history),
),
),
);
});
} catch (err) {
// soft fail
console.error(err);
setPromptPreviews([]);
}
};

const handleRunClick = useCallback(() => {
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map((e) => e.source);
@@ -318,15 +420,23 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
setProgress({ success: 2, error: 0 });

const handleError = (err: Error | string) => {
setStatus(Status.ERROR);
setProgress(undefined);
if (typeof err !== "string") console.error(err);
if (showAlert) showAlert(typeof err === "string" ? err : err?.message);
if (
err instanceof UserForcedPrematureExit ||
CancelTracker.has(cancelId)
) {
// Handle a premature cancelation
console.log("Canceled.");
setStatus(Status.NONE);
} else {
setStatus(Status.ERROR);
if (showAlert) showAlert(typeof err === "string" ? err : err?.message);
}
};

// Run LLM evaluator
llmEvaluatorRef?.current
?.run(input_node_ids, setProgress)
?.run(input_node_ids, setProgress, cancelId)
.then(function (evald_resps) {
// Ping any vis + inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
@@ -347,8 +457,15 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
setStatus,
showDrawer,
showAlert,
cancelId,
]);

const handleStopClick = useCallback(() => {
llmEvaluatorRef?.current?.cancel(cancelId, () => setProgress(undefined));
refreshCancelId();
setStatus(Status.NONE);
}, [cancelId, refreshCancelId]);

const showResponseInspector = useCallback(() => {
if (inspectModal && inspectModal.current && lastResponses) {
setUninspectedResponses(false);
@@ -384,13 +501,28 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
nodeId={id}
icon={<IconRobot size="16px" />}
status={status}
isRunning={status === Status.LOADING}
handleRunClick={handleRunClick}
handleStopClick={handleStopClick}
runButtonTooltip="Run scorer over inputs"
customButtons={[
<PromptListPopover
key="prompt-previews"
promptInfos={promptPreviews}
onHover={handlePreviewHover}
onClick={openInfoModal}
/>,
]}
/>
<LLMResponseInspectorModal
ref={inspectModal}
jsonResponses={lastResponses}
/>
<PromptListModal
promptPreviews={promptPreviews}
infoModalOpened={infoModalOpened}
closeInfoModal={closeInfoModal}
/>

<div className="llm-scorer-container">
<LLMEvaluatorComponent
@@ -389,6 +389,7 @@ export const LLMListContainer = forwardRef<
// Together models have a substring "together/" that we need to strip:
if (item.base_model === "together")
item.formData.model = item.model.substring(9);
else item.formData.model = item.model;

let new_items: LLMSpec[] = [];
if (selectModelAction === "add" || selectModelAction === undefined) {

File diff suppressed because it is too large
@@ -22,6 +22,7 @@ export default function LLMResponseInspectorDrawer({
>
<LLMResponseInspector
jsonResponses={jsonResponses}
isOpen={showDrawer}
wideFormat={false}
/>
</div>

@@ -80,6 +80,7 @@ const LLMResponseInspectorModal = forwardRef<
<Suspense fallback={<LoadingOverlay visible={true} />}>
<LLMResponseInspector
jsonResponses={props.jsonResponses}
isOpen={opened}
wideFormat={true}
/>
</Suspense>
@@ -156,11 +156,11 @@ const ChatGPTSettings: ModelSettingsDict = {
"If specified, the OpenAI API will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed.",
allow_empty_str: true,
},
max_tokens: {
max_completion_tokens: {
type: "integer",
title: "max_tokens",
title: "max_completion_tokens",
description:
"The maximum number of tokens to generate in the chat completion. (The total length of input tokens and generated tokens is limited by the model's context length.)",
"An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.",
},
presence_penalty: {
type: "number",
@@ -198,6 +198,7 @@ const ChatGPTSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to gpt-3.5-turbo.",
"ui:widget": "datalist",
},
system_msg: {
"ui:widget": "textarea",
@@ -243,7 +244,7 @@ const ChatGPTSettings: ModelSettingsDict = {
"ui:widget": "textarea",
"ui:help": "Defaults to empty.",
},
max_tokens: {
max_completion_tokens: {
"ui:help": "Defaults to infinity.",
},
seed: {
@@ -326,6 +327,121 @@ const GPT4Settings: ModelSettingsDict = {
postprocessors: ChatGPTSettings.postprocessors,
};

const DeepSeekSettings: ModelSettingsDict = {
fullName: "DeepSeek",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Deep Seek",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a DeepSeek model to query. For more details on the differences, see the DeepSeek API documentation.",
enum: ["deepseek-chat", "deepseek-reasoner"],
default: "deepseek-chat",
},
system_msg: {
type: "string",
title: "system_msg",
description:
"Many conversations begin with a system message to gently instruct the assistant. By default, ChainForge includes the suggested 'You are a helpful assistant.'",
default: "You are a helpful assistant.",
allow_empty_str: true,
},
temperature: {
type: "number",
title: "temperature",
description:
"What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.",
default: 1,
minimum: 0,
maximum: 2,
multipleOf: 0.01,
},
response_format: {
type: "string",
title: "response_format",
description:
"An object specifying the format that the model must output. Can be 'text' or 'json_object' or (late 2024) can be a JSON schema specifying structured outputs. In ChainForge, you should only specify text, json_object, or the verbatim JSON schema---do not add a JSON object with a 'type' parameter surrounding these values. JSON modes only works with newest GPT models. IMPORTANT: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.",
default: "text",
},
tools: {
type: "string",
title: "tools",
description:
"A list of JSON schema objects, each with 'name', 'description', and 'parameters' keys, which describe functions the model may generate JSON inputs for. For more info, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb",
default: "",
},
tool_choice: {
type: "string",
title: "tool_choice",
description:
"Controls how the model responds to function calls. 'none' means the model does not call a function, and responds to the end-user. 'auto' means the model can pick between an end-user or calling a function. 'required' means the model must call one or more tools. Specifying a particular function name forces the model to call only that function. Leave blank for default behavior.",
default: "",
},
top_p: {
type: "number",
title: "top_p",
description:
"An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.005,
},
stop: {
type: "string",
title: "stop sequences",
description:
'Up to 4 sequences where the API will stop generating further tokens. Enclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: "",
},
max_tokens: {
type: "integer",
title: "max_tokens",
description:
"The maximum number of tokens to generate in the chat completion. (The total length of input tokens and generated tokens is limited by the model's context length.)",
},
presence_penalty: {
type: "number",
title: "presence_penalty",
description:
"Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
default: 0,
minimum: -2,
maximum: 2,
multipleOf: 0.005,
},
frequency_penalty: {
type: "number",
title: "frequency_penalty",
description:
"Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
default: 0,
minimum: -2,
maximum: 2,
multipleOf: 0.005,
},
},
},
uiSchema: {
...ChatGPTSettings.uiSchema,
model: {
"ui:help": "Defaults to deepseek-chat.",
"ui:widget": "datalist",
},
},
postprocessors: ChatGPTSettings.postprocessors,
};

const DalleSettings: ModelSettingsDict = {
fullName: "Dall-E Image Models (OpenAI)",
schema: {
@@ -381,6 +497,7 @@ const DalleSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to dalle-2.",
"ui:widget": "datalist",
},
size: {
"ui:help": "Defaults to 256x256.",
@@ -542,6 +659,7 @@ const ClaudeSettings: ModelSettingsDict = {
model: {
"ui:help":
"Defaults to claude-2.1. Note that Anthropic models are subject to change. Model names prior to Claude 2, including 100k context window, are no longer listed on the Anthropic site, so they may or may not work.",
"ui:widget": "datalist",
},
system_msg: {
"ui:widget": "textarea",
@@ -606,7 +724,7 @@ const ClaudeSettings: ModelSettingsDict = {
};

const PaLM2Settings: ModelSettingsDict = {
fullName: "Google AI Models (Gemini & PaLM)",
fullName: "Google AI Models (Gemini)",
schema: {
type: "object",
required: ["shortname"],
@@ -623,14 +741,33 @@ const PaLM2Settings: ModelSettingsDict = {
title: "Model",
description:
"Select a PaLM model to query. For more details on the differences, see the Google PaLM API documentation.",
enum: ["gemini-pro", "text-bison-001", "chat-bison-001"],
default: "gemini-pro",
enum: [
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.5-pro",
"gemini-1.0-pro",
"gemini-pro",
"text-bison-001",
"chat-bison-001",
],
default: "gemini-1.5-flash",
shortname_map: {
"text-bison-001": "PaLM2-text",
"chat-bison-001": "PaLM2-chat",
"gemini-pro": "Gemini",
"gemini-pro": "Gemini 1.0",
"gemini-1.5-pro": "Gemini 1.5",
"gemini-1.0-pro": "Gemini 1.0",
"gemini-1.5-flash": "Gemini Flash",
"gemini-1.5-flash-8b": "Gemini Flash 8B",
},
},
system_msg: {
type: "string",
title: "system_msg",
description:
"Enter your system message here, to be passed to the systemInstructions parameter.",
default: "",
},
temperature: {
type: "number",
title: "temperature",
@@ -684,6 +821,10 @@ const PaLM2Settings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to gemini-pro.",
"ui:widget": "datalist",
},
system_msg: {
"ui:widget": "textarea",
},
temperature: {
"ui:help": "Defaults to 0.5.",
@@ -831,6 +972,7 @@ const DalaiModelSettings: ModelSettingsDict = {
model: {
"ui:help":
"NOTE: You must have installed the selected model and have Dalai be running and accessible on the local environment with which you are running the ChainForge server.",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 0.5.",
@@ -956,13 +1098,6 @@ const HuggingFaceTextInferenceSettings: ModelSettingsDict = {
"bigcode/starcoder": "starcoder",
},
},
custom_model: {
type: "string",
title: "Custom HF model endpoint",
description:
"(Only used if you select 'Other' above.) Enter the HuggingFace id of the text generation model you wish to query via the inference API. Alternatively, if you have hosted a model on HF Inference Endpoints, you can enter the full URL of the endpoint here.",
default: "",
},
model_type: {
type: "string",
title: "Model Type (Text or Chat)",
@@ -1052,6 +1187,7 @@ const HuggingFaceTextInferenceSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to Falcon.7B.",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -1194,6 +1330,7 @@ const AlephAlphaLuminousSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to Luminous Base.",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 0.0.",
@@ -1546,6 +1683,7 @@ const BedrockClaudeSettings: ModelSettingsDict = {
model: {
"ui:help":
"Defaults to claude-2. Note that Anthropic models in particular are subject to change. Model names prior to Claude 2, including 100k context window, are no longer listed on the Anthropic site, so they may or may not work.",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -1679,7 +1817,8 @@ const BedrockJurassic2Settings: ModelSettingsDict = {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to Jurassic 2 Ultra. ",
"ui:help": "Defaults to Jurassic 2 Ultra.",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -1781,6 +1920,7 @@ const BedrockTitanSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to Titan Large",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -1891,6 +2031,7 @@ const BedrockCommandTextSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to Command Text",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -2000,6 +2141,7 @@ const MistralSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to Mistral",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -2118,6 +2260,7 @@ const BedrockLlama2ChatSettings: ModelSettingsDict = {
},
model: {
"ui:help": "Defaults to LlamaChat 13B",
"ui:widget": "datalist",
},
temperature: {
"ui:help": "Defaults to 1.0.",
@@ -2168,9 +2311,13 @@ export const TogetherChatSettings: ModelSettingsDict = {
"Austism/chronos-hermes-13b",
"cognitivecomputations/dolphin-2.5-mixtral-8x7b",
"databricks/dbrx-instruct",
"deepseek-ai/DeepSeek-V3",
"deepseek-ai/DeepSeek-R1",
"deepseek-ai/deepseek-coder-33b-instruct",
"deepseek-ai/deepseek-llm-67b-chat",
"garage-bAInd/Platypus2-70B-instruct",
"google/gemma-2-27b-it",
"google/gemma-2-9b-it",
"google/gemma-2b-it",
"google/gemma-7b-it",
"Gryphe/MythoMax-L2-13b",
@@ -2180,11 +2327,22 @@ export const TogetherChatSettings: ModelSettingsDict = {
"codellama/CodeLlama-34b-Instruct-hf",
"codellama/CodeLlama-70b-Instruct-hf",
"codellama/CodeLlama-7b-Instruct-hf",
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
"meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
"meta-llama/Meta-Llama-3-8B-Instruct-Lite",
"meta-llama/Meta-Llama-3-70B-Instruct-Lite",
"meta-llama/Llama-2-70b-chat-hf",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-3-8b-chat-hf",
"meta-llama/Llama-3-70b-chat-hf",
"microsoft/WizardLM-2-8x22B",
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.2",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -2196,8 +2354,16 @@ export const TogetherChatSettings: ModelSettingsDict = {
"NousResearch/Nous-Hermes-llama-2-7b",
"NousResearch/Nous-Hermes-Llama2-13b",
"NousResearch/Nous-Hermes-2-Yi-34B",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
"openchat/openchat-3.5-1210",
"Open-Orca/Mistral-7B-OpenOrca",
"Qwen/Qwen2.5-7B-Instruct-Turbo",
"Qwen/Qwen2.5-72B-Instruct-Turbo",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-VL-72B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
"Qwen/QwQ-32B-Preview",
|
||||
"Qwen/Qwen1.5-0.5B-Chat",
|
||||
"Qwen/Qwen1.5-1.8B-Chat",
|
||||
"Qwen/Qwen1.5-4B-Chat",
|
||||
@ -2220,7 +2386,7 @@ export const TogetherChatSettings: ModelSettingsDict = {
|
||||
"WizardLM/WizardLM-13B-V1.2",
|
||||
"upstage/SOLAR-10.7B-Instruct-v1.0",
|
||||
],
|
||||
default: "meta-llama/Llama-2-7b-chat-hf",
|
||||
default: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
},
|
||||
temperature: {
|
||||
type: "number",
|
||||
@ -2267,7 +2433,8 @@ export const TogetherChatSettings: ModelSettingsDict = {
|
||||
"ui:autofocus": true,
|
||||
},
|
||||
model: {
|
||||
"ui:help": "Defaults to LlamaChat 13B",
|
||||
"ui:help": "Defaults to Llama-3.3-70B",
|
||||
"ui:widget": "datalist",
|
||||
},
|
||||
temperature: {
|
||||
"ui:help": "Defaults to 1.0.",
|
||||
@ -2342,8 +2509,38 @@ export const ModelSettings: Dict<ModelSettingsDict> = {
|
||||
"br.meta.llama2": BedrockLlama2ChatSettings,
|
||||
"br.meta.llama3": BedrockLlama3Settings,
|
||||
together: TogetherChatSettings,
|
||||
deepseek: DeepSeekSettings,
|
||||
};
|
||||
|
||||
// A lookup that converts the base_model names into LLMProviders.
|
||||
// Used for backwards compatibility.
|
||||
// TODO in future: Deprecate base_model and migrate fully to using LLMProvider type throughout.
|
||||
export function baseModelToProvider(base_model: string): LLMProvider {
|
||||
const lookup: Record<string, LLMProvider> = {
|
||||
"gpt-3.5-turbo": LLMProvider.OpenAI,
|
||||
"gpt-4": LLMProvider.OpenAI,
|
||||
"dall-e": LLMProvider.OpenAI,
|
||||
"claude-v1": LLMProvider.Anthropic,
|
||||
"palm2-bison": LLMProvider.Google,
|
||||
dalai: LLMProvider.Dalai,
|
||||
"azure-openai": LLMProvider.Azure_OpenAI,
|
||||
hf: LLMProvider.HuggingFace,
|
||||
"luminous-base": LLMProvider.Aleph_Alpha,
|
||||
ollama: LLMProvider.Ollama,
|
||||
"br.anthropic.claude": LLMProvider.Bedrock,
|
||||
"br.ai21.j2": LLMProvider.Bedrock,
|
||||
"br.amazon.titan": LLMProvider.Bedrock,
|
||||
"br.cohere.command": LLMProvider.Bedrock,
|
||||
"br.mistral.mistral": LLMProvider.Bedrock,
|
||||
"br.mistral.mixtral": LLMProvider.Bedrock,
|
||||
"br.meta.llama2": LLMProvider.Bedrock,
|
||||
"br.meta.llama3": LLMProvider.Bedrock,
|
||||
together: LLMProvider.Together,
|
||||
deepseek: LLMProvider.DeepSeek,
|
||||
};
|
||||
return lookup[base_model] ?? LLMProvider.Custom;
|
||||
}
|
||||
|
||||
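A minimal sketch of how this backwards-compatibility lookup behaves, assuming the function and the `LLMProvider` enum are exported from this module and the models module (import paths here are illustrative, not confirmed by the diff):

```ts
// Illustrative import paths; adjust to where ModelSettingSchemas and models actually live.
import { baseModelToProvider } from "./ModelSettingSchemas";
import { LLMProvider } from "./backend/models";

baseModelToProvider("deepseek"); // LLMProvider.DeepSeek (newly added in this change)
baseModelToProvider("br.meta.llama3"); // LLMProvider.Bedrock
baseModelToProvider("unknown-base-model"); // falls back to LLMProvider.Custom
```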
export function getSettingsSchemaForLLM(
llm_name: string,
): ModelSettingsDict | undefined {
@@ -2361,6 +2558,7 @@ export function getSettingsSchemaForLLM(
[LLMProvider.Aleph_Alpha]: AlephAlphaLuminousSettings,
[LLMProvider.Ollama]: OllamaSettings,
[LLMProvider.Together]: TogetherChatSettings,
[LLMProvider.DeepSeek]: DeepSeekSettings,
};

if (llm_provider === LLMProvider.Custom) return ModelSettings[llm_name];
@@ -4,14 +4,16 @@ import React, {
forwardRef,
useImperativeHandle,
useEffect,
useMemo,
} from "react";
import { Button, Modal, Popover } from "@mantine/core";
import { Button, Modal, Popover, Select } from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import emojidata from "@emoji-mart/data";
import Picker from "@emoji-mart/react";
// react-jsonschema-form
import validator from "@rjsf/validator-ajv8";
import Form from "@rjsf/core";
import { WidgetProps } from "@rjsf/utils";
import {
ModelSettings,
getDefaultModelFormData,
@@ -24,6 +26,46 @@ import {
ModelSettingsDict,
} from "./backend/typing";

// Custom UI widgets for react-jsonschema-form
const DatalistWidget = (props: WidgetProps) => {
const [data, setData] = useState(
(
props.options.enumOptions?.map((option, index) => ({
value: option.value,
label: option.value,
})) ?? []
).concat(
props.options.enumOptions?.find((o) => o.value === props.value)
? []
: { value: props.value, label: props.value },
),
);

return (
<Select
data={data}
defaultValue={props.value ?? ""}
onChange={(newVal) => props.onChange(newVal ?? "")}
size="sm"
placeholder="Select items"
nothingFound="Nothing found"
searchable
creatable
getCreateLabel={(query) => `+ Create ${query}`}
onCreate={(query) => {
const item = { value: query, label: query };
setData((current) => [...current, item]);
console.log(item);
return item;
}}
/>
);
};

const widgets = {
datalist: DatalistWidget,
};

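A minimal sketch of how a settings uiSchema opts into the new `datalist` widget registered above; the field name is illustrative, but the `"ui:widget": "datalist"` key matches what the other hunks in this change add to each provider's uiSchema:

```ts
// react-jsonschema-form resolves "ui:widget": "datalist" against the `widgets`
// map passed to <Form widgets={widgets} ... />, rendering the searchable,
// creatable Mantine Select defined in DatalistWidget.
const exampleUiSchema = {
  model: {
    "ui:help": "Defaults to gpt-3.5-turbo.",
    "ui:widget": "datalist",
  },
};
```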
export interface ModelSettingsModalRef {
trigger: () => void;
}
@@ -84,9 +126,16 @@ const ModelSettingsModal = forwardRef<
setSchema(schema);
setUISchema(settingsSpec.uiSchema);
setBaseModelName(settingsSpec.fullName);

// If the user has already saved custom settings...
if (model.formData) {
setFormData(model.formData);
setInitShortname(model.formData.shortname as string | undefined);

// If the "custom_model" field is set, use that as the initial model name, overriding "model".
// if (string_exists(model.formData.custom_model))
// setInitModelName(model.formData.custom_model as string);
// else
setInitModelName(model.formData.model as string | undefined);
} else {
// Create settings from schema
@@ -244,8 +293,9 @@ const ModelSettingsModal = forwardRef<
<Form
schema={schema}
uiSchema={uiSchema}
widgets={widgets} // Custom UI widgets
formData={formData}
// @ts-expect-error This is literally the example code from react-json-schema; no idea why it wouldn't typecheck correctly.
// // @ts-expect-error This is literally the example code from react-json-schema; no idea why it wouldn't typecheck correctly.
validator={validator}
// @ts-expect-error Expect format is LLMSpec.
onChange={onFormDataChange}
@@ -52,6 +52,7 @@ import {
QueryProgress,
LLMResponse,
TemplateVarInfo,
StringOrHash,
} from "./backend/typing";
import { AlertModalContext } from "./AlertModal";
import { Status } from "./StatusIndicatorComponent";
@@ -62,6 +63,7 @@ import {
grabResponses,
queryLLM,
} from "./backend/backend";
import { StringLookup } from "./backend/cache";

const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => {
const metakeys = new Set(
@@ -81,7 +83,7 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => {
return chats_by_llm;
};

class PromptInfo {
export class PromptInfo {
prompt: string;
settings: Dict;

@@ -118,7 +120,7 @@ export interface PromptListPopoverProps {
onClick: () => void;
}

const PromptListPopover: React.FC<PromptListPopoverProps> = ({
export const PromptListPopover: React.FC<PromptListPopoverProps> = ({
promptInfos,
onHover,
onClick,
@@ -176,6 +178,39 @@ const PromptListPopover: React.FC<PromptListPopoverProps> = ({
);
};

export interface PromptListModalProps {
promptPreviews: PromptInfo[];
infoModalOpened: boolean;
closeInfoModal: () => void;
}

export const PromptListModal: React.FC<PromptListModalProps> = ({
promptPreviews,
infoModalOpened,
closeInfoModal,
}) => {
return (
<Modal
title={
"List of prompts that will be sent to LLMs (" +
promptPreviews.length +
" total)"
}
size="xl"
opened={infoModalOpened}
onClose={closeInfoModal}
styles={{
header: { backgroundColor: "#FFD700" },
root: { position: "relative", left: "-5%" },
}}
>
<Box m="lg" mt="xl">
{displayPromptInfos(promptPreviews, true)}
</Box>
</Modal>
);
};

export interface PromptNodeProps {
data: {
title: string;
@@ -376,24 +411,36 @@ const PromptNode: React.FC<PromptNodeProps> = ({
[setTemplateVars, templateVars, pullInputData, id],
);

const handleInputChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const value = event.target.value;
const handleInputChange = useCallback(
(event: React.ChangeEvent<HTMLTextAreaElement>) => {
const value = event.target.value;
const updateStatus =
promptTextOnLastRun !== null &&
status !== Status.WARNING &&
value !== promptTextOnLastRun;

// Store prompt text
setPromptText(value);
data.prompt = value;
// Store prompt text
data.prompt = value;

// Update status icon, if need be:
if (
promptTextOnLastRun !== null &&
status !== Status.WARNING &&
value !== promptTextOnLastRun
)
setStatus(Status.WARNING);
// Debounce the global state change to happen only after 500ms, as it forces a costly rerender:
debounce((_value, _updateStatus) => {
setPromptText(_value);
setDataPropsForNode(id, { prompt: _value });
refreshTemplateHooks(_value);
if (_updateStatus) setStatus(Status.WARNING);
}, 300)(value, updateStatus);

// Debounce refreshing the template hooks so we don't annoy the user
debounce((_value) => refreshTemplateHooks(_value), 500)(value);
};
// Debounce refreshing the template hooks so we don't annoy the user
// debounce((_value) => refreshTemplateHooks(_value), 500)(value);
},
[
promptTextOnLastRun,
status,
refreshTemplateHooks,
setDataPropsForNode,
debounceTimeoutRef,
],
);

// On initialization
useEffect(() => {
@@ -432,7 +479,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({

// Chat nodes only. Pulls input data attached to the 'past conversations' handle.
// Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected.
const pullInputChats = () => {
const pullInputChats = useCallback(() => {
const pulled_data = pullInputData(["__past_chats"], id);
if (!("__past_chats" in pulled_data)) return [undefined, undefined];

@@ -454,6 +501,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Add to unique LLMs list, if necessary
if (
typeof info?.llm !== "string" &&
typeof info?.llm !== "number" &&
info?.llm?.name !== undefined &&
!llm_names.has(info.llm.name)
) {
@@ -464,8 +512,8 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Create revised chat_history on the TemplateVarInfo object,
// with the prompt and text of the pulled data as the 2nd-to-last, and last, messages:
const last_messages = [
{ role: "user", content: info.prompt ?? "" },
{ role: "assistant", content: info.text ?? "" },
{ role: "user", content: StringLookup.get(info.prompt) ?? "" },
{ role: "assistant", content: StringLookup.get(info.text) ?? "" },
];
let updated_chat_hist =
info.chat_history !== undefined
@@ -475,6 +523,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Append any present system message retroactively as the first message in the chat history:
if (
typeof info?.llm !== "string" &&
typeof info?.llm !== "number" &&
typeof info?.llm?.settings?.system_msg === "string" &&
updated_chat_hist[0].role !== "system"
)
@@ -487,7 +536,10 @@ const PromptNode: React.FC<PromptNodeProps> = ({
messages: updated_chat_hist,
fill_history: info.fill_history ?? {},
metavars: info.metavars ?? {},
llm: typeof info?.llm === "string" ? info.llm : info?.llm?.name,
llm:
typeof info?.llm === "string" || typeof info?.llm === "number"
? StringLookup.get(info.llm) ?? "(LLM lookup failed)"
: StringLookup.get(info?.llm?.name),
uid: uuid(),
};
},
@@ -495,36 +547,46 @@ const PromptNode: React.FC<PromptNodeProps> = ({

// Returns [list of LLM specs, list of ChatHistoryInfo]
return [past_chat_llms, past_chats];
};
}, [id, pullInputData]);

// Ask the backend how many responses it needs to collect, given the input data:
const fetchResponseCounts = (
prompt: string,
vars: Dict,
llms: (string | Dict)[],
chat_histories?:
| (ChatHistoryInfo | undefined)[]
| Dict<(ChatHistoryInfo | undefined)[]>,
) => {
return countQueries(
prompt,
vars,
llms,
const fetchResponseCounts = useCallback(
(
prompt: string,
vars: Dict,
llms: (StringOrHash | LLMSpec)[],
chat_histories?:
| (ChatHistoryInfo | undefined)[]
| Dict<(ChatHistoryInfo | undefined)[]>,
) => {
return countQueries(
prompt,
vars,
llms,
numGenerations,
chat_histories,
id,
node_type !== "chat" ? showContToggle && contWithPriorLLMs : undefined,
).then(function (results) {
return [results.counts, results.total_num_responses] as [
Dict<Dict<number>>,
Dict<number>,
];
});
},
[
countQueries,
numGenerations,
chat_histories,
showContToggle,
contWithPriorLLMs,
id,
node_type !== "chat" ? showContToggle && contWithPriorLLMs : undefined,
).then(function (results) {
return [results.counts, results.total_num_responses] as [
Dict<Dict<number>>,
Dict<number>,
];
});
};
node_type,
],
);

// On hover over the 'info' button, to preview the prompts that will be sent out
const [promptPreviews, setPromptPreviews] = useState<PromptInfo[]>([]);
const handlePreviewHover = () => {
const handlePreviewHover = useCallback(() => {
// Pull input data and prompt
try {
const pulled_vars = pullInputData(templateVars, id);
@@ -545,10 +607,18 @@ const PromptNode: React.FC<PromptNodeProps> = ({
console.error(err);
setPromptPreviews([]);
}
};
}, [
pullInputData,
templateVars,
id,
updateShowContToggle,
generatePrompts,
promptText,
pullInputChats,
]);

// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
const handleRunHover = () => {
const handleRunHover = useCallback(() => {
// Check if the PromptNode is not already waiting for a response...
if (status === "loading") {
setRunTooltip("Fetching responses...");
@@ -679,9 +749,17 @@ const PromptNode: React.FC<PromptNodeProps> = ({
console.error(err); // soft fail
setRunTooltip("Could not reach backend server.");
});
};
}, [
status,
llmItemsCurrState,
pullInputChats,
contWithPriorLLMs,
pullInputData,
fetchResponseCounts,
promptText,
]);

const handleRunClick = () => {
const handleRunClick = useCallback(() => {
// Go through all template hooks (if any) and check they're connected:
const is_fully_connected = templateVars.every((varname) => {
// Check that some edge has, as its target, this node and its template hook:
@@ -912,7 +990,12 @@ Soft failing by replacing undefined with empty strings.`,
resp_obj.responses.map((r) => {
// Carry over the response text, prompt, prompt fill history (vars), and llm nickname:
const o: TemplateVarInfo = {
text: typeof r === "string" ? escapeBraces(r) : undefined,
text:
typeof r === "number"
? escapeBraces(StringLookup.get(r)!)
: typeof r === "string"
? escapeBraces(r)
: undefined,
image:
typeof r === "object" && r.t === "img" ? r.d : undefined,
prompt: resp_obj.prompt,
@@ -923,6 +1006,11 @@ Soft failing by replacing undefined with empty strings.`,
uid: resp_obj.uid,
};

o.text =
o.text !== undefined
? StringLookup.intern(o.text as string)
: undefined;

// Carry over any metavars
o.metavars = resp_obj.metavars ?? {};

@@ -935,9 +1023,11 @@ Soft failing by replacing undefined with empty strings.`,

// Add a meta var to keep track of which LLM produced this response
o.metavars[llm_metavar_key] =
typeof resp_obj.llm === "string"
? resp_obj.llm
typeof resp_obj.llm === "string" ||
typeof resp_obj.llm === "number"
? StringLookup.get(resp_obj.llm) ?? "(LLM lookup failed)"
: resp_obj.llm.name;

return o;
}),
)
@@ -1006,7 +1096,31 @@ Soft failing by replacing undefined with empty strings.`,
.then(open_progress_listener)
.then(query_llms)
.catch(rejected);
};
}, [
templateVars,
triggerAlert,
pullInputChats,
pullInputData,
updateShowContToggle,
llmItemsCurrState,
contWithPriorLLMs,
showAlert,
fetchResponseCounts,
numGenerations,
promptText,
apiKeys,
showContToggle,
cancelId,
refreshCancelId,
node_type,
id,
setDataPropsForNode,
llmListContainer,
responsesWillChange,
showDrawer,
pingOutputNodes,
debounceTimeoutRef,
]);

const handleStopClick = useCallback(() => {
CancelTracker.add(cancelId);
@@ -1024,7 +1138,7 @@ Soft failing by replacing undefined with empty strings.`,
setStatus(Status.NONE);
setContChatToggleDisabled(false);
llmListContainer?.current?.resetLLMItemsProgress();
}, [cancelId, refreshCancelId]);
}, [cancelId, refreshCancelId, debounceTimeoutRef]);

const handleNumGenChange = useCallback(
(event: React.ChangeEvent<HTMLInputElement>) => {
@@ -1124,24 +1238,11 @@ Soft failing by replacing undefined with empty strings.`,
ref={inspectModal}
jsonResponses={jsonResponses ?? []}
/>
<Modal
title={
"List of prompts that will be sent to LLMs (" +
promptPreviews.length +
" total)"
}
size="xl"
opened={infoModalOpened}
onClose={closeInfoModal}
styles={{
header: { backgroundColor: "#FFD700" },
root: { position: "relative", left: "-5%" },
}}
>
<Box m="lg" mt="xl">
{displayPromptInfos(promptPreviews, true)}
</Box>
</Modal>
<PromptListModal
promptPreviews={promptPreviews}
infoModalOpened={infoModalOpened}
closeInfoModal={closeInfoModal}
/>

{node_type === "chat" ? (
<div ref={setRef}>
@@ -1,13 +1,15 @@
import React, { Suspense, useMemo, lazy } from "react";
import { Collapse, Flex, Stack } from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import { truncStr } from "./backend/utils";
import { llmResponseDataToString, truncStr } from "./backend/utils";
import {
Dict,
EvaluationScore,
LLMResponse,
LLMResponseData,
StringOrHash,
} from "./backend/typing";
import { StringLookup } from "./backend/cache";

// Lazy load the response toolbars
const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar"));
@@ -15,25 +17,53 @@ const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar"));
/* HELPER FUNCTIONS */
const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]);
const FAILURE_EVAL_SCORES = new Set(["false", "no"]);
/**
 * Returns an array of JSX elements, and the searchable text underpinning them,
 * that represents a concrete version of the Evaluation Scores passed in.
 * @param eval_item The evaluation result to visualize.
 * @param hide_prefix Whether to hide 'score: ' or '{key}: ' prefixes when printing.
 * @param onlyString Whether to only return string values.
 * @returns An array [JSX.Element, string] where the latter is a string representation of the eval score, to enable search
 */
export const getEvalResultStr = (
eval_item: EvaluationScore,
hide_prefix: boolean,
) => {
onlyString?: boolean,
): [JSX.Element | string, string] => {
if (Array.isArray(eval_item)) {
return (hide_prefix ? "" : "scores: ") + eval_item.join(", ");
const items_str = (hide_prefix ? "" : "scores: ") + eval_item.join(", ");
return [items_str, items_str];
} else if (typeof eval_item === "object") {
const strs = Object.keys(eval_item).map((key, j) => {
let val = eval_item[key];
if (typeof val === "number" && val.toString().indexOf(".") > -1)
val = val.toFixed(4); // truncate floats to 4 decimal places
return (
<div key={`${key}-${j}`}>
<span>{key}: </span>
<span>{getEvalResultStr(val, true)}</span>
</div>
);
});
return <Stack spacing={0}>{strs}</Stack>;
const strs: [JSX.Element | string, string][] = Object.keys(eval_item).map(
(key, j) => {
const innerKey = `${key}-${j}`;
let val = eval_item[key];
if (typeof val === "number" && val.toString().indexOf(".") > -1)
val = val.toFixed(4); // truncate floats to 4 decimal places
const [recurs_res, recurs_str] = getEvalResultStr(val, true);
if (onlyString) return [`${key}: ${recurs_str}`, recurs_str];
else
return [
<div key={innerKey}>
<span key={0}>{key}: </span>
<span key={1}>{recurs_res}</span>
</div>,
recurs_str,
];
},
);
const joined_strs = strs.map((s) => s[1]).join("\n");
if (onlyString) {
return [joined_strs, joined_strs];
} else
return [
<Stack key={1} spacing={0}>
{strs.map((s, i) => (
<span key={i}>{s[0]}</span>
))}
</Stack>,
joined_strs,
];
} else {
const eval_str = eval_item.toString().trim().toLowerCase();
const color = SUCCESS_EVAL_SCORES.has(eval_str)
@@ -41,12 +71,15 @@ export const getEvalResultStr = (
: FAILURE_EVAL_SCORES.has(eval_str)
? "red"
: "black";
return (
<>
{!hide_prefix && <span style={{ color: "gray" }}>{"score: "}</span>}
<span style={{ color }}>{eval_str}</span>
</>
);
if (onlyString) return [eval_str, eval_str];
else
return [
<>
{!hide_prefix && <span style={{ color: "gray" }}>{"score: "}</span>}
<span style={{ color }}>{eval_str}</span>
</>,
eval_str,
];
}
};

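A minimal sketch of the updated call pattern, assuming `getEvalResultStr` is imported from this module; the returned tuple is [rendered element or string, plain-text form usable for search], and the score values below are made up:

```ts
// Object scores: floats are truncated to 4 decimal places in the output.
const [rendered, searchText] = getEvalResultStr(
  { accuracy: 0.91234, passed: "yes" },
  true,
);
// `rendered` is a JSX.Element (or string) to display;
// `searchText` is newline-joined text like "0.9123\nyes" for filtering.

// Array scores with onlyString: both tuple slots are the same plain string.
const [plain] = getEvalResultStr([1, 2, 3], false, true); // "scores: 1, 2, 3"
```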
@@ -113,7 +146,7 @@ export const ResponseGroup: React.FC<ResponseGroupProps> = ({
*/
interface ResponseBoxProps {
children: React.ReactNode; // For components, HTML elements, text, etc.
vars?: Dict<string>;
vars?: Dict<StringOrHash>;
truncLenForVars?: number;
llmName?: string;
boxColor?: string;
@@ -131,7 +164,10 @@ export const ResponseBox: React.FC<ResponseBoxProps> = ({
const var_tags = useMemo(() => {
if (vars === undefined) return [];
return Object.entries(vars).map(([varname, val]) => {
const v = truncStr(val.trim(), truncLenForVars ?? 18);
const v = truncStr(
(StringLookup.get(val) ?? "").trim(),
truncLenForVars ?? 18,
);
return (
<div key={varname} className="response-var-inline">
<span className="response-var-name">{varname} = </span>
@@ -191,16 +227,15 @@ export const genResponseTextsDisplay = (
const resp_str_to_eval_res: Dict<EvaluationScore> = {};
if (eval_res_items)
responses.forEach((r, idx) => {
resp_str_to_eval_res[typeof r === "string" ? r : r.d] =
eval_res_items[idx];
resp_str_to_eval_res[llmResponseDataToString(r)] = eval_res_items[idx];
});

const same_resp_text_counts = countResponsesBy(responses, (r) =>
typeof r === "string" ? r : r.d,
llmResponseDataToString(r),
);
const resp_special_type_map: Dict<string> = {};
responses.forEach((r) => {
const key = typeof r === "string" ? r : r.d;
const key = llmResponseDataToString(r);
if (typeof r === "object") resp_special_type_map[key] = r.t;
});
const same_resp_keys = Object.keys(same_resp_text_counts).sort(
@@ -240,6 +275,7 @@ export const genResponseTextsDisplay = (
uid={res_obj.uid}
innerIdxs={origIdxs}
wideFormat={wideFormat}
responseData={r}
/>
</Suspense>
{llmName !== undefined &&
@@ -259,7 +295,7 @@ export const genResponseTextsDisplay = (
)}
{eval_res_items ? (
<p className="small-response-metrics">
{getEvalResultStr(resp_str_to_eval_res[r], true)}
{getEvalResultStr(resp_str_to_eval_res[r], true)[0]}
</p>
) : (
<></>
@@ -5,11 +5,17 @@ import React, {
useMemo,
useState,
} from "react";
import { Button, Flex, Popover, Stack, Textarea } from "@mantine/core";
import { IconMessage2, IconThumbDown, IconThumbUp } from "@tabler/icons-react";
import { Button, Flex, Popover, Stack, Textarea, Tooltip } from "@mantine/core";
import {
IconCopy,
IconMessage2,
IconThumbDown,
IconThumbUp,
} from "@tabler/icons-react";
import StorageCache from "./backend/cache";
import useStore from "./store";
import { deepcopy } from "./backend/utils";
import { LLMResponseData } from "./backend/typing";

type RatingDict = Record<number, boolean | string | undefined>;

@@ -63,14 +69,14 @@ export interface ResponseRatingToolbarProps {
uid: string;
wideFormat?: boolean;
innerIdxs: number[];
onUpdateResponses?: () => void;
responseData?: string;
}

const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
uid,
wideFormat,
innerIdxs,
onUpdateResponses,
responseData,
}) => {
// The cache keys storing the ratings for this response object
const gradeKey = getRatingKeyForResponse(uid, "grade");
@@ -108,6 +114,9 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
const [noteText, setNoteText] = useState("");
const [notePopoverOpened, setNotePopoverOpened] = useState(false);

// Text state
const [copied, setCopied] = useState(false);

// Override the text in the internal textarea whenever upstream annotation changes.
useEffect(() => {
setNoteText(note !== undefined ? note.toString() : "");
@@ -133,7 +142,6 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
new_grades[idx] = grade;
});
setRating(uid, "grade", new_grades);
if (onUpdateResponses) onUpdateResponses();
};

const onAnnotate = (label?: string) => {
@@ -145,7 +153,6 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
new_notes[idx] = label;
});
setRating(uid, "note", new_notes);
if (onUpdateResponses) onUpdateResponses();
};

const handleSaveAnnotation = useCallback(() => {
@@ -175,6 +182,33 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
>
<IconThumbDown size={size} />
</ToolbarButton>
<Tooltip
label={copied ? "Copied!" : "Copy"}
withArrow
arrowPosition="center"
>
<ToolbarButton
selected={copied}
onClick={() => {
if (responseData) {
navigator.clipboard
.writeText(responseData)
.then(() => {
console.log("Text copied to clipboard");
setCopied(() => true);
setTimeout(() => {
setCopied(() => false);
}, 1000);
})
.catch((err) => {
console.error("Failed to copy text: ", err);
});
}
}}
>
<IconCopy size={size} />
</ToolbarButton>
</Tooltip>
<Popover
opened={notePopoverOpened}
onChange={setNotePopoverOpened}
@@ -28,7 +28,7 @@ import {
} from "./backend/utils";

import { fromMarkdown } from "mdast-util-from-markdown";
import StorageCache from "./backend/cache";
import StorageCache, { StringLookup } from "./backend/cache";
import { ResponseBox } from "./ResponseBoxes";
import { Root, RootContent } from "mdast";
import { Dict, TemplateVarInfo } from "./backend/typing";
@@ -130,11 +130,13 @@ const displaySplitTexts = (
} else {
const llm_color =
typeof info.llm === "object" && "name" in info.llm
? color_for_llm(info.llm?.name)
? color_for_llm(
StringLookup.get(info.llm?.name) ?? "(string lookup failed)",
)
: "#ddd";
const llm_name =
typeof info.llm === "object" && "name" in info.llm
? info.llm?.name
? StringLookup.get(info.llm?.name)
: "";
return (
<ResponseBox
@@ -289,7 +291,11 @@ const SplitNode: React.FC<SplitNodeProps> = ({ data, id }) => {
.map((resp_obj: TemplateVarInfo | string) => {
if (typeof resp_obj === "string")
return splitText(resp_obj, formatting, true);
const texts = splitText(resp_obj?.text ?? "", formatting, true);
const texts = splitText(
StringLookup.get(resp_obj?.text) ?? "",
formatting,
true,
);
if (texts !== undefined && texts.length >= 1)
return texts.map(
(t: string) =>
@@ -26,6 +26,7 @@ import { Dict, TabularDataRowType, TabularDataColType } from "./backend/typing";
import { Position } from "reactflow";
import { AIGenReplaceTablePopover } from "./AiPopover";
import { parseTableData } from "./backend/tableUtils";
import { StringLookup } from "./backend/cache";

const defaultRows: TabularDataRowType[] = [
{
@@ -469,7 +470,7 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
const [isLoading, setIsLoading] = useState(false);

const [rowValues, setRowValues] = useState<string[]>(
tableData.map((row) => row.value || ""),
tableData.map((row) => StringLookup.get(row.value) ?? ""),
);

// Function to add new columns to the right of the existing columns (with optional row values)
@@ -7,7 +7,7 @@ import React, {
MouseEventHandler,
} from "react";
import { Handle, Node, Position } from "reactflow";
import { Textarea, Tooltip, Skeleton } from "@mantine/core";
import { Textarea, Tooltip, Skeleton, ScrollArea } from "@mantine/core";
import {
IconTextPlus,
IconEye,
@@ -74,6 +74,15 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
data.fields_visibility || {},
);

// For when textfields exceed the TextFields Node max height,
// when we add a new field, this gives us a way to scroll to the bottom. Better UX.
const viewport = useRef<HTMLDivElement>(null);
const scrollToBottom = () =>
viewport?.current?.scrollTo({
top: viewport.current.scrollHeight,
behavior: "smooth",
});

// Whether the text fields should be in a loading state
const [isLoading, setIsLoading] = useState(false);

@@ -163,6 +172,10 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
setDataPropsForNode(id, { fields: new_fields });
pingOutputNodes(id);

setTimeout(() => {
scrollToBottom();
}, 10);

// Cycle suggestions when new field is created
// aiSuggestionsManager.cycleSuggestions();

@@ -493,7 +506,11 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
}
/>
<Skeleton visible={isLoading}>
<div ref={setRef}>{textFields}</div>
<div ref={setRef} className="nodrag nowheel">
<ScrollArea.Autosize mah={580} type="hover" viewportRef={viewport}>
{textFields}
</ScrollArea.Autosize>
</div>
</Skeleton>
<Handle
type="source"
@@ -13,9 +13,11 @@ import {
EvaluationScore,
JSONCompatible,
LLMResponse,
LLMResponseData,
} from "./backend/typing";
import { Status } from "./StatusIndicatorComponent";
import { grabResponses } from "./backend/backend";
import { StringLookup } from "./backend/cache";

// Helper funcs
const splitAndAddBreaks = (s: string, chunkSize: number) => {
@@ -223,8 +225,9 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {

const get_llm = (resp_obj: LLMResponse) => {
if (selectedLLMGroup === "LLM")
return typeof resp_obj.llm === "string"
? resp_obj.llm
return typeof resp_obj.llm === "string" ||
typeof resp_obj.llm === "number"
? StringLookup.get(resp_obj.llm) ?? "(LLM lookup failed)"
: resp_obj.llm?.name;
else return resp_obj.metavars[selectedLLMGroup] as string;
};
@@ -332,7 +335,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
? resp_obj.metavars[varname.slice("__meta_".length)]
: resp_obj.vars[varname];
if (v === undefined && empty_str_if_undefined) return "";
return v;
return StringLookup.get(v) ?? "";
};

const get_var_and_trim = (
@@ -345,6 +348,11 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
else return v;
};

const castData = (v: LLMResponseData) =>
typeof v === "string" || typeof v === "number"
? StringLookup.get(v) ?? "(unknown lookup error)"
: v.d;

const get_items = (eval_res_obj?: EvaluationResults) => {
if (eval_res_obj === undefined) return [];
if (typeof_eval_res.includes("KeyValue"))
@@ -478,9 +486,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
if (resp_to_x(r) !== name) return;
x_items = x_items.concat(get_items(r.eval_res));
text_items = text_items.concat(
createHoverTexts(
r.responses.map((v) => (typeof v === "string" ? v : v.d)),
),
createHoverTexts(r.responses.map(castData)),
);
});
}
@@ -573,11 +579,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
if (resp_to_x(r) !== name) return;
x_items = x_items.concat(get_items(r.eval_res)).flat();
text_items = text_items
.concat(
createHoverTexts(
r.responses.map((v) => (typeof v === "string" ? v : v.d)),
),
)
.concat(createHoverTexts(r.responses.map(castData)))
.flat();
y_items = y_items
.concat(
@@ -10,7 +10,7 @@ import {
ResponseInfo,
grabResponses,
} from "../backend";
import { LLMResponse, Dict } from "../typing";
import { LLMResponse, Dict, StringOrHash, LLMSpec } from "../typing";
import StorageCache from "../cache";

test("count queries required", async () => {
@@ -22,7 +22,10 @@ test("count queries required", async () => {
};

// Double-check the queries required (not loading from cache)
const test_count_queries = async (llms: Array<string | Dict>, n: number) => {
const test_count_queries = async (
llms: Array<StringOrHash | LLMSpec>,
n: number,
) => {
const { counts, total_num_responses } = await countQueries(
prompt,
vars,
@@ -22,7 +22,7 @@ test("saving and loading cache data from localStorage", () => {
expect(d).toBeUndefined();

// Load cache from localStorage
StorageCache.loadFromLocalStorage("test");
StorageCache.loadFromLocalStorage("test", false);

// Verify stored data:
d = StorageCache.get("hello");
@@ -2,11 +2,11 @@
* @jest-environment node
*/
import { PromptPipeline } from "../query";
import { LLM, NativeLLM } from "../models";
import { LLM, LLMProvider, NativeLLM } from "../models";
import { expect, test } from "@jest/globals";
import { LLMResponseError, RawLLMResponseObject } from "../typing";

async function prompt_model(model: LLM): Promise<void> {
async function prompt_model(model: LLM, provider: LLMProvider): Promise<void> {
const pipeline = new PromptPipeline(
"What is the oldest {thing} in the world? Keep your answer brief.",
model.toString(),
@@ -15,6 +15,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
1,
1.0,
)) {
@@ -35,6 +36,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
2,
1.0,
)) {
@@ -54,7 +56,6 @@ async function prompt_model(model: LLM): Promise<void> {
`Prompt: ${prompt}\nResponses: ${JSON.stringify(resp_obj.responses)}`,
);
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts

@@ -63,6 +64,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
2,
1.0,
)) {
@@ -79,20 +81,19 @@ async function prompt_model(model: LLM): Promise<void> {
Object.entries(cache).forEach(([prompt, response]) => {
const resp_obj = Array.isArray(response) ? response[0] : response;
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
}

test("basic prompt pipeline with chatgpt", async () => {
// Setup a simple pipeline with a prompt template, 1 variable and 3 input values
await prompt_model(NativeLLM.OpenAI_ChatGPT);
await prompt_model(NativeLLM.OpenAI_ChatGPT, LLMProvider.OpenAI);
}, 20000);

test("basic prompt pipeline with anthropic", async () => {
await prompt_model(NativeLLM.Claude_v1);
await prompt_model(NativeLLM.Claude_v1, LLMProvider.Anthropic);
}, 40000);

test("basic prompt pipeline with google palm2", async () => {
await prompt_model(NativeLLM.PaLM2_Chat_Bison);
await prompt_model(NativeLLM.PaLM2_Chat_Bison, LLMProvider.Google);
}, 40000);
@@ -9,7 +9,7 @@ import {
extract_responses,
merge_response_objs,
} from "../utils";
import { NativeLLM } from "../models";
import { LLMProvider, NativeLLM } from "../models";
import { expect, test } from "@jest/globals";
import { RawLLMResponseObject } from "../typing";

@@ -17,7 +17,6 @@ test("merge response objects", () => {
// Merging two response objects
const A: RawLLMResponseObject = {
responses: ["x", "y", "z"],
raw_response: ["x", "y", "z"],
prompt: "this is a test",
query: {},
llm: NativeLLM.OpenAI_ChatGPT,
@@ -27,7 +26,6 @@ test("merge response objects", () => {
};
const B: RawLLMResponseObject = {
responses: ["a", "b", "c"],
raw_response: { B: "B" },
prompt: "this is a test 2",
query: {},
llm: NativeLLM.OpenAI_ChatGPT,
@@ -40,7 +38,6 @@ test("merge response objects", () => {
expect(JSON.stringify(C.responses)).toBe(
JSON.stringify(["x", "y", "z", "a", "b", "c"]),
);
expect(C.raw_response).toHaveLength(4);
expect(Object.keys(C.vars)).toHaveLength(2);
expect(Object.keys(C.vars)).toContain("varB1");
expect(Object.keys(C.metavars)).toHaveLength(1);
@@ -68,7 +65,11 @@ test("openai chat completions", async () => {
expect(query).toHaveProperty("temperature");

// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.OpenAI_ChatGPT);
const resps = extract_responses(
response,
NativeLLM.OpenAI_ChatGPT,
LLMProvider.OpenAI,
);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe("string");
}, 20000);
@@ -86,7 +87,11 @@ test("openai text completions", async () => {
expect(query).toHaveProperty("n");

// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.OpenAI_Davinci003);
const resps = extract_responses(
response,
NativeLLM.OpenAI_Davinci003,
LLMProvider.OpenAI,
);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe("string");
}, 20000);
@@ -104,7 +109,11 @@ test("anthropic models", async () => {
expect(query).toHaveProperty("max_tokens_to_sample");

// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.Claude_v1);
const resps = extract_responses(
response,
NativeLLM.Claude_v1,
LLMProvider.Anthropic,
);
expect(resps).toHaveLength(1);
expect(typeof resps[0]).toBe("string");
}, 20000);
@@ -121,7 +130,11 @@ test("google palm2 models", async () => {
expect(query).toHaveProperty("candidateCount");

// Extract responses, check their type
let resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
let resps = extract_responses(
response,
NativeLLM.PaLM2_Chat_Bison,
LLMProvider.Google,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));
@@ -137,7 +150,11 @@ test("google palm2 models", async () => {
expect(query).toHaveProperty("maxOutputTokens");

// Extract responses, check their type
resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
resps = extract_responses(
response,
NativeLLM.PaLM2_Chat_Bison,
LLMProvider.Google,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));
@@ -153,7 +170,11 @@ test("aleph alpha model", async () => {
expect(response).toHaveLength(3);

// Extract responses, check their type
let resps = extract_responses(response, NativeLLM.Aleph_Alpha_Luminous_Base);
let resps = extract_responses(
response,
NativeLLM.Aleph_Alpha_Luminous_Base,
LLMProvider.Aleph_Alpha,
);
expect(resps).toHaveLength(3);
|
||||
expect(typeof resps[0]).toBe("string");
|
||||
console.log(JSON.stringify(resps));
|
||||
@ -169,7 +190,11 @@ test("aleph alpha model", async () => {
|
||||
expect(response).toHaveLength(3);
|
||||
|
||||
// Extract responses, check their type
|
||||
resps = extract_responses(response, NativeLLM.Aleph_Alpha_Luminous_Base);
|
||||
resps = extract_responses(
|
||||
response,
|
||||
NativeLLM.Aleph_Alpha_Luminous_Base,
|
||||
LLMProvider.Aleph_Alpha,
|
||||
);
|
||||
expect(resps).toHaveLength(3);
|
||||
expect(typeof resps[0]).toBe("string");
|
||||
console.log(JSON.stringify(resps));
|
||||
|
@ -9,7 +9,7 @@ import {
|
||||
} from "./template";
|
||||
import { ChatHistoryInfo, Dict, TabularDataColType } from "./typing";
|
||||
import { fromMarkdown } from "mdast-util-from-markdown";
|
||||
import { sampleRandomElements } from "./utils";
|
||||
import { llmResponseDataToString, sampleRandomElements } from "./utils";
|
||||
|
||||
export class AIError extends Error {
|
||||
constructor(message: string) {
|
||||
@ -314,7 +314,7 @@ export async function autofill(
|
||||
if (result.errors && Object.keys(result.errors).length > 0)
|
||||
throw new Error(Object.values(result.errors)[0].toString());
|
||||
|
||||
const output = result.responses[0].responses[0] as string;
|
||||
const output = llmResponseDataToString(result.responses[0].responses[0]);
|
||||
|
||||
console.log("LLM said: ", output);
|
||||
|
||||
@ -381,7 +381,7 @@ export async function autofillTable(
|
||||
throw new Error(Object.values(result.errors)[0].toString());
|
||||
|
||||
// Extract the output from the LLM response
|
||||
const output = result.responses[0].responses[0] as string;
|
||||
const output = llmResponseDataToString(result.responses[0].responses[0]);
|
||||
console.log("LLM said: ", output);
|
||||
const newRows = decodeTable(output).rows;
|
||||
|
||||
@ -451,7 +451,7 @@ ${prompt}: ?`;
|
||||
throw new AIError(Object.values(result.errors)[0].toString());
|
||||
}
|
||||
|
||||
const output = result.responses[0].responses[0] as string;
|
||||
const output = llmResponseDataToString(result.responses[0].responses[0]);
|
||||
return output.trim();
|
||||
}
|
||||
|
||||
@ -484,7 +484,10 @@ export async function generateColumn(
|
||||
apiKeys,
|
||||
true,
|
||||
);
|
||||
colName = (result.responses[0].responses[0] as string).replace("_", " ");
|
||||
colName = llmResponseDataToString(result.responses[0].responses[0]).replace(
|
||||
"_",
|
||||
" ",
|
||||
);
|
||||
}
|
||||
|
||||
// Remove any leading/trailing whitespace from the column name as well as any double quotes
|
||||
@ -573,9 +576,10 @@ export async function generateAndReplace(
|
||||
if (result.errors && Object.keys(result.errors).length > 0)
|
||||
throw new Error(Object.values(result.errors)[0].toString());
|
||||
|
||||
console.log("LLM said: ", result.responses[0].responses[0]);
|
||||
const resp = llmResponseDataToString(result.responses[0].responses[0]);
|
||||
console.log("LLM said: ", resp);
|
||||
|
||||
const new_items = decode(result.responses[0].responses[0] as string);
|
||||
const new_items = decode(resp);
|
||||
return new_items.slice(0, n);
|
||||
}
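The replacements above swap direct `as string` casts for the `llmResponseDataToString` helper imported from `./utils`. A hedged illustration of the intent (the hash-resolution behavior is an assumption based on how `StringLookup` is used elsewhere in this changeset):

```typescript
import { llmResponseDataToString } from "./utils";
import { StringLookup } from "./cache";

// Plain strings should pass through unchanged...
console.log(llmResponseDataToString("a plain response"));
// ...while interned hashes are (assumed to be) resolved back to their strings:
console.log(llmResponseDataToString(StringLookup.intern("an interned response")));
```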
|
||||
|
||||
@ -633,7 +637,7 @@ export async function generateAndReplaceTable(
|
||||
console.log("LLM said: ", result.responses[0].responses[0]);
|
||||
|
||||
const { cols: new_cols, rows: new_rows } = decodeTable(
|
||||
result.responses[0].responses[0] as string,
|
||||
llmResponseDataToString(result.responses[0].responses[0]),
|
||||
);
|
||||
|
||||
// Return the generated table with "n" number of rows
|
||||
|
@ -1,4 +1,5 @@
|
||||
import MarkdownIt from "markdown-it";
|
||||
import axios from "axios";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import {
|
||||
Dict,
|
||||
@ -12,11 +13,12 @@ import {
|
||||
EvaluationScore,
|
||||
LLMSpec,
|
||||
EvaluatedResponsesResults,
|
||||
TemplateVarInfo,
|
||||
CustomLLMProviderSpec,
|
||||
LLMResponseData,
|
||||
PromptVarType,
|
||||
StringOrHash,
|
||||
} from "./typing";
|
||||
import { LLM, getEnumName } from "./models";
|
||||
import { LLM, LLMProvider, getEnumName, getProvider } from "./models";
|
||||
import {
|
||||
APP_IS_RUNNING_LOCALLY,
|
||||
set_api_keys,
|
||||
@ -26,8 +28,9 @@ import {
|
||||
areEqualVarsDicts,
|
||||
repairCachedResponses,
|
||||
deepcopy,
|
||||
llmResponseDataToString,
|
||||
} from "./utils";
|
||||
import StorageCache from "./cache";
|
||||
import StorageCache, { StringLookup } from "./cache";
|
||||
import { PromptPipeline } from "./query";
|
||||
import {
|
||||
PromptPermutationGenerator,
|
||||
@ -38,6 +41,7 @@ import {
|
||||
import { UserForcedPrematureExit } from "./errors";
|
||||
import CancelTracker from "./canceler";
|
||||
import { execPy } from "./pyodide/exec-py";
|
||||
import { baseModelToProvider } from "../ModelSettingSchemas";
|
||||
|
||||
// """ =================
|
||||
// SETUP AND GLOBALS
|
||||
@ -215,19 +219,30 @@ function gen_unique_cache_filename(
|
||||
return `${cache_id}_${idx}.json`;
|
||||
}
|
||||
|
||||
function extract_llm_nickname(llm_spec: Dict | string) {
|
||||
function extract_llm_nickname(llm_spec: StringOrHash | LLMSpec): string {
|
||||
if (typeof llm_spec === "object" && llm_spec.name !== undefined)
|
||||
return llm_spec.name;
|
||||
else return llm_spec;
|
||||
else
|
||||
return (
|
||||
StringLookup.get(llm_spec as StringOrHash) ?? "(string lookup failed)"
|
||||
);
|
||||
}
|
||||
|
||||
function extract_llm_name(llm_spec: Dict | string): string {
|
||||
if (typeof llm_spec === "string") return llm_spec;
|
||||
function extract_llm_name(llm_spec: StringOrHash | LLMSpec): string {
|
||||
if (typeof llm_spec === "string" || typeof llm_spec === "number")
|
||||
return StringLookup.get(llm_spec) ?? "(string lookup failed)";
|
||||
else return llm_spec.model;
|
||||
}
|
||||
|
||||
function extract_llm_key(llm_spec: Dict | string): string {
|
||||
if (typeof llm_spec === "string") return llm_spec;
|
||||
function extract_llm_provider(llm_spec: StringOrHash | LLMSpec): LLMProvider {
|
||||
if (typeof llm_spec === "string" || typeof llm_spec === "number")
|
||||
return getProvider(StringLookup.get(llm_spec) ?? "") ?? LLMProvider.Custom;
|
||||
else return baseModelToProvider(llm_spec.base_model);
|
||||
}
|
||||
|
||||
function extract_llm_key(llm_spec: StringOrHash | LLMSpec): string {
|
||||
if (typeof llm_spec === "string" || typeof llm_spec === "number")
|
||||
return StringLookup.get(llm_spec) ?? "(string lookup failed)";
|
||||
else if (llm_spec.key !== undefined) return llm_spec.key;
|
||||
else
|
||||
throw new Error(
|
||||
@ -237,7 +252,7 @@ function extract_llm_key(llm_spec: Dict | string): string {
|
||||
);
|
||||
}
|
||||
|
||||
function extract_llm_params(llm_spec: Dict | string): Dict {
|
||||
function extract_llm_params(llm_spec: StringOrHash | LLMSpec): Dict {
|
||||
if (typeof llm_spec === "object" && llm_spec.settings !== undefined)
|
||||
return llm_spec.settings;
|
||||
else return {};
|
||||
@ -250,8 +265,10 @@ function filterVarsByLLM(vars: PromptVarsDict, llm_key: string): Dict {
|
||||
_vars[key] = vs.filter(
|
||||
(v) =>
|
||||
typeof v === "string" ||
|
||||
typeof v === "number" ||
|
||||
v?.llm === undefined ||
|
||||
typeof v.llm === "string" ||
|
||||
typeof v.llm === "number" ||
|
||||
v.llm.key === llm_key,
|
||||
);
|
||||
});
|
||||
@ -294,8 +311,8 @@ function isLooselyEqual(value1: any, value2: any): boolean {
|
||||
* determines whether the response query used the same parameters.
|
||||
*/
|
||||
function matching_settings(
|
||||
cache_llm_spec: Dict | string,
|
||||
llm_spec: Dict | string,
|
||||
cache_llm_spec: StringOrHash | LLMSpec,
|
||||
llm_spec: StringOrHash | LLMSpec,
|
||||
): boolean {
|
||||
if (extract_llm_name(cache_llm_spec) !== extract_llm_name(llm_spec))
|
||||
return false;
|
||||
@ -398,10 +415,7 @@ async function run_over_responses(
|
||||
const evald_resps: Promise<LLMResponse>[] = responses.map(
|
||||
async (_resp_obj: LLMResponse) => {
|
||||
// Deep clone the response object
|
||||
const resp_obj = JSON.parse(JSON.stringify(_resp_obj));
|
||||
|
||||
// Clean up any escaped braces
|
||||
resp_obj.responses = resp_obj.responses.map(cleanEscapedBraces);
|
||||
const resp_obj: LLMResponse = JSON.parse(JSON.stringify(_resp_obj));
|
||||
|
||||
// Whether the processor function is async or not
|
||||
const async_processor =
|
||||
@ -410,12 +424,12 @@ async function run_over_responses(
|
||||
// Map the processor func over every individual response text in each response object
|
||||
const res = resp_obj.responses;
|
||||
const llm_name = extract_llm_nickname(resp_obj.llm);
|
||||
let processed = res.map((r: string) => {
|
||||
let processed = res.map((r: LLMResponseData) => {
|
||||
const r_info = new ResponseInfo(
|
||||
r,
|
||||
resp_obj.prompt,
|
||||
resp_obj.vars,
|
||||
resp_obj.metavars || {},
|
||||
cleanEscapedBraces(llmResponseDataToString(r)),
|
||||
StringLookup.get(resp_obj.prompt) ?? "",
|
||||
StringLookup.concretizeDict(resp_obj.vars),
|
||||
StringLookup.concretizeDict(resp_obj.metavars) || {},
|
||||
llm_name,
|
||||
);
|
||||
|
||||
@ -455,7 +469,8 @@ async function run_over_responses(
|
||||
// Store items with summary of mean, median, etc
|
||||
resp_obj.eval_res = {
|
||||
items: processed,
|
||||
dtype: getEnumName(MetricType, eval_res_type),
|
||||
dtype: (getEnumName(MetricType, eval_res_type) ??
|
||||
"Unknown") as keyof typeof MetricType,
|
||||
};
|
||||
} else if (
|
||||
[MetricType.Unknown, MetricType.Empty].includes(eval_res_type)
|
||||
@ -467,7 +482,8 @@ async function run_over_responses(
|
||||
// Categorical, KeyValue, etc, we just store the items:
|
||||
resp_obj.eval_res = {
|
||||
items: processed,
|
||||
dtype: getEnumName(MetricType, eval_res_type),
|
||||
dtype: (getEnumName(MetricType, eval_res_type) ??
|
||||
"Unknown") as keyof typeof MetricType,
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -492,7 +508,7 @@ async function run_over_responses(
|
||||
*/
|
||||
export async function generatePrompts(
|
||||
root_prompt: string,
|
||||
vars: Dict<(TemplateVarInfo | string)[]>,
|
||||
vars: Dict<PromptVarType[]>,
|
||||
): Promise<PromptTemplate[]> {
|
||||
const gen_prompts = new PromptPermutationGenerator(root_prompt);
|
||||
const all_prompt_permutations = Array.from(
|
||||
@ -517,7 +533,7 @@ export async function generatePrompts(
|
||||
export async function countQueries(
|
||||
prompt: string,
|
||||
vars: PromptVarsDict,
|
||||
llms: Array<Dict | string>,
|
||||
llms: Array<StringOrHash | LLMSpec>,
|
||||
n: number,
|
||||
chat_histories?:
|
||||
| (ChatHistoryInfo | undefined)[]
|
||||
@ -591,7 +607,9 @@ export async function countQueries(
|
||||
found_cache = true;
|
||||
|
||||
// Load the cache file
|
||||
const cache_llm_responses = load_from_cache(cache_filename);
|
||||
const cache_llm_responses: Dict<
|
||||
RawLLMResponseObject[] | RawLLMResponseObject
|
||||
> = load_from_cache(cache_filename);
|
||||
|
||||
// Iterate through all prompt permutations and check how many responses there are in the cache for that prompt
|
||||
_all_prompt_perms.forEach((prompt) => {
|
||||
@ -680,6 +698,40 @@ export async function fetchEnvironAPIKeys(): Promise<Dict<string>> {
|
||||
}).then((res) => res.json());
|
||||
}
|
||||
|
||||
export async function saveFlowToLocalFilesystem(
|
||||
flowJSON: Dict,
|
||||
filename: string,
|
||||
): Promise<void> {
|
||||
try {
|
||||
await axios.put(`${FLASK_BASE_URL}api/flows/${filename}`, {
|
||||
flow: flowJSON,
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Error saving flow with name ${filename}: ${(error as Error).toString()}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export async function ensureUniqueFlowFilename(
|
||||
filename: string,
|
||||
): Promise<string> {
|
||||
try {
|
||||
const response = await axios.put(
|
||||
`${FLASK_BASE_URL}api/getUniqueFlowFilename`,
|
||||
{
|
||||
name: filename,
|
||||
},
|
||||
);
|
||||
return response.data as string;
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error contacting Flask to ensure a unique filename for the imported flow. Defaulting to the passed filename (warning: this risks overriding an existing flow). Error: ${(error as Error).toString()}`,
|
||||
);
|
||||
return filename;
|
||||
}
|
||||
}
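A minimal usage sketch of the two new flow-persistence helpers above, assuming they are imported from this backend module and that the local Flask server behind `FLASK_BASE_URL` is running:

```typescript
import {
  ensureUniqueFlowFilename,
  saveFlowToLocalFilesystem,
} from "./backend"; // assumed import path

async function persistFlow(flow: Record<string, unknown>): Promise<string> {
  // Ask the server for a name that will not clobber an existing flow...
  const filename = await ensureUniqueFlowFilename("my-flow");
  // ...then PUT the flow JSON under that name.
  await saveFlowToLocalFilesystem(flow, filename);
  return filename;
}
```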
|
||||
|
||||
/**
|
||||
* Queries LLM(s) with root prompt template `prompt` and prompt input variables `vars`, `n` times per prompt.
|
||||
* Soft-fails if API calls fail, and collects the errors in `errors` property of the return object.
|
||||
@ -766,7 +818,7 @@ export async function queryLLM(
|
||||
// Create a new cache JSON object
|
||||
cache = { cache_files: {}, responses_last_run: [] };
|
||||
const prev_filenames: Array<string> = [];
|
||||
llms.forEach((llm_spec: string | Dict) => {
|
||||
llms.forEach((llm_spec) => {
|
||||
const fname = gen_unique_cache_filename(id, prev_filenames);
|
||||
llm_to_cache_filename[extract_llm_key(llm_spec)] = fname;
|
||||
cache.cache_files[fname] = llm_spec;
|
||||
@ -800,9 +852,12 @@ export async function queryLLM(
|
||||
const responses: { [key: string]: Array<RawLLMResponseObject> } = {};
|
||||
const all_errors: Dict<string[]> = {};
|
||||
const num_generations = n ?? 1;
|
||||
async function query(llm_spec: string | Dict): Promise<LLMPrompterResults> {
|
||||
async function query(
|
||||
llm_spec: StringOrHash | LLMSpec,
|
||||
): Promise<LLMPrompterResults> {
|
||||
// Get LLM model name and any params
|
||||
const llm_str = extract_llm_name(llm_spec);
|
||||
const llm_provider = extract_llm_provider(llm_spec);
|
||||
const llm_nickname = extract_llm_nickname(llm_spec);
|
||||
const llm_params = extract_llm_params(llm_spec);
|
||||
const llm_key = extract_llm_key(llm_spec);
|
||||
@ -842,6 +897,7 @@ export async function queryLLM(
|
||||
for await (const response of prompter.gen_responses(
|
||||
_vars,
|
||||
llm_str as LLM,
|
||||
llm_provider,
|
||||
num_generations,
|
||||
temperature,
|
||||
llm_params,
|
||||
@ -1218,11 +1274,12 @@ export async function executepy(
|
||||
*/
|
||||
export async function evalWithLLM(
|
||||
id: string,
|
||||
llm: string | LLMSpec,
|
||||
llm: LLMSpec,
|
||||
root_prompt: string,
|
||||
response_ids: string | string[],
|
||||
api_keys?: Dict,
|
||||
progress_listener?: (progress: { [key: symbol]: any }) => void,
|
||||
cancel_id?: string | number,
|
||||
): Promise<{ responses?: LLMResponse[]; errors: string[] }> {
|
||||
// Check format of response_ids
|
||||
if (!Array.isArray(response_ids)) response_ids = [response_ids];
|
||||
@ -1242,17 +1299,27 @@ export async function evalWithLLM(
|
||||
const resp_objs = (load_cache_responses(fname) as LLMResponse[]).map((r) =>
|
||||
JSON.parse(JSON.stringify(r)),
|
||||
) as LLMResponse[];
|
||||
|
||||
if (resp_objs.length === 0) continue;
|
||||
|
||||
console.log(resp_objs);
|
||||
|
||||
// We need to keep track of the index of each response in the response object.
|
||||
// We can generate var dicts with metadata to store the indices:
|
||||
const inputs = resp_objs
|
||||
.map((obj, __i) =>
|
||||
obj.responses.map((r: LLMResponseData, __j: number) => ({
|
||||
text: typeof r === "string" ? escapeBraces(r) : undefined,
|
||||
text:
|
||||
typeof r === "string" || typeof r === "number"
|
||||
? escapeBraces(StringLookup.get(r) ?? "(string lookup failed)")
|
||||
: undefined,
|
||||
image: typeof r === "object" && r.t === "img" ? r.d : undefined,
|
||||
fill_history: obj.vars,
|
||||
metavars: { ...obj.metavars, __i, __j },
|
||||
metavars: {
|
||||
...obj.metavars,
|
||||
__i: __i.toString(),
|
||||
__j: __j.toString(),
|
||||
},
|
||||
})),
|
||||
)
|
||||
.flat();
|
||||
@ -1263,11 +1330,13 @@ export async function evalWithLLM(
|
||||
[llm],
|
||||
1,
|
||||
root_prompt,
|
||||
{ input: inputs },
|
||||
{ __input: inputs },
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
progress_listener,
|
||||
false,
|
||||
cancel_id,
|
||||
);
|
||||
|
||||
const err_vals: string[] = Object.values(errors).flat();
|
||||
@ -1276,15 +1345,17 @@ export async function evalWithLLM(
|
||||
// Now we need to apply each response as an eval_res (a score) back to each response object,
|
||||
// using the aforementioned mapping metadata:
|
||||
responses.forEach((r: LLMResponse) => {
|
||||
const resp_obj = resp_objs[r.metavars.__i];
|
||||
const __i = parseInt(StringLookup.get(r.metavars.__i) ?? "");
|
||||
const __j = parseInt(StringLookup.get(r.metavars.__j) ?? "");
|
||||
const resp_obj = resp_objs[__i];
|
||||
if (resp_obj.eval_res !== undefined)
|
||||
resp_obj.eval_res.items[r.metavars.__j] = r.responses[0];
|
||||
resp_obj.eval_res.items[__j] = llmResponseDataToString(r.responses[0]);
|
||||
else {
|
||||
resp_obj.eval_res = {
|
||||
items: [],
|
||||
dtype: "Categorical",
|
||||
};
|
||||
resp_obj.eval_res.items[r.metavars.__j] = r.responses[0];
|
||||
resp_obj.eval_res.items[__j] = llmResponseDataToString(r.responses[0]);
|
||||
}
|
||||
});
|
||||
|
||||
@ -1412,8 +1483,8 @@ export async function exportCache(ids: string[]): Promise<Dict<Dict>> {
|
||||
}
|
||||
// Bundle up specific other state in StorageCache, which
|
||||
// includes things like human ratings for responses:
|
||||
const cache_state = StorageCache.getAllMatching((key) =>
|
||||
key.startsWith("r."),
|
||||
const cache_state = StorageCache.getAllMatching(
|
||||
(key) => key.startsWith("r.") || key === "__s",
|
||||
);
|
||||
return { ...cache_files, ...cache_state };
|
||||
}
|
||||
@ -1438,6 +1509,9 @@ export async function importCache(files: {
|
||||
Object.entries(files).forEach(([filename, data]) => {
|
||||
StorageCache.store(filename, data);
|
||||
});
|
||||
|
||||
// Load StringLookup table from cache
|
||||
StringLookup.restoreFrom(StorageCache.get("__s"));
|
||||
} catch (err) {
|
||||
throw new Error("Error importing from cache:" + (err as Error).message);
|
||||
}
|
||||
|
@ -72,6 +72,7 @@ export default class StorageCache {
|
||||
*/
|
||||
public static clear(key?: string): void {
|
||||
StorageCache.getInstance().clearCache(key);
|
||||
if (key === undefined) StringLookup.restoreFrom([]);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -117,11 +118,12 @@ export default class StorageCache {
|
||||
* Performs lz-string decompression from UTF16 encoding.
|
||||
*
|
||||
* @param localStorageKey The key that will be used in localStorage (default='chainforge')
|
||||
* @param replaceStorageCacheWithLoadedData Whether the data currently in the StorageCache should be replaced with the loaded data. Erases all current memory. Only set this to true if you are replacing the ChainForge flow state entirely.
|
||||
* @returns Loaded data if succeeded, undefined if failure (e.g., key not found).
|
||||
*/
|
||||
public static loadFromLocalStorage(
|
||||
localStorageKey = "chainforge",
|
||||
setStorageCacheData = true,
|
||||
replaceStorageCacheWithLoadedData = false,
|
||||
): JSONCompatible | undefined {
|
||||
const compressed = localStorage.getItem(localStorageKey);
|
||||
if (!compressed) {
|
||||
@ -132,7 +134,12 @@ export default class StorageCache {
|
||||
}
|
||||
try {
|
||||
const data = JSON.parse(LZString.decompressFromUTF16(compressed));
|
||||
if (setStorageCacheData) StorageCache.getInstance().data = data;
|
||||
if (replaceStorageCacheWithLoadedData) {
|
||||
// Replaces the current cache data with the loaded data
|
||||
StorageCache.getInstance().data = data;
|
||||
// Restores the current StringLookup table with the contents of the loaded data, if the __s key is present.
|
||||
StringLookup.restoreFrom(data.__s);
|
||||
}
|
||||
console.log("loaded", data);
|
||||
return data;
|
||||
} catch (error) {
|
||||
@ -141,3 +148,171 @@ export default class StorageCache {
|
||||
}
|
||||
}
|
||||
}
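A short sketch of the renamed `replaceStorageCacheWithLoadedData` parameter in use (the import path is an assumption):

```typescript
import StorageCache from "./cache"; // assumed path

// Peek at the saved flow without touching the in-memory cache or intern table:
const peeked = StorageCache.loadFromLocalStorage("chainforge", false);

// Fully replace the cache AND restore the StringLookup table from data.__s:
const restored = StorageCache.loadFromLocalStorage("chainforge", true);

console.log(peeked !== undefined, restored !== undefined);
```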
|
||||
|
||||
/** Global string intern table for efficient storage of repeated strings */
|
||||
export class StringLookup {
|
||||
// eslint-disable-next-line no-use-before-define
|
||||
private static instance: StringLookup;
|
||||
private stringToIndex: Map<string, number> = new Map();
|
||||
private indexToString: string[] = [];
|
||||
|
||||
/** Gets the string intern lookup table. Initializes it if the singleton instance does not yet exist. */
|
||||
public static getInstance(): StringLookup {
|
||||
if (!StringLookup.instance) {
|
||||
StringLookup.instance = new StringLookup();
|
||||
}
|
||||
return StringLookup.instance;
|
||||
}
|
||||
|
||||
/** Adds a string to the table and returns its index */
|
||||
public static intern(str: string): number {
|
||||
const s = StringLookup.getInstance();
|
||||
if (s.stringToIndex.has(str)) {
|
||||
return s.stringToIndex.get(str)!; // Return existing index
|
||||
}
|
||||
|
||||
// Add new string to the table
|
||||
const index = s.indexToString.length;
|
||||
s.indexToString.push(str);
|
||||
s.stringToIndex.set(str, index);
|
||||
|
||||
// Save to cache
|
||||
StorageCache.store("__s", s.indexToString);
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
// Overloaded signatures
|
||||
// This tells TypeScript that a number or string will always produce a string or undefined,
|
||||
// whereas any other type T will return the same type.
|
||||
public static get(index: number | string | undefined): string | undefined;
|
||||
public static get<T>(index: T): T;
|
||||
|
||||
/**
|
||||
* Retrieves the string in the lookup table, given its index.
|
||||
* - **Note**: This function soft-fails: if `index` is not a number, it returns `index` unchanged.
|
||||
*/
|
||||
public static get<T>(index: T | number): T | string {
|
||||
if (typeof index !== "number") return index;
|
||||
const s = StringLookup.getInstance();
|
||||
return s.indexToString[index]; // O(1) lookup
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a Dict by interning all strings encountered, up to 1 level of depth,
|
||||
* and returning the modified Dict with the strings as hash indexes instead.
|
||||
*
|
||||
* NOTE: Keys listed in `ignoreKey` (by default "llm", "uid", and "eval_res") are left untouched and are not recursed into.
|
||||
*/
|
||||
public static internDict(
|
||||
d: Dict,
|
||||
inplace?: boolean,
|
||||
depth = 1,
|
||||
ignoreKey = ["llm", "uid", "eval_res"],
|
||||
): Dict {
|
||||
const newDict = inplace ? d : ({} as Dict);
|
||||
const entries = Object.entries(d);
|
||||
|
||||
for (const [key, value] of entries) {
|
||||
if (ignoreKey.includes(key)) {
|
||||
// Keep the ignored key the same
|
||||
if (!inplace) newDict[key] = value;
|
||||
continue;
|
||||
}
|
||||
if (typeof value === "string") {
|
||||
newDict[key] = StringLookup.intern(value);
|
||||
} else if (
|
||||
Array.isArray(value) &&
|
||||
value.every((v) => typeof v === "string")
|
||||
) {
|
||||
newDict[key] = value.map((v) => StringLookup.intern(v));
|
||||
} else if (depth > 0 && typeof value === "object" && value !== null) {
|
||||
newDict[key] = StringLookup.internDict(
|
||||
value as Dict,
|
||||
inplace,
|
||||
depth - 1,
|
||||
);
|
||||
} else {
|
||||
if (!inplace) newDict[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return newDict as Map<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Treats all numeric values in the dictionary as hashes, and maps them to strings.
|
||||
* Leaves the rest of the dict unchanged. (Only operates 1 level deep.)
|
||||
* @param d The dictionary to operate over
|
||||
*/
|
||||
public static concretizeDict<T>(
|
||||
d: Dict<T | number>,
|
||||
inplace = false,
|
||||
depth = 1,
|
||||
ignoreKey = ["llm", "uid", "eval_res"],
|
||||
): Dict<T | string> {
|
||||
const newDict = inplace ? d : ({} as Dict);
|
||||
const entries = Object.entries(d);
|
||||
for (const [key, value] of entries) {
|
||||
const ignore = ignoreKey.includes(key);
|
||||
if (!ignore && typeof value === "number")
|
||||
newDict[key] = StringLookup.get(value);
|
||||
else if (
|
||||
!ignore &&
|
||||
Array.isArray(value) &&
|
||||
value.every((v) => typeof v === "number")
|
||||
)
|
||||
newDict[key] = value.map((v) => StringLookup.get(v));
|
||||
else if (
|
||||
!ignore &&
|
||||
depth > 0 &&
|
||||
typeof value === "object" &&
|
||||
value !== null
|
||||
) {
|
||||
newDict[key] = StringLookup.concretizeDict(
|
||||
value as Dict<unknown>,
|
||||
false,
|
||||
0,
|
||||
);
|
||||
} else if (!inplace) newDict[key] = value;
|
||||
}
|
||||
return newDict;
|
||||
}
|
||||
|
||||
public static restoreFrom(savedIndexToString?: string[]): void {
|
||||
const s = StringLookup.getInstance();
|
||||
s.stringToIndex = new Map<string, number>();
|
||||
if (savedIndexToString === undefined || savedIndexToString.length === 0) {
|
||||
// Reset
|
||||
s.indexToString = [];
|
||||
return;
|
||||
} else if (!Array.isArray(savedIndexToString)) {
|
||||
// Reset, but warn user
|
||||
console.error(
|
||||
"String lookup table could not be loaded: data.__s is not an array.",
|
||||
);
|
||||
s.indexToString = [];
|
||||
return;
|
||||
}
|
||||
|
||||
// Recreate from the index array
|
||||
s.indexToString = savedIndexToString;
|
||||
savedIndexToString.forEach((v, i) => {
|
||||
s.stringToIndex.set(v, i);
|
||||
});
|
||||
}
|
||||
|
||||
/** Serializes interned strings and their mappings */
|
||||
public static toJSON() {
|
||||
const s = StringLookup.getInstance();
|
||||
return s.indexToString;
|
||||
}
|
||||
|
||||
/** Restores from JSON */
|
||||
static fromJSON(data: { dictionary: string[] }) {
|
||||
const table = new StringLookup();
|
||||
table.indexToString = data.dictionary;
|
||||
table.stringToIndex = new Map(data.dictionary.map((str, i) => [str, i]));
|
||||
StringLookup.instance = table;
|
||||
}
|
||||
}
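Based on the class above, a brief sketch of how the intern table is meant to be used (the import path is an assumption):

```typescript
import { StringLookup } from "./cache"; // assumed path

const h = StringLookup.intern("What is the oldest tree in the world?");
console.log(typeof h); // "number" -- a hash index into the table
console.log(StringLookup.get(h)); // the original string
console.log(StringLookup.get("plain string")); // soft-fails: returns the input unchanged

// Dicts are interned one level deep; keys in ignoreKey (e.g., "llm") are left alone...
const interned = StringLookup.internDict({ prompt: "hello", llm: "gpt-4" });
// ...and can be concretized back into plain strings when needed:
const concrete = StringLookup.concretizeDict(interned);
console.log(concrete.prompt); // "hello"
```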
|
||||
|
@ -74,6 +74,15 @@ export enum NativeLLM {
|
||||
PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
|
||||
PaLM2_Chat_Bison = "chat-bison-001",
|
||||
GEMINI_PRO = "gemini-pro",
|
||||
GEMINI_v2_flash = "gemini-2.0-flash-exp",
|
||||
GEMINI_v1_5_flash = "gemini-1.5-flash",
|
||||
GEMINI_v1_5_flash_8B = "gemini-1.5-flash-8b",
|
||||
GEMINI_v1_5_pro = "gemini-1.5-pro",
|
||||
GEMINI_v1_pro = "gemini-1.0-pro",
|
||||
|
||||
// DeepSeek
|
||||
DeepSeek_Chat = "deepseek-chat",
|
||||
DeepSeek_Reasoner = "deepseek-reasoner",
|
||||
|
||||
// Aleph Alpha
|
||||
Aleph_Alpha_Luminous_Extended = "luminous-extended",
|
||||
@ -92,12 +101,10 @@ export enum NativeLLM {
|
||||
HF_DIALOGPT_LARGE = "microsoft/DialoGPT-large", // chat model
|
||||
HF_GPT2 = "gpt2",
|
||||
HF_BLOOM_560M = "bigscience/bloom-560m",
|
||||
// HF_GPTJ_6B = "EleutherAI/gpt-j-6b",
|
||||
// HF_LLAMA_7B = "decapoda-research/llama-7b-hf",
|
||||
|
||||
// A special flag for a user-defined HuggingFace model endpoint.
|
||||
// The actual model name will be passed as a param to the LLM call function.
|
||||
HF_OTHER = "Other (HuggingFace)",
|
||||
|
||||
Ollama = "ollama",
|
||||
|
||||
Bedrock_Claude_2_1 = "anthropic.claude-v2:1",
|
||||
@ -146,6 +153,29 @@ export enum NativeLLM {
|
||||
Together_Meta_LLaMA2_Chat_7B = "together/meta-llama/Llama-2-7b-chat-hf",
|
||||
Together_Meta_LLaMA3_Chat_8B = "together/meta-llama/Llama-3-8b-chat-hf",
|
||||
Together_Meta_LLaMA3_Chat_70B = "together/meta-llama/Llama-3-70b-chat-hf",
|
||||
Together_Meta_LLaMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_1_405B = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_8B = "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_70B = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
Together_Meta_LLaMA3_8B_Lite = "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
|
||||
Together_Meta_LLaMA3_70B_Lite = "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
|
||||
Together_Nvidia_LLaMA3_1_Nemotron_70B = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
|
||||
Together_Qwen_Qwen2_5_Coder_32B = "Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
Together_Qwen_QwQ_32B_Preview = "Qwen/QwQ-32B-Preview",
|
||||
Together_Microsoft_WizardLM_2_8x22B = "microsoft/WizardLM-2-8x22B",
|
||||
Together_Google_Gemma2_27B = "google/gemma-2-27b-it",
|
||||
Together_Google_Gemma2_9B = "google/gemma-2-9b-it",
|
||||
Together_DeepSeek_3 = "deepseek-ai/DeepSeek-V3",
|
||||
Together_DeepSeek_R1 = "deepseek-ai/DeepSeek-R1",
|
||||
Together_mistralai_Mistral_7B_Instruct_v0_3 = "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
Together_Qwen_Qwen2_5_7B_Turbo = "Qwen/Qwen2.5-7B-Instruct-Turbo",
|
||||
Together_Qwen_Qwen2_5_72B_Turbo = "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
Together_Qwen_Qwen2_5_72B = "Qwen/Qwen2-72B-Instruct",
|
||||
Together_Qwen_Qwen2_VL_72B = "Qwen/Qwen2-VL-72B-Instruct",
|
||||
Together_Qwen_Qwen2_5_32B_Coder = "Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
Together_mistralai_Mistral_7B_Instruct = "together/mistralai/Mistral-7B-Instruct-v0.1",
|
||||
Together_mistralai_Mistral_7B_Instruct_v0_2 = "together/mistralai/Mistral-7B-Instruct-v0.2",
|
||||
Together_mistralai_Mixtral8x7B_Instruct_46_7B = "together/mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
@ -207,6 +237,7 @@ export enum LLMProvider {
|
||||
Ollama = "ollama",
|
||||
Bedrock = "bedrock",
|
||||
Together = "together",
|
||||
DeepSeek = "deepseek",
|
||||
Custom = "__custom",
|
||||
}
|
||||
|
||||
@ -228,6 +259,7 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
|
||||
else if (llm_name?.startsWith("Ollama")) return LLMProvider.Ollama;
|
||||
else if (llm_name?.startsWith("Bedrock")) return LLMProvider.Bedrock;
|
||||
else if (llm_name?.startsWith("Together")) return LLMProvider.Together;
|
||||
else if (llm_name?.startsWith("DeepSeek")) return LLMProvider.DeepSeek;
|
||||
else if (llm.toString().startsWith("__custom/")) return LLMProvider.Custom;
|
||||
|
||||
return undefined;
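A quick sanity check of the new provider mapping, assuming `getProvider` dispatches on the enum key name as the checks above suggest:

```typescript
import { NativeLLM, LLMProvider, getProvider } from "./models"; // assumed path

// The new native DeepSeek entries resolve to the new DeepSeek provider...
console.log(getProvider(NativeLLM.DeepSeek_Chat) === LLMProvider.DeepSeek); // true
// ...while Together-hosted DeepSeek models still resolve to Together:
console.log(getProvider(NativeLLM.Together_DeepSeek_R1) === LLMProvider.Together); // true
```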
|
||||
@ -283,10 +315,12 @@ export const RATE_LIMIT_BY_MODEL: { [key in LLM]?: number } = {
|
||||
export const RATE_LIMIT_BY_PROVIDER: { [key in LLMProvider]?: number } = {
|
||||
[LLMProvider.Anthropic]: 25, // Tier 1 pricing limit is 50 per minute, across all models; we halve this, to be safe.
|
||||
[LLMProvider.Together]: 30, // Paid tier limit is 60 per minute, across all models; we halve this, to be safe.
|
||||
[LLMProvider.Google]: 1000, // The RPM for Google Gemini 1.5 models is quite generous; at base it is 1000 RPM. The free tier is only 15 RPM, but we can expect most ChainForge users to be on the paid tier (and you can always re-run the prompt node until satisfied).
|
||||
[LLMProvider.DeepSeek]: 1000, // DeepSeek does not rate-limit users at the moment, but may in the future. To be safe, we limit it to 1000 queries per minute.
|
||||
};
|
||||
|
||||
// Max concurrent requests. Add to this to further constrain the rate limiter.
|
||||
export const MAX_CONCURRENT: { [key in LLM]?: number } = {};
|
||||
export const MAX_CONCURRENT: { [key in string | NativeLLM]?: number } = {};
|
||||
|
||||
const DEFAULT_RATE_LIMIT = 100; // RPM for any models not listed above
|
||||
|
||||
@ -312,14 +346,14 @@ export class RateLimiter {
|
||||
}
|
||||
|
||||
/** Get the Bottleneck limiter for the given model. If it doesn't already exist, instantiates it dynamically. */
|
||||
private getLimiter(model: LLM): Bottleneck {
|
||||
private getLimiter(model: LLM, provider: LLMProvider): Bottleneck {
|
||||
// Find if there's an existing limiter for this model
|
||||
if (!(model in this.limiters)) {
|
||||
// If there isn't, make one:
|
||||
// Find the RPM. First search if the model is present in predefined rate limits; then search for pre-defined RLs by provider; then set to default.
|
||||
const rpm =
|
||||
RATE_LIMIT_BY_MODEL[model] ??
|
||||
RATE_LIMIT_BY_PROVIDER[getProvider(model) ?? LLMProvider.Custom] ??
|
||||
RATE_LIMIT_BY_PROVIDER[provider ?? LLMProvider.Custom] ??
|
||||
DEFAULT_RATE_LIMIT;
|
||||
this.limiters[model] = new Bottleneck({
|
||||
reservoir: rpm, // max requests per minute
|
||||
@ -340,12 +374,13 @@ export class RateLimiter {
|
||||
*/
|
||||
public static throttle<T>(
|
||||
model: LLM,
|
||||
provider: LLMProvider,
|
||||
func: () => PromiseLike<T>,
|
||||
should_cancel?: () => boolean,
|
||||
): Promise<T> {
|
||||
// Rate limit per model, and abort if the API request takes 3 minutes or more.
|
||||
return this.getInstance()
|
||||
.getLimiter(model)
|
||||
.getLimiter(model, provider)
|
||||
.schedule({}, () => {
|
||||
if (should_cancel && should_cancel())
|
||||
throw new UserForcedPrematureExit();
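A hedged sketch of the updated `throttle` signature in use; the provider is now passed explicitly so the limiter can fall back to `RATE_LIMIT_BY_PROVIDER` when a model has no per-model entry:

```typescript
import { NativeLLM, LLMProvider, RateLimiter } from "./models"; // assumed path

async function rateLimitedCall(): Promise<string> {
  return RateLimiter.throttle(
    NativeLLM.DeepSeek_Chat,
    LLMProvider.DeepSeek, // falls back to the 1000 RPM provider-level limit above
    async () => {
      // ...the actual API call would go here...
      return "ok";
    },
    () => false, // should_cancel: never cancel in this sketch
  );
}
```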
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { PromptTemplate, PromptPermutationGenerator } from "./template";
|
||||
import { LLM, NativeLLM, RateLimiter } from "./models";
|
||||
import { LLM, LLMProvider, NativeLLM, RateLimiter } from "./models";
|
||||
import {
|
||||
Dict,
|
||||
LLMResponseError,
|
||||
@ -21,7 +21,7 @@ import {
|
||||
repairCachedResponses,
|
||||
compressBase64Image,
|
||||
} from "./utils";
|
||||
import StorageCache from "./cache";
|
||||
import StorageCache, { StringLookup } from "./cache";
|
||||
import { UserForcedPrematureExit } from "./errors";
|
||||
import { typecastSettingsDict } from "../ModelSettingSchemas";
|
||||
|
||||
@ -75,6 +75,7 @@ export class PromptPipeline {
|
||||
private async collect_LLM_response(
|
||||
result: _IntermediateLLMResponseType,
|
||||
llm: LLM,
|
||||
provider: LLMProvider,
|
||||
cached_responses: Dict,
|
||||
): Promise<RawLLMResponseObject | LLMResponseError> {
|
||||
const {
|
||||
@ -97,7 +98,7 @@ export class PromptPipeline {
|
||||
const metavars = prompt.metavars;
|
||||
|
||||
// Extract and format the responses into `LLMResponseData`
|
||||
const extracted_resps = extract_responses(response, llm);
|
||||
const extracted_resps = extract_responses(response, llm, provider);
|
||||
|
||||
// Detect any images and downrez them if the user has approved of automatic compression.
|
||||
// This saves a lot of compute and storage. We also need to disable storing the raw response here, to save space.
|
||||
@ -130,7 +131,6 @@ export class PromptPipeline {
|
||||
query: query ?? {},
|
||||
uid: uuid(),
|
||||
responses: extracted_resps,
|
||||
raw_response: contains_imgs ? {} : response ?? {}, // don't double-store images
|
||||
llm,
|
||||
vars: mergeDicts(info, chat_history?.fill_history) ?? {},
|
||||
metavars: mergeDicts(metavars, chat_history?.metavars) ?? {},
|
||||
@ -140,6 +140,9 @@ export class PromptPipeline {
|
||||
if (chat_history !== undefined)
|
||||
resp_obj.chat_history = chat_history.messages;
|
||||
|
||||
// Hash strings present in the response object, to improve performance
|
||||
StringLookup.internDict(resp_obj, true);
|
||||
|
||||
// Merge the response obj with the past one, if necessary
|
||||
if (past_resp_obj)
|
||||
resp_obj = merge_response_objs(
|
||||
@ -149,14 +152,14 @@ export class PromptPipeline {
|
||||
|
||||
// Save the current state of cache'd responses to a JSON file
|
||||
// NOTE: We do this to save money -- if something breaks between calls, we can ensure we still have the data!
|
||||
if (!(resp_obj.prompt in cached_responses))
|
||||
cached_responses[resp_obj.prompt] = [];
|
||||
else if (!Array.isArray(cached_responses[resp_obj.prompt]))
|
||||
cached_responses[resp_obj.prompt] = [cached_responses[resp_obj.prompt]];
|
||||
const prompt_str = prompt.toString();
|
||||
if (!(prompt_str in cached_responses)) cached_responses[prompt_str] = [];
|
||||
else if (!Array.isArray(cached_responses[prompt_str]))
|
||||
cached_responses[prompt_str] = [cached_responses[prompt_str]];
|
||||
|
||||
if (past_resp_obj_cache_idx !== undefined && past_resp_obj_cache_idx > -1)
|
||||
cached_responses[resp_obj.prompt][past_resp_obj_cache_idx] = resp_obj;
|
||||
else cached_responses[resp_obj.prompt].push(resp_obj);
|
||||
cached_responses[prompt_str][past_resp_obj_cache_idx] = resp_obj;
|
||||
else cached_responses[prompt_str].push(resp_obj);
|
||||
|
||||
this._cache_responses(cached_responses);
|
||||
|
||||
@ -183,6 +186,7 @@ export class PromptPipeline {
|
||||
|
||||
* @param vars The 'vars' dict to fill variables in the root prompt template. For instance, for 'Who is {person}?', vars might = { person: ['TJ', 'MJ', 'AD'] }.
|
||||
* @param llm The specific LLM model to call. See the LLM enum for supported models.
|
||||
* @param provider The specific LLM provider to call. See the LLMProvider enum for supported providers.
|
||||
* @param n How many generations per prompt sent to the LLM.
|
||||
* @param temperature The temperature to use when querying the LLM.
|
||||
* @param llm_params Optional. The model-specific settings to pass into the LLM API call. Varies by LLM.
|
||||
@ -195,6 +199,7 @@ export class PromptPipeline {
|
||||
async *gen_responses(
|
||||
vars: Dict,
|
||||
llm: LLM,
|
||||
provider: LLMProvider,
|
||||
n = 1,
|
||||
temperature = 1.0,
|
||||
llm_params?: Dict,
|
||||
@ -284,11 +289,10 @@ export class PromptPipeline {
|
||||
if (cached_resp && extracted_resps.length >= n) {
|
||||
// console.log(` - Found cache'd response for prompt ${prompt_str}. Using...`);
|
||||
const resp: RawLLMResponseObject = {
|
||||
prompt: prompt_str,
|
||||
prompt: cached_resp.prompt,
|
||||
query: cached_resp.query,
|
||||
uid: cached_resp.uid ?? uuid(),
|
||||
responses: extracted_resps.slice(0, n),
|
||||
raw_response: cached_resp.raw_response,
|
||||
llm: cached_resp.llm || NativeLLM.OpenAI_ChatGPT,
|
||||
// We want to use the new info, since 'vars' could have changed even though
|
||||
// the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
|
||||
@ -297,6 +301,7 @@ export class PromptPipeline {
|
||||
};
|
||||
if (chat_history !== undefined)
|
||||
resp.chat_history = chat_history.messages;
|
||||
|
||||
yield resp;
|
||||
continue;
|
||||
}
|
||||
@ -305,6 +310,7 @@ export class PromptPipeline {
|
||||
tasks.push(
|
||||
this._prompt_llm(
|
||||
llm,
|
||||
provider,
|
||||
prompt,
|
||||
n,
|
||||
temperature,
|
||||
@ -319,7 +325,9 @@ export class PromptPipeline {
|
||||
},
|
||||
chat_history,
|
||||
should_cancel,
|
||||
).then((result) => this.collect_LLM_response(result, llm, responses)),
|
||||
).then((result) =>
|
||||
this.collect_LLM_response(result, llm, provider, responses),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -358,6 +366,7 @@ export class PromptPipeline {
|
||||
|
||||
async _prompt_llm(
|
||||
llm: LLM,
|
||||
provider: LLMProvider,
|
||||
prompt: PromptTemplate,
|
||||
n = 1,
|
||||
temperature = 1.0,
|
||||
@ -393,9 +402,11 @@ export class PromptPipeline {
|
||||
// It's not perfect, but it's simpler than throttling at the call-specific level.
|
||||
[query, response] = await RateLimiter.throttle(
|
||||
llm,
|
||||
provider,
|
||||
() =>
|
||||
call_llm(
|
||||
llm,
|
||||
provider,
|
||||
prompt.toString(),
|
||||
n,
|
||||
temperature,
|
||||
|
@ -1,4 +1,12 @@
|
||||
import { StringLookup } from "./cache";
|
||||
import { isEqual } from "./setUtils";
|
||||
import {
|
||||
Dict,
|
||||
PromptVarsDict,
|
||||
PromptVarType,
|
||||
StringOrHash,
|
||||
TemplateVarInfo,
|
||||
} from "./typing";
|
||||
|
||||
function len(o: object | string | Array<any>): number {
|
||||
// Acts akin to Python's builtin 'len' method
|
||||
@ -193,8 +201,8 @@ export class PromptTemplate {
|
||||
print(partial_prompt)
|
||||
*/
|
||||
template: string;
|
||||
fill_history: { [key: string]: any };
|
||||
metavars: { [key: string]: any };
|
||||
fill_history: Dict<StringOrHash>;
|
||||
metavars: Dict<StringOrHash>;
|
||||
|
||||
constructor(templateStr: string) {
|
||||
/**
|
||||
@ -234,7 +242,8 @@ export class PromptTemplate {
|
||||
has_unfilled_settings_var(varname: string): boolean {
|
||||
return Object.entries(this.fill_history).some(
|
||||
([key, val]) =>
|
||||
key.startsWith("=") && new StringTemplate(val).has_vars([varname]),
|
||||
key.startsWith("=") &&
|
||||
new StringTemplate(StringLookup.get(val) ?? "").has_vars([varname]),
|
||||
);
|
||||
}
|
||||
|
||||
@ -257,33 +266,47 @@ export class PromptTemplate {
|
||||
"PL": "Python"
|
||||
});
|
||||
*/
|
||||
fill(paramDict: { [key: string]: any }): PromptTemplate {
|
||||
fill(paramDict: Dict<PromptVarType>): PromptTemplate {
|
||||
// Check for special 'past fill history' format:
|
||||
let past_fill_history = {};
|
||||
let past_metavars = {};
|
||||
let past_fill_history: Dict<StringOrHash> = {};
|
||||
let past_metavars: Dict<StringOrHash> = {};
|
||||
const some_key = Object.keys(paramDict).pop();
|
||||
const some_val = some_key ? paramDict[some_key] : undefined;
|
||||
if (len(paramDict) > 0 && isDict(some_val)) {
|
||||
// Transfer over the fill history and metavars
|
||||
Object.values(paramDict).forEach((obj) => {
|
||||
if ("fill_history" in obj)
|
||||
past_fill_history = { ...obj.fill_history, ...past_fill_history };
|
||||
if ("metavars" in obj)
|
||||
past_metavars = { ...obj.metavars, ...past_metavars };
|
||||
if ("fill_history" in (obj as TemplateVarInfo))
|
||||
past_fill_history = {
|
||||
...(obj as TemplateVarInfo).fill_history,
|
||||
...past_fill_history,
|
||||
};
|
||||
if ("metavars" in (obj as TemplateVarInfo))
|
||||
past_metavars = {
|
||||
...(obj as TemplateVarInfo).metavars,
|
||||
...past_metavars,
|
||||
};
|
||||
});
|
||||
|
||||
past_fill_history = StringLookup.concretizeDict(past_fill_history);
|
||||
past_metavars = StringLookup.concretizeDict(past_metavars);
|
||||
|
||||
// Recreate the param dict from just the 'text' property of the fill object
|
||||
const newParamDict: { [key: string]: any } = {};
|
||||
const newParamDict: Dict<StringOrHash> = {};
|
||||
Object.entries(paramDict).forEach(([param, obj]) => {
|
||||
newParamDict[param] = obj.text;
|
||||
newParamDict[param] = (obj as TemplateVarInfo).text as StringOrHash;
|
||||
});
|
||||
paramDict = newParamDict;
|
||||
}
|
||||
|
||||
// Concretize the params
|
||||
paramDict = StringLookup.concretizeDict(paramDict) as Dict<
|
||||
string | TemplateVarInfo
|
||||
>;
|
||||
|
||||
// For 'settings' template vars of form {=system_msg}, we use the same logic of storing param
|
||||
// values as before -- the only difference is that, when it comes to the actual substitution of
|
||||
// the string, we *don't fill the template with anything* -- it vanishes.
|
||||
let params_wo_settings = paramDict;
|
||||
let params_wo_settings = paramDict as Dict<string>;
|
||||
// To improve performance, we first check if there's a settings var present at all before deep cloning:
|
||||
if (Object.keys(paramDict).some((key) => key?.charAt(0) === "=")) {
|
||||
// A settings var is present; deep clone the param dict and replace it with the empty string:
|
||||
@ -305,9 +328,9 @@ export class PromptTemplate {
|
||||
// Perform the fill inside any and all 'settings' template vars
|
||||
Object.entries(filled_pt.fill_history).forEach(([key, val]) => {
|
||||
if (!key.startsWith("=")) return;
|
||||
filled_pt.fill_history[key] = new StringTemplate(val).safe_substitute(
|
||||
params_wo_settings,
|
||||
);
|
||||
filled_pt.fill_history[key] = new StringTemplate(
|
||||
StringLookup.get(val) ?? "",
|
||||
).safe_substitute(params_wo_settings);
|
||||
});
|
||||
|
||||
// Append any past history passed as vars:
|
||||
@ -325,7 +348,7 @@ export class PromptTemplate {
|
||||
});
|
||||
|
||||
// Add the new fill history using the passed parameters that we just filled in
|
||||
Object.entries(paramDict).forEach(([key, val]) => {
|
||||
Object.entries(paramDict as Dict<string>).forEach(([key, val]) => {
|
||||
if (key in filled_pt.fill_history)
|
||||
console.log(
|
||||
`Warning: PromptTemplate already has fill history for key ${key}.`,
|
||||
@ -341,11 +364,11 @@ export class PromptTemplate {
|
||||
* Modifies the prompt template in place.
|
||||
* @param fill_history A fill history dict.
|
||||
*/
|
||||
fill_special_vars(fill_history: { [key: string]: any }): void {
|
||||
fill_special_vars(fill_history: Dict<StringOrHash>): void {
|
||||
// Special variables {#...} denote filling a variable from a matching var in fill_history or metavars.
|
||||
// Find any special variables:
|
||||
const unfilled_vars = new StringTemplate(this.template).get_vars();
|
||||
const special_vars_to_fill: { [key: string]: string } = {};
|
||||
const special_vars_to_fill: Dict<StringOrHash> = {};
|
||||
for (const v of unfilled_vars) {
|
||||
if (v.length > 0 && v[0] === "#") {
|
||||
// special template variables must begin with #
|
||||
@ -360,7 +383,7 @@ export class PromptTemplate {
|
||||
// Fill any special variables, using the fill history of the template in question:
|
||||
if (Object.keys(special_vars_to_fill).length > 0)
|
||||
this.template = new StringTemplate(this.template).safe_substitute(
|
||||
special_vars_to_fill,
|
||||
StringLookup.concretizeDict(special_vars_to_fill),
|
||||
);
|
||||
}
|
||||
}
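The `fill()` changes above concretize interned parameters before substitution. A hedged sketch of what that looks like from the caller's side (import paths and the exact fill-object shape are assumptions drawn from this diff):

```typescript
import { PromptTemplate } from "./template"; // assumed path
import { StringLookup } from "./cache"; // assumed path

const pt = new PromptTemplate("Tell me about {topic}.");
const filled = pt.fill({
  topic: {
    text: StringLookup.intern("volcanoes"), // an interned hash is acceptable here
    fill_history: {},
    metavars: { source: "wiki" },
  },
});
console.log(filled.toString()); // "Tell me about volcanoes."
```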
|
||||
@ -389,7 +412,7 @@ export class PromptPermutationGenerator {
|
||||
*_gen_perm(
|
||||
template: PromptTemplate,
|
||||
params_to_fill: Array<string>,
|
||||
paramDict: { [key: string]: any },
|
||||
paramDict: PromptVarsDict,
|
||||
): Generator<PromptTemplate, boolean, undefined> {
|
||||
if (len(params_to_fill) === 0) return true;
|
||||
|
||||
@ -417,24 +440,31 @@ export class PromptPermutationGenerator {
|
||||
val.forEach((v) => {
|
||||
if (param === undefined) return;
|
||||
|
||||
const param_fill_dict: { [key: string]: any } = {};
|
||||
const param_fill_dict: Dict<PromptVarType> = {};
|
||||
param_fill_dict[param] = v;
|
||||
|
||||
/* If this var has an "associate_id", then it wants to "carry with"
|
||||
values of other prompt parameters with the same id.
|
||||
We have to find any parameters with values of the same id,
|
||||
and fill them in alongside the initial parameter v: */
|
||||
if (isDict(v) && "associate_id" in v) {
|
||||
const v_associate_id = v.associate_id;
|
||||
if (isDict(v) && "associate_id" in (v as object)) {
|
||||
const v_associate_id = (v as TemplateVarInfo).associate_id;
|
||||
params_left.forEach((other_param) => {
|
||||
if (
|
||||
(template.has_var(other_param) ||
|
||||
template.has_unfilled_settings_var(other_param)) &&
|
||||
Array.isArray(paramDict[other_param])
|
||||
) {
|
||||
for (let i = 0; i < paramDict[other_param].length; i++) {
|
||||
const ov = paramDict[other_param][i];
|
||||
if (isDict(ov) && ov.associate_id === v_associate_id) {
|
||||
for (
|
||||
let i = 0;
|
||||
i < (paramDict[other_param] as PromptVarType[]).length;
|
||||
i++
|
||||
) {
|
||||
const ov = (paramDict[other_param] as PromptVarType[])[i];
|
||||
if (
|
||||
isDict(ov) &&
|
||||
(ov as TemplateVarInfo).associate_id === v_associate_id
|
||||
) {
|
||||
// This is a match. We should add the val to our param_fill_dict:
|
||||
param_fill_dict[other_param] = ov;
|
||||
break;
|
||||
@ -447,8 +477,8 @@ export class PromptPermutationGenerator {
|
||||
// Fill the template with the param values and append it to the list
|
||||
new_prompt_temps.push(template.fill(param_fill_dict));
|
||||
});
|
||||
} else if (typeof val === "string") {
|
||||
const sub_dict: { [key: string]: any } = {};
|
||||
} else if (typeof val === "string" || typeof val === "number") {
|
||||
const sub_dict: Dict<StringOrHash> = {};
|
||||
sub_dict[param] = val;
|
||||
new_prompt_temps = [template.fill(sub_dict)];
|
||||
} else
|
||||
@ -470,9 +500,9 @@ export class PromptPermutationGenerator {
|
||||
}
|
||||
|
||||
// Generator class method to yield permutations of a root prompt template
|
||||
*generate(paramDict: {
|
||||
[key: string]: any;
|
||||
}): Generator<PromptTemplate, boolean, undefined> {
|
||||
*generate(
|
||||
paramDict: PromptVarsDict,
|
||||
): Generator<PromptTemplate, boolean, undefined> {
|
||||
const template =
|
||||
typeof this.template === "string"
|
||||
? new PromptTemplate(this.template)
|
||||
|
@ -13,6 +13,9 @@ export interface Dict<T = any> {
|
||||
[key: string]: T;
|
||||
}
|
||||
|
||||
/** A string or a number representing the index to a hash table (`StringLookup`). */
|
||||
export type StringOrHash = string | number;
|
||||
|
||||
// Function types
|
||||
export type Func<T = void> = (...args: any[]) => T;
|
||||
|
||||
@ -52,7 +55,7 @@ export interface PaLMChatContext {
|
||||
|
||||
export interface GeminiChatMessage {
|
||||
role: string;
|
||||
parts: string;
|
||||
parts: [{ text: string }];
|
||||
}
|
||||
|
||||
export interface GeminiChatContext {
|
||||
@ -68,8 +71,8 @@ export interface HuggingFaceChatHistory {
|
||||
// Chat history with 'carried' variable metadata
|
||||
export interface ChatHistoryInfo {
|
||||
messages: ChatHistory;
|
||||
fill_history: Dict;
|
||||
metavars?: Dict;
|
||||
fill_history: Dict<StringOrHash>;
|
||||
metavars?: Dict<StringOrHash>;
|
||||
llm?: string;
|
||||
}
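Two small examples to make the typing changes above concrete (the import path is an assumption): the Gemini `parts` field is now an array of `{ text }` objects, and chat-history vars may now carry interned hashes.

```typescript
import { GeminiChatMessage, ChatHistoryInfo } from "./typing"; // assumed path

const msg: GeminiChatMessage = {
  role: "user",
  parts: [{ text: "What is the oldest tree in the world?" }],
};

const history: ChatHistoryInfo = {
  messages: [],
  fill_history: { thing: "tree" }, // a numeric StringLookup hash would also typecheck
  metavars: {},
};
console.log(msg.parts[0].text, history.fill_history.thing);
```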
|
||||
|
||||
@ -164,7 +167,7 @@ export type LLMResponseData =
|
||||
t: "img"; // type
|
||||
d: string; // payload
|
||||
}
|
||||
| string;
|
||||
| StringOrHash;
|
||||
|
||||
export function isImageResponseData(
|
||||
r: LLMResponseData,
|
||||
@ -177,13 +180,13 @@ export interface BaseLLMResponseObject {
|
||||
/** A unique ID to refer to this response */
|
||||
uid: ResponseUID;
|
||||
/** The concrete prompt that led to this response. */
|
||||
prompt: string;
|
||||
prompt: StringOrHash;
|
||||
/** The variables fed into the prompt. */
|
||||
vars: Dict;
|
||||
vars: Dict<StringOrHash>;
|
||||
/** Any associated metavariables. */
|
||||
metavars: Dict;
|
||||
metavars: Dict<StringOrHash>;
|
||||
/** The LLM to query (usually a dict of settings) */
|
||||
llm: string | LLMSpec;
|
||||
llm: StringOrHash | LLMSpec;
|
||||
/** Optional: The chat history to pass the LLM */
|
||||
chat_history?: ChatHistory;
|
||||
}
|
||||
@ -193,7 +196,8 @@ export interface RawLLMResponseObject extends BaseLLMResponseObject {
|
||||
// A snapshot of the exact query (payload) sent to the LLM's API
|
||||
query: Dict;
|
||||
// The raw JSON response from the LLM
|
||||
raw_response: Dict;
|
||||
// NOTE: This is now deprecated since it wastes precious storage space.
|
||||
// raw_response: Dict;
|
||||
// Extracted responses (1 or more) from raw_response
|
||||
responses: LLMResponseData[];
|
||||
// Token lengths (if given)
|
||||
@ -240,19 +244,19 @@ export type EvaluatedResponsesResults = {
|
||||
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
|
||||
* Used to populate prompt templates and carry variables/metavariables along the chain. */
|
||||
export interface TemplateVarInfo {
|
||||
text?: string;
|
||||
image?: string; // base-64 encoding
|
||||
fill_history?: Dict<string>;
|
||||
metavars?: Dict<string>;
|
||||
associate_id?: string;
|
||||
prompt?: string;
|
||||
text?: StringOrHash;
|
||||
image?: StringOrHash; // base-64 encoding
|
||||
fill_history?: Dict<StringOrHash>;
|
||||
metavars?: Dict<StringOrHash>;
|
||||
associate_id?: StringOrHash;
|
||||
prompt?: StringOrHash;
|
||||
uid?: ResponseUID;
|
||||
llm?: string | LLMSpec;
|
||||
llm?: StringOrHash | LLMSpec;
|
||||
chat_history?: ChatHistory;
|
||||
}
|
||||
|
||||
export type LLMResponsesByVarDict = Dict<
|
||||
(BaseLLMResponseObject | LLMResponse | TemplateVarInfo | string)[]
|
||||
(BaseLLMResponseObject | LLMResponse | TemplateVarInfo | StringOrHash)[]
|
||||
>;
|
||||
|
||||
export type VarsContext = {
|
||||
@ -260,12 +264,12 @@ export type VarsContext = {
|
||||
metavars: string[];
|
||||
};
|
||||
|
||||
export type PromptVarType = string | TemplateVarInfo;
|
||||
export type PromptVarType = StringOrHash | TemplateVarInfo;
|
||||
export type PromptVarsDict = {
|
||||
[key: string]: PromptVarType[];
|
||||
[key: string]: PromptVarType[] | StringOrHash;
|
||||
};
|
||||
|
||||
export type TabularDataRowType = Dict<string>;
|
||||
export type TabularDataRowType = Dict<StringOrHash>;
|
||||
export type TabularDataColType = {
|
||||
key: string;
|
||||
header: string;
|
||||
|
@ -26,6 +26,8 @@ import {
|
||||
EvaluationScore,
|
||||
LLMResponseData,
|
||||
isImageResponseData,
|
||||
StringOrHash,
|
||||
PromptVarsDict,
|
||||
} from "./typing";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { StringTemplate } from "./template";
|
||||
@ -47,9 +49,9 @@ import {
|
||||
fromModelId,
|
||||
ChatMessage as BedrockChatMessage,
|
||||
} from "@mirai73/bedrock-fm";
|
||||
import StorageCache from "./cache";
|
||||
import StorageCache, { StringLookup } from "./cache";
|
||||
import Compressor from "compressorjs";
|
||||
import { Models } from "@mirai73/bedrock-fm/lib/bedrock";
|
||||
// import { Models } from "@mirai73/bedrock-fm/lib/bedrock";
|
||||
|
||||
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
|
||||
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";
|
||||
@ -156,6 +158,7 @@ let AWS_SECRET_ACCESS_KEY = get_environ("AWS_SECRET_ACCESS_KEY");
|
||||
let AWS_SESSION_TOKEN = get_environ("AWS_SESSION_TOKEN");
|
||||
let AWS_REGION = get_environ("AWS_REGION");
|
||||
let TOGETHER_API_KEY = get_environ("TOGETHER_API_KEY");
|
||||
let DEEPSEEK_API_KEY = get_environ("DEEPSEEK_API_KEY");
|
||||
|
||||
/**
|
||||
* Sets the local API keys for the relevant LLM API(s).
|
||||
@ -188,6 +191,7 @@ export function set_api_keys(api_keys: Dict<string>): void {
|
||||
AWS_SESSION_TOKEN = api_keys.AWS_Session_Token;
|
||||
if (key_is_present("AWS_Region")) AWS_REGION = api_keys.AWS_Region;
|
||||
if (key_is_present("Together")) TOGETHER_API_KEY = api_keys.Together;
|
||||
if (key_is_present("DeepSeek")) DEEPSEEK_API_KEY = api_keys.DeepSeek;
|
||||
}
|
||||
|
||||
export function get_azure_openai_api_keys(): [
|
||||
@ -207,12 +211,16 @@ function construct_openai_chat_history(
|
||||
prompt: string,
|
||||
chat_history?: ChatHistory,
|
||||
system_msg?: string,
|
||||
system_role_name?: string,
|
||||
): ChatHistory {
|
||||
const sys_role_name = system_role_name ?? "system";
|
||||
const prompt_msg: ChatMessage = { role: "user", content: prompt };
|
||||
const sys_msg: ChatMessage[] =
|
||||
system_msg !== undefined ? [{ role: "system", content: system_msg }] : [];
|
||||
system_msg !== undefined
|
||||
? [{ role: sys_role_name, content: system_msg }]
|
||||
: [];
|
||||
if (chat_history !== undefined && chat_history.length > 0) {
|
||||
if (chat_history[0].role === "system") {
|
||||
if (chat_history[0].role === sys_role_name) {
|
||||
// In this case, the system_msg is ignored because the prior history already contains one.
|
||||
return chat_history.concat([prompt_msg]);
|
||||
} else {
|
||||
@ -234,6 +242,8 @@ export async function call_chatgpt(
|
||||
temperature = 1.0,
|
||||
params?: Dict,
|
||||
should_cancel?: () => boolean,
|
||||
BASE_URL?: string,
|
||||
API_KEY?: string,
|
||||
): Promise<[Dict, Dict]> {
|
||||
if (!OPENAI_API_KEY)
|
||||
throw new Error(
|
||||
@ -241,8 +251,8 @@ export async function call_chatgpt(
|
||||
);
|
||||
|
||||
const configuration = new OpenAIConfig({
|
||||
apiKey: OPENAI_API_KEY,
|
||||
basePath: OPENAI_BASE_URL ?? undefined,
|
||||
apiKey: API_KEY ?? OPENAI_API_KEY,
|
||||
basePath: BASE_URL ?? OPENAI_BASE_URL ?? undefined,
|
||||
});
|
||||
|
||||
// Since we are running client-side, we need to remove the user-agent header:
|
||||
@ -284,17 +294,21 @@ export async function call_chatgpt(
|
||||
if (params?.tools === undefined && params?.parallel_tool_calls !== undefined)
|
||||
delete params?.parallel_tool_calls;
|
||||
|
||||
console.log(`Querying OpenAI model '${model}' with prompt '${prompt}'...`);
|
||||
if (!BASE_URL)
|
||||
console.log(`Querying OpenAI model '${model}' with prompt '${prompt}'...`);
|
||||
|
||||
// Determine the system message and whether there's chat history to continue:
|
||||
const chat_history: ChatHistory | undefined = params?.chat_history;
|
||||
const system_msg: string =
|
||||
params?.system_msg !== undefined
|
||||
? params.system_msg
|
||||
: "You are a helpful assistant.";
|
||||
let system_msg: string | undefined =
|
||||
params?.system_msg !== undefined ? params.system_msg : undefined;
|
||||
delete params?.system_msg;
|
||||
delete params?.chat_history;
|
||||
|
||||
// The o1 and later OpenAI models, for whatever reason, block the system message from being sent in the chat history.
|
||||
// The official API states that "developer" works, but it doesn't for some models, so until OpenAI fixes this
|
||||
// and fully supports system messages, we have to block them from being sent in the chat history.
|
||||
if (model.startsWith("o")) system_msg = undefined;
|
||||
|
||||
const query: Dict = {
|
||||
model: modelname,
|
||||
n,
|
||||
@ -339,6 +353,36 @@ export async function call_chatgpt(
|
||||
return [query, response];
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls DeepSeek models via DeepSeek's API.
|
||||
*/
|
||||
export async function call_deepseek(
|
||||
prompt: string,
|
||||
model: LLM,
|
||||
n = 1,
|
||||
temperature = 1.0,
|
||||
params?: Dict,
|
||||
should_cancel?: () => boolean,
|
||||
): Promise<[Dict, Dict]> {
|
||||
if (!DEEPSEEK_API_KEY)
|
||||
throw new Error(
|
||||
"Could not find a DeepSeek API key. Double-check that your API key is set in Settings or in your local environment.",
|
||||
);
|
||||
|
||||
console.log(`Querying DeepSeek model '${model}' with prompt '${prompt}'...`);
|
||||
|
||||
return await call_chatgpt(
|
||||
prompt,
|
||||
model,
|
||||
n,
|
||||
temperature,
|
||||
params,
|
||||
should_cancel,
|
||||
"https://api.deepseek.com",
|
||||
DEEPSEEK_API_KEY,
|
||||
);
|
||||
}
|
||||
|
||||
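DeepSeek's API is OpenAI-compatible, so `call_deepseek` above simply re-routes through `call_chatgpt` with a different base URL and API key. A hedged usage sketch follows; the `deepseek-chat` model id matches the provider menu entry added further down, while the `as LLM` cast and the wrapper function are illustrative assumptions.

```ts
// Sketch: querying DeepSeek through the OpenAI-compatible path shown above.
// Assumes a key was registered earlier, e.g. set_api_keys({ DeepSeek: "sk-..." }).
async function demoDeepSeek(): Promise<LLMResponseData[]> {
  const [, response] = await call_deepseek(
    "Summarize the plot of Hamlet in one sentence.",
    "deepseek-chat" as LLM, // "deepseek-reasoner" is the other id added below
    1, // n: number of completions
    1.0, // temperature
  );
  // DeepSeek answers in OpenAI's response format, so the OpenAI extractor applies:
  return extract_responses(response, "deepseek-chat", LLMProvider.DeepSeek);
}
```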
/**
* Calls OpenAI Image models via OpenAI's API.
@returns raw query and response JSON dicts.
@ -700,26 +744,27 @@ export async function call_google_ai(
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
switch (model) {
case NativeLLM.GEMINI_PRO:
return call_google_gemini(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
default:
return call_google_palm(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
}
if (
model === NativeLLM.PaLM2_Chat_Bison ||
model === NativeLLM.PaLM2_Text_Bison
)
return call_google_palm(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
else
return call_google_gemini(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
}

/**
@ -861,18 +906,21 @@ export async function call_google_gemini(
"Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.",
);

// calling the correct model client
model = NativeLLM.GEMINI_PRO;

const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
const gemini_model = genAI.getGenerativeModel({ model: model.toString() });

// removing chat for now. by default chat is supported

// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 1000;
const chat_history = params?.chat_history;
const chat_history: ChatHistory = params?.chat_history;
const system_msg = params?.system_msg;
delete params?.chat_history;
delete params?.system_msg;

const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
const gemini_model = genAI.getGenerativeModel({
model: model.toString(),
systemInstruction:
typeof system_msg === "string" && chat_history === undefined
? system_msg
: undefined,
});

const query: Dict = {
model: `models/${model}`,
@ -924,10 +972,20 @@ export async function call_google_gemini(
for (const chat_msg of chat_history) {
if (chat_msg.role === "system") {
// Carry the system message over as PaLM's chat 'context':
gemini_messages.push({ role: "model", parts: chat_msg.content });
gemini_messages.push({
role: "model",
parts: [{ text: chat_msg.content }],
});
} else if (chat_msg.role === "user") {
gemini_messages.push({ role: "user", parts: chat_msg.content });
} else gemini_messages.push({ role: "model", parts: chat_msg.content });
gemini_messages.push({
role: "user",
parts: [{ text: chat_msg.content }],
});
} else
gemini_messages.push({
role: "model",
parts: [{ text: chat_msg.content }],
});
}
gemini_chat_context.history = gemini_messages;
}
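The rewritten history loop above wraps each message's text in a `parts` array, since the Google Generative AI SDK expects content objects of the form `{ role, parts: [{ text }] }` rather than a bare string. A small sketch of that mapping under the same assumption (the type names here are illustrative, not taken from the SDK's declarations):

```ts
// Sketch of the message shape the loop above builds for Gemini.
// Assumption: the SDK accepts { role: "user" | "model", parts: [{ text: string }] }.
type SketchPart = { text: string };
type SketchContent = { role: "user" | "model"; parts: SketchPart[] };

function toGeminiHistory(
  chat_history: { role: string; content: string }[],
): SketchContent[] {
  return chat_history.map((msg) => ({
    // System and assistant turns both map onto the "model" role,
    // mirroring the branches of the loop in the diff.
    role: msg.role === "user" ? "user" : "model",
    parts: [{ text: msg.content }],
  }));
}
```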
@ -1362,7 +1420,7 @@ export async function call_bedrock(
temperature,
};

const fm = fromModelId(modelName as Models, {
const fm = fromModelId(modelName, {
region: bedrockConfig.region,
credentials: bedrockConfig.credentials,
...query,
@ -1561,6 +1619,7 @@ async function call_custom_provider(
*/
export async function call_llm(
llm: LLM,
provider: LLMProvider,
prompt: string,
n: number,
temperature: number,
@ -1569,7 +1628,7 @@ export async function call_llm(
): Promise<[Dict, Dict]> {
// Get the correct API call for the given LLM:
let call_api: LLMAPICall | undefined;
const llm_provider: LLMProvider | undefined = getProvider(llm);
const llm_provider: LLMProvider | undefined = provider ?? getProvider(llm); // backwards compatibility if there's no explicit provider

if (llm_provider === undefined)
throw new Error(`Language model ${llm} is not supported.`);
@ -1590,6 +1649,7 @@ export async function call_llm(
else if (llm_provider === LLMProvider.Custom) call_api = call_custom_provider;
else if (llm_provider === LLMProvider.Bedrock) call_api = call_bedrock;
else if (llm_provider === LLMProvider.Together) call_api = call_together;
else if (llm_provider === LLMProvider.DeepSeek) call_api = call_deepseek;
if (call_api === undefined)
throw new Error(
`Adapter for Language model ${llm} and ${llm_provider} not found`,
@ -1682,10 +1742,12 @@ function _extract_google_ai_responses(
llm: LLM | string,
): Array<string> {
switch (llm) {
case NativeLLM.GEMINI_PRO:
return _extract_gemini_responses(response as Array<Dict>);
default:
case NativeLLM.PaLM2_Chat_Bison:
return _extract_palm_responses(response);
case NativeLLM.PaLM2_Text_Bison:
return _extract_palm_responses(response);
default:
return _extract_gemini_responses(response as Array<Dict>);
}
}

@ -1775,8 +1837,10 @@ function _extract_ollama_responses(
export function extract_responses(
response: Array<string | Dict> | Dict,
llm: LLM | string,
provider: LLMProvider,
): Array<LLMResponseData> {
const llm_provider: LLMProvider | undefined = getProvider(llm as LLM);
const llm_provider: LLMProvider | undefined =
provider ?? getProvider(llm as LLM);
const llm_name = llm.toString().toLowerCase();
switch (llm_provider) {
case LLMProvider.OpenAI:
@ -1807,6 +1871,8 @@ export function extract_responses(
return response as Array<string>;
case LLMProvider.Together:
return _extract_openai_responses(response as Dict[]);
case LLMProvider.DeepSeek:
return _extract_openai_responses(response as Dict[]);
default:
if (
Array.isArray(response) &&
@ -1839,13 +1905,8 @@ export function merge_response_objs(
else if (!resp_obj_A && resp_obj_B) return resp_obj_B;
resp_obj_A = resp_obj_A as RawLLMResponseObject; // required by typescript
resp_obj_B = resp_obj_B as RawLLMResponseObject;
let raw_resp_A = resp_obj_A.raw_response;
let raw_resp_B = resp_obj_B.raw_response;
if (!Array.isArray(raw_resp_A)) raw_resp_A = [raw_resp_A];
if (!Array.isArray(raw_resp_B)) raw_resp_B = [raw_resp_B];
const res: RawLLMResponseObject = {
responses: resp_obj_A.responses.concat(resp_obj_B.responses),
raw_response: raw_resp_A.concat(raw_resp_B),
prompt: resp_obj_B.prompt,
query: resp_obj_B.query,
llm: resp_obj_B.llm,
@ -1901,7 +1962,7 @@ export const transformDict = (
*
* Returns empty dict {} if no settings vars found.
*/
export const extractSettingsVars = (vars?: Dict) => {
export const extractSettingsVars = (vars?: PromptVarsDict) => {
if (
vars !== undefined &&
Object.keys(vars).some((k) => k.charAt(0) === "=")
@ -1918,8 +1979,8 @@ export const extractSettingsVars = (vars?: Dict) => {
* Given two info vars dicts, detects whether any + all vars (keys) match values.
*/
export const areEqualVarsDicts = (
A: Dict | undefined,
B: Dict | undefined,
A: PromptVarsDict | undefined,
B: PromptVarsDict | undefined,
): boolean => {
if (A === undefined || B === undefined) {
if (A === undefined && B === undefined) return true;
@ -2000,12 +2061,15 @@ export const stripLLMDetailsFromResponses = (
): LLMResponse[] =>
resps.map((r) => ({
...r,
llm: typeof r?.llm === "string" ? r?.llm : r?.llm?.name ?? "undefined",
llm:
(typeof r?.llm === "string" || typeof r?.llm === "number"
? StringLookup.get(r?.llm)
: r?.llm?.name) ?? "undefined",
}));

// NOTE: The typing is purposefully general since we are trying to cast to an expected format.
export const toStandardResponseFormat = (r: Dict | string) => {
if (typeof r === "string")
if (typeof r === "string" || typeof r === "number")
return {
vars: {},
metavars: {},
@ -2020,7 +2084,7 @@ export const toStandardResponseFormat = (r: Dict | string) => {
uid: r?.uid ?? r?.batch_id ?? uuid(),
llm: r?.llm ?? undefined,
prompt: r?.prompt ?? "",
responses: [typeof r === "string" ? r : r?.text],
responses: [typeof r === "string" || typeof r === "number" ? r : r?.text],
tokens: r?.raw_response?.usage ?? {},
};
if (r?.eval_res !== undefined) resp_obj.eval_res = r.eval_res;
@ -2046,8 +2110,10 @@ export const tagMetadataWithLLM = (input_data: LLMResponsesByVarDict) => {
if (
!r ||
typeof r === "string" ||
typeof r === "number" ||
!r?.llm ||
typeof r.llm === "string" ||
typeof r.llm === "number" ||
!r.llm.key
)
return r;
@ -2061,20 +2127,21 @@ export const tagMetadataWithLLM = (input_data: LLMResponsesByVarDict) => {

export const extractLLMLookup = (
input_data: Dict<
(string | TemplateVarInfo | BaseLLMResponseObject | LLMResponse)[]
(StringOrHash | TemplateVarInfo | BaseLLMResponseObject | LLMResponse)[]
>,
) => {
const llm_lookup: Dict<string | LLMSpec> = {};
const llm_lookup: Dict<StringOrHash | LLMSpec> = {};
Object.values(input_data).forEach((resp_objs) => {
resp_objs.forEach((r) => {
const llm_name =
typeof r === "string"
typeof r === "string" || typeof r === "number"
? undefined
: !r.llm || typeof r.llm === "string"
? r.llm
: !r.llm || typeof r.llm === "string" || typeof r.llm === "number"
? StringLookup.get(r.llm)
: r.llm.key;
if (
typeof r === "string" ||
typeof r === "number" ||
!r.llm ||
!llm_name ||
llm_name in llm_lookup
@ -2154,6 +2221,13 @@ export const batchResponsesByUID = (
.concat(unspecified_id_group);
};

export function llmResponseDataToString(data: LLMResponseData): string {
if (typeof data === "string") return data;
else if (typeof data === "number")
return StringLookup.get(data) ?? "(string lookup failed)";
else return data.d;
}

/**
* Naive method to sample N items at random from an array.
* @param arr an array of items
@ -22,13 +22,13 @@ import {
LLMGroup,
LLMSpec,
PromptVarType,
PromptVarsDict,
TemplateVarInfo,
TabularDataColType,
TabularDataRowType,
} from "./backend/typing";
import { TogetherChatSettings } from "./ModelSettingSchemas";
import { NativeLLM } from "./backend/models";
import { StringLookup } from "./backend/cache";

// Initial project settings
const initialAPIKeys = {};
@ -133,18 +133,85 @@ export const initLLMProviderMenu: (LLMSpec | LLMGroup)[] = [
],
},
{
name: "Claude",
group: "Claude",
emoji: "📚",
model: "claude-2",
base_model: "claude-v1",
temp: 0.5,
items: [
{
name: "Claude 3.5 Sonnet",
emoji: "📚",
model: "claude-3-5-sonnet-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 3.5 Haiku",
emoji: "📗",
model: "claude-3-5-haiku-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 3 Opus",
emoji: "📙",
model: "claude-3-opus-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 2",
emoji: "📓",
model: "claude-2",
base_model: "claude-v1",
temp: 0.5,
},
],
},
{
name: "Gemini",
group: "Gemini",
emoji: "♊",
model: "gemini-pro",
base_model: "palm2-bison",
temp: 0.7,
items: [
{
name: "Gemini 1.5",
emoji: "♊",
model: "gemini-1.5-pro",
base_model: "palm2-bison",
temp: 0.7,
},
{
name: "Gemini 1.5 Flash",
emoji: "📸",
model: "gemini-1.5-flash",
base_model: "palm2-bison",
temp: 0.7,
},
{
name: "Gemini 1.5 Flash 8B",
emoji: "⚡️",
model: "gemini-1.5-flash-8b",
base_model: "palm2-bison",
temp: 0.7,
},
],
},
{
group: "DeepSeek",
emoji: "🐋",
items: [
{
name: "DeepSeek Chat",
emoji: "🐋",
model: "deepseek-chat",
base_model: "deepseek",
temp: 1.0,
}, // The base_model designates what settings form will be used, and must be unique.
{
name: "DeepSeek Reasoner",
emoji: "🐳",
model: "deepseek-reasoner",
base_model: "deepseek",
temp: 1.0,
},
],
},
{
group: "HuggingFace",
@ -557,7 +624,9 @@ const useStore = create<StoreHandles>((set, get) => ({
(key) =>
key === "__uid" ||
!row[key] ||
(typeof row[key] === "string" && row[key].trim() === ""),
((typeof row[key] === "string" ||
typeof row[key] === "number") &&
StringLookup.get(row[key])?.trim() === ""),
)
)
return undefined;
@ -565,14 +634,20 @@ const useStore = create<StoreHandles>((set, get) => ({
const row_excluding_col: Dict<string> = {};
row_keys.forEach((key) => {
if (key !== src_col.key && key !== "__uid")
row_excluding_col[col_header_lookup[key]] =
row[key].toString();
row_excluding_col[col_header_lookup[key]] = (
StringLookup.get(row[key]) ?? "(string lookup failed)"
).toString();
});
return {
// We escape any braces in the source text before they're passed downstream.
// This is a special property of tabular data nodes: we don't want their text to be treated as prompt templates.
text: escapeBraces(
src_col.key in row ? row[src_col.key].toString() : "",
src_col.key in row
? (
StringLookup.get(row[src_col.key]) ??
"(string lookup failed)"
).toString()
: "",
),
metavars: row_excluding_col,
associate_id: row.__uid, // this is used by the backend to 'carry' certain values together
@ -641,7 +716,7 @@ const useStore = create<StoreHandles>((set, get) => ({
const store_data = (
_texts: PromptVarType[],
_varname: string,
_data: PromptVarsDict,
_data: Dict<PromptVarType[]>,
) => {
if (_varname in _data) _data[_varname] = _data[_varname].concat(_texts);
else _data[_varname] = _texts;
@ -5,6 +5,5 @@ requests
openai
dalaipy==2.0.2
urllib3==1.26.6
anthropic
google-generativeai
mistune>=2.0
mistune>=2.0
platformdirs

5
setup.py
@ -6,7 +6,7 @@ def readme():

setup(
name="chainforge",
version="0.3.2.5",
version="0.3.4.3",
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",
@ -21,10 +21,9 @@ setup(
"flask[async]",
"flask_cors",
"requests",
"platformdirs",
"urllib3==1.26.6",
"openai",
"anthropic",
"google-generativeai",
"dalaipy>=2.0.2",
"mistune>=2.0", # for LLM response markdown parsing
],