Compare commits

...

31 Commits
v0.3.1 ... main

Author SHA1 Message Date
ianarawjo
2c7447d5f7
Saved flows sidebar and save button when running locally (#332)
* Add save to local disk/cache button

* Autosave to local filesystem but hide it

* Use imported filename when saving to local disk after import. Ensure no name clashes.

* Remove unneeded dependencies in pip and update package version

* Hide sidebar button and sidebar when hosted on web
2025-03-02 17:13:40 -05:00
Ian Arawjo
3b929880dc Update package version 2025-03-02 10:44:48 -05:00
Ian Arawjo
6ed42fe518 Bug fix: prompt node data saving 2025-03-02 10:43:48 -05:00
ianarawjo
98a8184a6a
Add copy button to response toolbars, and improve UI rendering performance (#331)
* Add copy button to toolbar.

* Debounce text update on prompt node

* Refactored App.tsx and BaseNode to useMemo for many places such as the toolbar

* Mantine React Table was rerendering cells even when hidden. This change ensures that the inner views of Response Inspectors do not display when the inspector is hidden.

* Add buffering before displaying the inner response inspector view, so that the modal immediately pops up

* Update package version
2025-03-01 10:28:31 -05:00
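A minimal sketch of the ref-based debounce pattern referenced above (the diffs further down use a `genDebounceFunc(debounceTimeoutRef)` helper); the names and exact signature here are assumptions for illustration, not ChainForge's actual API.

```typescript
// Sketch of a ref-based debounce helper in the spirit of genDebounceFunc:
// repeated calls within `delayMs` collapse into a single trailing invocation.
// Names and signature are illustrative assumptions.
import { useRef } from "react";

type Timeout = ReturnType<typeof setTimeout> | null;

function genDebounce(timeoutRef: { current: Timeout }) {
  return (fn: () => void, delayMs: number) => () => {
    if (timeoutRef.current) clearTimeout(timeoutRef.current);
    timeoutRef.current = setTimeout(fn, delayMs);
  };
}

// Usage inside a component: push text edits to the parent at most every ~200 ms,
// so typing in a prompt field does not re-render the whole flow on each keystroke.
function usePromptFieldChange(onEdit: (text: string) => void) {
  const timeoutRef = useRef<Timeout>(null);
  const debounce = genDebounce(timeoutRef);
  return (text: string) => debounce(() => onEdit(text), 200)();
}
```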
ianarawjo
9d7c458b7a
Better Table View, optimized string storage (#328)
* Add loading spinner overlay to Response Inspector when it hangs on UI reload

* Use Mantine React Table in Response Inspector

* Merge toolbar into Mantine React Table toolbar, to save space

* Disable counting the number of matches in the search bar, since this is non-trivial now. We may revisit this later, but the user can use the number of items in the table view as a proxy.

* Adds `StringLookup` table. Changes types related to `LLMResponse`s to use a `StringOrHash` type, which can be a string or a number.

* Import/export the lookup table when importing/exporting the flow state. 

* Update package version
2025-02-28 14:55:31 -05:00
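A minimal sketch of the string-interning idea behind `StringLookup` and `StringOrHash` (a value that is either a raw string or a numeric index into a shared table); the class and method names below are illustrative assumptions, not the exact ChainForge implementation.

```typescript
// Sketch of a string-interning table: repeated response strings are stored once
// and referenced by numeric index, so a value of type StringOrHash may be either
// the original string or a number pointing into the table.
type StringOrHash = string | number;

class StringLookupTable {
  private strings: string[] = [];
  private index = new Map<string, number>();

  // Store a string (once) and return its numeric id.
  intern(s: string): number {
    const existing = this.index.get(s);
    if (existing !== undefined) return existing;
    const id = this.strings.length;
    this.strings.push(s);
    this.index.set(s, id);
    return id;
  }

  // Resolve a StringOrHash back to the underlying string.
  get(v: StringOrHash): string | undefined {
    return typeof v === "number" ? this.strings[v] : v;
  }

  // Serialize alongside the flow so interned ids survive import/export.
  toJSON(): string[] {
    return this.strings;
  }
}
```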
ianarawjo
723022fb31
Give TextFields a maxHeight and wrap in ScrollArea (#327) 2025-02-18 23:17:13 -05:00
ianarawjo
6b7e7935d0
Custom model names to all providers (#324)
* Add custom widget using datalist to set model not present in dropdown enum list

* Require `LLMProvider` explicitly passed when querying LLMs.

* Finished UI datalist widget for `react-jsonschema-form`. Tested custom model endpoints in settings windows. 

* Added o1+ model hack to strip system message, since OpenAI o1+ models do not yet support system messages (and the `developer` command on their API does not currently work...)
2025-02-12 15:54:28 -05:00
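The o1 workaround mentioned above can be sketched as a preprocessing step that folds the system message into the first user turn for models that reject a system role; the model-name check and function name are assumptions for illustration.

```typescript
// Sketch: strip the system message for OpenAI o1-style models that do not accept
// a "system" role, prepending its text to the first user message instead.
// The o1-detection heuristic below is an assumption for illustration.
interface ChatMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

function stripSystemMessageForO1(
  model: string,
  messages: ChatMessage[],
): ChatMessage[] {
  if (!/^o1/.test(model)) return messages; // only applies to o1-style models
  const system = messages.find((m) => m.role === "system");
  const rest = messages.filter((m) => m.role !== "system");
  if (!system) return rest;
  const firstUserIdx = rest.findIndex((m) => m.role === "user");
  if (firstUserIdx === -1)
    return [{ role: "user", content: system.content }, ...rest];
  return rest.map((m, i) =>
    i === firstUserIdx
      ? { ...m, content: `${system.content}\n\n${m.content}` }
      : m,
  );
}
```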
ianarawjo
1e215f4238
Add DeepSeek to supported providers list 2025-01-27 22:57:22 -05:00
ianarawjo
1206d62b7b
Add Deepseek API (#322)
* Add DeepSeek Model API and settings

* Update Together model list
2025-01-27 22:48:37 -05:00
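DeepSeek exposes an OpenAI-compatible chat completions API, so a minimal request sketch looks like the following; the endpoint URL and model name follow DeepSeek's public documentation and are assumptions here, not ChainForge's internal call path.

```typescript
// Sketch: DeepSeek's API is OpenAI-compatible, so a chat completion is a plain
// POST with a Bearer key. URL and model name come from DeepSeek's public docs
// and may change; this is not ChainForge's actual provider code.
async function callDeepSeek(apiKey: string, prompt: string): Promise<string> {
  const resp = await fetch("https://api.deepseek.com/chat/completions", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify({
      model: "deepseek-chat",
      messages: [{ role: "user", content: prompt }],
    }),
  });
  if (!resp.ok) throw new Error(`DeepSeek request failed: ${resp.status}`);
  const data = await resp.json();
  return data.choices[0].message.content;
}
```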
ianarawjo
0174b53aff
Update README.md to point to CHI paper 2024-12-30 12:25:24 -05:00
ianarawjo
0e96fa2e1c
Update README.md to add collaborators
Added Shreya Shankar (EvalGen) and Cassandre Hamel (GenAI support in Tabular Data Nodes).
2024-12-30 12:23:06 -05:00
ianarawjo
9ec7a3a4fc
Add prompt previews and cancel button to LLM scorer (#319)
Fix bug where the default model did not show the selected one when adding a new LLM model from the dropdown list.
2024-12-29 11:31:47 -05:00
Ian Arawjo
f5882768ba Update Google Gemini models 2024-12-28 17:56:51 -05:00
ianarawjo
ff813c7255
Revamped Example Flows (#316)
* Fix Try Me button spacing

* Added new examples

* Updated package version
2024-12-27 18:47:32 -05:00
ianarawjo
7e86f19aac
GenAI Data Synthesis for Tabular Data Node (#315)
* TabularDataNode supports Replace and Extend for AiGen (#312)

* Testing Values

* Fixed typing issue with Models in fromModelId

* TabularDataNode now supports table generation.
modified:   src/AiPopover.tsx
            Added support for table replacement
            and future support for extension.
modified:   src/TabularDataNode.tsx
            Added the AiPopover button and
            functionality for table replacement.
modified:   src/backend/ai.ts
            Added specific prompts and decoding
            for markdown table generation.
new file:   src/backend/tableUtils.ts
            Separated the parsing for tables into
            a separate utility file for better
            organization and future extensibility.

* Added Extend Functionality to Table Popover.
modified:   src/AiPopover.tsx
            Removed unnecessary import.
            Changed handleCommandFill to work with
            autofillTable function in ai.ts.
modified:   src/TabularDataNode.tsx
            Removed Skeleton from Popover.
            Changed addMultipleRows such that it
            now renders the new rows correctly
            and removes the blank row.
modified:   src/backend/ai.ts
            Added autofillTable function and
            changed decodeTable so that they
            are flexible with both proper and
            improper markdown tables.
            Added new system message prompt
            specific to autofillTable.
            Removed unnecessary log statements.
removed:    src/backend/utils.ts
            Removed change.

* Added "add column" prompt & button in TablePopover

modified:   src/AiPopover.tsx
            Added handleGenerateColumn so that
            a column can be generated given
            a prompt.
            Updated the TablePopover UI:
            Extend is now divided into AddRow
            and AddColumn sections.
modified:   src/TabularDataNode.tsx
            Modified addColumns so that it's safer.
            Added optional pass of rowValue to
            support generateColumn.
modified:   src/backend/ai.ts
            Added generateColumn and its
            corresponding system message.
Cleaned up some comments and added missing commas.

* Generate Columns now processes rows item-by-item
when generating the new column values.

modified:   src/AiPopover.tsx
            Fixed the key issue for onAddColumn.
modified:   src/TabularDataNode.tsx
            Changed addColumns to filter out
            previously added columns.
modified:   src/backend/ai.ts
            Changed generateColumns to process
            item-by-item to generate new columns.

* Fix bugs. Change OpenAI small model for GenAI features to GPT-4o.

* Update package version

* Remove gen diverse outputs switch in genAI for table

---------

Co-authored-by: Kraft-Cheese <114844630+Kraft-Cheese@users.noreply.github.com>
2024-12-19 15:38:23 -05:00
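The markdown-table decoding described in this commit (`src/backend/tableUtils.ts`) can be sketched as a tolerant parser that accepts both proper and improper markdown tables; the function name and exact tolerance rules below are illustrative, not the actual `decodeTable` implementation.

```typescript
// Sketch of a tolerant markdown-table decoder in the spirit of the commit above:
// accepts tables with or without leading/trailing pipes and an optional separator
// row, returning column headers plus "cell | cell | ..." row strings.
function decodeMarkdownTable(md: string): { cols: string[]; rows: string[] } {
  const lines = md
    .split("\n")
    .map((l) => l.trim())
    .filter((l) => l.length > 0 && l.includes("|"));
  if (lines.length === 0) throw new Error("No markdown table found.");

  const splitRow = (line: string) =>
    line
      .replace(/^\|/, "")
      .replace(/\|$/, "")
      .split("|")
      .map((c) => c.trim());

  // Treat the first pipe-containing line as the header row.
  const cols = splitRow(lines[0]);
  // Drop a |---|:---| style separator row if present.
  const body = lines.slice(1).filter((l) => !/^[\s|:-]+$/.test(l));
  const rows = body.map((l) => splitRow(l).join(" | "));
  return { cols, rows };
}
```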
ianarawjo
1641abe975
Structured outputs support for Ollama, OpenAI, and Anthropic models (#313)
* Add structured outputs support for OpenAI and Ollama

* Extract outputs from tool_calls and refusal in OpenAI API responses

* Add tool use for Anthropic API. Add new Anthropic models.

* Add num_ctx to Ollama API call

* Update package version

* Update function calling example
2024-12-16 16:24:55 -05:00
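A hedged sketch of the request-body shapes involved: OpenAI's `response_format` with a JSON schema, Ollama's `format` plus `options.num_ctx`, and Anthropic's forced tool use. Field names follow the providers' public APIs and are assumptions here, not ChainForge's exact wiring.

```typescript
// Sketch of structured-output request payloads for the three providers named in
// the commit above. Values (model names, schema) are illustrative assumptions.
const personSchema = {
  type: "object",
  properties: { name: { type: "string" }, age: { type: "number" } },
  required: ["name", "age"],
  additionalProperties: false,
};

// OpenAI: constrain the response with a JSON schema via response_format.
const openaiBody = {
  model: "gpt-4o-2024-08-06",
  messages: [{ role: "user", content: "Extract the person mentioned." }],
  response_format: {
    type: "json_schema",
    json_schema: { name: "person", strict: true, schema: personSchema },
  },
};

// Ollama: pass the schema in `format`, and a context window via options.num_ctx.
const ollamaBody = {
  model: "llama3.1",
  messages: [{ role: "user", content: "Extract the person mentioned." }],
  format: personSchema,
  options: { num_ctx: 8192 },
};

// Anthropic: structured output via forced tool use with an input_schema.
const anthropicBody = {
  model: "claude-3-5-sonnet-latest",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Extract the person mentioned." }],
  tools: [
    {
      name: "record_person",
      description: "Record a person",
      input_schema: personSchema,
    },
  ],
  tool_choice: { type: "tool", name: "record_person" },
};
```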
Ian Arawjo
dd28754959 Remove rate limiting ceiling in bottleneck. Fix eslint to <9.0. 2024-12-11 10:17:04 -05:00
Sam
f6565537fa
chore(deps): bump deps for Docker, python versions (#305) 2024-10-29 13:33:01 -04:00
Ian Arawjo
98b140b5fa Add newest OpenAI and Anthropic models 2024-08-15 22:06:13 -04:00
ianarawjo
f6e1bfa38a
Add Together.ai and update Bedrock (#283)
* feat(bedrock_llama3): added support for Llama3 (#270)

- also added Claude 3 Opus to the list of models
- replaced hardcoded model ID strings with references to the NativeLLM enum

* chore: bump @mirai73/bedrock-fm library (#277)

- the new version adds source code to facilitate debugging

Co-authored-by: ianarawjo <fatso784@gmail.com>

* Adding together.ai support (#280)


---------

Co-authored-by: ianarawjo <fatso784@gmail.com>

* Add Together.ai and update Bedrock models

---------

Co-authored-by: Massimiliano Angelino <angmas@amazon.com>
Co-authored-by: Can Bal <canbal@users.noreply.github.com>
2024-05-17 20:17:18 -10:00
Ian Arawjo
e3259ecc1b Add new OpenAI models 2024-05-14 07:21:48 -10:00
Ian Arawjo
735268e331 Fix Claude carrying system message issue and bug with OpenAI_BaseURL loading 2024-04-29 21:17:49 -04:00
Ian Arawjo
af7f53f76e Fix bug loading OPENAI_BASE_URL from environ var 2024-04-28 14:29:58 -04:00
Ian Arawjo
4fa4b7bcc0 Escape braces in LLM scorer. Add OpenAI_BaseURL setting. 2024-04-26 07:24:31 -04:00
ianarawjo
6fa3092cd0
Add Multi-Eval node (#265)
* Port over and type MultiEvalNode code from the `multi-eval` branch

* Merge css changes from `multi-eval`

* Merge changes to inspector table view from `multi-eval`

* Criteria progress rings

* Debounce renders on text edits

* Add sandbox toggle to Python evals inside MultiEval

* Add uids to evals in MultiEval, for correct cache ids not dependent on name

* <Stack> scores

* Add debounce to editing code or prompts in eval UI

* Update package version
2024-04-25 13:51:25 -04:00
Ian Arawjo
2998c99f08 Bug fix for loading example flows in web version 2024-04-19 19:50:32 -04:00
Ian Arawjo
7126f4f4d4 Fix typing error and update package vers 2024-04-17 19:19:54 -04:00
Massimiliano Angelino
ffd947e636
Update to Bedrock integration (#258)
* fix(aws credentials): correct check for credentials

* chore(bedrock): bump @mirai73/ bedrock-fm library

* feat(bedrock): updating library and adding new mistral large model

- fix stop_sequences
2024-04-15 12:55:17 -04:00
Zigelboim Misha
4c56928cb9
Running ChainForge inside a Docker Container (#254)
* Create a Dockerfile

* Edit README.md to contain information about using Chainforge inside a container

* Update Dockerfile to use python 3.10 as the base image

---------

Co-authored-by: Rob-Powell <7034920+Rob-Powell@users.noreply.github.com>
2024-04-08 19:00:10 -04:00
Ian Arawjo
6b65d96369 Fix bug w/ non-updating custom providers in model list 2024-04-02 12:16:35 -04:00
yipengfei
5d4d196260
Fixed a bug that prevented custom models from appearing in model list. (#255) 2024-04-02 12:03:37 -04:00
60 changed files with 352297 additions and 29088 deletions

Dockerfile (new file, 9 lines added)
View File

@ -0,0 +1,9 @@
FROM python:3.12-slim AS builder
RUN pip install --upgrade pip
RUN pip install chainforge --no-cache-dir
WORKDIR /chainforge
EXPOSE 8000
ENTRYPOINT [ "chainforge", "serve", "--host", "0.0.0.0" ]

View File

@ -46,12 +46,30 @@ Open [localhost:8000](http://localhost:8000/) in a Google Chrome, Firefox, Micro
You can set your API keys by clicking the Settings icon in the top-right corner. If you prefer not to worry about this every time you open ChainForge, we recommend that you save your OpenAI, Anthropic, Google PaLM API keys and/or Amazon AWS credentials to your local environment. For more details, see the [How to Install](https://chainforge.ai/docs/getting_started/) guide.
## Run using Docker
You can use our [Dockerfile](/Dockerfile) to run `ChainForge` locally using `Docker Desktop`:
- Build the `Dockerfile`:
```shell
docker build -t chainforge .
```
- Run the image:
```shell
docker run -p 8000:8000 chainforge
```
Now open the browser of your choice and navigate to `http://127.0.0.1:8000`.
# Supported providers
- OpenAI
- Anthropic
- Google (Gemini, PaLM2)
- DeepSeek
- HuggingFace (Inference and Endpoints)
- Together.ai
- [Ollama](https://github.com/jmorganca/ollama) (locally-hosted models)
- Microsoft Azure OpenAI Endpoints
- [AlephAlpha](https://app.aleph-alpha.com/)
@ -117,7 +135,7 @@ For more specific details, see our [documentation](https://chainforge.ai/docs/no
# Development
ChainForge was created by [Ian Arawjo](http://ianarawjo.com/index.html), a postdoctoral scholar in Harvard HCI's [Glassman Lab](http://glassmanlab.seas.harvard.edu/) with support from the Harvard HCI community. Collaborators include PhD students [Priyan Vaithilingam](https://priyan.info) and [Chelse Swoopes](https://seas.harvard.edu/person/chelse-swoopes), Harvard undergraduate [Sean Yang](https://shawsean.com), and faculty members [Elena Glassman](http://glassmanlab.seas.harvard.edu/glassman.html) and [Martin Wattenberg](https://www.bewitched.com/about.html).
ChainForge was created by [Ian Arawjo](http://ianarawjo.com/index.html), a postdoctoral scholar in Harvard HCI's [Glassman Lab](http://glassmanlab.seas.harvard.edu/) with support from the Harvard HCI community. Collaborators include PhD students [Priyan Vaithilingam](https://priyan.info) and [Chelse Swoopes](https://seas.harvard.edu/person/chelse-swoopes), Harvard undergraduate [Sean Yang](https://shawsean.com), and faculty members [Elena Glassman](http://glassmanlab.seas.harvard.edu/glassman.html) and [Martin Wattenberg](https://www.bewitched.com/about.html). Additional collaborators include UC Berkeley PhD student Shreya Shankar and Université de Montréal undergraduate Cassandre Hamel.
This work was partially funded by the NSF grants IIS-2107391, IIS-2040880, and IIS-1955699. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation.
@ -140,17 +158,15 @@ We welcome open-source collaborators. If you want to report a bug or request a f
# Cite Us
If you use ChainForge for research purposes, or build upon the source code, we ask that you cite our [arXiv pre-print](https://arxiv.org/abs/2309.09128) in any related publications.
The BibTeX you can use is:
If you use ChainForge for research purposes, whether by building upon the source code or investigating LLM behavior using the tool, we ask that you cite our [CHI research paper](https://dl.acm.org/doi/full/10.1145/3613904.3642016) in any related publications. The BibTeX you can use is:
```bibtex
@misc{arawjo2023chainforge,
title={ChainForge: A Visual Toolkit for Prompt Engineering and LLM Hypothesis Testing},
author={Ian Arawjo and Chelse Swoopes and Priyan Vaithilingam and Martin Wattenberg and Elena Glassman},
year={2023},
eprint={2309.09128},
archivePrefix={arXiv},
primaryClass={cs.HC}
@inproceedings{arawjo2024chainforge,
title={ChainForge: A Visual Toolkit for Prompt Engineering and LLM Hypothesis Testing},
author={Arawjo, Ian and Swoopes, Chelse and Vaithilingam, Priyan and Wattenberg, Martin and Glassman, Elena L},
booktitle={Proceedings of the CHI Conference on Human Factors in Computing Systems},
pages={1--18},
year={2024}
}
```

File diff suppressed because one or more lines are too long

13 file diffs suppressed because they are too large

View File

@ -3,11 +3,13 @@ from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from datetime import datetime
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests
from platformdirs import user_data_dir
""" =================
SETUP AND GLOBALS
@ -26,6 +28,7 @@ app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
cors = CORS(app, resources={r"/*": {"origins": "*"}})
# The cache and examples files base directories
FLOWS_DIR = user_data_dir("chainforge") # platform-agnostic local storage that persists outside the package install location
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')
@ -460,6 +463,7 @@ def fetchOpenAIEval():
def fetchEnvironAPIKeys():
keymap = {
'OPENAI_API_KEY': 'OpenAI',
'OPENAI_BASE_URL': 'OpenAI_BaseURL',
'ANTHROPIC_API_KEY': 'Anthropic',
'PALM_API_KEY': 'Google',
'HUGGINGFACE_API_KEY': 'HuggingFace',
@ -469,7 +473,9 @@ def fetchEnvironAPIKeys():
'AWS_ACCESS_KEY_ID': 'AWS_Access_Key_ID',
'AWS_SECRET_ACCESS_KEY': 'AWS_Secret_Access_Key',
'AWS_REGION': 'AWS_Region',
'AWS_SESSION_TOKEN': 'AWS_Session_Token'
'AWS_SESSION_TOKEN': 'AWS_Session_Token',
'TOGETHER_API_KEY': 'Together',
'DEEPSEEK_API_KEY': 'DeepSeek',
}
d = { alias: os.environ.get(key) for key, alias in keymap.items() }
ret = jsonify(d)
@ -506,7 +512,7 @@ def makeFetchCall():
ret.headers.add('Access-Control-Allow-Origin', '*')
return ret
else:
err_msg = "API request to Anthropic failed"
err_msg = "API request failed"
ret = response.json()
if "error" in ret and "message" in ret["error"]:
err_msg += ": " + ret["error"]["message"]
@ -718,6 +724,109 @@ async def callCustomProvider():
# Return the response
return jsonify({'response': response})
"""
LOCALLY SAVED FLOWS
"""
@app.route('/api/flows', methods=['GET'])
def get_flows():
"""Return a list of all saved flows. If the directory does not exist, try to create it."""
os.makedirs(FLOWS_DIR, exist_ok=True) # Creates the directory if it doesn't exist
flows = [
{
"name": f,
"last_modified": datetime.fromtimestamp(os.path.getmtime(os.path.join(FLOWS_DIR, f))).isoformat()
}
for f in os.listdir(FLOWS_DIR)
if f.endswith('.cforge') and f != "__autosave.cforge" # ignore the special autosave file
]
# Sort the flow files by last modified date in descending order (most recent first)
flows.sort(key=lambda x: x["last_modified"], reverse=True)
return jsonify({
"flow_dir": FLOWS_DIR,
"flows": flows
})
@app.route('/api/flows/<filename>', methods=['GET'])
def get_flow(filename):
"""Return the content of a specific flow"""
if not filename.endswith('.cforge'):
filename += '.cforge'
try:
with open(os.path.join(FLOWS_DIR, filename), 'r') as f:
return jsonify(json.load(f))
except FileNotFoundError:
return jsonify({"error": "Flow not found"}), 404
@app.route('/api/flows/<filename>', methods=['DELETE'])
def delete_flow(filename):
"""Delete a flow"""
if not filename.endswith('.cforge'):
filename += '.cforge'
try:
os.remove(os.path.join(FLOWS_DIR, filename))
return jsonify({"message": f"Flow {filename} deleted successfully"})
except FileNotFoundError:
return jsonify({"error": "Flow not found"}), 404
@app.route('/api/flows/<filename>', methods=['PUT'])
def save_or_rename_flow(filename):
"""Save or rename a flow"""
data = request.json
if not filename.endswith('.cforge'):
filename += '.cforge'
if data.get('flow'):
# Save flow (overwriting any existing flow file with the same name)
flow_data = data.get('flow')
try:
filepath = os.path.join(FLOWS_DIR, filename)
with open(filepath, 'w') as f:
json.dump(flow_data, f)
return jsonify({"message": f"Flow '{filename}' saved!"})
except FileNotFoundError:
return jsonify({"error": f"Could not save flow '{filename}' to local filesystem. See terminal for more details."}), 404
elif data.get('newName'):
# Rename flow
new_name = data.get('newName')
if not new_name.endswith('.cforge'):
new_name += '.cforge'
try:
# Check for name clashes (if a flow already exists with the new name)
if os.path.isfile(os.path.join(FLOWS_DIR, new_name)):
raise Exception("A flow with that name already exists.")
os.rename(os.path.join(FLOWS_DIR, filename), os.path.join(FLOWS_DIR, new_name))
return jsonify({"message": f"Flow renamed from {filename} to {new_name}"})
except Exception as error:
return jsonify({"error": str(error)}), 404
@app.route('/api/getUniqueFlowFilename', methods=['PUT'])
def get_unique_flow_name():
"""Return a non-name-clashing filename to store in the local disk."""
data = request.json
filename = data.get("name")
try:
base, ext = os.path.splitext(filename)
if ext is None or len(ext) == 0:
ext = ".cforge"
unique_filename = base + ext
i = 1
# Find the first non-clashing filename of the form <filename>(i).cforge where i=1,2,3 etc
while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)):
unique_filename = f"{base}({i}){ext}"
i += 1
return jsonify(unique_filename.replace(".cforge", ""))
except Exception as e:
return jsonify({"error": str(e)}), 404
def run_server(host="", port=8000, cmd_args=None):
global HOSTNAME, PORT
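A minimal sketch of how a front-end client might exercise the flow endpoints added above, assuming the Flask backend is reachable at `http://localhost:8000`; the helper names are illustrative, not ChainForge's actual front-end code.

```typescript
// Sketch: calling the locally-saved-flows endpoints shown in the diff above.
// Assumes the Flask backend at http://localhost:8000; helper names are illustrative.
const BASE = "http://localhost:8000/api";

async function listFlows() {
  const res = await fetch(`${BASE}/flows`);
  return res.json(); // { flow_dir, flows: [{ name, last_modified }, ...] }
}

async function saveFlow(name: string, flow: object) {
  // PUT with { flow } saves/overwrites; PUT with { newName } renames instead.
  const res = await fetch(`${BASE}/flows/${encodeURIComponent(name)}`, {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ flow }),
  });
  return res.json();
}

async function deleteFlow(name: string) {
  const res = await fetch(`${BASE}/flows/${encodeURIComponent(name)}`, {
    method: "DELETE",
  });
  return res.json();
}
```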

View File

@ -16,13 +16,13 @@
"@emoji-mart/react": "^1.1.1",
"@fontsource/geist-mono": "^5.0.1",
"@google-ai/generativelanguage": "^0.2.0",
"@google/generative-ai": "^0.1.3",
"@google/generative-ai": "^0.21.0",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",
"@mantine/form": "^6.0.11",
"@mantine/prism": "^6.0.15",
"@mirai73/bedrock-fm": "^0.4.3",
"@mirai73/bedrock-fm": "^0.4.10",
"@reactflow/background": "^11.2.0",
"@reactflow/controls": "^11.1.11",
"@reactflow/core": "^11.7.0",
@ -65,7 +65,7 @@
"lodash": "^4.17.21",
"lz-string": "^1.5.0",
"mantine-contextmenu": "^6.1.0",
"mantine-react-table": "^1.0.0-beta.8",
"mantine-react-table": "^1.3.4",
"markdown-it": "^13.0.1",
"mathjs": "^11.8.2",
"mdast-util-from-markdown": "^2.0.0",
@ -112,7 +112,7 @@
"@types/react-edit-text": "^5.0.4",
"@types/react-plotly.js": "^2.6.3",
"@types/styled-components": "^5.1.34",
"eslint": "^8.56.0",
"eslint": "<9.0.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-semistandard": "^17.0.0",
"eslint-config-standard": "^17.1.0",
@ -3785,9 +3785,9 @@
}
},
"node_modules/@google/generative-ai": {
"version": "0.1.3",
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.3.tgz",
"integrity": "sha512-Cm4uJX1sKarpm1mje/MiOIinM7zdUUrQp/5/qGPAgznbdd/B9zup5ehT6c1qGqycFcSopTA1J1HpqHS5kJR8hQ==",
"version": "0.21.0",
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.21.0.tgz",
"integrity": "sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==",
"engines": {
"node": ">=18.0.0"
}
@ -4653,9 +4653,9 @@
}
},
"node_modules/@mirai73/bedrock-fm": {
"version": "0.4.3",
"resolved": "https://registry.npmjs.org/@mirai73/bedrock-fm/-/bedrock-fm-0.4.3.tgz",
"integrity": "sha512-+4ytmKfZFswTS5ajkah1O3CTNonPg3Ti7poYnlK1XS7At1xcHEx3tHHGNtTpVOPNeeALz9gawvbOif2GanigoA==",
"version": "0.4.10",
"resolved": "https://registry.npmjs.org/@mirai73/bedrock-fm/-/bedrock-fm-0.4.10.tgz",
"integrity": "sha512-j4Nx9RcrnGoue14MhR0LUzB8LdjfwIQ4FVkkytAMalMU43oO/LNcw4gJGLhrNOaTLRztICe24rIJCrfpnxa7jA==",
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.507.0"
}
@ -6191,11 +6191,11 @@
}
},
"node_modules/@tanstack/react-table": {
"version": "8.10.0",
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.0.tgz",
"integrity": "sha512-FNhKE3525hryvuWw90xRbP16qNiq7OLJkDZopOKcwyktErLi1ibJzAN9DFwA/gR1br9SK4StXZh9JPvp9izrrQ==",
"version": "8.10.6",
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.6.tgz",
"integrity": "sha512-D0VEfkIYnIKdy6SHiBNEaMc4SxO+MV7ojaPhRu8jP933/gbMi367+Wul2LxkdovJ5cq6awm0L1+jgxdS/unzIg==",
"dependencies": {
"@tanstack/table-core": "8.10.0"
"@tanstack/table-core": "8.10.6"
},
"engines": {
"node": ">=12"
@ -6210,11 +6210,11 @@
}
},
"node_modules/@tanstack/react-virtual": {
"version": "3.0.0-beta.60",
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.60.tgz",
"integrity": "sha512-F0wL9+byp7lf/tH6U5LW0ZjBqs+hrMXJrj5xcIGcklI0pggvjzMNW9DdIBcyltPNr6hmHQ0wt8FDGe1n1ZAThA==",
"version": "3.0.0-beta.63",
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.63.tgz",
"integrity": "sha512-n4aaZs3g9U2oZjFp8dAeT1C2g4rr/3lbCo2qWbD9NquajKnGx7R+EfLBAHJ6pVMmfsTMZ0XCBwkIs7U74R/s0A==",
"dependencies": {
"@tanstack/virtual-core": "3.0.0-beta.60"
"@tanstack/virtual-core": "3.0.0-beta.63"
},
"funding": {
"type": "github",
@ -6225,9 +6225,9 @@
}
},
"node_modules/@tanstack/table-core": {
"version": "8.10.0",
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.0.tgz",
"integrity": "sha512-e701yAJ18aGDP6mzVworlFAmQ+gi3Wtqx5mGZUe2BUv4W4D80dJxUfkHdtEGJ6GryAnlCCNTej7eaJiYmPhyYg==",
"version": "8.10.6",
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.6.tgz",
"integrity": "sha512-9t8brthhAmCBIjzk7fCDa/kPKoLQTtA31l9Ir76jYxciTlHU61r/6gYm69XF9cbg9n88gVL5y7rNpeJ2dc1AFA==",
"engines": {
"node": ">=12"
},
@ -6237,9 +6237,9 @@
}
},
"node_modules/@tanstack/virtual-core": {
"version": "3.0.0-beta.60",
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.60.tgz",
"integrity": "sha512-QlCdhsV1+JIf0c0U6ge6SQmpwsyAT0oQaOSZk50AtEeAyQl9tQrd6qCHAslxQpgphrfe945abvKG8uYvw3hIGA==",
"version": "3.0.0-beta.63",
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.63.tgz",
"integrity": "sha512-KhhfRYSoQpl0y+2axEw+PJZd/e/9p87PDpPompxcXnweNpt9ZHCT/HuNx7MKM9PVY/xzg9xJSWxwnSCrO+d6PQ==",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
@ -17692,13 +17692,13 @@
}
},
"node_modules/mantine-react-table": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/mantine-react-table/-/mantine-react-table-1.3.0.tgz",
"integrity": "sha512-ljAd9ZI7S89glI8OGbM5DsNHzInZPfigbeDXboR5gLGPuuseLVgezlVjqV1h8MUQkY06++d9ah9zsMLr2VNxlQ==",
"version": "1.3.4",
"resolved": "https://registry.npmjs.org/mantine-react-table/-/mantine-react-table-1.3.4.tgz",
"integrity": "sha512-rD0CaeC4RCU7k/ZKvfj5ijFFMd4clGpeg/EwMcogYFioZjj8aNfD78osTNNYr90AnOAFGnd7ZnderLK89+W1ZQ==",
"dependencies": {
"@tanstack/match-sorter-utils": "8.8.4",
"@tanstack/react-table": "8.10.0",
"@tanstack/react-virtual": "3.0.0-beta.60"
"@tanstack/react-table": "8.10.6",
"@tanstack/react-virtual": "3.0.0-beta.63"
},
"engines": {
"node": ">=14"
@ -17712,7 +17712,7 @@
"@mantine/core": "^6.0",
"@mantine/dates": "^6.0",
"@mantine/hooks": "^6.0",
"@tabler/icons-react": ">=2.23.0",
"@tabler/icons-react": ">=2.23",
"react": ">=18.0",
"react-dom": ">=18.0"
}

View File

@ -14,13 +14,13 @@
"@emoji-mart/react": "^1.1.1",
"@fontsource/geist-mono": "^5.0.1",
"@google-ai/generativelanguage": "^0.2.0",
"@google/generative-ai": "^0.1.3",
"@google/generative-ai": "^0.21.0",
"@mantine/core": "^6.0.9",
"@mantine/dates": "^6.0.13",
"@mantine/dropzone": "^6.0.19",
"@mantine/form": "^6.0.11",
"@mantine/prism": "^6.0.15",
"@mirai73/bedrock-fm": "^0.4.3",
"@mirai73/bedrock-fm": "^0.4.10",
"@reactflow/background": "^11.2.0",
"@reactflow/controls": "^11.1.11",
"@reactflow/core": "^11.7.0",
@ -63,7 +63,7 @@
"lodash": "^4.17.21",
"lz-string": "^1.5.0",
"mantine-contextmenu": "^6.1.0",
"mantine-react-table": "^1.0.0-beta.8",
"mantine-react-table": "^1.3.4",
"markdown-it": "^13.0.1",
"mathjs": "^11.8.2",
"mdast-util-from-markdown": "^2.0.0",
@ -138,7 +138,7 @@
"@types/react-edit-text": "^5.0.4",
"@types/react-plotly.js": "^2.6.3",
"@types/styled-components": "^5.1.34",
"eslint": "^8.56.0",
"eslint": "<9.0.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-semistandard": "^17.0.0",
"eslint-config-standard": "^17.1.0",

View File

@ -10,12 +10,17 @@ import {
Badge,
Textarea,
Alert,
Divider,
Tooltip,
} from "@mantine/core";
import {
autofill,
autofillTable,
generateColumn,
generateAndReplace,
AIError,
getAIFeaturesModels,
generateAndReplaceTable,
} from "./backend/ai";
import { IconSparkles, IconAlertCircle } from "@tabler/icons-react";
import { AlertModalContext } from "./AlertModal";
@ -30,7 +35,14 @@ import { queryLLM } from "./backend/backend";
import { splitText } from "./SplitNode";
import { escapeBraces } from "./backend/template";
import { cleanMetavarsFilterFunc } from "./backend/utils";
import { Dict, VarsContext } from "./backend/typing";
import {
Dict,
TabularDataColType,
TabularDataRowType,
VarsContext,
} from "./backend/typing";
import { v4 as uuidv4 } from "uuid";
import { StringLookup } from "./backend/cache";
const zeroGap = { gap: "0rem" };
const popoverShadow = "rgb(38, 57, 77) 0px 10px 30px -14px";
@ -74,7 +86,7 @@ export const buildGenEvalCodePrompt = (
specPrompt: string,
manyFuncs?: boolean,
onlyBooleanFuncs?: boolean,
) => `You are to generate ${manyFuncs ? "many different functions" : "one function"} to evaluate textual data, given a user-specified specification.
) => `You are to generate ${manyFuncs ? "many different functions" : "one function"} to evaluate textual data, given a user-specified specification.
The function${manyFuncs ? "s" : ""} will be mapped over an array of objects of type ResponseInfo.
${manyFuncs ? "Each" : "Your"} solution must contain a single function called 'evaluate' that takes a single object, 'r', of type ResponseInfo. A ResponseInfo is defined as:
@ -84,8 +96,8 @@ For instance, here is an evaluator that returns the length of a response:
\`\`\`${progLang === "javascript" ? INFO_EXAMPLE_JS : INFO_EXAMPLE_PY}\`\`\`
You can only write in ${progLang.charAt(0).toUpperCase() + progLang.substring(1)}.
You ${progLang === "javascript" ? 'CANNOT import any external packages, and always use "let" to define variables instead of "var".' : "can use imports if necessary. Do not include any type hints."}
You can only write in ${progLang.charAt(0).toUpperCase() + progLang.substring(1)}.
You ${progLang === "javascript" ? 'CANNOT import any external packages, and always use "let" to define variables instead of "var".' : "can use imports if necessary. Do not include any type hints."}
Your function${manyFuncs ? "s" : ""} can ONLY return ${onlyBooleanFuncs ? "boolean" : "boolean, numeric, or string"} values.
${context}
Here is the user's specification:
@ -226,6 +238,340 @@ export function AIPopover({
);
}
export interface AIGenReplaceTablePopoverProps {
// Values in the rows of the table's columns
values: TabularDataRowType[];
// Names of the table's columns
colValues: TabularDataColType[];
// Function to add new rows
onAddRows: (newRows: TabularDataRowType[]) => void;
// Function to replace the table
onReplaceTable: (
columns: TabularDataColType[],
rows: TabularDataRowType[],
) => void;
// Function to add new columns
onAddColumns: (
newColumns: TabularDataColType[],
rowValues?: string[], // Optional row values
) => void;
// Indicates if values are loading
areValuesLoading: boolean;
// Callback to set loading state
setValuesLoading: (isLoading: boolean) => void;
}
/**
* AI Popover UI for TablularData nodes
*/
export function AIGenReplaceTablePopover({
values,
colValues,
onAddRows,
onReplaceTable,
onAddColumns,
areValuesLoading,
setValuesLoading,
}: AIGenReplaceTablePopoverProps) {
// API keys and provider
const apiKeys = useStore((state) => state.apiKeys);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
// Alert context
const showAlert = useContext(AlertModalContext);
// Command Fill state
const [commandFillNumber, setCommandFillNumber] = useState<number>(5);
const [isCommandFillLoading, setIsCommandFillLoading] = useState(false);
const [didCommandFillError, setDidCommandFillError] = useState(false);
// Generate and Replace state
const [generateAndReplaceNumber, setGenerateAndReplaceNumber] = useState(5);
const [generateAndReplacePrompt, setGenerateAndReplacePrompt] = useState("");
const [didGenerateAndReplaceTableError, setDidGenerateAndReplaceTableError] =
useState(false);
// Generate Column state
const [isGenerateColumnLoading, setIsGenerateColumnLoading] = useState(false);
const [generateColumnPrompt, setGenerateColumnPrompt] = useState("");
const [didGenerateColumnError, setDidGenerateColumnError] = useState(false);
// Check if there are any non-empty rows
const nonEmptyRows = useMemo(
() =>
values.filter((row) =>
Object.values(row).some((val) => StringLookup.get(val)?.trim()),
).length,
[values],
);
// Check if there are enough rows to suggest autofilling
const enoughRowsForSuggestions = useMemo(
() => nonEmptyRows >= ROW_CONSTANTS.beginAutofilling,
[nonEmptyRows],
);
const showWarning = useMemo(
() => enoughRowsForSuggestions && nonEmptyRows < ROW_CONSTANTS.warnIfBelow,
[enoughRowsForSuggestions, nonEmptyRows],
);
const handleGenerateAndReplaceTable = async () => {
setDidGenerateAndReplaceTableError(false);
setValuesLoading(true);
try {
// Fetch the generated table
const generatedTable = await generateAndReplaceTable(
generateAndReplacePrompt,
generateAndReplaceNumber,
aiFeaturesProvider,
apiKeys,
);
const { cols, rows } = generatedTable;
// Transform the result into TabularDataNode format
const columns = cols.map((col, index) => ({
key: `col-${index}`,
header: col,
}));
const tabularRows = rows.map((row) => {
const rowData: TabularDataRowType = { __uid: uuidv4() };
cols.forEach((col, index) => {
rowData[`col-${index}`] = row.split(" | ")[index]?.trim() || "";
});
return rowData;
});
// Update state with the transformed columns and rows
onReplaceTable(columns, tabularRows);
console.log("Generated table:", { columns, tabularRows });
} catch (error) {
console.error("Error in generateAndReplaceTable:", error);
setDidGenerateAndReplaceTableError(true);
showAlert && showAlert("An error occurred. Please try again.");
} finally {
setValuesLoading(false);
}
};
const handleCommandFill = async () => {
setIsCommandFillLoading(true);
setDidCommandFillError(false);
try {
// Extract columns from the values, excluding the __uid column
const tableColumns = colValues.map((col) => col.key);
// Extract rows as strings, excluding the __uid column and handling empty rows
const tableRows = values
.slice(0, -1) // Remove the last empty row
.map((row) =>
tableColumns
.map((col) => StringLookup.get(row[col])?.trim() || "")
.join(" | "),
);
const tableInput = {
cols: tableColumns,
rows: tableRows,
};
// Fetch new rows from the autofillTable function
const result = await autofillTable(
tableInput,
commandFillNumber,
aiFeaturesProvider,
apiKeys,
);
// Transform result.rows into TabularDataNode format
const newRows = result.rows.map((row) => {
const newRow: TabularDataRowType = { __uid: uuidv4() };
row.split(" | ").forEach((cell, index) => {
newRow[`col-${index}`] = cell;
});
return newRow;
});
// Append the new rows to the existing rows
onAddRows(newRows);
} catch (error) {
console.error("Error generating rows:", error);
setDidCommandFillError(true);
showAlert && showAlert("Failed to generate new rows. Please try again.");
} finally {
setIsCommandFillLoading(false);
}
};
const handleGenerateColumn = async () => {
setDidGenerateColumnError(false);
setIsGenerateColumnLoading(true);
try {
// Extract columns from the values, excluding the __uid column
const tableColumns = colValues;
// Extract rows as strings, excluding the __uid column and handling empty rows
const lastRow = values[values.length - 1]; // Get the last row
const emptyLastRow = Object.values(lastRow).every((val) => !val); // Check if the last row is empty
const tableRows = values
.slice(0, emptyLastRow ? -1 : values.length)
.map((row) =>
tableColumns
.map((col) => StringLookup.get(row[col.key])?.trim() || "")
.join(" | "),
);
const tableInput = {
cols: tableColumns,
rows: tableRows,
};
// Fetch the generated column
const generatedColumn = await generateColumn(
tableInput,
generateColumnPrompt,
aiFeaturesProvider,
apiKeys,
);
const rowValues = generatedColumn.rows;
// Append the new column to the existing columns
onAddColumns(
[{ key: `col-${tableColumns.length}`, header: generatedColumn.col }], // set key to length of columns
rowValues,
);
} catch (error) {
console.error("Error generating column:", error);
setDidGenerateColumnError(true);
showAlert &&
showAlert("Failed to generate a new column. Please try again.");
} finally {
setIsGenerateColumnLoading(false);
}
};
const extendUI = (
<Stack>
{didCommandFillError && (
<Text size="xs" color="red">
Failed to generate rows. Please try again.
</Text>
)}
<div style={{ display: "flex", alignItems: "center", gap: "0.5rem" }}>
<NumberInput
label="Rows to add"
mt={5}
min={1}
max={10}
value={commandFillNumber}
onChange={(num) => setCommandFillNumber(num || 1)}
style={{ flex: 1 }}
/>
<Button
size="sm"
variant="light"
color="grape"
onClick={handleCommandFill}
disabled={!enoughRowsForSuggestions}
loading={isCommandFillLoading}
style={{ marginTop: "1.5rem", flex: 1 }}
>
Extend
</Button>
</div>
{showWarning && (
<Text size="xs" color="grape">
You may want to add more fields for better suggestions.
</Text>
)}
<Divider label="OR" labelPosition="center" />
{didGenerateColumnError && (
<Text size="xs" color="red">
Failed to generate column. Please try again.
</Text>
)}
<Textarea
label="Generate a column for..."
value={generateColumnPrompt}
onChange={(e) => setGenerateColumnPrompt(e.currentTarget.value)}
/>
<Tooltip
label="Can take awhile if you have many rows. Please be patient."
withArrow
position="bottom"
>
<Button
size="sm"
variant="light"
color="grape"
fullWidth
onClick={handleGenerateColumn}
disabled={!enoughRowsForSuggestions}
loading={isGenerateColumnLoading}
>
Add Column
</Button>
</Tooltip>
</Stack>
);
const replaceUI = (
<Stack>
{didGenerateAndReplaceTableError && (
<Text size="xs" color="red">
Failed to replace rows. Please try again.
</Text>
)}
<Textarea
label="Generate data for..."
value={generateAndReplacePrompt}
onChange={(e) => setGenerateAndReplacePrompt(e.currentTarget.value)}
/>
<NumberInput
label="Rows to generate"
min={1}
max={50}
value={generateAndReplaceNumber}
onChange={(num) => setGenerateAndReplaceNumber(num || 1)}
/>
<Button
size="sm"
variant="light"
color="grape"
fullWidth
onClick={handleGenerateAndReplaceTable}
loading={areValuesLoading}
>
Replace
</Button>
</Stack>
);
return (
<AIPopover>
<Tabs color="grape" defaultValue="replace">
<Tabs.List grow>
<Tabs.Tab value="replace">Replace</Tabs.Tab>
<Tabs.Tab value="extend">Extend</Tabs.Tab>
</Tabs.List>
<Tabs.Panel value="extend" pb="xs">
{extendUI}
</Tabs.Panel>
<Tabs.Panel value="replace" pb="xs">
{replaceUI}
</Tabs.Panel>
</Tabs>
</AIPopover>
);
}
export interface AIGenReplaceItemsPopoverProps {
// Strings for the Extend feature to use as a basis.
values: Dict<string> | string[];
@ -308,6 +654,7 @@ export function AIGenReplaceItemsPopover({
const handleGenerateAndReplace = () => {
setDidGenerateAndReplaceError(false);
setValuesLoading(true);
generateAndReplace(
generateAndReplacePrompt,
generateAndReplaceNumber,
@ -604,10 +951,10 @@ export function AIGenCodeEvaluatorPopover({
const template = `Edit the code below according to the following: ${editPrompt}
You ${progLang === "javascript" ? "CANNOT import any external packages." : "can use imports if necessary. Do not include any type hints."}
You ${progLang === "javascript" ? "CANNOT import any external packages." : "can use imports if necessary. Do not include any type hints."}
Functions should only return boolean, numeric, or string values. Present the edited code in a single block.
Code:
Code:
\`\`\`${progLang}
${currentEvalCode}
\`\`\``;

File diff suppressed because it is too large

View File

@ -58,11 +58,10 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
// Remove the node, after user confirmation dialog
const handleRemoveNode = useCallback(() => {
// Open the 'are you sure' modal:
if (deleteConfirmModal && deleteConfirmModal.current)
deleteConfirmModal.current.trigger();
deleteConfirmModal?.current?.trigger();
}, [deleteConfirmModal]);
const handleOpenContextMenu = (e: Dict) => {
const handleOpenContextMenu = useCallback((e: Dict) => {
// Ignore all right-clicked elements that aren't children of the parent,
// and that aren't divs (for instance, textfields should still have normal right-click)
if (e.target?.localName !== "div") return;
@ -91,23 +90,22 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
},
});
setContextMenuOpened(true);
};
}, []);
// A BaseNode is just a div with "cfnode" as a class, and optional other className(s) for the specific node.
// It adds a context menu to all nodes upon right-click of the node itself (the div), to duplicate or delete the node.
return (
<div
className={classes}
onPointerDown={() => setContextMenuOpened(false)}
onContextMenu={handleOpenContextMenu}
style={style}
>
const areYouSureModal = useMemo(
() => (
<AreYouSureModal
ref={deleteConfirmModal}
title="Delete node"
message="Are you sure you want to delete this node? This action is irreversible."
onConfirm={() => removeNode(nodeId)}
/>
),
[removeNode, nodeId, deleteConfirmModal],
);
const contextMenu = useMemo(
() => (
<Menu
opened={contextMenuOpened}
withinPortal={true}
@ -132,6 +130,29 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
</Menu.Item>
</Menu.Dropdown>
</Menu>
),
[
handleDuplicateNode,
handleRemoveNode,
contextMenuExts,
children,
contextMenuStyle,
contextMenuOpened,
setContextMenuOpened,
],
);
// A BaseNode is just a div with "cfnode" as a class, and optional other className(s) for the specific node.
// It adds a context menu to all nodes upon right-click of the node itself (the div), to duplicate or delete the node.
return (
<div
className={classes}
onPointerDown={() => setContextMenuOpened(false)}
onContextMenu={handleOpenContextMenu}
style={style}
>
{areYouSureModal}
{contextMenu}
</div>
);
};

View File

@ -33,6 +33,7 @@ import "ace-builds/src-noconflict/theme-xcode";
import "ace-builds/src-noconflict/ext-language_tools";
import {
APP_IS_RUNNING_LOCALLY,
genDebounceFunc,
getVarsAndMetavars,
stripLLMDetailsFromResponses,
toStandardResponseFormat,
@ -52,6 +53,7 @@ import {
import { Status } from "./StatusIndicatorComponent";
import { executejs, executepy, grabResponses } from "./backend/backend";
import { AlertModalContext } from "./AlertModal";
import { StringLookup } from "./backend/cache";
// Whether we are running on localhost or not, and hence whether
// we have access to the Flask backend for, e.g., Python code evaluation.
@ -188,6 +190,7 @@ export interface CodeEvaluatorComponentProps {
onCodeEdit?: (code: string) => void;
onCodeChangedFromLastRun?: () => void;
onCodeEqualToLastRun?: () => void;
sandbox?: boolean;
}
/**
@ -206,6 +209,7 @@ export const CodeEvaluatorComponent = forwardRef<
onCodeEdit,
onCodeChangedFromLastRun,
onCodeEqualToLastRun,
sandbox,
},
ref,
) {
@ -215,6 +219,10 @@ export const CodeEvaluatorComponent = forwardRef<
false,
);
// Debounce helpers
const debounceTimeoutRef = useRef(null);
const debounce = genDebounceFunc(debounceTimeoutRef);
// Controlled handle when user edits code
const handleCodeEdit = (code: string) => {
if (codeTextOnLastRun !== false) {
@ -223,7 +231,9 @@ export const CodeEvaluatorComponent = forwardRef<
else if (!code_changed && onCodeEqualToLastRun) onCodeEqualToLastRun();
}
setCodeText(code);
if (onCodeEdit) onCodeEdit(code);
// Debounce to control number of re-renders to parent, when user is editing/typing:
if (onCodeEdit) debounce(() => onCodeEdit(code), 200)();
};
// Runs the code evaluator/processor over the inputs, returning the results as a Promise.
@ -233,6 +243,8 @@ export const CodeEvaluatorComponent = forwardRef<
script_paths?: string[],
runInSandbox?: boolean,
) => {
if (runInSandbox === undefined) runInSandbox = sandbox;
// Double-check that the code includes an 'evaluate' or 'process' function, whichever is needed:
const find_func_regex =
node_type === "evaluator"
@ -317,7 +329,7 @@ export const CodeEvaluatorComponent = forwardRef<
mode={progLang}
theme="xcode"
onChange={handleCodeEdit}
value={code}
value={codeText}
name={"aceeditor_" + id}
editorProps={{ $blockScrolling: true }}
width="100%"
@ -529,7 +541,12 @@ The Python interpeter in the browser is Pyodide. You may not be able to run some
resp_obj.responses.map((r) => {
// Carry over the response text, prompt, prompt fill history (vars), and llm data
const o: TemplateVarInfo = {
text: typeof r === "string" ? escapeBraces(r) : undefined,
text:
typeof r === "number"
? escapeBraces(StringLookup.get(r)!)
: typeof r === "string"
? escapeBraces(r)
: undefined,
image:
typeof r === "object" && r.t === "img" ? r.d : undefined,
prompt: resp_obj.prompt,
@ -539,6 +556,11 @@ The Python interpeter in the browser is Pyodide. You may not be able to run some
uid: resp_obj.uid,
};
o.text =
o.text !== undefined
? StringLookup.intern(o.text as string)
: undefined;
// Carry over any chat history
if (resp_obj.chat_history)
o.chat_history = resp_obj.chat_history;

View File

@ -1,5 +1,13 @@
import React, { forwardRef, useImperativeHandle } from "react";
import { SimpleGrid, Card, Modal, Text, Button, Tabs } from "@mantine/core";
import {
SimpleGrid,
Card,
Modal,
Text,
Button,
Tabs,
Stack,
} from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import { IconChartDots3 } from "@tabler/icons-react";
import { Dict } from "./backend/typing";
@ -331,28 +339,38 @@ const ExampleFlowCard: React.FC<ExampleFlowCardProps> = ({
onSelect,
}) => {
return (
<Card shadow="sm" padding="lg" radius="md" withBorder>
<Text mb="xs" weight={500}>
{title}
</Text>
<Card
shadow="sm"
radius="md"
withBorder
style={{ padding: "16px 10px 16px 10px" }}
>
<Stack justify="space-between" spacing="sm" h={160}>
<div>
<Text mb="xs" weight={500} lh={1.1} align="center">
{title}
</Text>
<Text size="sm" color="dimmed" lh={1.3}>
{description}
</Text>
<Text size="sm" color="dimmed" lh={1.1} align="center">
{description}
</Text>
</div>
<Button
onClick={() => {
if (onSelect) onSelect(filename);
}}
variant="light"
color="blue"
fullWidth
size="sm"
mt="md"
radius="md"
>
{buttonText ?? "Try me"}
</Button>
<Button
onClick={() => {
if (onSelect) onSelect(filename);
}}
variant="light"
color="blue"
h={32}
mih={32}
fullWidth
size="sm"
radius="md"
>
{buttonText ?? "Try me"}
</Button>
</Stack>
</Card>
);
};
@ -414,39 +432,105 @@ const ExampleFlowsModal = forwardRef<
<Tabs.Panel value="examples" pt="xs">
<SimpleGrid cols={3} spacing="sm" verticalSpacing="sm">
<ExampleFlowCard
title="Compare length of responses across LLMs"
title="📑 Compare between prompt templates"
description="Compare between prompt templates using template chaining. Visualize response quality across models."
filename="compare-prompts"
onSelect={onSelect}
/>
<ExampleFlowCard
title="📊 Compare prompt across models"
description="A simple evaluation with a prompt template, some inputs, and three models to prompt. Visualizes variability in response length."
filename="basic-comparison"
onSelect={onSelect}
/>
<ExampleFlowCard
title="Robustness to prompt injection attacks"
description="Get a sense of different model's robustness against prompt injection attacks."
filename="prompt-injection-test"
title="🤖 Compare system prompts"
description="Compares response quality across different system prompts. Visualizes how well it sticks to the instructions to only print Racket code."
filename="comparing-system-msg"
onSelect={onSelect}
/>
<ExampleFlowCard
title="Chain prompts together"
title="📗 Testing knowledge of book beginnings"
description="Test whether different LLMs know the first sentences of famous books."
filename="book-beginnings"
onSelect={onSelect}
/>
<ExampleFlowCard
title="⛓️ Extract data with prompt chaining"
description="Chain one prompt into another to extract entities from a text response. Plots number of entities."
filename="chaining-prompts"
onSelect={onSelect}
/>
<ExampleFlowCard
title="Measure impact of system message on response"
description="Compares response quality across different ChatGPT system prompts. Visualizes how well it sticks to the instructions to only print Racket code."
filename="comparing-system-msg"
title="💬🙇 Estimate chat model sycophancy"
description="Estimate how sycophantic a chat model is: ask it for a well-known fact, then tell it it's wrong, and check whether it apologizes or changes its answer."
filename="chat-sycophancy"
onSelect={onSelect}
/>
<ExampleFlowCard
title="Ground truth evaluation for math problems"
description="Uses a tabular data node to evaluate LLM performance on basic math problems. Compares responses to expected answer and plots performance across LLMs."
title="🧪 Audit models for gender bias"
description="Asks an LLM to estimate the gender of a person, given a profession and salary."
filename="audit-bias"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🛑 Red-teaming of stereotypes about nationalities"
description="Check for whether models refuse to generate stereotypes about people from different countries."
filename="red-team-stereotypes"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🐦 Multi-evals of prompt to extract structured data from tweets"
description="Extracts named entities from a dataset of tweets, and double-checks the output against multiple eval criteria."
filename="tweet-multi-eval"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🧮 Produce structured outputs"
description="Extract information from a dataset and output it in a structured JSON format using OpenAI's structured outputs feature."
filename="structured-outputs"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🔨 Detect whether tool is triggered"
description="Basic example showing whether a given prompt triggered tool usage."
filename="basic-function-calls"
onSelect={onSelect}
/>
<ExampleFlowCard
title="📑 Compare output format"
description="Check whether asking for a different format (YAML, XML, JSON, etc.) changes the content."
filename="comparing-formats"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🧑‍💻️ HumanEvals Python coding benchmark"
description="Run the HumanEvals Python coding benchmark to evaluate LLMs on Python code completion, entirely in your browser. A classic!"
filename="python-coding-eval"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🗯 Check robustness to prompt injection attacks"
description="Get a sense of different model's robustness against prompt injection attacks."
filename="prompt-injection-test"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🔢 Ground truth evaluation for math problems"
description="Uses a Tabular Data Node to evaluate LLM performance on basic math problems. Compares responses to expected answer and plots performance."
filename="basic-math"
onSelect={onSelect}
/>
<ExampleFlowCard
title="Detect whether OpenAI function call was triggered"
description="Basic example showing whether a given prompt triggered an OpenAI function call. Also shows difference between ChatGPT prior to function calls, and function call version."
filename="basic-function-calls"
title="🦟 Test knowledge of mosquitos"
description="Uses an LLM scorer to test whether LLMs know the difference between lifetimes of male and female mosquitos."
filename="mosquito-knowledge"
onSelect={onSelect}
/>
<ExampleFlowCard
title="🖼 Generate images of animals"
description="Shows images of a fox, sparrow, and a pig as a computer scientist and a gamer, using Dall-E2."
filename="animal-images"
onSelect={onSelect}
/>
</SimpleGrid>
@ -462,10 +546,10 @@ const ExampleFlowsModal = forwardRef<
>
OpenAI evals
</a>
{`benchmarking package. We currently load evals with a common system
{` benchmarking package. We currently load evals with a common system
message, a single 'turn' (prompt), and evaluation types of
'includes', 'match', and 'fuzzy match', and a reasonable number of
prompts. &nbsp;`}
prompts. `}
<i>
Warning: some evals include tables with 1000 prompts or more.{" "}
</i>

View File

@ -0,0 +1,312 @@
import React, { useState, useEffect, useContext } from "react";
import {
IconEdit,
IconTrash,
IconMenu2,
IconX,
IconCheck,
} from "@tabler/icons-react";
import axios from "axios";
import { AlertModalContext } from "./AlertModal";
import { Dict } from "./backend/typing";
import {
ActionIcon,
Box,
Drawer,
Group,
Stack,
TextInput,
Text,
Flex,
Divider,
ScrollArea,
} from "@mantine/core";
import { FLASK_BASE_URL } from "./backend/utils";
interface FlowFile {
name: string;
last_modified: string;
}
interface FlowSidebarProps {
/** The name of flow that's currently loaded in the front-end, if defined. */
currentFlow?: string;
onLoadFlow: (flowFile?: Dict<any>, flowName?: string) => void;
}
const FlowSidebar: React.FC<FlowSidebarProps> = ({
onLoadFlow,
currentFlow,
}) => {
const [isOpen, setIsOpen] = useState(false);
const [savedFlows, setSavedFlows] = useState<FlowFile[]>([]);
const [editName, setEditName] = useState<string | null>(null);
const [newEditName, setNewEditName] = useState<string>("newName");
// The name of the local directory where flows are stored
const [flowDir, setFlowDir] = useState<string | undefined>(undefined);
// For displaying alerts
const showAlert = useContext(AlertModalContext);
// Fetch saved flows from the Flask backend
const fetchSavedFlowList = async () => {
try {
const response = await axios.get(`${FLASK_BASE_URL}api/flows`);
const flows = response.data.flows as FlowFile[];
setFlowDir(response.data.flow_dir);
setSavedFlows(
flows.map((item) => ({
name: item.name.replace(".cforge", ""),
last_modified: new Date(item.last_modified).toLocaleString(),
})),
);
} catch (error) {
console.error("Error fetching saved flows:", error);
}
};
// Load a flow when clicked, and push it to the caller
const handleLoadFlow = async (filename: string) => {
try {
// Fetch the flow
const response = await axios.get(
`${FLASK_BASE_URL}api/flows/${filename}`,
);
// Push the flow to the ReactFlow UI. We also pass the filename
// so that the caller can use that info to save the right flow when the user presses save.
onLoadFlow(response.data, filename);
setIsOpen(false); // Close sidebar after loading
} catch (error) {
console.error(`Error loading flow ${filename}:`, error);
if (showAlert) showAlert(error as Error);
}
};
// Delete a flow
const handleDeleteFlow = async (
filename: string,
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
) => {
event.stopPropagation(); // Prevent triggering the parent click
if (window.confirm(`Are you sure you want to delete "${filename}"?`)) {
try {
await axios.delete(`${FLASK_BASE_URL}api/flows/${filename}`);
fetchSavedFlowList(); // Refresh the list
} catch (error) {
console.error(`Error deleting flow ${filename}:`, error);
if (showAlert) showAlert(error as Error);
}
}
};
// Start editing a flow name
const handleEditClick = (
flowFile: string,
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
) => {
event.stopPropagation(); // Prevent triggering the parent click
setEditName(flowFile);
setNewEditName(flowFile);
};
// Cancel editing
const handleCancelEdit = (
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
) => {
event.stopPropagation(); // Prevent triggering the parent click
setEditName(null);
};
// Save the edited flow name
const handleSaveEdit = async (
oldFilename: string,
newFilename: string,
event: React.MouseEvent<HTMLButtonElement, MouseEvent>,
) => {
event?.stopPropagation(); // Prevent triggering the parent click
if (newFilename && newFilename !== oldFilename) {
await axios
.put(`${FLASK_BASE_URL}api/flows/${oldFilename}`, {
newName: newFilename,
})
.then(() => {
onLoadFlow(undefined, newFilename); // Tell the parent that the filename has changed. This won't replace the flow.
fetchSavedFlowList(); // Refresh the list
})
.catch((error) => {
let msg: string;
if (error.response) {
msg = `404 Error: ${error.response.status === 404 ? error.response.data?.error ?? "Not Found" : error.response.data}`;
} else if (error.request) {
// Request was made but no response was received
msg = "No response received from server.";
} else {
// Something else happened in setting up the request
msg = `Unknown Error: ${error.message}`;
}
console.error(msg);
if (showAlert) showAlert(msg);
});
}
// No longer editing
setEditName(null);
setNewEditName("newName");
};
// Load flows when component mounts
useEffect(() => {
if (isOpen) {
fetchSavedFlowList();
}
}, [isOpen]);
return (
<div className="relative">
{/* <RenameValueModal title="Rename flow" label="Edit name" initialValue="" onSubmit={handleEditName} /> */}
{/* Toggle Button */}
<ActionIcon
variant="gradient"
size="1.625rem"
style={{
position: "absolute",
top: "10px",
left: "10px",
// left: isOpen ? "250px" : "10px",
// transition: "left 0.3s ease-in-out",
zIndex: 10,
}}
onClick={() => setIsOpen(!isOpen)}
>
{isOpen ? <IconX /> : <IconMenu2 />}
</ActionIcon>
{/* Sidebar */}
<Drawer
opened={isOpen}
onClose={() => setIsOpen(false)}
title="Saved Flows"
position="left"
size="250px" // Adjust sidebar width
padding="md"
withCloseButton={true}
scrollAreaComponent={ScrollArea.Autosize}
>
<Divider />
<Stack spacing="4px" mt="0px" mb="120px">
{savedFlows.length === 0 ? (
<Text color="dimmed">No saved flows found</Text>
) : (
savedFlows.map((flow) => (
<Box
key={flow.name}
p="6px"
sx={(theme) => ({
borderRadius: theme.radius.sm,
cursor: "pointer",
"&:hover": {
backgroundColor:
theme.colorScheme === "dark"
? theme.colors.dark[6]
: theme.colors.gray[0],
},
})}
onClick={() => {
if (editName !== flow.name) handleLoadFlow(flow.name);
}}
>
{editName === flow.name ? (
<Group spacing="xs">
<TextInput
value={newEditName}
onChange={(e) => setNewEditName(e.target.value)}
style={{ flex: 1 }}
autoFocus
/>
<ActionIcon
color="green"
onClick={(e) => handleSaveEdit(editName, newEditName, e)}
>
<IconCheck size={18} />
</ActionIcon>
<ActionIcon color="gray" onClick={handleCancelEdit}>
<IconX size={18} />
</ActionIcon>
</Group>
) : (
<>
<Flex
justify="space-between"
align="center"
gap="0px"
h="auto"
>
{currentFlow === flow.name ? (
<Box
ml="-15px"
mr="5px"
bg="green"
w="10px"
h="10px"
style={{ borderRadius: "50%" }}
></Box>
) : (
<></>
)}
<Text size="sm" mr="auto">
{flow.name}
</Text>
<Flex gap="0px">
<ActionIcon
color="blue"
onClick={(e) => handleEditClick(flow.name, e)}
>
<IconEdit size={18} />
</ActionIcon>
<ActionIcon
color="red"
onClick={(e) => handleDeleteFlow(flow.name, e)}
>
<IconTrash size={18} />
</ActionIcon>
</Flex>
</Flex>
<Text size="xs" color="gray">
{flow.last_modified}
</Text>
</>
)}
<Divider />
</Box>
))
)}
</Stack>
{/* Sticky footer */}
<div
style={{
position: "fixed",
bottom: 0,
background: "white",
padding: "10px",
borderTop: "1px solid #ddd",
}}
>
{flowDir ? (
<Text size="xs" color="gray">
Local flows are saved at: {flowDir}
</Text>
) : (
<></>
)}
</div>
</Drawer>
</div>
);
};
export default FlowSidebar;

View File

@ -271,7 +271,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
// Load image compression settings from cache (if possible)
const cachedImgCompr = (
StorageCache.loadFromLocalStorage("imageCompression") as Dict
StorageCache.loadFromLocalStorage("imageCompression", false) as Dict
)?.value as boolean | undefined;
if (typeof cachedImgCompr === "boolean") {
setImageCompression(cachedImgCompr);
@ -285,6 +285,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
const form = useForm({
initialValues: {
OpenAI: "",
OpenAI_BaseURL: "",
Anthropic: "",
Google: "",
Azure_OpenAI: "",
@ -369,6 +370,15 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
placeholder="Paste your OpenAI API key here"
{...form.getInputProps("OpenAI")}
/>
<br />
<TextInput
label="OpenAI Base URL"
description="Note: This is rarely changed."
placeholder="Paste a different base URL to use for OpenAI calls"
{...form.getInputProps("OpenAI_BaseURL")}
/>
<br />
<TextInput
label="HuggingFace API Key"
@ -390,12 +400,24 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
{...form.getInputProps("Google")}
/>
<br />
<TextInput
label="DeepSeek API Key"
placeholder="Paste your DeepSeek API key here"
{...form.getInputProps("DeepSeek")}
/>
<br />
<TextInput
label="Aleph Alpha API Key"
placeholder="Paste your Aleph Alpha API key here"
{...form.getInputProps("AlephAlpha")}
/>
<br />
<TextInput
label="Together API Key"
placeholder="Paste your Together API key here"
{...form.getInputProps("Together")}
/>
<br />
<Divider
my="xs"
@ -502,7 +524,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
/>
<Select
label="LLM Provider"
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-3.5 and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-4o and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
dropdownPosition="bottom"
withinPortal
defaultValue={getAIFeaturesModelProviders()[0]}

View File

@ -94,6 +94,7 @@ const InspectorNode: React.FC<InspectorNodeProps> = ({ data, id }) => {
>
<LLMResponseInspector
jsonResponses={jsonResponses ?? []}
isOpen={true}
wideFormat={false}
/>
</div>

View File

@ -27,7 +27,7 @@ import {
getVarsAndMetavars,
cleanMetavarsFilterFunc,
} from "./backend/utils";
import StorageCache from "./backend/cache";
import StorageCache, { StringLookup } from "./backend/cache";
import { ResponseBox } from "./ResponseBoxes";
import {
Dict,
@ -78,9 +78,9 @@ const displayJoinedTexts = (
return textInfos.map((info, idx) => {
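// info.llm may be a plain string, an interned string id (number) resolved via StringLookup, or an LLM spec object with a name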
const llm_name =
typeof info !== "string"
? typeof info.llm === "string"
? info.llm
: info.llm?.name
? typeof info.llm === "string" || typeof info.llm === "number"
? StringLookup.get(info.llm)
: StringLookup.get(info.llm?.name)
: "";
const ps = (
<pre className="small-response">
@ -267,7 +267,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
if (groupByVar !== "A") vars[varname] = var_val;
return {
text: joinTexts(
resp_objs.map((r) => (typeof r === "string" ? r : r.text ?? "")),
resp_objs.map(
(r) =>
(typeof r === "string" || typeof r === "number"
? StringLookup.get(r)
: StringLookup.get(r.text)) ?? "",
),
formatting,
),
fill_history: isMetavar ? {} : vars,
@ -284,7 +289,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
countNumLLMs(unspecGroup) > 1 ? undefined : unspecGroup[0].llm;
joined_texts.push({
text: joinTexts(
unspecGroup.map((u) => (typeof u === "string" ? u : u.text ?? "")),
unspecGroup.map(
(u) =>
(typeof u === "string" || typeof u === "number"
? StringLookup.get(u)
: StringLookup.get(u.text)) ?? "",
),
formatting,
),
fill_history: {},
@ -326,7 +336,10 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
let joined_texts: (TemplateVarInfo | string)[] = [];
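// Group responses by LLM (resolving interned string ids via StringLookup, or using the spec's key); responses without LLM info form their own group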
const [groupedRespsByLLM, nonLLMRespGroup] = groupResponsesBy(
resp_objs,
(r) => (typeof r.llm === "string" ? r.llm : r.llm?.key),
(r) =>
typeof r.llm === "string" || typeof r.llm === "number"
? StringLookup.get(r.llm)
: r.llm?.key,
);
// eslint-disable-next-line
Object.entries(groupedRespsByLLM).forEach(([llm_key, resp_objs]) => {
@ -337,7 +350,7 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
if (nonLLMRespGroup.length > 0)
joined_texts.push(
joinTexts(
nonLLMRespGroup.map((t) => t.text ?? ""),
nonLLMRespGroup.map((t) => StringLookup.get(t.text) ?? ""),
formatting,
),
);
@ -353,7 +366,12 @@ const JoinNode: React.FC<JoinNodeProps> = ({ data, id }) => {
setDataPropsForNode(id, { fields: joined_texts });
} else {
let joined_texts: string | TemplateVarInfo = joinTexts(
resp_objs.map((r) => (typeof r === "string" ? r : r.text ?? "")),
resp_objs.map(
(r) =>
(typeof r === "string" || typeof r === "number"
? StringLookup.get(r)
: StringLookup.get(r.text)) ?? "",
),
formatting,
);

View File

@ -21,11 +21,27 @@ import LLMResponseInspectorModal, {
} from "./LLMResponseInspectorModal";
import InspectFooter from "./InspectFooter";
import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer";
import { stripLLMDetailsFromResponses } from "./backend/utils";
import {
extractSettingsVars,
genDebounceFunc,
stripLLMDetailsFromResponses,
} from "./backend/utils";
import { AlertModalContext } from "./AlertModal";
import { Dict, LLMResponse, LLMSpec, QueryProgress } from "./backend/typing";
import {
Dict,
LLMResponse,
LLMResponseData,
LLMSpec,
QueryProgress,
} from "./backend/typing";
import { Status } from "./StatusIndicatorComponent";
import { evalWithLLM, grabResponses } from "./backend/backend";
import { evalWithLLM, generatePrompts, grabResponses } from "./backend/backend";
import { UserForcedPrematureExit } from "./backend/errors";
import CancelTracker from "./backend/canceler";
import { PromptInfo, PromptListModal, PromptListPopover } from "./PromptNode";
import { useDisclosure } from "@mantine/hooks";
import { PromptTemplate } from "./backend/template";
import { StringLookup } from "./backend/cache";
// The default prompt shown in gray highlights to give people a good example of an evaluation prompt.
const PLACEHOLDER_PROMPT =
@ -57,7 +73,9 @@ const DEFAULT_LLM_ITEM = (() => {
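// Default grader: the gpt-4 entry from the built-in provider list, given a fresh uuid key and its default model settings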
const item = [initLLMProviders.find((i) => i.base_model === "gpt-4")].map(
(i) => ({
key: uuid(),
settings: getDefaultModelSettings((i as LLMSpec).base_model),
settings: getDefaultModelSettings(
StringLookup.get(i?.base_model) as string,
),
...i,
}),
)[0];
@ -69,12 +87,15 @@ export interface LLMEvaluatorComponentRef {
run: (
input_node_ids: string[],
onProgressChange?: (progress: QueryProgress) => void,
cancelId?: string | number,
) => Promise<LLMResponse[]>;
cancel: (cancelId: string | number, cancelProgress: () => void) => void;
serialize: () => {
prompt: string;
format: string;
grader?: LLMSpec;
};
getPromptTemplate: () => string;
}
export interface LLMEvaluatorComponentProps {
@ -116,11 +137,17 @@ export const LLMEvaluatorComponent = forwardRef<
);
const apiKeys = useStore((state) => state.apiKeys);
// Debounce helpers
const debounceTimeoutRef = useRef(null);
const debounce = genDebounceFunc(debounceTimeoutRef);
const handlePromptChange = useCallback(
(e: React.ChangeEvent<HTMLTextAreaElement>) => {
// Store prompt text
setPromptText(e.target.value);
if (onPromptEdit) onPromptEdit(e.target.value);
// Update the caller, but debounce to reduce the number of callbacks when user is typing
if (onPromptEdit) debounce(() => onPromptEdit(e.target.value), 200)();
},
[setPromptText, onPromptEdit],
);
@ -143,50 +170,83 @@ export const LLMEvaluatorComponent = forwardRef<
[setExpectedFormat, onFormatChange],
);
const getPromptTemplate = () => {
const formatting_instr = OUTPUT_FORMAT_PROMPTS[expectedFormat] ?? "";
return (
"You are evaluating text that will be pasted below. " +
promptText +
" " +
formatting_instr +
"\n```\n{__input}\n```"
);
};
// Runs the LLM evaluator over the inputs, returning the results in a Promise.
// Errors are raised as a rejected Promise.
const run = (
input_node_ids: string[],
onProgressChange?: (progress: QueryProgress) => void,
cancelId?: string | number,
) => {
// Create prompt template to wrap user-specified scorer prompt and input data
const formatting_instr = OUTPUT_FORMAT_PROMPTS[expectedFormat] ?? "";
const template =
"You are evaluating text that will be pasted below. " +
promptText +
" " +
formatting_instr +
"\n```\n{input}\n```";
// Keeping track of progress (unpacking the progress state since there's only a single LLM)
const template = getPromptTemplate();
const llm_key = llmScorers[0].key ?? "";
const _progress_listener = onProgressChange
? (progress_by_llm: Dict<QueryProgress>) =>
onProgressChange({
success: progress_by_llm[llm_key].success,
error: progress_by_llm[llm_key].error,
})
: undefined;
// Run LLM as evaluator
return evalWithLLM(
id ?? Date.now().toString(),
llmScorers[0],
template,
input_node_ids,
apiKeys ?? {},
_progress_listener,
).then(function (res) {
// Check if there's an error; if so, bubble it up to user and exit:
if (res.errors && res.errors.length > 0) throw new Error(res.errors[0]);
else if (res.responses === undefined)
throw new Error(
"Unknown error encountered when requesting evaluations: empty response returned.",
// Fetch info about the number of queries we'll need to make
return grabResponses(input_node_ids)
.then(function (resps) {
// Create progress listener
// Keeping track of progress (unpacking the progress state since there's only a single LLM)
const num_resps_required = resps.reduce(
(acc, resp_obj) => acc + resp_obj.responses.length,
0,
);
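// Normalize raw success/error counts into percentages (0-100) of the total number of responses to evaluate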
return onProgressChange
? (progress_by_llm: Dict<QueryProgress>) =>
// Debounce the progress bar UI updates to ensure we don't re-render too often:
debounce(() => {
onProgressChange({
success:
(100 * progress_by_llm[llm_key].success) /
num_resps_required,
error:
(100 * progress_by_llm[llm_key].error) / num_resps_required,
});
}, 30)()
: undefined;
})
.then((progress_listener) => {
// Run LLM as evaluator
return evalWithLLM(
id ?? Date.now().toString(),
llmScorers[0],
template,
input_node_ids,
apiKeys ?? {},
progress_listener,
cancelId,
);
})
.then(function (res) {
// eslint-disable-next-line
debounce(() => {}, 1)(); // erase any pending debounces
// Success!
return res.responses;
});
// Check if there's an error; if so, bubble it up to user and exit:
if (res.errors && res.errors.length > 0) throw new Error(res.errors[0]);
else if (res.responses === undefined)
throw new Error(
"Unknown error encountered when requesting evaluations: empty response returned.",
);
// Success!
return res.responses;
});
};
const cancel = (cancelId: string | number, cancelProgress: () => void) => {
CancelTracker.add(cancelId);
// eslint-disable-next-line
debounce(cancelProgress, 1)(); // erase any pending debounces
};
// Export the current internal state as JSON
@ -199,7 +259,9 @@ export const LLMEvaluatorComponent = forwardRef<
// Define functions accessible from the parent component
useImperativeHandle(ref, () => ({
run,
cancel,
serialize,
getPromptTemplate,
}));
return (
@ -270,11 +332,20 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
const [status, setStatus] = useState<Status>(Status.NONE);
const showAlert = useContext(AlertModalContext);
// Cancelation of pending queries
const [cancelId, setCancelId] = useState(Date.now());
const refreshCancelId = () => setCancelId(Date.now());
const inspectModal = useRef<LLMResponseInspectorModalRef>(null);
// eslint-disable-next-line
const [uninspectedResponses, setUninspectedResponses] = useState(false);
const [showDrawer, setShowDrawer] = useState(false);
// For an info pop-up that shows all the prompts that will be sent off
// NOTE: This is the 'full' version of the PromptListPopover that activates on hover.
const [infoModalOpened, { open: openInfoModal, close: closeInfoModal }] =
useDisclosure(false);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
@ -287,6 +358,56 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
undefined,
);
// On hover over the 'info' button, to preview the prompts that will be sent out
const [promptPreviews, setPromptPreviews] = useState<PromptInfo[]>([]);
const handlePreviewHover = () => {
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map((e) => e.source);
if (input_node_ids.length === 0) {
console.warn("No inputs for evaluator node.");
return;
}
const promptText = llmEvaluatorRef?.current?.getPromptTemplate();
if (!promptText) return;
// Pull input data
try {
grabResponses(input_node_ids)
.then(function (resp_objs) {
const inputs = resp_objs
.map((obj: LLMResponse) =>
obj.responses.map((r: LLMResponseData) => ({
text:
typeof r === "string" || typeof r === "number"
? r
: undefined,
image: typeof r === "object" && r.t === "img" ? r.d : undefined,
fill_history: obj.vars,
metavars: obj.metavars,
})),
)
.flat();
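// Fill the scorer prompt's {__input} slot with each pulled response to build concrete prompt previews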
return generatePrompts(promptText, { __input: inputs });
})
.then(function (prompts) {
setPromptPreviews(
prompts.map(
(p: PromptTemplate) =>
new PromptInfo(
p.toString(),
extractSettingsVars(p.fill_history),
),
),
);
});
} catch (err) {
// soft fail
console.error(err);
setPromptPreviews([]);
}
};
const handleRunClick = useCallback(() => {
// Get the ids from the connected input nodes:
const input_node_ids = inputEdgesForNode(id).map((e) => e.source);
@ -299,47 +420,36 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
setProgress({ success: 2, error: 0 });
const handleError = (err: Error | string) => {
setStatus(Status.ERROR);
setProgress(undefined);
if (typeof err !== "string") console.error(err);
if (showAlert) showAlert(typeof err === "string" ? err : err?.message);
if (
err instanceof UserForcedPrematureExit ||
CancelTracker.has(cancelId)
) {
// Handle a premature cancelation
console.log("Canceled.");
setStatus(Status.NONE);
} else {
setStatus(Status.ERROR);
if (showAlert) showAlert(typeof err === "string" ? err : err?.message);
}
};
// Fetch info about the number of queries we'll need to make
grabResponses(input_node_ids)
.then(function (resps) {
// Create progress listener
const num_resps_required = resps.reduce(
(acc, resp_obj) => acc + resp_obj.responses.length,
0,
);
const onProgressChange = (prog: QueryProgress) => {
setProgress({
success: (100 * prog.success) / num_resps_required,
error: (100 * prog.error) / num_resps_required,
});
};
// Run LLM evaluator
llmEvaluatorRef?.current
?.run(input_node_ids, setProgress, cancelId)
.then(function (evald_resps) {
// Ping any vis + inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
// Run LLM evaluator
llmEvaluatorRef?.current
?.run(input_node_ids, onProgressChange)
.then(function (evald_resps) {
// Ping any vis + inspect nodes attached to this node to refresh their contents:
pingOutputNodes(id);
console.log(evald_resps);
setLastResponses(evald_resps);
console.log(evald_resps);
setLastResponses(evald_resps);
if (!showDrawer) setUninspectedResponses(true);
if (!showDrawer) setUninspectedResponses(true);
setStatus(Status.READY);
setProgress(undefined);
})
.catch(handleError);
setStatus(Status.READY);
setProgress(undefined);
})
.catch(() => {
handleError("Error pulling input data for node: No input data found.");
});
.catch(handleError);
}, [
inputEdgesForNode,
llmEvaluatorRef,
@ -347,8 +457,15 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
setStatus,
showDrawer,
showAlert,
cancelId,
]);
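// Cancel any in-flight scorer queries, clear the progress indicator, and reset the node status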
const handleStopClick = useCallback(() => {
llmEvaluatorRef?.current?.cancel(cancelId, () => setProgress(undefined));
refreshCancelId();
setStatus(Status.NONE);
}, [cancelId, refreshCancelId]);
const showResponseInspector = useCallback(() => {
if (inspectModal && inspectModal.current && lastResponses) {
setUninspectedResponses(false);
@ -384,13 +501,28 @@ const LLMEvaluatorNode: React.FC<LLMEvaluatorNodeProps> = ({ data, id }) => {
nodeId={id}
icon={<IconRobot size="16px" />}
status={status}
isRunning={status === Status.LOADING}
handleRunClick={handleRunClick}
handleStopClick={handleStopClick}
runButtonTooltip="Run scorer over inputs"
customButtons={[
<PromptListPopover
key="prompt-previews"
promptInfos={promptPreviews}
onHover={handlePreviewHover}
onClick={openInfoModal}
/>,
]}
/>
<LLMResponseInspectorModal
ref={inspectModal}
jsonResponses={lastResponses}
/>
<PromptListModal
promptPreviews={promptPreviews}
infoModalOpened={infoModalOpened}
closeInfoModal={closeInfoModal}
/>
<div className="llm-scorer-container">
<LLMEvaluatorComponent

View File

@ -23,9 +23,12 @@ import { StrictModeDroppable } from "./StrictModeDroppable";
import ModelSettingsModal, {
ModelSettingsModalRef,
} from "./ModelSettingsModal";
import { getDefaultModelSettings } from "./ModelSettingSchemas";
import {
getDefaultModelFormData,
getDefaultModelSettings,
} from "./ModelSettingSchemas";
import useStore, { initLLMProviders, initLLMProviderMenu } from "./store";
import { Dict, JSONCompatible, LLMSpec } from "./backend/typing";
import { Dict, JSONCompatible, LLMGroup, LLMSpec } from "./backend/typing";
import { useContextMenu } from "mantine-contextmenu";
import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types";
@ -134,6 +137,8 @@ export function LLMList({
if (item.base_model.startsWith("__custom"))
// Custom models must always have their base name, to avoid name collisions
updated_item.model = item.base_model + "/" + formData.model;
else if (item.base_model === "together")
updated_item.model = ("together/" + formData.model) as string;
else updated_item.model = formData.model as string;
}
if ("shortname" in formData) {
@ -366,22 +371,12 @@ export const LLMListContainer = forwardRef<
);
const handleSelectModel = useCallback(
(model: string) => {
// Get the item for that model
let item = AvailableLLMs.find((llm) => llm.base_model === model);
if (!item) {
// This should never trigger, but in case it does:
console.error(
`Could not find model named '${model}' in list of available LLMs.`,
);
return;
}
(item: LLMSpec) => {
// Give it a uid as a unique key (this is needed for the draggable list to support multiple same-model items; keys must be unique)
item = { key: uuid(), ...item };
// Generate the default settings for this model
item.settings = getDefaultModelSettings(model);
item.settings = getDefaultModelSettings(item.base_model);
// Repair names to ensure they are unique
const unique_name = ensureUniqueName(
@ -391,6 +386,11 @@ export const LLMListContainer = forwardRef<
item.name = unique_name;
item.formData = { shortname: unique_name };
// Together model ids carry a "together/" prefix that we need to strip for the settings form:
if (item.base_model === "together")
item.formData.model = item.model.substring("together/".length);
else item.formData.model = item.model;
let new_items: LLMSpec[] = [];
if (selectModelAction === "add" || selectModelAction === undefined) {
// Add the model to the LLM list (regardless of whether it's already present).
@ -430,25 +430,34 @@ export const LLMListContainer = forwardRef<
);
const menuItems = useMemo(() => {
const res: ContextMenuItemOptions[] = [];
for (const item of initLLMProviderMenu) {
if (!("group" in item)) {
res.push({
key: item.model,
title: `${item.emoji} ${item.name}`,
onClick: () => handleSelectModel(item.base_model),
});
} else {
res.push({
const initModels: Set<string> = new Set<string>();
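// Recursively convert the default provider menu (groups and individual models) into context-menu items, recording which base models it covers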
const convert = (item: LLMSpec | LLMGroup): ContextMenuItemOptions => {
if ("group" in item) {
return {
key: item.group,
title: `${item.emoji} ${item.group}`,
items: item.items.map((k) => ({
key: k.model,
title: `${k.emoji} ${k.name}`,
onClick: () => handleSelectModel(k.base_model),
})),
});
items: item.items.map(convert),
};
} else {
initModels.add(item.base_model);
return {
key: item.model,
title: `${item.emoji} ${item.name}`,
onClick: () => handleSelectModel(item),
};
}
};
const res = initLLMProviderMenu.map(convert);
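// Append any available LLMs that aren't already part of the default provider menu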
for (const item of AvailableLLMs) {
if (initModels.has(item.base_model)) {
continue;
}
res.push({
key: item.base_model,
title: `${item.emoji} ${item.name}`,
onClick: () => handleSelectModel(item),
});
}
return res;
}, [AvailableLLMs, handleSelectModel]);

File diff suppressed because it is too large

View File

@ -22,6 +22,7 @@ export default function LLMResponseInspectorDrawer({
>
<LLMResponseInspector
jsonResponses={jsonResponses}
isOpen={showDrawer}
wideFormat={false}
/>
</div>

View File

@ -80,6 +80,7 @@ const LLMResponseInspectorModal = forwardRef<
<Suspense fallback={<LoadingOverlay visible={true} />}>
<LLMResponseInspector
jsonResponses={props.jsonResponses}
isOpen={opened}
wideFormat={true}
/>
</Suspense>

File diff suppressed because it is too large

View File

@ -4,14 +4,16 @@ import React, {
forwardRef,
useImperativeHandle,
useEffect,
useMemo,
} from "react";
import { Button, Modal, Popover } from "@mantine/core";
import { Button, Modal, Popover, Select } from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import emojidata from "@emoji-mart/data";
import Picker from "@emoji-mart/react";
// react-jsonschema-form
import validator from "@rjsf/validator-ajv8";
import Form from "@rjsf/core";
import { WidgetProps } from "@rjsf/utils";
import {
ModelSettings,
getDefaultModelFormData,
@ -24,6 +26,46 @@ import {
ModelSettingsDict,
} from "./backend/typing";
// Custom UI widgets for react-jsonschema-form
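// A searchable, creatable Select seeded with the schema's enum options; the current value is appended if not already listed, so custom model names persist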
const DatalistWidget = (props: WidgetProps) => {
const [data, setData] = useState(
(
props.options.enumOptions?.map((option, index) => ({
value: option.value,
label: option.value,
})) ?? []
).concat(
props.options.enumOptions?.find((o) => o.value === props.value)
? []
: { value: props.value, label: props.value },
),
);
return (
<Select
data={data}
defaultValue={props.value ?? ""}
onChange={(newVal) => props.onChange(newVal ?? "")}
size="sm"
placeholder="Select items"
nothingFound="Nothing found"
searchable
creatable
getCreateLabel={(query) => `+ Create ${query}`}
onCreate={(query) => {
const item = { value: query, label: query };
setData((current) => [...current, item]);
console.log(item);
return item;
}}
/>
);
};
const widgets = {
datalist: DatalistWidget,
};
export interface ModelSettingsModalRef {
trigger: () => void;
}
@ -84,9 +126,16 @@ const ModelSettingsModal = forwardRef<
setSchema(schema);
setUISchema(settingsSpec.uiSchema);
setBaseModelName(settingsSpec.fullName);
// If the user has already saved custom settings...
if (model.formData) {
setFormData(model.formData);
setInitShortname(model.formData.shortname as string | undefined);
// If the "custom_model" field is set, use that as the initial model name, overriding "model".
// if (string_exists(model.formData.custom_model))
// setInitModelName(model.formData.custom_model as string);
// else
setInitModelName(model.formData.model as string | undefined);
} else {
// Create settings from schema
@ -166,7 +215,7 @@ const ModelSettingsModal = forwardRef<
modelname in shortname_map
)
state.formData.shortname = shortname_map[modelname];
else state.formData.shortname = modelname;
else state.formData.shortname = modelname?.split("/").at(-1);
setInitShortname(shortname);
}
@ -244,8 +293,9 @@ const ModelSettingsModal = forwardRef<
<Form
schema={schema}
uiSchema={uiSchema}
widgets={widgets} // Custom UI widgets
formData={formData}
// @ts-expect-error This is literally the example code from react-json-schema; no idea why it wouldn't typecheck correctly.
// // @ts-expect-error This is literally the example code from react-json-schema; no idea why it wouldn't typecheck correctly.
validator={validator}
// @ts-expect-error Expect format is LLMSpec.
onChange={onFormDataChange}

View File

@ -0,0 +1,875 @@
import React, {
useState,
useCallback,
useEffect,
useMemo,
useRef,
useContext,
} from "react";
import { Handle, Position } from "reactflow";
import { v4 as uuid } from "uuid";
import {
TextInput,
Text,
Group,
ActionIcon,
Menu,
Card,
rem,
Collapse,
Button,
Alert,
Tooltip,
} from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import {
IconAbacus,
IconBox,
IconChevronDown,
IconChevronRight,
IconDots,
IconPlus,
IconRobot,
IconSearch,
IconTerminal,
IconTrash,
} from "@tabler/icons-react";
import BaseNode from "./BaseNode";
import NodeLabel from "./NodeLabelComponent";
import InspectFooter from "./InspectFooter";
import LLMResponseInspectorModal, {
LLMResponseInspectorModalRef,
} from "./LLMResponseInspectorModal";
import useStore from "./store";
import {
APP_IS_RUNNING_LOCALLY,
batchResponsesByUID,
genDebounceFunc,
toStandardResponseFormat,
} from "./backend/utils";
import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer";
import {
CodeEvaluatorComponent,
CodeEvaluatorComponentRef,
} from "./CodeEvaluatorNode";
import { LLMEvaluatorComponent, LLMEvaluatorComponentRef } from "./LLMEvalNode";
import { GatheringResponsesRingProgress } from "./LLMItemButtonGroup";
import { Dict, LLMResponse, QueryProgress } from "./backend/typing";
import { AlertModalContext } from "./AlertModal";
import { Status } from "./StatusIndicatorComponent";
const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY();
const EVAL_TYPE_PRETTY_NAME = {
python: "Python",
javascript: "JavaScript",
llm: "LLM",
};
export interface EvaluatorContainerProps {
name: string;
type: string;
padding?: string | number;
onDelete: () => void;
onChangeTitle: (newTitle: string) => void;
progress?: QueryProgress;
customButton?: React.ReactNode;
children: React.ReactNode;
initiallyOpen?: boolean;
}
/** A wrapper for a single evaluator, that can be renamed */
const EvaluatorContainer: React.FC<EvaluatorContainerProps> = ({
name,
type: evalType,
padding,
onDelete,
onChangeTitle,
progress,
customButton,
children,
initiallyOpen,
}) => {
const [opened, { toggle }] = useDisclosure(initiallyOpen ?? false);
const _padding = useMemo(() => padding ?? "0px", [padding]);
const [title, setTitle] = useState(name ?? "Criteria");
const handleChangeTitle = (newTitle: string) => {
setTitle(newTitle);
if (onChangeTitle) onChangeTitle(newTitle);
};
return (
<Card
withBorder
// shadow="sm"
mb={4}
radius="md"
style={{ cursor: "default" }}
>
<Card.Section withBorder pl="8px">
<Group>
<Group spacing="0px">
<Button
onClick={toggle}
variant="subtle"
color="gray"
p="0px"
m="0px"
>
{opened ? (
<IconChevronDown size="14pt" />
) : (
<IconChevronRight size="14pt" />
)}
</Button>
<TextInput
value={title}
onChange={(e) => setTitle(e.target.value)}
onBlur={(e) => handleChangeTitle(e.target.value)}
placeholder="Criteria name"
variant="unstyled"
size="sm"
className="nodrag nowheel"
styles={{
input: {
padding: "0px",
height: "14pt",
minHeight: "0pt",
fontWeight: 500,
},
}}
/>
</Group>
<Group spacing="4px" ml="auto">
{customButton}
<Text color="#bbb" size="sm" mr="6px">
{evalType}
</Text>
{progress ? (
<GatheringResponsesRingProgress progress={progress} />
) : (
<></>
)}
{/* <Progress
radius="xl"
w={32}
size={14}
sections={[
{ value: 70, color: 'green', tooltip: '70% true' },
{ value: 30, color: 'red', tooltip: '30% false' },
]} /> */}
<Menu withinPortal position="right-start" shadow="sm">
<Menu.Target>
<ActionIcon variant="subtle" color="gray">
<IconDots style={{ width: rem(16), height: rem(16) }} />
</ActionIcon>
</Menu.Target>
<Menu.Dropdown>
{/* <Menu.Item icon={<IconSearch size="14px" />}>
Inspect scores
</Menu.Item>
<Menu.Item icon={<IconInfoCircle size="14px" />}>
Help / info
</Menu.Item> */}
<Menu.Item
icon={<IconTrash size="14px" />}
color="red"
onClick={onDelete}
>
Delete
</Menu.Item>
</Menu.Dropdown>
</Menu>
</Group>
</Group>
</Card.Section>
<Card.Section p={opened ? _padding : "0px"}>
<Collapse in={opened}>{children}</Collapse>
</Card.Section>
</Card>
);
};
export interface EvaluatorContainerDesc {
name: string; // the user's nickname for the evaluator, which displays as the title of the banner
uid: string; // a unique identifier for this evaluator, since name can change
type: "python" | "javascript" | "llm"; // the type of evaluator
state: Dict; // the internal state necessary for that specific evaluator component (e.g., a prompt for llm eval, or code for code eval)
progress?: QueryProgress;
justAdded?: boolean;
}
export interface MultiEvalNodeProps {
data: {
evaluators: EvaluatorContainerDesc[];
refresh: boolean;
title: string;
};
id: string;
}
/** A node that stores multiple evaluator functions (can be mix of LLM scorer prompts and arbitrary code.) */
const MultiEvalNode: React.FC<MultiEvalNodeProps> = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pullInputData = useStore((state) => state.pullInputData);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const bringNodeToFront = useStore((state) => state.bringNodeToFront);
const inputEdgesForNode = useStore((state) => state.inputEdgesForNode);
const flags = useStore((state) => state.flags);
const AI_SUPPORT_ENABLED = useMemo(() => {
return flags.aiSupport;
}, [flags]);
const [status, setStatus] = useState<Status>(Status.NONE);
// For displaying error messages to user
const showAlert = useContext(AlertModalContext);
const inspectModal = useRef<LLMResponseInspectorModalRef>(null);
// -- EvalGen access --
// const pickCriteriaModalRef = useRef(null);
// const onClickPickCriteria = () => {
// const inputs = handlePullInputs();
// pickCriteriaModalRef?.current?.trigger(inputs, (implementations: EvaluatorContainerDesc[]) => {
// // Returned if/when the Pick Criteria modal finishes generating implementations.
// console.warn(implementations);
// // Append the returned implementations to the end of the existing eval list
// setEvaluators((evs) => evs.concat(implementations));
// });
// };
const [uninspectedResponses, setUninspectedResponses] = useState(false);
const [lastResponses, setLastResponses] = useState<LLMResponse[]>([]);
const [lastRunSuccess, setLastRunSuccess] = useState(true);
const [showDrawer, setShowDrawer] = useState(false);
// Debounce helpers
const debounceTimeoutRef = useRef(null);
const debounce = genDebounceFunc(debounceTimeoutRef);
/** Store evaluators as array of JSON serialized state:
* { name: <string> // the user's nickname for the evaluator, which displays as the title of the banner
* type: 'python' | 'javascript' | 'llm' // the type of evaluator
* state: <dict> // the internal state necessary for that specific evaluator component (e.g., a prompt for llm eval, or code for code eval)
* }
*/
const [evaluators, setEvaluators] = useState(data.evaluators ?? []);
// Add an evaluator to the end of the list
const addEvaluator = useCallback(
(name: string, type: EvaluatorContainerDesc["type"], state: Dict) => {
setEvaluators(
evaluators.concat({ name, uid: uuid(), type, state, justAdded: true }),
);
},
[evaluators],
);
// Sync evaluator state to stored state of this node
useEffect(() => {
setDataPropsForNode(id, {
evaluators: evaluators.map((e) => ({ ...e, justAdded: undefined })),
});
}, [evaluators]);
// Generate UI for the evaluator state
const evaluatorComponentRefs = useRef<
{
type: "code" | "llm";
name: string;
ref: CodeEvaluatorComponentRef | LLMEvaluatorComponentRef | null;
}[]
>([]);
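// Apply a transformation to the evaluator at the given index and flag the node as needing a re-run (warning status)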
const updateEvalState = (
idx: number,
transformFunc: (e: EvaluatorContainerDesc) => void,
) => {
setStatus(Status.WARNING);
setEvaluators((es) =>
es.map((e, i) => {
if (idx === i) transformFunc(e);
return e;
}),
);
};
// const evaluatorComponents = useMemo(() => {
// // evaluatorComponentRefs.current = [];
// return evaluators.map((e, idx) => {
// let component: React.ReactNode;
// if (e.type === "python" || e.type === "javascript") {
// component = (
// <CodeEvaluatorComponent
// ref={(el) =>
// (evaluatorComponentRefs.current[idx] = {
// type: "code",
// name: e.name,
// ref: el,
// })
// }
// code={e.state?.code}
// progLang={e.type}
// type="evaluator"
// id={id}
// onCodeEdit={(code) =>
// updateEvalState(idx, (e) => (e.state.code = code))
// }
// showUserInstruction={false}
// />
// );
// } else if (e.type === "llm") {
// component = (
// <LLMEvaluatorComponent
// ref={(el) =>
// (evaluatorComponentRefs.current[idx] = {
// type: "llm",
// name: e.name,
// ref: el,
// })
// }
// prompt={e.state?.prompt}
// grader={e.state?.grader}
// format={e.state?.format}
// id={id}
// showUserInstruction={false}
// onPromptEdit={(prompt) =>
// updateEvalState(idx, (e) => (e.state.prompt = prompt))
// }
// onLLMGraderChange={(grader) =>
// updateEvalState(idx, (e) => (e.state.grader = grader))
// }
// onFormatChange={(format) =>
// updateEvalState(idx, (e) => (e.state.format = format))
// }
// />
// );
// } else {
// console.error(
// `Unknown evaluator type ${e.type} inside multi-evaluator node. Cannot display evaluator UI.`,
// );
// component = <Alert>Error: Unknown evaluator type {e.type}</Alert>;
// }
// return (
// <EvaluatorContainer
// name={e.name}
// key={`${e.name}-${idx}`}
// type={EVAL_TYPE_PRETTY_NAME[e.type]}
// progress={e.progress}
// onDelete={() => {
// delete evaluatorComponentRefs.current[idx];
// setEvaluators(evaluators.filter((_, i) => i !== idx));
// }}
// onChangeTitle={(newTitle) =>
// setEvaluators(
// evaluators.map((e, i) => {
// if (i === idx) e.name = newTitle;
// console.log(e);
// return e;
// }),
// )
// }
// padding={e.type === "llm" ? "8px" : undefined}
// >
// {component}
// </EvaluatorContainer>
// );
// });
// }, [evaluators, id]);
const handleError = useCallback(
(err: Error | string) => {
console.error(err);
setStatus(Status.ERROR);
showAlert && showAlert(err);
},
[showAlert, setStatus],
);
const handlePullInputs = useCallback(() => {
// Pull input data
try {
const pulled_inputs = pullInputData(["responseBatch"], id);
if (!pulled_inputs || !pulled_inputs.responseBatch) {
console.warn(`No inputs to the Multi-Evaluator node.`);
return [];
}
// Convert to standard response format (StandardLLMResponseFormat)
return pulled_inputs.responseBatch.map(toStandardResponseFormat);
} catch (err) {
handleError(err as Error);
return [];
}
}, [pullInputData, id, toStandardResponseFormat]);
const handleRunClick = useCallback(() => {
// Pull inputs to the node
const pulled_inputs = handlePullInputs();
if (!pulled_inputs || pulled_inputs.length === 0) return;
// Get the ids from the connected input nodes:
// TODO: Remove this dependency; have everything go through pull instead.
const input_node_ids = inputEdgesForNode(id).map((e) => e.source);
if (input_node_ids.length === 0) {
console.warn("No inputs to multi-evaluator node.");
return;
}
// Sanity check that there are evaluators in the multi-eval node
if (
!evaluatorComponentRefs.current ||
evaluatorComponentRefs.current.length === 0
) {
console.error("Cannot run multievals: No current evaluators found.");
return;
}
// Set loading status and clear any previous responses
setStatus(Status.LOADING);
setLastResponses([]);
// Helper function to update progress ring on a single evaluator component
const updateProgressRing = (
evaluator_idx: number,
progress?: QueryProgress,
) => {
// Update the progress rings, debouncing to avoid too many rerenders
debounce(
(_idx, _progress) =>
setEvaluators((evs) => {
if (_idx >= evs.length) return evs;
evs[_idx].progress = _progress;
return [...evs];
}),
30,
)(evaluator_idx, progress);
};
// Run all evaluators here!
// TODO
const runPromises = evaluatorComponentRefs.current.map(
({ type, name, ref }, idx) => {
if (ref === null) return { type: "error", name, result: null };
// Start loading spinner status on running evaluators
updateProgressRing(idx, { success: 0, error: 0 });
// Run each evaluator
if (type === "code") {
// Run code evaluator
// TODO: Change runInSandbox to be user-controlled, for Python code evals (right now it is always sandboxed)
return (ref as CodeEvaluatorComponentRef)
.run(pulled_inputs, undefined)
.then((ret) => {
console.log("Code evaluator done!", ret);
updateProgressRing(idx, undefined);
if (ret.error !== undefined) throw new Error(ret.error);
return {
type: "code",
name,
result: ret.responses,
};
});
} else {
// Run LLM-based evaluator
// TODO: Add back live progress, e.g. (progress) => updateProgressRing(idx, progress)) but with appropriate mapping for progress.
return (ref as LLMEvaluatorComponentRef)
.run(input_node_ids, (progress) => {
updateProgressRing(idx, progress);
})
.then((ret) => {
console.log("LLM evaluator done!", ret);
updateProgressRing(idx, undefined);
return {
type: "llm",
name,
result: ret,
};
});
}
},
);
// When all evaluators finish...
Promise.allSettled(runPromises).then((settled) => {
if (settled.some((s) => s.status === "rejected")) {
setStatus(Status.ERROR);
setLastRunSuccess(false);
// @ts-expect-error Reason exists on rejected settled promises, but TS doesn't know it for some reason.
handleError(settled.find((s) => s.status === "rejected").reason);
return;
}
// Remove progress rings without errors
setEvaluators((evs) =>
evs.map((e) => {
if (e.progress && !e.progress.error) e.progress = undefined;
return e;
}),
);
// Ignore null refs
settled = settled.filter(
(s) => s.status === "fulfilled" && s.value.result !== null,
);
// Success -- set the responses for the inspector
// First we need to group up all response evals by UID, *within* each evaluator.
const evalResults = settled.map((s) => {
const v =
s.status === "fulfilled"
? s.value
: { type: "code", name: "Undefined", result: [] };
if (v.type === "llm") return v; // responses are already batched by uid
// If it's a code evaluator: in this version of CF, code evals return de-batched responses.
// We need to re-batch them by UID before returning, to correct this:
return {
type: v.type,
name: v.name,
result: batchResponsesByUID(v.result ?? []),
};
});
// Now we have duplicates of each response object, one per evaluator run,
// with evaluation results per evaluator. They are not yet merged. We now need
// to merge the evaluation results within response objects with the same UIDs.
// It *should* be the case (invariant) that response objects with the same UID
// have exactly the same number of evaluation results (e.g. n=3 for num resps per prompt=3).
const merged_res_objs_by_uid: Dict<LLMResponse> = {};
// For each set of evaluation results...
evalResults.forEach(({ name, result }) => {
// For each response obj in the results...
result?.forEach((res_obj: LLMResponse) => {
// If it's not already in the merged dict, add it:
const uid = res_obj.uid;
if (
res_obj.eval_res !== undefined &&
!(uid in merged_res_objs_by_uid)
) {
// Transform evaluation results into dict form, indexed by "name" of the evaluator:
res_obj.eval_res.items = res_obj.eval_res.items.map((item) => {
if (typeof item === "object") item = item.toString();
return {
[name]: item,
};
});
res_obj.eval_res.dtype = "KeyValue_Mixed"; // "KeyValue_Mixed" enum;
merged_res_objs_by_uid[uid] = res_obj; // we don't make a copy, to save time
} else {
// It is already in the merged dict, so add the new eval results
// Sanity check that the lengths of eval result lists are equal across evaluators:
if (merged_res_objs_by_uid[uid].eval_res === undefined) return;
else if (
// @ts-expect-error We've already checked that eval_res is defined, yet TS throws an error anyway... skip it:
merged_res_objs_by_uid[uid].eval_res.items.length !==
res_obj.eval_res?.items?.length
) {
console.error(
`Critical error: Evaluation result lists for response ${uid} do not contain the same number of items per evaluator. Skipping...`,
);
return;
}
// Add the new evaluation result, keyed by evaluator name:
// @ts-expect-error We've already checked that eval_res is defined, yet TS throws an error anyway... skip it:
merged_res_objs_by_uid[uid].eval_res.items.forEach((item, idx) => {
if (typeof item === "object") {
let v = res_obj.eval_res?.items[idx];
if (typeof v === "object") v = v.toString();
item[name] = v ?? "undefined";
}
});
}
});
});
// We now have a dict of the form { uid: LLMResponse }
// We need return only the values of this dict:
setLastResponses(Object.values(merged_res_objs_by_uid));
setLastRunSuccess(true);
setStatus(Status.READY);
});
}, [
handlePullInputs,
pingOutputNodes,
status,
showDrawer,
evaluators,
evaluatorComponentRefs,
]);
const showResponseInspector = useCallback(() => {
if (inspectModal && inspectModal.current && lastResponses) {
setUninspectedResponses(false);
inspectModal.current.trigger();
}
}, [inspectModal, lastResponses]);
// Something changed upstream
useEffect(() => {
if (data.refresh && data.refresh === true) {
setDataPropsForNode(id, { refresh: false });
setStatus(Status.WARNING);
}
}, [data]);
return (
<BaseNode
classNames="evaluator-node"
nodeId={id}
style={{ backgroundColor: "#eee" }}
>
<NodeLabel
title={data.title || "Multi-Evaluator"}
nodeId={id}
icon={<IconAbacus size="16px" />}
status={status}
handleRunClick={handleRunClick}
runButtonTooltip="Run all evaluators over inputs"
/>
<LLMResponseInspectorModal
ref={inspectModal}
jsonResponses={lastResponses}
/>
{/* <PickCriteriaModal ref={pickCriteriaModalRef} /> */}
<iframe style={{ display: "none" }} id={`${id}-iframe`}></iframe>
{/* {evaluatorComponents} */}
{evaluators.map((e, idx) => (
<EvaluatorContainer
name={e.name}
key={`${e.name}-${idx}`}
type={EVAL_TYPE_PRETTY_NAME[e.type]}
initiallyOpen={e.justAdded}
progress={e.progress}
customButton={
e.state?.sandbox !== undefined ? (
<Tooltip
label={
e.state?.sandbox
? "Running in sandbox (pyodide)"
: "Running unsandboxed (local Python)"
}
withinPortal
withArrow
>
<button
onClick={() =>
updateEvalState(
idx,
(e) => (e.state.sandbox = !e.state.sandbox),
)
}
className="custom-button"
style={{ border: "none", padding: "0px", marginTop: "3px" }}
>
<IconBox
size="12pt"
color={e.state.sandbox ? "orange" : "#999"}
/>
</button>
</Tooltip>
) : undefined
}
onDelete={() => {
delete evaluatorComponentRefs.current[idx];
setEvaluators(evaluators.filter((_, i) => i !== idx));
}}
onChangeTitle={(newTitle) =>
setEvaluators((evs) =>
evs.map((e, i) => {
if (i === idx) e.name = newTitle;
console.log(e);
return e;
}),
)
}
padding={e.type === "llm" ? "8px" : undefined}
>
{e.type === "python" || e.type === "javascript" ? (
<CodeEvaluatorComponent
ref={(el) =>
(evaluatorComponentRefs.current[idx] = {
type: "code",
name: e.name,
ref: el,
})
}
code={e.state?.code}
progLang={e.type}
sandbox={e.state?.sandbox}
type="evaluator"
id={id}
onCodeEdit={(code) =>
updateEvalState(idx, (e) => (e.state.code = code))
}
showUserInstruction={false}
/>
) : e.type === "llm" ? (
<LLMEvaluatorComponent
ref={(el) =>
(evaluatorComponentRefs.current[idx] = {
type: "llm",
name: e.name,
ref: el,
})
}
prompt={e.state?.prompt}
grader={e.state?.grader}
format={e.state?.format}
id={`${id}-${e.uid}`}
showUserInstruction={false}
onPromptEdit={(prompt) =>
updateEvalState(idx, (e) => (e.state.prompt = prompt))
}
onLLMGraderChange={(grader) =>
updateEvalState(idx, (e) => (e.state.grader = grader))
}
onFormatChange={(format) =>
updateEvalState(idx, (e) => (e.state.format = format))
}
/>
) : (
<Alert>Error: Unknown evaluator type {e.type}</Alert>
)}
</EvaluatorContainer>
))}
<Handle
type="target"
position={Position.Left}
id="responseBatch"
className="grouped-handle"
style={{ top: "50%" }}
/>
{/* TO IMPLEMENT <Handle
type="source"
position={Position.Right}
id="output"
className="grouped-handle"
style={{ top: "50%" }}
/> */}
<div className="add-text-field-btn">
<Menu withinPortal position="right-start" shadow="sm">
<Menu.Target>
<Tooltip label="Add evaluator" position="left" withArrow>
<ActionIcon variant="outline" color="gray" size="sm">
<IconPlus size="12px" />
</ActionIcon>
</Tooltip>
</Menu.Target>
<Menu.Dropdown>
<Menu.Item
icon={<IconTerminal size="14px" />}
onClick={() =>
addEvaluator(
`Criteria ${evaluators.length + 1}`,
"javascript",
{
code: "function evaluate(r) {\n\treturn r.text.length;\n}",
},
)
}
>
JavaScript
</Menu.Item>
{IS_RUNNING_LOCALLY ? (
<Menu.Item
icon={<IconTerminal size="14px" />}
onClick={() =>
addEvaluator(`Criteria ${evaluators.length + 1}`, "python", {
code: "def evaluate(r):\n\treturn len(r.text)",
sandbox: true,
})
}
>
Python
</Menu.Item>
) : (
<></>
)}
<Menu.Item
icon={<IconRobot size="14px" />}
onClick={() =>
addEvaluator(`Criteria ${evaluators.length + 1}`, "llm", {
prompt: "",
format: "bin",
})
}
>
LLM
</Menu.Item>
{/* {AI_SUPPORT_ENABLED ? <Menu.Divider /> : <></>} */}
{/* {AI_SUPPORT_ENABLED ? (
<Menu.Item
icon={<IconSparkles size="14px" />}
onClick={onClickPickCriteria}
>
Let an AI decide!
</Menu.Item>
) : (
<></>
)} */}
</Menu.Dropdown>
</Menu>
</div>
{/* EvalGen {evaluators && evaluators.length === 0 ? (
<Flex justify="center" gap={12} mt="md">
<Tooltip
label="Let an AI help you generate criteria and implement evaluation functions."
multiline
position="bottom"
withArrow
>
<Button onClick={onClickPickCriteria} variant="outline" size="xs">
<IconSparkles size="11pt" />
&nbsp;Generate criteria
</Button>
</Tooltip> */}
{/* <Button disabled variant='gradient' gradient={{ from: 'teal', to: 'lime', deg: 105 }}><IconSparkles />&nbsp;Validate</Button> */}
{/* </Flex>
) : (
<></>
)} */}
{lastRunSuccess && lastResponses && lastResponses.length > 0 ? (
<InspectFooter
label={
<>
Inspect scores&nbsp;
<IconSearch size="12pt" />
</>
}
onClick={showResponseInspector}
showNotificationDot={uninspectedResponses}
isDrawerOpen={showDrawer}
showDrawerButton={true}
onDrawerClick={() => {
setShowDrawer(!showDrawer);
setUninspectedResponses(false);
bringNodeToFront(id);
}}
/>
) : (
<></>
)}
<LLMResponseInspectorDrawer
jsonResponses={lastResponses}
showDrawer={showDrawer}
/>
</BaseNode>
);
};
export default MultiEvalNode;

View File

@ -52,6 +52,7 @@ import {
QueryProgress,
LLMResponse,
TemplateVarInfo,
StringOrHash,
} from "./backend/typing";
import { AlertModalContext } from "./AlertModal";
import { Status } from "./StatusIndicatorComponent";
@ -62,6 +63,7 @@ import {
grabResponses,
queryLLM,
} from "./backend/backend";
import { StringLookup } from "./backend/cache";
const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => {
const metakeys = new Set(
@ -81,7 +83,7 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => {
return chats_by_llm;
};
class PromptInfo {
export class PromptInfo {
prompt: string;
settings: Dict;
@ -118,7 +120,7 @@ export interface PromptListPopoverProps {
onClick: () => void;
}
const PromptListPopover: React.FC<PromptListPopoverProps> = ({
export const PromptListPopover: React.FC<PromptListPopoverProps> = ({
promptInfos,
onHover,
onClick,
@ -176,6 +178,39 @@ const PromptListPopover: React.FC<PromptListPopoverProps> = ({
);
};
export interface PromptListModalProps {
promptPreviews: PromptInfo[];
infoModalOpened: boolean;
closeInfoModal: () => void;
}
export const PromptListModal: React.FC<PromptListModalProps> = ({
promptPreviews,
infoModalOpened,
closeInfoModal,
}) => {
return (
<Modal
title={
"List of prompts that will be sent to LLMs (" +
promptPreviews.length +
" total)"
}
size="xl"
opened={infoModalOpened}
onClose={closeInfoModal}
styles={{
header: { backgroundColor: "#FFD700" },
root: { position: "relative", left: "-5%" },
}}
>
<Box m="lg" mt="xl">
{displayPromptInfos(promptPreviews, true)}
</Box>
</Modal>
);
};
export interface PromptNodeProps {
data: {
title: string;
@ -376,24 +411,36 @@ const PromptNode: React.FC<PromptNodeProps> = ({
[setTemplateVars, templateVars, pullInputData, id],
);
const handleInputChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const value = event.target.value;
const handleInputChange = useCallback(
(event: React.ChangeEvent<HTMLTextAreaElement>) => {
const value = event.target.value;
const updateStatus =
promptTextOnLastRun !== null &&
status !== Status.WARNING &&
value !== promptTextOnLastRun;
// Store prompt text
setPromptText(value);
data.prompt = value;
// Store prompt text
data.prompt = value;
// Update status icon, if need be:
if (
promptTextOnLastRun !== null &&
status !== Status.WARNING &&
value !== promptTextOnLastRun
)
setStatus(Status.WARNING);
// Debounce the global state change to happen only after 300ms, as it forces a costly rerender:
debounce((_value, _updateStatus) => {
setPromptText(_value);
setDataPropsForNode(id, { prompt: _value });
refreshTemplateHooks(_value);
if (_updateStatus) setStatus(Status.WARNING);
}, 300)(value, updateStatus);
// Debounce refreshing the template hooks so we don't annoy the user
debounce((_value) => refreshTemplateHooks(_value), 500)(value);
};
// Debounce refreshing the template hooks so we don't annoy the user
// debounce((_value) => refreshTemplateHooks(_value), 500)(value);
},
[
promptTextOnLastRun,
status,
refreshTemplateHooks,
setDataPropsForNode,
debounceTimeoutRef,
],
);
// On initialization
useEffect(() => {
@ -432,7 +479,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Chat nodes only. Pulls input data attached to the 'past conversations' handle.
// Returns a tuple (past_chat_llms, __past_chats), where both are undefined if nothing is connected.
const pullInputChats = () => {
const pullInputChats = useCallback(() => {
const pulled_data = pullInputData(["__past_chats"], id);
if (!("__past_chats" in pulled_data)) return [undefined, undefined];
@ -454,6 +501,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Add to unique LLMs list, if necessary
if (
typeof info?.llm !== "string" &&
typeof info?.llm !== "number" &&
info?.llm?.name !== undefined &&
!llm_names.has(info.llm.name)
) {
@ -464,8 +512,8 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Create revised chat_history on the TemplateVarInfo object,
// with the prompt and text of the pulled data as the 2nd-to-last, and last, messages:
const last_messages = [
{ role: "user", content: info.prompt ?? "" },
{ role: "assistant", content: info.text ?? "" },
{ role: "user", content: StringLookup.get(info.prompt) ?? "" },
{ role: "assistant", content: StringLookup.get(info.text) ?? "" },
];
let updated_chat_hist =
info.chat_history !== undefined
@ -475,6 +523,7 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Append any present system message retroactively as the first message in the chat history:
if (
typeof info?.llm !== "string" &&
typeof info?.llm !== "number" &&
typeof info?.llm?.settings?.system_msg === "string" &&
updated_chat_hist[0].role !== "system"
)
@ -487,7 +536,10 @@ const PromptNode: React.FC<PromptNodeProps> = ({
messages: updated_chat_hist,
fill_history: info.fill_history ?? {},
metavars: info.metavars ?? {},
llm: typeof info?.llm === "string" ? info.llm : info?.llm?.name,
llm:
typeof info?.llm === "string" || typeof info?.llm === "number"
? StringLookup.get(info.llm) ?? "(LLM lookup failed)"
: StringLookup.get(info?.llm?.name),
uid: uuid(),
};
},
@ -495,36 +547,46 @@ const PromptNode: React.FC<PromptNodeProps> = ({
// Returns [list of LLM specs, list of ChatHistoryInfo]
return [past_chat_llms, past_chats];
};
}, [id, pullInputData]);
// Ask the backend how many responses it needs to collect, given the input data:
const fetchResponseCounts = (
prompt: string,
vars: Dict,
llms: (string | Dict)[],
chat_histories?:
| (ChatHistoryInfo | undefined)[]
| Dict<(ChatHistoryInfo | undefined)[]>,
) => {
return countQueries(
prompt,
vars,
llms,
const fetchResponseCounts = useCallback(
(
prompt: string,
vars: Dict,
llms: (StringOrHash | LLMSpec)[],
chat_histories?:
| (ChatHistoryInfo | undefined)[]
| Dict<(ChatHistoryInfo | undefined)[]>,
) => {
return countQueries(
prompt,
vars,
llms,
numGenerations,
chat_histories,
id,
node_type !== "chat" ? showContToggle && contWithPriorLLMs : undefined,
).then(function (results) {
return [results.counts, results.total_num_responses] as [
Dict<Dict<number>>,
Dict<number>,
];
});
},
[
countQueries,
numGenerations,
chat_histories,
showContToggle,
contWithPriorLLMs,
id,
node_type !== "chat" ? showContToggle && contWithPriorLLMs : undefined,
).then(function (results) {
return [results.counts, results.total_num_responses] as [
Dict<Dict<number>>,
Dict<number>,
];
});
};
node_type,
],
);
// On hover over the 'info' button, to preview the prompts that will be sent out
const [promptPreviews, setPromptPreviews] = useState<PromptInfo[]>([]);
const handlePreviewHover = () => {
const handlePreviewHover = useCallback(() => {
// Pull input data and prompt
try {
const pulled_vars = pullInputData(templateVars, id);
@ -545,10 +607,18 @@ const PromptNode: React.FC<PromptNodeProps> = ({
console.error(err);
setPromptPreviews([]);
}
};
}, [
pullInputData,
templateVars,
id,
updateShowContToggle,
generatePrompts,
promptText,
pullInputChats,
]);
// On hover over the 'Run' button, request how many responses are required and update the tooltip. Soft fails.
const handleRunHover = () => {
const handleRunHover = useCallback(() => {
// Check if the PromptNode is not already waiting for a response...
if (status === "loading") {
setRunTooltip("Fetching responses...");
@ -679,9 +749,17 @@ const PromptNode: React.FC<PromptNodeProps> = ({
console.error(err); // soft fail
setRunTooltip("Could not reach backend server.");
});
};
}, [
status,
llmItemsCurrState,
pullInputChats,
contWithPriorLLMs,
pullInputData,
fetchResponseCounts,
promptText,
]);
const handleRunClick = () => {
const handleRunClick = useCallback(() => {
// Go through all template hooks (if any) and check they're connected:
const is_fully_connected = templateVars.every((varname) => {
// Check that some edge has, as its target, this node and its template hook:
@ -912,7 +990,12 @@ Soft failing by replacing undefined with empty strings.`,
resp_obj.responses.map((r) => {
// Carry over the response text, prompt, prompt fill history (vars), and llm nickname:
const o: TemplateVarInfo = {
text: typeof r === "string" ? escapeBraces(r) : undefined,
text:
typeof r === "number"
? escapeBraces(StringLookup.get(r)!)
: typeof r === "string"
? escapeBraces(r)
: undefined,
image:
typeof r === "object" && r.t === "img" ? r.d : undefined,
prompt: resp_obj.prompt,
@ -923,6 +1006,11 @@ Soft failing by replacing undefined with empty strings.`,
uid: resp_obj.uid,
};
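// Intern the response text in the StringLookup table so duplicate strings are stored only once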
o.text =
o.text !== undefined
? StringLookup.intern(o.text as string)
: undefined;
// Carry over any metavars
o.metavars = resp_obj.metavars ?? {};
@ -935,9 +1023,11 @@ Soft failing by replacing undefined with empty strings.`,
// Add a meta var to keep track of which LLM produced this response
o.metavars[llm_metavar_key] =
typeof resp_obj.llm === "string"
? resp_obj.llm
typeof resp_obj.llm === "string" ||
typeof resp_obj.llm === "number"
? StringLookup.get(resp_obj.llm) ?? "(LLM lookup failed)"
: resp_obj.llm.name;
return o;
}),
)
@ -1006,7 +1096,31 @@ Soft failing by replacing undefined with empty strings.`,
.then(open_progress_listener)
.then(query_llms)
.catch(rejected);
};
}, [
templateVars,
triggerAlert,
pullInputChats,
pullInputData,
updateShowContToggle,
llmItemsCurrState,
contWithPriorLLMs,
showAlert,
fetchResponseCounts,
numGenerations,
promptText,
apiKeys,
showContToggle,
cancelId,
refreshCancelId,
node_type,
id,
setDataPropsForNode,
llmListContainer,
responsesWillChange,
showDrawer,
pingOutputNodes,
debounceTimeoutRef,
]);
const handleStopClick = useCallback(() => {
CancelTracker.add(cancelId);
@ -1024,7 +1138,7 @@ Soft failing by replacing undefined with empty strings.`,
setStatus(Status.NONE);
setContChatToggleDisabled(false);
llmListContainer?.current?.resetLLMItemsProgress();
}, [cancelId, refreshCancelId]);
}, [cancelId, refreshCancelId, debounceTimeoutRef]);
const handleNumGenChange = useCallback(
(event: React.ChangeEvent<HTMLInputElement>) => {
@ -1124,24 +1238,11 @@ Soft failing by replacing undefined with empty strings.`,
ref={inspectModal}
jsonResponses={jsonResponses ?? []}
/>
<Modal
title={
"List of prompts that will be sent to LLMs (" +
promptPreviews.length +
" total)"
}
size="xl"
opened={infoModalOpened}
onClose={closeInfoModal}
styles={{
header: { backgroundColor: "#FFD700" },
root: { position: "relative", left: "-5%" },
}}
>
<Box m="lg" mt="xl">
{displayPromptInfos(promptPreviews, true)}
</Box>
</Modal>
<PromptListModal
promptPreviews={promptPreviews}
infoModalOpened={infoModalOpened}
closeInfoModal={closeInfoModal}
/>
{node_type === "chat" ? (
<div ref={setRef}>

View File

@ -1,13 +1,15 @@
import React, { Suspense, useMemo, lazy } from "react";
import { Collapse, Flex } from "@mantine/core";
import { Collapse, Flex, Stack } from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import { truncStr } from "./backend/utils";
import { llmResponseDataToString, truncStr } from "./backend/utils";
import {
Dict,
EvaluationScore,
LLMResponse,
LLMResponseData,
StringOrHash,
} from "./backend/typing";
import { StringLookup } from "./backend/cache";
// Lazy load the response toolbars
const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar"));
@ -15,19 +17,53 @@ const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar"));
/* HELPER FUNCTIONS */
const SUCCESS_EVAL_SCORES = new Set(["true", "yes"]);
const FAILURE_EVAL_SCORES = new Set(["false", "no"]);
const getEvalResultStr = (
eval_item: string[] | Dict | string | number | boolean,
) => {
/**
* Returns an array of JSX elements, and the searchable text underpinning them,
* that represents a concrete version of the Evaluation Scores passed in.
* @param eval_item The evaluation result to visualize.
* @param hide_prefix Whether to hide 'score: ' or '{key}: ' prefixes when printing.
* @param onlyString Whether to only return string values.
* @returns An array [JSX.Element, string] where the latter is a string representation of the eval score, to enable search
*/
export const getEvalResultStr = (
eval_item: EvaluationScore,
hide_prefix: boolean,
onlyString?: boolean,
): [JSX.Element | string, string] => {
if (Array.isArray(eval_item)) {
return "scores: " + eval_item.join(", ");
const items_str = (hide_prefix ? "" : "scores: ") + eval_item.join(", ");
return [items_str, items_str];
} else if (typeof eval_item === "object") {
const strs = Object.keys(eval_item).map((key) => {
let val = eval_item[key];
if (typeof val === "number" && val.toString().indexOf(".") > -1)
val = val.toFixed(4); // truncate floats to 4 decimal places
return `${key}: ${val}`;
});
return strs.join(", ");
const strs: [JSX.Element | string, string][] = Object.keys(eval_item).map(
(key, j) => {
const innerKey = `${key}-${j}`;
let val = eval_item[key];
if (typeof val === "number" && val.toString().indexOf(".") > -1)
val = val.toFixed(4); // truncate floats to 4 decimal places
const [recurs_res, recurs_str] = getEvalResultStr(val, true);
if (onlyString) return [`${key}: ${recurs_str}`, recurs_str];
else
return [
<div key={innerKey}>
<span key={0}>{key}: </span>
<span key={1}>{recurs_res}</span>
</div>,
recurs_str,
];
},
);
const joined_strs = strs.map((s) => s[1]).join("\n");
if (onlyString) {
return [joined_strs, joined_strs];
} else
return [
<Stack key={1} spacing={0}>
{strs.map((s, i) => (
<span key={i}>{s[0]}</span>
))}
</Stack>,
joined_strs,
];
} else {
const eval_str = eval_item.toString().trim().toLowerCase();
const color = SUCCESS_EVAL_SCORES.has(eval_str)
@ -35,12 +71,15 @@ const getEvalResultStr = (
: FAILURE_EVAL_SCORES.has(eval_str)
? "red"
: "black";
return (
<>
<span style={{ color: "gray" }}>{"score: "}</span>
<span style={{ color }}>{eval_str}</span>
</>
);
if (onlyString) return [eval_str, eval_str];
else
return [
<>
{!hide_prefix && <span style={{ color: "gray" }}>{"score: "}</span>}
<span style={{ color }}>{eval_str}</span>
</>,
eval_str,
];
}
};
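A minimal usage sketch of the revised `getEvalResultStr` signature (the score value and variable names are assumed for illustration):

const score: EvaluationScore = { accuracy: 0.9123456, passed: true };
const [scoreElement, searchableText] = getEvalResultStr(score, /* hide_prefix */ true);
// scoreElement: JSX (or a plain string) to render inside the response inspector
// searchableText: a plain-string form of the same score, usable by the table's search filter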
@ -107,7 +146,7 @@ export const ResponseGroup: React.FC<ResponseGroupProps> = ({
*/
interface ResponseBoxProps {
children: React.ReactNode; // For components, HTML elements, text, etc.
vars?: Dict<string>;
vars?: Dict<StringOrHash>;
truncLenForVars?: number;
llmName?: string;
boxColor?: string;
@ -125,7 +164,10 @@ export const ResponseBox: React.FC<ResponseBoxProps> = ({
const var_tags = useMemo(() => {
if (vars === undefined) return [];
return Object.entries(vars).map(([varname, val]) => {
const v = truncStr(val.trim(), truncLenForVars ?? 18);
const v = truncStr(
(StringLookup.get(val) ?? "").trim(),
truncLenForVars ?? 18,
);
return (
<div key={varname} className="response-var-inline">
<span className="response-var-name">{varname}&nbsp;=&nbsp;</span>
@ -164,10 +206,12 @@ export const genResponseTextsDisplay = (
onlyShowScores?: boolean,
llmName?: string,
wideFormat?: boolean,
hideEvalScores?: boolean,
): React.ReactNode[] | React.ReactNode => {
if (!res_obj) return <></>;
const eval_res_items = res_obj.eval_res ? res_obj.eval_res.items : null;
const eval_res_items =
!hideEvalScores && res_obj.eval_res ? res_obj.eval_res.items : null;
// Bucket responses that have the same text, and sort by the
// number of same responses so that the top div is the most prevalent response.
@ -183,16 +227,15 @@ export const genResponseTextsDisplay = (
const resp_str_to_eval_res: Dict<EvaluationScore> = {};
if (eval_res_items)
responses.forEach((r, idx) => {
resp_str_to_eval_res[typeof r === "string" ? r : r.d] =
eval_res_items[idx];
resp_str_to_eval_res[llmResponseDataToString(r)] = eval_res_items[idx];
});
const same_resp_text_counts = countResponsesBy(responses, (r) =>
typeof r === "string" ? r : r.d,
llmResponseDataToString(r),
);
const resp_special_type_map: Dict<string> = {};
responses.forEach((r) => {
const key = typeof r === "string" ? r : r.d;
const key = llmResponseDataToString(r);
if (typeof r === "object") resp_special_type_map[key] = r.t;
});
const same_resp_keys = Object.keys(same_resp_text_counts).sort(
@ -232,6 +275,7 @@ export const genResponseTextsDisplay = (
uid={res_obj.uid}
innerIdxs={origIdxs}
wideFormat={wideFormat}
responseData={r}
/>
</Suspense>
{llmName !== undefined &&
@ -251,7 +295,7 @@ export const genResponseTextsDisplay = (
)}
{eval_res_items ? (
<p className="small-response-metrics">
{getEvalResultStr(resp_str_to_eval_res[r])}
{getEvalResultStr(resp_str_to_eval_res[r], true)[0]}
</p>
) : (
<></>

View File

@ -5,11 +5,17 @@ import React, {
useMemo,
useState,
} from "react";
import { Button, Flex, Popover, Stack, Textarea } from "@mantine/core";
import { IconMessage2, IconThumbDown, IconThumbUp } from "@tabler/icons-react";
import { Button, Flex, Popover, Stack, Textarea, Tooltip } from "@mantine/core";
import {
IconCopy,
IconMessage2,
IconThumbDown,
IconThumbUp,
} from "@tabler/icons-react";
import StorageCache from "./backend/cache";
import useStore from "./store";
import { deepcopy } from "./backend/utils";
import { LLMResponseData } from "./backend/typing";
type RatingDict = Record<number, boolean | string | undefined>;
@ -63,14 +69,14 @@ export interface ResponseRatingToolbarProps {
uid: string;
wideFormat?: boolean;
innerIdxs: number[];
onUpdateResponses?: () => void;
responseData?: string;
}
const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
uid,
wideFormat,
innerIdxs,
onUpdateResponses,
responseData,
}) => {
// The cache keys storing the ratings for this response object
const gradeKey = getRatingKeyForResponse(uid, "grade");
@ -108,6 +114,9 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
const [noteText, setNoteText] = useState("");
const [notePopoverOpened, setNotePopoverOpened] = useState(false);
// Text state
const [copied, setCopied] = useState(false);
// Override the text in the internal textarea whenever upstream annotation changes.
useEffect(() => {
setNoteText(note !== undefined ? note.toString() : "");
@ -133,7 +142,6 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
new_grades[idx] = grade;
});
setRating(uid, "grade", new_grades);
if (onUpdateResponses) onUpdateResponses();
};
const onAnnotate = (label?: string) => {
@ -145,7 +153,6 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
new_notes[idx] = label;
});
setRating(uid, "note", new_notes);
if (onUpdateResponses) onUpdateResponses();
};
const handleSaveAnnotation = useCallback(() => {
@ -175,6 +182,33 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({
>
<IconThumbDown size={size} />
</ToolbarButton>
<Tooltip
label={copied ? "Copied!" : "Copy"}
withArrow
arrowPosition="center"
>
<ToolbarButton
selected={copied}
onClick={() => {
if (responseData) {
navigator.clipboard
.writeText(responseData)
.then(() => {
console.log("Text copied to clipboard");
setCopied(() => true);
setTimeout(() => {
setCopied(() => false);
}, 1000);
})
.catch((err) => {
console.error("Failed to copy text: ", err);
});
}
}}
>
<IconCopy size={size} />
</ToolbarButton>
</Tooltip>
<Popover
opened={notePopoverOpened}
onChange={setNotePopoverOpened}

View File

@ -28,7 +28,7 @@ import {
} from "./backend/utils";
import { fromMarkdown } from "mdast-util-from-markdown";
import StorageCache from "./backend/cache";
import StorageCache, { StringLookup } from "./backend/cache";
import { ResponseBox } from "./ResponseBoxes";
import { Root, RootContent } from "mdast";
import { Dict, TemplateVarInfo } from "./backend/typing";
@ -130,11 +130,13 @@ const displaySplitTexts = (
} else {
const llm_color =
typeof info.llm === "object" && "name" in info.llm
? color_for_llm(info.llm?.name)
? color_for_llm(
StringLookup.get(info.llm?.name) ?? "(string lookup failed)",
)
: "#ddd";
const llm_name =
typeof info.llm === "object" && "name" in info.llm
? info.llm?.name
? StringLookup.get(info.llm?.name)
: "";
return (
<ResponseBox
@ -289,7 +291,11 @@ const SplitNode: React.FC<SplitNodeProps> = ({ data, id }) => {
.map((resp_obj: TemplateVarInfo | string) => {
if (typeof resp_obj === "string")
return splitText(resp_obj, formatting, true);
const texts = splitText(resp_obj?.text ?? "", formatting, true);
const texts = splitText(
StringLookup.get(resp_obj?.text) ?? "",
formatting,
true,
);
if (texts !== undefined && texts.length >= 1)
return texts.map(
(t: string) =>

View File

@ -24,6 +24,9 @@ import useStore from "./store";
import { sampleRandomElements } from "./backend/utils";
import { Dict, TabularDataRowType, TabularDataColType } from "./backend/typing";
import { Position } from "reactflow";
import { AIGenReplaceTablePopover } from "./AiPopover";
import { parseTableData } from "./backend/tableUtils";
import { StringLookup } from "./backend/cache";
const defaultRows: TabularDataRowType[] = [
{
@ -304,43 +307,8 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
// NOTE: JSON objects should be in row format, with keys
// as the header names. The internal keys of the columns will use uids to be unique.
const importJSONList = (jsonl: unknown) => {
if (!Array.isArray(jsonl)) {
throw new Error(
"Imported tabular data is not in array format: " +
(jsonl !== undefined ? (jsonl as object).toString() : ""),
);
}
// Extract unique column names
const headers = new Set<string>();
jsonl.forEach((o) => Object.keys(o).forEach((key) => headers.add(key)));
// Create new columns with unique ids c0, c1 etc
const cols = Array.from(headers).map((h, idx) => ({
header: h,
key: `c${idx.toString()}`,
}));
// Construct a lookup table from header name to our new key uid
const col_key_lookup: Dict<string> = {};
cols.forEach((c) => {
col_key_lookup[c.header] = c.key;
});
// Construct the table rows by swapping the header names for our new columm keys
const rows = jsonl.map((o) => {
const row: TabularDataRowType = { __uid: uuidv4() };
Object.keys(o).forEach((header) => {
const raw_val = o[header];
const val =
typeof raw_val === "object" ? JSON.stringify(raw_val) : raw_val;
row[col_key_lookup[header]] = val.toString();
});
return row;
});
// Save the new columns and rows
setTableColumns(cols);
const { columns, rows } = parseTableData(jsonl as any[]);
setTableColumns(columns);
setTableData(rows);
pingOutputNodes(id);
};
@ -499,6 +467,141 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
[ref],
);
const [isLoading, setIsLoading] = useState(false);
const [rowValues, setRowValues] = useState<string[]>(
tableData.map((row) => StringLookup.get(row.value) ?? ""),
);
// Function to add new columns to the right of the existing columns (with optional row values)
const addColumns = (
newColumns: TabularDataColType[],
rowValues?: string[], // If values are passed, they will be used to populate the new columns
) => {
setTableColumns((prevColumns) => {
// Filter out columns that already exist
const filteredNewColumns = newColumns.filter(
(col) =>
!prevColumns.some((existingCol) => existingCol.key === col.key),
);
// If no genuinely new columns, return previous columns
if (filteredNewColumns.length === 0) return prevColumns;
const updatedColumns = [...prevColumns, ...filteredNewColumns];
setTableData((prevData) => {
let updatedRows: TabularDataRowType[] = [];
if (prevData.length > 0) {
// Update the existing rows with the new column values
updatedRows = prevData.map((row, rowIndex) => {
const updatedRow = { ...row };
// Set the value for each new column
filteredNewColumns.forEach((col) => {
// Only set the value if it's not already set
if (updatedRow[col.key] === undefined) {
updatedRow[col.key] =
rowValues && rowValues[rowIndex] !== undefined
? rowValues[rowIndex]
: "";
}
});
return updatedRow;
});
} else if (rowValues && rowValues.length > 0) {
// If no rows exist, create rows using rowValues
updatedRows = rowValues.map((value) => {
const newRow: TabularDataRowType = { __uid: uuidv4() };
filteredNewColumns.forEach((col) => {
newRow[col.key] = value || "";
});
return newRow;
});
} else {
// If no rows and no rowValues, create a single blank row
const blankRow: TabularDataRowType = { __uid: uuidv4() };
filteredNewColumns.forEach((col) => {
blankRow[col.key] = "";
});
updatedRows.push(blankRow);
}
return updatedRows; // Update table rows
});
return updatedColumns; // Update table columns
});
};
// Function to add multiple rows to the table
const addMultipleRows = (newRows: TabularDataRowType[]) => {
setTableData((prev) => {
// Remove the last row of the current table data as it is a blank row (if table is not empty)
let newTableData = prev;
if (prev.length > 0) {
const lastRow = prev[prev.length - 1]; // Get the last row
const emptyLastRow = Object.values(lastRow).every((val) => !val); // Check if the last row is empty
if (emptyLastRow) newTableData = prev.slice(0, -1); // Remove the last row if it is empty
}
// Add the new rows to the table
const addedRows = newRows.map((value) => {
const newRow: TabularDataRowType = { __uid: uuidv4() };
// Map to correct column keys
tableColumns.forEach((col, index) => {
newRow[col.key] = value[`col-${index}`] || ""; // If (false, empty, null, etc...), default to empty string
});
return newRow;
});
// Return the updated table data with the new rows
return [...newTableData, ...addedRows];
});
};
// Function to replace the entire table (columns and rows)
const replaceTable = (
columns: TabularDataColType[],
rows: TabularDataRowType[],
) => {
// Validate columns
if (!Array.isArray(columns) || columns.length === 0) {
console.error("Invalid columns provided for table replacement.");
return;
}
// Validate rows
if (!Array.isArray(rows)) {
console.error("Invalid rows provided for table replacement.");
return;
}
// Replace columns
const updatedColumns = columns.map((col, idx) => ({
header: col.header,
key: col.key || `c${idx}`, // Ensure each column has a uid
}));
// Replace rows
const updatedRows = rows.map((row) => {
const newRow: TabularDataRowType = { __uid: uuidv4() };
updatedColumns.forEach((column) => {
// Map row data to columns, default to empty strings for missing values
newRow[column.key] = row[column.key] || "";
});
return newRow;
});
setTableColumns(updatedColumns); // Replace table columns
setTableData(updatedRows); // Replace table rows
setRowValues(updatedRows.map((row) => JSON.stringify(row))); // Update row values
};
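// Illustrative call shapes for the three helpers above (headers, keys, and values are made up):
//   addColumns([{ header: "Sentiment", key: "c3" }], ["positive", "negative"]);
//   addMultipleRows([{ __uid: uuidv4(), "col-0": "How do I reset my password?" }]);
//   replaceTable(
//     [{ header: "Question", key: "c0" }],
//     [{ __uid: uuidv4(), c0: "What is the capital of France?" }],
//   );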
return (
<BaseNode
classNames="tabular-data-node"
@ -511,6 +614,16 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
nodeId={id}
icon={"🗂️"}
customButtons={[
<AIGenReplaceTablePopover
key="ai-popover"
values={tableData}
colValues={tableColumns}
onAddRows={addMultipleRows}
onAddColumns={addColumns}
onReplaceTable={replaceTable}
areValuesLoading={isLoading}
setValuesLoading={setIsLoading}
/>,
<Tooltip
key={0}
label="Accepts xlsx, jsonl, and csv files with a header row"
@ -525,7 +638,6 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
</Tooltip>,
]}
/>
<RenameValueModal
ref={renameColumnModal}
initialValue={

View File

@ -7,7 +7,7 @@ import React, {
MouseEventHandler,
} from "react";
import { Handle, Node, Position } from "reactflow";
import { Textarea, Tooltip, Skeleton } from "@mantine/core";
import { Textarea, Tooltip, Skeleton, ScrollArea } from "@mantine/core";
import {
IconTextPlus,
IconEye,
@ -74,6 +74,15 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
data.fields_visibility || {},
);
// For when textfields exceed the TextFields Node max height,
// when we add a new field, this gives us a way to scroll to the bottom. Better UX.
const viewport = useRef<HTMLDivElement>(null);
const scrollToBottom = () =>
viewport?.current?.scrollTo({
top: viewport.current.scrollHeight,
behavior: "smooth",
});
// Whether the text fields should be in a loading state
const [isLoading, setIsLoading] = useState(false);
@ -163,6 +172,10 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
setDataPropsForNode(id, { fields: new_fields });
pingOutputNodes(id);
setTimeout(() => {
scrollToBottom();
}, 10);
// Cycle suggestions when new field is created
// aiSuggestionsManager.cycleSuggestions();
@ -493,7 +506,11 @@ const TextFieldsNode: React.FC<TextFieldsNodeProps> = ({ data, id }) => {
}
/>
<Skeleton visible={isLoading}>
<div ref={setRef}>{textFields}</div>
<div ref={setRef} className="nodrag nowheel">
<ScrollArea.Autosize mah={580} type="hover" viewportRef={viewport}>
{textFields}
</ScrollArea.Autosize>
</div>
</Skeleton>
<Handle
type="source"

View File

@ -13,9 +13,11 @@ import {
EvaluationScore,
JSONCompatible,
LLMResponse,
LLMResponseData,
} from "./backend/typing";
import { Status } from "./StatusIndicatorComponent";
import { grabResponses } from "./backend/backend";
import { StringLookup } from "./backend/cache";
// Helper funcs
const splitAndAddBreaks = (s: string, chunkSize: number) => {
@ -223,8 +225,9 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
const get_llm = (resp_obj: LLMResponse) => {
if (selectedLLMGroup === "LLM")
return typeof resp_obj.llm === "string"
? resp_obj.llm
return typeof resp_obj.llm === "string" ||
typeof resp_obj.llm === "number"
? StringLookup.get(resp_obj.llm) ?? "(LLM lookup failed)"
: resp_obj.llm?.name;
else return resp_obj.metavars[selectedLLMGroup] as string;
};
@ -332,7 +335,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
? resp_obj.metavars[varname.slice("__meta_".length)]
: resp_obj.vars[varname];
if (v === undefined && empty_str_if_undefined) return "";
return v;
return StringLookup.get(v) ?? "";
};
const get_var_and_trim = (
@ -345,6 +348,11 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
else return v;
};
const castData = (v: LLMResponseData) =>
typeof v === "string" || typeof v === "number"
? StringLookup.get(v) ?? "(unknown lookup error)"
: v.d;
const get_items = (eval_res_obj?: EvaluationResults) => {
if (eval_res_obj === undefined) return [];
if (typeof_eval_res.includes("KeyValue"))
@ -478,9 +486,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
if (resp_to_x(r) !== name) return;
x_items = x_items.concat(get_items(r.eval_res));
text_items = text_items.concat(
createHoverTexts(
r.responses.map((v) => (typeof v === "string" ? v : v.d)),
),
createHoverTexts(r.responses.map(castData)),
);
});
}
@ -573,11 +579,7 @@ const VisNode: React.FC<VisNodeProps> = ({ data, id }) => {
if (resp_to_x(r) !== name) return;
x_items = x_items.concat(get_items(r.eval_res)).flat();
text_items = text_items
.concat(
createHoverTexts(
r.responses.map((v) => (typeof v === "string" ? v : v.d)),
),
)
.concat(createHoverTexts(r.responses.map(castData)))
.flat();
y_items = y_items
.concat(

View File

@ -10,7 +10,7 @@ import {
ResponseInfo,
grabResponses,
} from "../backend";
import { LLMResponse, Dict } from "../typing";
import { LLMResponse, Dict, StringOrHash, LLMSpec } from "../typing";
import StorageCache from "../cache";
test("count queries required", async () => {
@ -22,7 +22,10 @@ test("count queries required", async () => {
};
// Double-check the queries required (not loading from cache)
const test_count_queries = async (llms: Array<string | Dict>, n: number) => {
const test_count_queries = async (
llms: Array<StringOrHash | LLMSpec>,
n: number,
) => {
const { counts, total_num_responses } = await countQueries(
prompt,
vars,

View File

@ -22,7 +22,7 @@ test("saving and loading cache data from localStorage", () => {
expect(d).toBeUndefined();
// Load cache from localStorage
StorageCache.loadFromLocalStorage("test");
StorageCache.loadFromLocalStorage("test", false);
// Verify stored data:
d = StorageCache.get("hello");

View File

@ -2,11 +2,11 @@
* @jest-environment node
*/
import { PromptPipeline } from "../query";
import { LLM, NativeLLM } from "../models";
import { LLM, LLMProvider, NativeLLM } from "../models";
import { expect, test } from "@jest/globals";
import { LLMResponseError, RawLLMResponseObject } from "../typing";
async function prompt_model(model: LLM): Promise<void> {
async function prompt_model(model: LLM, provider: LLMProvider): Promise<void> {
const pipeline = new PromptPipeline(
"What is the oldest {thing} in the world? Keep your answer brief.",
model.toString(),
@ -15,6 +15,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
1,
1.0,
)) {
@ -35,6 +36,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
2,
1.0,
)) {
@ -54,7 +56,6 @@ async function prompt_model(model: LLM): Promise<void> {
`Prompt: ${prompt}\nResponses: ${JSON.stringify(resp_obj.responses)}`,
);
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
@ -63,6 +64,7 @@ async function prompt_model(model: LLM): Promise<void> {
for await (const response of pipeline.gen_responses(
{ thing: ["bar", "tree", "book"] },
model,
provider,
2,
1.0,
)) {
@ -79,20 +81,19 @@ async function prompt_model(model: LLM): Promise<void> {
Object.entries(cache).forEach(([prompt, response]) => {
const resp_obj = Array.isArray(response) ? response[0] : response;
expect(resp_obj.responses).toHaveLength(2);
expect(resp_obj.raw_response).toHaveLength(2); // these should've been merged
});
expect(Object.keys(cache)).toHaveLength(3); // still expect 3 prompts
}
test("basic prompt pipeline with chatgpt", async () => {
// Setup a simple pipeline with a prompt template, 1 variable and 3 input values
await prompt_model(NativeLLM.OpenAI_ChatGPT);
await prompt_model(NativeLLM.OpenAI_ChatGPT, LLMProvider.OpenAI);
}, 20000);
test("basic prompt pipeline with anthropic", async () => {
await prompt_model(NativeLLM.Claude_v1);
await prompt_model(NativeLLM.Claude_v1, LLMProvider.Anthropic);
}, 40000);
test("basic prompt pipeline with google palm2", async () => {
await prompt_model(NativeLLM.PaLM2_Chat_Bison);
await prompt_model(NativeLLM.PaLM2_Chat_Bison, LLMProvider.Google);
}, 40000);

View File

@ -9,7 +9,7 @@ import {
extract_responses,
merge_response_objs,
} from "../utils";
import { NativeLLM } from "../models";
import { LLMProvider, NativeLLM } from "../models";
import { expect, test } from "@jest/globals";
import { RawLLMResponseObject } from "../typing";
@ -17,7 +17,6 @@ test("merge response objects", () => {
// Merging two response objects
const A: RawLLMResponseObject = {
responses: ["x", "y", "z"],
raw_response: ["x", "y", "z"],
prompt: "this is a test",
query: {},
llm: NativeLLM.OpenAI_ChatGPT,
@ -27,7 +26,6 @@ test("merge response objects", () => {
};
const B: RawLLMResponseObject = {
responses: ["a", "b", "c"],
raw_response: { B: "B" },
prompt: "this is a test 2",
query: {},
llm: NativeLLM.OpenAI_ChatGPT,
@ -40,7 +38,6 @@ test("merge response objects", () => {
expect(JSON.stringify(C.responses)).toBe(
JSON.stringify(["x", "y", "z", "a", "b", "c"]),
);
expect(C.raw_response).toHaveLength(4);
expect(Object.keys(C.vars)).toHaveLength(2);
expect(Object.keys(C.vars)).toContain("varB1");
expect(Object.keys(C.metavars)).toHaveLength(1);
@ -68,7 +65,11 @@ test("openai chat completions", async () => {
expect(query).toHaveProperty("temperature");
// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.OpenAI_ChatGPT);
const resps = extract_responses(
response,
NativeLLM.OpenAI_ChatGPT,
LLMProvider.OpenAI,
);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe("string");
}, 20000);
@ -86,7 +87,11 @@ test("openai text completions", async () => {
expect(query).toHaveProperty("n");
// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.OpenAI_Davinci003);
const resps = extract_responses(
response,
NativeLLM.OpenAI_Davinci003,
LLMProvider.OpenAI,
);
expect(resps).toHaveLength(2);
expect(typeof resps[0]).toBe("string");
}, 20000);
@ -104,7 +109,11 @@ test("anthropic models", async () => {
expect(query).toHaveProperty("max_tokens_to_sample");
// Extract responses, check their type
const resps = extract_responses(response, NativeLLM.Claude_v1);
const resps = extract_responses(
response,
NativeLLM.Claude_v1,
LLMProvider.Anthropic,
);
expect(resps).toHaveLength(1);
expect(typeof resps[0]).toBe("string");
}, 20000);
@ -121,7 +130,11 @@ test("google palm2 models", async () => {
expect(query).toHaveProperty("candidateCount");
// Extract responses, check their type
let resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
let resps = extract_responses(
response,
NativeLLM.PaLM2_Chat_Bison,
LLMProvider.Google,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));
@ -137,7 +150,11 @@ test("google palm2 models", async () => {
expect(query).toHaveProperty("maxOutputTokens");
// Extract responses, check their type
resps = extract_responses(response, NativeLLM.PaLM2_Chat_Bison);
resps = extract_responses(
response,
NativeLLM.PaLM2_Chat_Bison,
LLMProvider.Google,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));
@ -153,7 +170,11 @@ test("aleph alpha model", async () => {
expect(response).toHaveLength(3);
// Extract responses, check their type
let resps = extract_responses(response, NativeLLM.Aleph_Alpha_Luminous_Base);
let resps = extract_responses(
response,
NativeLLM.Aleph_Alpha_Luminous_Base,
LLMProvider.Aleph_Alpha,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));
@ -169,7 +190,11 @@ test("aleph alpha model", async () => {
expect(response).toHaveLength(3);
// Extract responses, check their type
resps = extract_responses(response, NativeLLM.Aleph_Alpha_Luminous_Base);
resps = extract_responses(
response,
NativeLLM.Aleph_Alpha_Luminous_Base,
LLMProvider.Aleph_Alpha,
);
expect(resps).toHaveLength(3);
expect(typeof resps[0]).toBe("string");
console.log(JSON.stringify(resps));

View File

@ -7,8 +7,9 @@ import {
escapeBraces,
containsSameTemplateVariables,
} from "./template";
import { ChatHistoryInfo, Dict } from "./typing";
import { ChatHistoryInfo, Dict, TabularDataColType } from "./typing";
import { fromMarkdown } from "mdast-util-from-markdown";
import { llmResponseDataToString, sampleRandomElements } from "./utils";
export class AIError extends Error {
constructor(message: string) {
@ -24,7 +25,7 @@ export type Row = string;
const AIFeaturesLLMs = [
{
provider: "OpenAI",
small: { value: "gpt-3.5-turbo", label: "OpenAI GPT3.5" },
small: { value: "gpt-4o", label: "OpenAI GPT4o" },
large: { value: "gpt-4", label: "OpenAI GPT4" },
},
{
@ -111,6 +112,24 @@ function autofillSystemMessage(
return `Here is a list of commands or items. Say what the pattern seems to be in a single sentence. Then, generate ${n} more commands or items following the pattern, as an unordered markdown list. ${templateVariables && templateVariables.length > 0 ? templateVariableMessage(templateVariables) : ""}`;
}
/**
* Generate the system message used for autofilling tables.
* @param n number of rows to generate
*/
function autofillTableSystemMessage(n: number): string {
return `Here is a table. Generate ${n} more commands or items following the pattern. You must format your response as a markdown table with labeled columns and a divider with only the next ${n} generated commands or items of the table.`;
}
/**
* Generate the system message used for generating a single column value from a user prompt.
*/
function generateColumnSystemMessage(): string {
return `You are a helpful assistant. Given partial row data and a prompt for a missing field, produce only the new field's value. No extra formatting or explanations, just the value itself.`;
}
/**
* Generate the system message used for generate and replace (GAR).
*/
@ -122,6 +141,17 @@ function GARSystemMessage(
return `Generate a list of exactly ${n} items. Format your response as an unordered markdown list using "-". Do not ever repeat anything.${creative ? "Be unconventional with your outputs." : ""} ${generatePrompts ? "Your outputs should be commands that can be given to an AI chat assistant." : ""} If the user has specified items or inputs to their command, generate a template in Jinja format, with single braces {} around the masked variables.`;
}
/**
* Generate the system message used for generate and replace table (GART).
* @param n number of rows to generate
* @param generatePrompts whether the output should be commands
* @returns the system message
*/
function GARTSystemMessage(n: number, generatePrompts?: boolean): string {
return `Generate a table with exactly ${n} rows. Format your response as a markdown table. Do not ever repeat anything. ${generatePrompts ? "Your outputs should be commands that can be given to an AI chat assistant." : ""} If the user has specified items or inputs to their command, generate a template in Jinja format, with single braces {} around the masked variables.`;
}
/**
* Returns a string representing the given rows as a markdown list
* @param rows to encode
@ -130,6 +160,19 @@ function encode(rows: Row[]): string {
return escapeBraces(rows.map((row) => `- ${row}`).join("\n"));
}
/**
* Returns a string representing the given rows and columns as a markdown table
* @param cols to encode as headers
* @param rows to encode as table rows
* @returns a string representing the table in markdown format
*/
function encodeTable(cols: string[], rows: Row[]): string {
const header = `| ${cols.join(" | ")} |`;
const divider = `| ${cols.map(() => "---").join(" | ")} |`;
const body = rows.map((row) => `| ${row} |`).join("\n");
return escapeBraces(`${header}\n${divider}\n${body}`);
}
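For illustration, a sketch of what `encodeTable` produces for assumed sample inputs:

const md = encodeTable(["country", "capital"], ["France | Paris", "Japan | Tokyo"]);
// md (before brace-escaping) is:
// | country | capital |
// | --- | --- |
// | France | Paris |
// | Japan | Tokyo |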
/**
* Returns a list of items that appears in the given markdown text. Throws an AIError if the string is not in markdown list format.
* @param mdText raw text to decode (in markdown format)
@ -162,6 +205,69 @@ function decode(mdText: string): Row[] {
return result;
}
/**
* Returns an object containing the columns and rows of the table decoded from the given markdown text. Throws an AIError if the string is not in "markdown table format".
* @param mdText markdown text to decode
* @returns an object containing the columns and rows of the table
*/
function decodeTable(mdText: string): { cols: string[]; rows: Row[] } {
// Remove code block markers and trim the text
const mdTextCleaned = mdText
.replace(/```markdown/g, "")
.replace(/```/g, "")
.trim();
// Split into lines and clean up whitespace
const lines = mdTextCleaned.split("\n").map((line) => line.trim());
// If there are no lines at all, throw an error
if (lines.length < 1) {
throw new AIError(`Invalid table format: ${mdText}`);
}
let cols: string[];
let dataLines: string[];
// Check if a proper header exists
if (/^(\|\s*-+\s*)+\|$/.test(lines[1])) {
// If valid header and divider exist
cols = lines[0]
.split("|")
.map((col) => col.trim())
.filter((col) => col.length > 0);
dataLines = lines.slice(2); // Skip header and divider lines
} else {
// If no valid header/divider, generate default column names
const firstRowCells = lines[0]
.split("|")
.map((cell) => cell.trim())
.filter((cell) => cell.length > 0);
// Generate default column names (col_1, col_2, ...)
cols = firstRowCells.map((_, idx) => `col_${idx + 1}`);
dataLines = lines; // Treat all lines as data rows
}
// Parse the data rows (using the header-aware dataLines computed above)
const rows = dataLines.map((line) => {
const cells = line
.split("|")
.map((cell) => cell.trim())
.slice(1, -1); // Remove leading/trailing "|" splits
if (cells.length !== cols.length) {
throw new AIError(`Row column mismatch: ${line}`);
}
return cells.join(" | ");
});
// Validate the parsed content
if (cols.length === 0 || rows.length === 0) {
throw new AIError(`Failed to decode output: ${mdText}`);
}
return { cols, rows };
}
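And the reverse direction, a sketch of decoding an LLM reply in the same assumed format:

const { cols, rows } = decodeTable(
"| country | capital |\n| --- | --- |\n| France | Paris |\n| Japan | Tokyo |",
);
// cols -> ["country", "capital"]
// rows -> ["France | Paris", "Japan | Tokyo"]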
/**
* Uses an LLM to interpret the pattern from the given rows and return new rows following the pattern.
* @param input rows for the autofilling system
@ -208,7 +314,7 @@ export async function autofill(
if (result.errors && Object.keys(result.errors).length > 0)
throw new Error(Object.values(result.errors)[0].toString());
const output = result.responses[0].responses[0] as string;
const output = llmResponseDataToString(result.responses[0].responses[0]);
console.log("LLM said: ", output);
@ -221,6 +327,208 @@ export async function autofill(
return new_items.slice(0, n);
}
/**
* Uses an LLM to interpret the pattern from the given table (columns and rows) and generate new rows following the pattern.
* @param input Object containing the columns and rows of the input table.
* @param n Number of new rows to generate.
* @param provider The LLM provider to use.
* @param apiKeys API keys required for the LLM query.
* @returns A promise resolving to an object containing updated columns and rows.
*/
export async function autofillTable(
input: { cols: string[]; rows: Row[] },
n: number,
provider: string,
apiKeys: Dict,
): Promise<{ cols: string[]; rows: Row[] }> {
// Get a random sample of the table rows, if there are more than 30 (as an estimate):
// TODO: This is a temporary solution to avoid sending large tables to the LLM. In future, check the number of characters too.
const sampleRows =
input.rows.length > 30 ? sampleRandomElements(input.rows, 30) : input.rows;
// Hash the arguments to get a unique id
const id = JSON.stringify([input.cols, sampleRows, n]);
// Encode the input table to a markdown table
const encoded = encodeTable(input.cols, sampleRows);
const history: ChatHistoryInfo[] = [
{
messages: [
{
role: "system",
content: autofillTableSystemMessage(n),
},
],
fill_history: {},
},
];
try {
// Query the LLM
const result = await queryLLM(
id,
getAIFeaturesModels(provider).small,
1,
encoded,
{},
history,
apiKeys,
true,
);
if (result.errors && Object.keys(result.errors).length > 0)
throw new Error(Object.values(result.errors)[0].toString());
// Extract the output from the LLM response
const output = llmResponseDataToString(result.responses[0].responses[0]);
console.log("LLM said: ", output);
const newRows = decodeTable(output).rows;
// Return the updated table with "n" number of rows
return {
cols: input.cols,
rows: newRows, // Return the new rows generated by the LLM
};
} catch (error) {
console.error("Error in autofillTable:", error);
throw new AIError(
`Failed to autofill table. Details: ${(error as Error).message || error}`,
);
}
}
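A minimal usage sketch (the provider string and `apiKeys` dict are assumptions; in the app they come from the node's settings):

const extended = await autofillTable(
{ cols: ["country", "capital"], rows: ["France | Paris", "Japan | Tokyo"] },
5, // ask for 5 more rows
"OpenAI", // provider name passed to getAIFeaturesModels
apiKeys,
);
// extended.cols is unchanged; extended.rows contains only the newly generated rows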
// Queries the model for a single row's missing field:
async function fillMissingFieldForRow(
existingRowData: Record<string, string>, // Key-value pairs for the row
prompt: string, // The user prompt describing what the missing field should be
provider: string,
apiKeys: Dict,
): Promise<string> {
// Generate a user prompt for the LLM, passing the existing row data in list format
// const userPrompt = `You are given partial data for a row of a table. Here is the data:
// ${Object.entries(existingRowData)
// .map(([key, val]) => `- ${key}: ${val}`)
// .join("\n")}
// This is the requirement of the new column: "${prompt}". Produce an appropriate value for the item. Respond with just the new field's value, and nothing else.`;
const userPrompt = `Fill in the last piece of information. Respond with just the missing information, nothing else.
${Object.entries(existingRowData)
.map(([key, val]) => `${key}: ${val}`)
.join("\n")}
${prompt}: ?`;
const history: ChatHistoryInfo[] = [
{
messages: [
{
role: "system",
content: generateColumnSystemMessage(),
},
],
fill_history: {},
},
];
const id = JSON.stringify([existingRowData, prompt]);
const result = await queryLLM(
id,
getAIFeaturesModels(provider).small,
1,
userPrompt,
{},
history,
apiKeys,
true,
);
console.log("LLM said: ", result.responses[0].responses[0]);
// Handle any errors in the response
if (result.errors && Object.keys(result.errors).length > 0) {
throw new AIError(Object.values(result.errors)[0].toString());
}
const output = llmResponseDataToString(result.responses[0].responses[0]);
return output.trim();
}
/**
* Uses an LLM to generate one new column with data based on the pattern explained in `prompt`.
* @param prompt Description or pattern for the column content.
* @param provider The LLM provider to use (e.g., OpenAI, Bedrock).
* @param apiKeys API keys required for the LLM query.
* @returns A promise resolving to an array of strings (column values).
*/
export async function generateColumn(
tableData: { cols: TabularDataColType[]; rows: string[] },
prompt: string,
provider: string,
apiKeys: Dict,
): Promise<{ col: string; rows: string[] }> {
// If the length of the prompt is less than 20 characters, use the prompt
// Else, use the LLM to generate an appropriate column name for the prompt
let colName: string;
if (prompt.length <= 20) {
colName = prompt;
} else {
const result = await queryLLM(
JSON.stringify([prompt]),
getAIFeaturesModels(provider).small,
1,
`You produce column names for a table. The column names must be short, less than 20 characters, and in natural language, like "Column Name." Return only the column name. Generate an appropriate column name for the prompt: "${prompt}"`,
{},
[],
apiKeys,
true,
);
colName = llmResponseDataToString(result.responses[0].responses[0]).replace(
"_",
" ",
);
}
// Remove any leading/trailing whitespace from the column name as well as any double quotes
colName = colName.trim().replace(/"/g, "");
// Parse the existing markdown table rows into row objects keyed by column header
const columnNames = tableData.cols.map((col) => col.header);
const parsedRows = tableData.rows.map((rowStr) => {
// Remove leading/trailing "|" along with any whitespace
const cells = rowStr
.replace(/^\|/, "")
.replace(/\|$/, "")
.split("|")
.map((cell) => cell.trim());
const rowData: Record<string, string> = {};
columnNames.forEach((colName, index) => {
rowData[colName] = cells[index] || "";
});
return rowData;
});
const newColumnValues: string[] = [];
for (const rowData of parsedRows) {
// For each row, we request a new field from the LLM:
const newValue = await fillMissingFieldForRow(
rowData,
prompt,
provider,
apiKeys,
);
newColumnValues.push(newValue);
}
// Return the new column name and values
return {
col: colName,
rows: newColumnValues,
};
}
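A sketch of calling the column generator (column keys, prompt, and `apiKeys` are illustrative assumptions):

const { col, rows } = await generateColumn(
{
cols: [
{ header: "country", key: "c0" },
{ header: "capital", key: "c1" },
],
rows: ["| France | Paris |", "| Japan | Tokyo |"],
},
"Official language", // short prompts (<= 20 chars) double as the column name
"OpenAI",
apiKeys,
);
// col -> "Official language"; rows -> one generated value per input row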
/**
* Uses an LLM to generate `n` new rows based on the pattern explained in `prompt`.
* @param prompt
@ -268,8 +576,79 @@ export async function generateAndReplace(
if (result.errors && Object.keys(result.errors).length > 0)
throw new Error(Object.values(result.errors)[0].toString());
console.log("LLM said: ", result.responses[0].responses[0]);
const resp = llmResponseDataToString(result.responses[0].responses[0]);
console.log("LLM said: ", resp);
const new_items = decode(result.responses[0].responses[0] as string);
const new_items = decode(resp);
return new_items.slice(0, n);
}
/**
* Uses an LLM to generate a table with `n` rows based on the pattern explained in `prompt`.
* @param prompt Description or pattern for the table content.
* @param n Number of rows to generate.
* @param provider The LLM provider to use.
* @param apiKeys API keys required for the LLM query.
* @returns A promise resolving to an object containing the columns and rows of the generated table.
*/
export async function generateAndReplaceTable(
prompt: string,
n: number,
provider: string,
apiKeys: Dict,
): Promise<{ cols: string[]; rows: Row[] }> {
// Hash the arguments to get a unique id
const id = JSON.stringify([prompt, n]);
// Determine if the prompt includes the word "prompt"
const generatePrompts = prompt.toLowerCase().includes("prompt");
const history: ChatHistoryInfo[] = [
{
messages: [
{
role: "system",
content: GARTSystemMessage(n, generatePrompts),
},
],
fill_history: {},
},
];
const input = `Generate a table with data of ${escapeBraces(prompt)}`;
try {
// Query the LLM
const result = await queryLLM(
id,
getAIFeaturesModels(provider).small,
1,
input,
{},
history,
apiKeys,
true,
);
if (result.errors && Object.keys(result.errors).length > 0)
throw new Error(Object.values(result.errors)[0].toString());
console.log("LLM result: ", result);
console.log("LLM said: ", result.responses[0].responses[0]);
const { cols: new_cols, rows: new_rows } = decodeTable(
llmResponseDataToString(result.responses[0].responses[0]),
);
// Return the generated table with "n" number of rows
return {
cols: new_cols,
rows: new_rows.slice(0, n),
};
} catch (error) {
console.error("Error in generateAndReplaceTable:", error);
throw new AIError(
`Failed to generate and replace table. Details: ${(error as Error).message || error}`,
);
}
}
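A usage sketch (the prompt wording, provider, and `apiKeys` are assumptions):

const table = await generateAndReplaceTable(
"famous scientists and their fields",
10, // number of rows requested
"OpenAI",
apiKeys,
);
// table.cols: column names inferred by the LLM; table.rows: up to 10 pipe-joined row strings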

View File

@ -1,4 +1,5 @@
import MarkdownIt from "markdown-it";
import axios from "axios";
import { v4 as uuid } from "uuid";
import {
Dict,
@ -12,11 +13,12 @@ import {
EvaluationScore,
LLMSpec,
EvaluatedResponsesResults,
TemplateVarInfo,
CustomLLMProviderSpec,
LLMResponseData,
PromptVarType,
StringOrHash,
} from "./typing";
import { LLM, getEnumName } from "./models";
import { LLM, LLMProvider, getEnumName, getProvider } from "./models";
import {
APP_IS_RUNNING_LOCALLY,
set_api_keys,
@ -26,17 +28,20 @@ import {
areEqualVarsDicts,
repairCachedResponses,
deepcopy,
llmResponseDataToString,
} from "./utils";
import StorageCache from "./cache";
import StorageCache, { StringLookup } from "./cache";
import { PromptPipeline } from "./query";
import {
PromptPermutationGenerator,
PromptTemplate,
cleanEscapedBraces,
escapeBraces,
} from "./template";
import { UserForcedPrematureExit } from "./errors";
import CancelTracker from "./canceler";
import { execPy } from "./pyodide/exec-py";
import { baseModelToProvider } from "../ModelSettingSchemas";
// """ =================
// SETUP AND GLOBALS
@ -214,19 +219,30 @@ function gen_unique_cache_filename(
return `${cache_id}_${idx}.json`;
}
function extract_llm_nickname(llm_spec: Dict | string) {
function extract_llm_nickname(llm_spec: StringOrHash | LLMSpec): string {
if (typeof llm_spec === "object" && llm_spec.name !== undefined)
return llm_spec.name;
else return llm_spec;
else
return (
StringLookup.get(llm_spec as StringOrHash) ?? "(string lookup failed)"
);
}
function extract_llm_name(llm_spec: Dict | string): string {
if (typeof llm_spec === "string") return llm_spec;
function extract_llm_name(llm_spec: StringOrHash | LLMSpec): string {
if (typeof llm_spec === "string" || typeof llm_spec === "number")
return StringLookup.get(llm_spec) ?? "(string lookup failed)";
else return llm_spec.model;
}
function extract_llm_key(llm_spec: Dict | string): string {
if (typeof llm_spec === "string") return llm_spec;
function extract_llm_provider(llm_spec: StringOrHash | LLMSpec): LLMProvider {
if (typeof llm_spec === "string" || typeof llm_spec === "number")
return getProvider(StringLookup.get(llm_spec) ?? "") ?? LLMProvider.Custom;
else return baseModelToProvider(llm_spec.base_model);
}
function extract_llm_key(llm_spec: StringOrHash | LLMSpec): string {
if (typeof llm_spec === "string" || typeof llm_spec === "number")
return StringLookup.get(llm_spec) ?? "(string lookup failed)";
else if (llm_spec.key !== undefined) return llm_spec.key;
else
throw new Error(
@ -236,7 +252,7 @@ function extract_llm_key(llm_spec: Dict | string): string {
);
}
function extract_llm_params(llm_spec: Dict | string): Dict {
function extract_llm_params(llm_spec: StringOrHash | LLMSpec): Dict {
if (typeof llm_spec === "object" && llm_spec.settings !== undefined)
return llm_spec.settings;
else return {};
@ -249,8 +265,10 @@ function filterVarsByLLM(vars: PromptVarsDict, llm_key: string): Dict {
_vars[key] = vs.filter(
(v) =>
typeof v === "string" ||
typeof v === "number" ||
v?.llm === undefined ||
typeof v.llm === "string" ||
typeof v.llm === "number" ||
v.llm.key === llm_key,
);
});
@ -293,8 +311,8 @@ function isLooselyEqual(value1: any, value2: any): boolean {
* determines whether the response query used the same parameters.
*/
function matching_settings(
cache_llm_spec: Dict | string,
llm_spec: Dict | string,
cache_llm_spec: StringOrHash | LLMSpec,
llm_spec: StringOrHash | LLMSpec,
): boolean {
if (extract_llm_name(cache_llm_spec) !== extract_llm_name(llm_spec))
return false;
@ -397,7 +415,7 @@ async function run_over_responses(
const evald_resps: Promise<LLMResponse>[] = responses.map(
async (_resp_obj: LLMResponse) => {
// Deep clone the response object
const resp_obj = JSON.parse(JSON.stringify(_resp_obj));
const resp_obj: LLMResponse = JSON.parse(JSON.stringify(_resp_obj));
// Whether the processor function is async or not
const async_processor =
@ -406,12 +424,12 @@ async function run_over_responses(
// Map the processor func over every individual response text in each response object
const res = resp_obj.responses;
const llm_name = extract_llm_nickname(resp_obj.llm);
let processed = res.map((r: string) => {
let processed = res.map((r: LLMResponseData) => {
const r_info = new ResponseInfo(
cleanEscapedBraces(r),
resp_obj.prompt,
resp_obj.vars,
resp_obj.metavars || {},
cleanEscapedBraces(llmResponseDataToString(r)),
StringLookup.get(resp_obj.prompt) ?? "",
StringLookup.concretizeDict(resp_obj.vars),
StringLookup.concretizeDict(resp_obj.metavars) || {},
llm_name,
);
@ -451,7 +469,8 @@ async function run_over_responses(
// Store items with summary of mean, median, etc
resp_obj.eval_res = {
items: processed,
dtype: getEnumName(MetricType, eval_res_type),
dtype: (getEnumName(MetricType, eval_res_type) ??
"Unknown") as keyof typeof MetricType,
};
} else if (
[MetricType.Unknown, MetricType.Empty].includes(eval_res_type)
@ -463,7 +482,8 @@ async function run_over_responses(
// Categorical, KeyValue, etc, we just store the items:
resp_obj.eval_res = {
items: processed,
dtype: getEnumName(MetricType, eval_res_type),
dtype: (getEnumName(MetricType, eval_res_type) ??
"Unknown") as keyof typeof MetricType,
};
}
}
@ -488,7 +508,7 @@ async function run_over_responses(
*/
export async function generatePrompts(
root_prompt: string,
vars: Dict<(TemplateVarInfo | string)[]>,
vars: Dict<PromptVarType[]>,
): Promise<PromptTemplate[]> {
const gen_prompts = new PromptPermutationGenerator(root_prompt);
const all_prompt_permutations = Array.from(
@ -513,7 +533,7 @@ export async function generatePrompts(
export async function countQueries(
prompt: string,
vars: PromptVarsDict,
llms: Array<Dict | string>,
llms: Array<StringOrHash | LLMSpec>,
n: number,
chat_histories?:
| (ChatHistoryInfo | undefined)[]
@ -587,7 +607,9 @@ export async function countQueries(
found_cache = true;
// Load the cache file
const cache_llm_responses = load_from_cache(cache_filename);
const cache_llm_responses: Dict<
RawLLMResponseObject[] | RawLLMResponseObject
> = load_from_cache(cache_filename);
// Iterate through all prompt permutations and check if how many responses there are in the cache with that prompt
_all_prompt_perms.forEach((prompt) => {
@ -676,6 +698,40 @@ export async function fetchEnvironAPIKeys(): Promise<Dict<string>> {
}).then((res) => res.json());
}
export async function saveFlowToLocalFilesystem(
flowJSON: Dict,
filename: string,
): Promise<void> {
try {
await axios.put(`${FLASK_BASE_URL}api/flows/${filename}`, {
flow: flowJSON,
});
} catch (error) {
throw new Error(
`Error saving flow with name ${filename}: ${(error as Error).toString()}`,
);
}
}
export async function ensureUniqueFlowFilename(
filename: string,
): Promise<string> {
try {
const response = await axios.put(
`${FLASK_BASE_URL}api/getUniqueFlowFilename`,
{
name: filename,
},
);
return response.data as string;
} catch (error) {
console.error(
`Error contacting Flask to ensure a unique filename for the imported flow. Defaulting to the passed filename (warning: this risks overriding an existing flow). Error: ${(error as Error).toString()}`,
);
return filename;
}
}
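A sketch of how the two new flow helpers might combine when importing a flow locally (`flowJSON` and the base name are assumed to be available to the caller):

const uniqueName = await ensureUniqueFlowFilename("imported-flow");
await saveFlowToLocalFilesystem(flowJSON, uniqueName);
// The Flask backend now stores the flow under a name that does not clash with existing saved flows.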
/**
* Queries LLM(s) with root prompt template `prompt` and prompt input variables `vars`, `n` times per prompt.
* Soft-fails if API calls fail, and collects the errors in `errors` property of the return object.
@ -762,7 +818,7 @@ export async function queryLLM(
// Create a new cache JSON object
cache = { cache_files: {}, responses_last_run: [] };
const prev_filenames: Array<string> = [];
llms.forEach((llm_spec: string | Dict) => {
llms.forEach((llm_spec) => {
const fname = gen_unique_cache_filename(id, prev_filenames);
llm_to_cache_filename[extract_llm_key(llm_spec)] = fname;
cache.cache_files[fname] = llm_spec;
@ -796,9 +852,12 @@ export async function queryLLM(
const responses: { [key: string]: Array<RawLLMResponseObject> } = {};
const all_errors: Dict<string[]> = {};
const num_generations = n ?? 1;
async function query(llm_spec: string | Dict): Promise<LLMPrompterResults> {
async function query(
llm_spec: StringOrHash | LLMSpec,
): Promise<LLMPrompterResults> {
// Get LLM model name and any params
const llm_str = extract_llm_name(llm_spec);
const llm_provider = extract_llm_provider(llm_spec);
const llm_nickname = extract_llm_nickname(llm_spec);
const llm_params = extract_llm_params(llm_spec);
const llm_key = extract_llm_key(llm_spec);
@ -838,6 +897,7 @@ export async function queryLLM(
for await (const response of prompter.gen_responses(
_vars,
llm_str as LLM,
llm_provider,
num_generations,
temperature,
llm_params,
@ -1214,11 +1274,12 @@ export async function executepy(
*/
export async function evalWithLLM(
id: string,
llm: string | LLMSpec,
llm: LLMSpec,
root_prompt: string,
response_ids: string | string[],
api_keys?: Dict,
progress_listener?: (progress: { [key: symbol]: any }) => void,
cancel_id?: string | number,
): Promise<{ responses?: LLMResponse[]; errors: string[] }> {
// Check format of response_ids
if (!Array.isArray(response_ids)) response_ids = [response_ids];
@ -1238,17 +1299,27 @@ export async function evalWithLLM(
const resp_objs = (load_cache_responses(fname) as LLMResponse[]).map((r) =>
JSON.parse(JSON.stringify(r)),
) as LLMResponse[];
if (resp_objs.length === 0) continue;
console.log(resp_objs);
// We need to keep track of the index of each response in the response object.
// We can generate var dicts with metadata to store the indices:
const inputs = resp_objs
.map((obj, __i) =>
obj.responses.map((r: LLMResponseData, __j: number) => ({
text: typeof r === "string" ? r : undefined,
text:
typeof r === "string" || typeof r === "number"
? escapeBraces(StringLookup.get(r) ?? "(string lookup failed)")
: undefined,
image: typeof r === "object" && r.t === "img" ? r.d : undefined,
fill_history: obj.vars,
metavars: { ...obj.metavars, __i, __j },
metavars: {
...obj.metavars,
__i: __i.toString(),
__j: __j.toString(),
},
})),
)
.flat();
@ -1259,11 +1330,13 @@ export async function evalWithLLM(
[llm],
1,
root_prompt,
{ input: inputs },
{ __input: inputs },
undefined,
undefined,
undefined,
progress_listener,
false,
cancel_id,
);
const err_vals: string[] = Object.values(errors).flat();
@ -1272,15 +1345,17 @@ export async function evalWithLLM(
// Now we need to apply each response as an eval_res (a score) back to each response object,
// using the aforementioned mapping metadata:
responses.forEach((r: LLMResponse) => {
const resp_obj = resp_objs[r.metavars.__i];
const __i = parseInt(StringLookup.get(r.metavars.__i) ?? "");
const __j = parseInt(StringLookup.get(r.metavars.__j) ?? "");
const resp_obj = resp_objs[__i];
if (resp_obj.eval_res !== undefined)
resp_obj.eval_res.items[r.metavars.__j] = r.responses[0];
resp_obj.eval_res.items[__j] = llmResponseDataToString(r.responses[0]);
else {
resp_obj.eval_res = {
items: [],
dtype: "Categorical",
};
resp_obj.eval_res.items[r.metavars.__j] = r.responses[0];
resp_obj.eval_res.items[__j] = llmResponseDataToString(r.responses[0]);
}
});
@ -1408,8 +1483,8 @@ export async function exportCache(ids: string[]): Promise<Dict<Dict>> {
}
// Bundle up specific other state in StorageCache, which
// includes things like human ratings for responses:
const cache_state = StorageCache.getAllMatching((key) =>
key.startsWith("r."),
const cache_state = StorageCache.getAllMatching(
(key) => key.startsWith("r.") || key === "__s",
);
return { ...cache_files, ...cache_state };
}
@ -1434,6 +1509,9 @@ export async function importCache(files: {
Object.entries(files).forEach(([filename, data]) => {
StorageCache.store(filename, data);
});
// Load StringLookup table from cache
StringLookup.restoreFrom(StorageCache.get("__s"));
} catch (err) {
throw new Error("Error importing from cache:" + (err as Error).message);
}
@ -1475,9 +1553,9 @@ export async function fetchExampleFlow(evalname: string): Promise<Dict> {
// App is not running locally, but hosted on a site.
// If this is the case, attempt to fetch the example flow from a relative site path:
return fetch(`examples/${evalname}.cforge`)
.then((response) => response.json())
.then((res) => ({ data: res }));
return fetch(`examples/${evalname}.cforge`).then((response) =>
response.json(),
);
}
/**
@ -1517,9 +1595,9 @@ export async function fetchOpenAIEval(evalname: string): Promise<Dict> {
// App is not running locally, but hosted on a site.
// If this is the case, attempt to fetch the example flow from relative path on the site:
// > ALT: `https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/${_name}.cforge`
return fetch(`oaievals/${evalname}.cforge`)
.then((response) => response.json())
.then((res) => ({ data: res }));
return fetch(`oaievals/${evalname}.cforge`).then((response) =>
response.json(),
);
}
/**

View File

@ -72,6 +72,7 @@ export default class StorageCache {
*/
public static clear(key?: string): void {
StorageCache.getInstance().clearCache(key);
if (key === undefined) StringLookup.restoreFrom([]);
}
/**
@ -117,11 +118,12 @@ export default class StorageCache {
* Performs lz-string decompression from UTF16 encoding.
*
* @param localStorageKey The key that will be used in localStorage (default='chainforge')
* @param replaceStorageCacheWithLoadedData Whether the data currently in the StorageCache should be replaced with the loaded data. This erases all data currently in the cache. Only set this to true if you are replacing the ChainForge flow state entirely.
* @returns Loaded data if succeeded, undefined if failure (e.g., key not found).
*/
public static loadFromLocalStorage(
localStorageKey = "chainforge",
setStorageCacheData = true,
replaceStorageCacheWithLoadedData = false,
): JSONCompatible | undefined {
const compressed = localStorage.getItem(localStorageKey);
if (!compressed) {
@ -132,7 +134,12 @@ export default class StorageCache {
}
try {
const data = JSON.parse(LZString.decompressFromUTF16(compressed));
if (setStorageCacheData) StorageCache.getInstance().data = data;
if (replaceStorageCacheWithLoadedData) {
// Replaces the current cache data with the loaded data
StorageCache.getInstance().data = data;
// Restores the current StringLookup table with the contents of the loaded data, if the __s key is present.
StringLookup.restoreFrom(data.__s);
}
console.log("loaded", data);
return data;
} catch (error) {
@ -141,3 +148,171 @@ export default class StorageCache {
}
}
}
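// Illustrative sketch (not part of the diff): how a caller might use the new
// `replaceStorageCacheWithLoadedData` flag. The localStorage key below is the default;
// only the flag's behavior is taken from the method above.
function exampleRestoreFlowFromBrowser(): JSONCompatible | undefined {
  // Passing `true` wipes the in-memory cache, replaces it with the loaded data,
  // and restores the StringLookup table from the data's "__s" entry.
  // Passing `false` (the default) only reads and returns the stored data.
  return StorageCache.loadFromLocalStorage("chainforge", true);
}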
/** Global string intern table for efficient storage of repeated strings */
export class StringLookup {
// eslint-disable-next-line no-use-before-define
private static instance: StringLookup;
private stringToIndex: Map<string, number> = new Map();
private indexToString: string[] = [];
/** Gets the string intern lookup table. Initializes it if the singleton instance does not yet exist. */
public static getInstance(): StringLookup {
if (!StringLookup.instance) {
StringLookup.instance = new StringLookup();
}
return StringLookup.instance;
}
/** Adds a string to the table and returns its index */
public static intern(str: string): number {
const s = StringLookup.getInstance();
if (s.stringToIndex.has(str)) {
return s.stringToIndex.get(str)!; // Return existing index
}
// Add new string to the table
const index = s.indexToString.length;
s.indexToString.push(str);
s.stringToIndex.set(str, index);
// Save to cache
StorageCache.store("__s", s.indexToString);
return index;
}
// Overloaded signatures
// This tells TypeScript that a number or string will always produce a string or undefined,
// whereas any other type T will return the same type.
public static get(index: number | string | undefined): string | undefined;
public static get<T>(index: T): T;
/**
* Retrieves the string in the lookup table, given its index.
* - **Note**: This function soft fails: if index is not a number, returns index unchanged.
*/
public static get<T>(index: T | number): T | string {
if (typeof index !== "number") return index;
const s = StringLookup.getInstance();
return s.indexToString[index]; // O(1) lookup
}
/**
* Transforms a Dict by interning all strings encountered, up to 1 level of depth,
* and returning the modified Dict with the strings as hash indexes instead.
*
* NOTE: This skips recursing into ignored keys (by default "llm", "uid", and "eval_res").
*/
public static internDict(
d: Dict,
inplace?: boolean,
depth = 1,
ignoreKey = ["llm", "uid", "eval_res"],
): Dict {
const newDict = inplace ? d : ({} as Dict);
const entries = Object.entries(d);
for (const [key, value] of entries) {
if (ignoreKey.includes(key)) {
// Keep the ignored key the same
if (!inplace) newDict[key] = value;
continue;
}
if (typeof value === "string") {
newDict[key] = StringLookup.intern(value);
} else if (
Array.isArray(value) &&
value.every((v) => typeof v === "string")
) {
newDict[key] = value.map((v) => StringLookup.intern(v));
} else if (depth > 0 && typeof value === "object" && value !== null) {
newDict[key] = StringLookup.internDict(
value as Dict,
inplace,
depth - 1,
);
} else {
if (!inplace) newDict[key] = value;
}
}
return newDict;
}
/**
* Treats all numeric values in the dictionary as hashes, and maps them to strings.
* Leaves the rest of the dict unchanged. (Only operates 1 level deep.)
* @param d The dictionary to operate over
*/
public static concretizeDict<T>(
d: Dict<T | number>,
inplace = false,
depth = 1,
ignoreKey = ["llm", "uid", "eval_res"],
): Dict<T | string> {
const newDict = inplace ? d : ({} as Dict);
const entries = Object.entries(d);
for (const [key, value] of entries) {
const ignore = ignoreKey.includes(key);
if (!ignore && typeof value === "number")
newDict[key] = StringLookup.get(value);
else if (
!ignore &&
Array.isArray(value) &&
value.every((v) => typeof v === "number")
)
newDict[key] = value.map((v) => StringLookup.get(v));
else if (
!ignore &&
depth > 0 &&
typeof value === "object" &&
value !== null
) {
newDict[key] = StringLookup.concretizeDict(
value as Dict<unknown>,
false,
0,
);
} else if (!inplace) newDict[key] = value;
}
return newDict;
}
public static restoreFrom(savedIndexToString?: string[]): void {
const s = StringLookup.getInstance();
s.stringToIndex = new Map<string, number>();
if (savedIndexToString === undefined || savedIndexToString.length === 0) {
// Reset
s.indexToString = [];
return;
} else if (!Array.isArray(savedIndexToString)) {
// Reset, but warn user
console.error(
"String lookup table could not be loaded: data.__s is not an array.",
);
s.indexToString = [];
return;
}
// Recreate from the index array
s.indexToString = savedIndexToString;
savedIndexToString.forEach((v, i) => {
s.stringToIndex.set(v, i);
});
}
/** Serializes interned strings and their mappings */
public static toJSON() {
const s = StringLookup.getInstance();
return s.indexToString;
}
/** Restores from JSON */
static fromJSON(data: { dictionary: string[] }) {
const table = new StringLookup();
table.indexToString = data.dictionary;
table.stringToIndex = new Map(data.dictionary.map((str, i) => [str, i]));
StringLookup.instance = table;
}
}
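// Illustrative sketch (not part of the diff): the intended round-trip through the
// intern table. Interning the same string twice returns the same index, `get`
// soft-fails on non-numbers, and ignored keys ("llm", "uid", "eval_res") are never hashed.
function exampleStringLookupRoundTrip(): void {
  const a = StringLookup.intern("hello world");
  const b = StringLookup.intern("hello world");
  console.assert(a === b && StringLookup.get(a) === "hello world");
  console.assert(StringLookup.get("already a string") === "already a string");

  // Hash the top-level strings of a response-like dict, then map them back:
  const hashed = StringLookup.internDict({ prompt: "Who is {person}?", uid: "123" });
  const concrete = StringLookup.concretizeDict(hashed);
  console.assert(concrete.prompt === "Who is {person}?");
  console.assert(concrete.uid === "123"); // "uid" is ignored, so it was never hashed
}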

View File

@ -19,6 +19,13 @@ export enum NativeLLM {
OpenAI_GPT4_1106_Prev = "gpt-4-1106-preview",
OpenAI_GPT4_0125_Prev = "gpt-4-0125-preview",
OpenAI_GPT4_Turbo_Prev = "gpt-4-turbo-preview",
OpenAI_GPT4_Turbo = "gpt-4-turbo",
OpenAI_GPT4_Turbo_0409 = "gpt-4-turbo-2024-04-09",
OpenAI_GPT4_O = "gpt-4o",
OpenAI_GPT4_O_Mini = "gpt-4o-mini",
OpenAI_GPT4_O_0513 = "gpt-4o-2024-05-13",
OpenAI_GPT4_O_0806 = "gpt-4o-2024-08-06",
OpenAI_ChatGPT4_O = "chatgpt-4o-latest",
OpenAI_GPT4_32k = "gpt-4-32k",
OpenAI_GPT4_32k_0314 = "gpt-4-32k-0314",
OpenAI_GPT4_32k_0613 = "gpt-4-32k-0613",
@ -44,8 +51,12 @@ export enum NativeLLM {
Dalai_Llama_65B = "llama.65B",
// Anthropic
Claude_v3_opus_latest = "claude-3-opus-latest",
Claude_v3_opus = "claude-3-opus-20240229",
Claude_v3_sonnet = "claude-3-sonnet-20240229",
Claude_v3_5_sonnet_latest = "claude-3-5-sonnet-latest",
Claude_v3_5_sonnet = "claude-3-5-sonnet-20240620",
Claude_v3_5_haiku_latest = "claude-3-5-haiku-latest",
Claude_v3_haiku = "claude-3-haiku-20240307",
Claude_v2_1 = "claude-2.1",
Claude_v2 = "claude-2",
@ -63,6 +74,15 @@ export enum NativeLLM {
PaLM2_Text_Bison = "text-bison-001", // it's really models/text-bison-001, but that's confusing
PaLM2_Chat_Bison = "chat-bison-001",
GEMINI_PRO = "gemini-pro",
GEMINI_v2_flash = "gemini-2.0-flash-exp",
GEMINI_v1_5_flash = "gemini-1.5-flash",
GEMINI_v1_5_flash_8B = "gemini-1.5-flash-8b",
GEMINI_v1_5_pro = "gemini-1.5-pro",
GEMINI_v1_pro = "gemini-1.0-pro",
// DeepSeek
DeepSeek_Chat = "deepseek-chat",
DeepSeek_Reasoner = "deepseek-reasoner",
// Aleph Alpha
Aleph_Alpha_Luminous_Extended = "luminous-extended",
@ -81,18 +101,17 @@ export enum NativeLLM {
HF_DIALOGPT_LARGE = "microsoft/DialoGPT-large", // chat model
HF_GPT2 = "gpt2",
HF_BLOOM_560M = "bigscience/bloom-560m",
// HF_GPTJ_6B = "EleutherAI/gpt-j-6b",
// HF_LLAMA_7B = "decapoda-research/llama-7b-hf",
// A special flag for a user-defined HuggingFace model endpoint.
// The actual model name will be passed as a param to the LLM call function.
HF_OTHER = "Other (HuggingFace)",
Ollama = "ollama",
Bedrock_Claude_2_1 = "anthropic.claude-v2:1",
Bedrock_Claude_2 = "anthropic.claude-v2",
Bedrock_Claude_3_Sonnet = "anthropic.claude-3-sonnet-20240229-v1:0",
Bedrock_Claude_3_Haiku = "anthropic.claude-3-haiku-20240307-v1:0",
Bedrock_Claude_3_Opus = "anthropic.claude-3-opus-20240229-v1:0",
Bedrock_Claude_Instant_1 = "anthropic.claude-instant-v1",
Bedrock_Jurassic_Ultra = "ai21.j2-ultra",
Bedrock_Jurassic_Mid = "ai21.j2-mid",
@ -103,8 +122,94 @@ export enum NativeLLM {
Bedrock_Command_Text_Light = "cohere.command-light-text-v14",
Bedrock_Meta_LLama2Chat_13b = "meta.llama2-13b-chat-v1",
Bedrock_Meta_LLama2Chat_70b = "meta.llama2-70b-chat-v1",
Bedrock_Meta_LLama3Instruct_8b = "meta.llama3-8b-instruct-v1:0",
Bedrock_Meta_LLama3Instruct_70b = "meta.llama3-70b-instruct-v1:0",
Bedrock_Mistral_Mistral = "mistral.mistral-7b-instruct-v0:2",
Bedrock_Mistral_Mistral_Large = "mistral.mistral-large-2402-v1:0",
Bedrock_Mistral_Mixtral = "mistral.mixtral-8x7b-instruct-v0:1",
// Together.ai
Together_ZeroOneAI_01ai_Yi_Chat_34B = "together/zero-one-ai/Yi-34B-Chat",
Together_AllenAI_OLMo_Instruct_7B = "together/allenai/OLMo-7B-Instruct",
Together_AllenAI_OLMo_Twin2T_7B = "together/allenai/OLMo-7B-Twin-2T",
Together_AllenAI_OLMo_7B = "together/allenai/OLMo-7B",
Together_Austism_Chronos_Hermes_13B = "together/Austism/chronos-hermes-13b",
Together_cognitivecomputations_Dolphin_2_5_Mixtral_8x7b = "together/cognitivecomputations/dolphin-2.5-mixtral-8x7b",
Together_databricks_DBRX_Instruct = "together/databricks/dbrx-instruct",
Together_DeepSeek_Deepseek_Coder_Instruct_33B = "together/deepseek-ai/deepseek-coder-33b-instruct",
Together_DeepSeek_DeepSeek_LLM_Chat_67B = "together/deepseek-ai/deepseek-llm-67b-chat",
Together_garagebAInd_Platypus2_Instruct_70B = "together/garage-bAInd/Platypus2-70B-instruct",
Together_Google_Gemma_Instruct_2B = "together/google/gemma-2b-it",
Together_Google_Gemma_Instruct_7B = "together/google/gemma-7b-it",
Together_Gryphe_MythoMaxL2_13B = "together/Gryphe/MythoMax-L2-13b",
Together_LMSys_Vicuna_v1_5_13B = "together/lmsys/vicuna-13b-v1.5",
Together_LMSys_Vicuna_v1_5_7B = "together/lmsys/vicuna-7b-v1.5",
Together_Meta_Code_Llama_Instruct_13B = "together/codellama/CodeLlama-13b-Instruct-hf",
Together_Meta_Code_Llama_Instruct_34B = "together/codellama/CodeLlama-34b-Instruct-hf",
Together_Meta_Code_Llama_Instruct_70B = "together/codellama/CodeLlama-70b-Instruct-hf",
Together_Meta_Code_Llama_Instruct_7B = "together/codellama/CodeLlama-7b-Instruct-hf",
Together_Meta_LLaMA2_Chat_70B = "together/meta-llama/Llama-2-70b-chat-hf",
Together_Meta_LLaMA2_Chat_13B = "together/meta-llama/Llama-2-13b-chat-hf",
Together_Meta_LLaMA2_Chat_7B = "together/meta-llama/Llama-2-7b-chat-hf",
Together_Meta_LLaMA3_Chat_8B = "together/meta-llama/Llama-3-8b-chat-hf",
Together_Meta_LLaMA3_Chat_70B = "together/meta-llama/Llama-3-70b-chat-hf",
Together_Meta_LLaMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo",
Together_Meta_LLaMA3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
Together_Meta_LLaMA3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
Together_Meta_LLaMA3_1_405B = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
Together_Meta_LLaMA3_8B = "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
Together_Meta_LLaMA3_70B = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
Together_Meta_LLaMA3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo",
Together_Meta_LLaMA3_8B_Lite = "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
Together_Meta_LLaMA3_70B_Lite = "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
Together_Nvidia_LLaMA3_1_Nemotron_70B = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
Together_Qwen_Qwen2_5_Coder_32B = "Qwen/Qwen2.5-Coder-32B-Instruct",
Together_Qwen_QwQ_32B_Preview = "Qwen/QwQ-32B-Preview",
Together_Microsoft_WizardLM_2_8x22B = "microsoft/WizardLM-2-8x22B",
Together_Google_Gemma2_27B = "google/gemma-2-27b-it",
Together_Google_Gemma2_9B = "google/gemma-2-9b-it",
Together_DeepSeek_3 = "deepseek-ai/DeepSeek-V3",
Together_DeepSeek_R1 = "deepseek-ai/DeepSeek-R1",
Together_mistralai_Mistral_7B_Instruct_v0_3 = "mistralai/Mistral-7B-Instruct-v0.3",
Together_Qwen_Qwen2_5_7B_Turbo = "Qwen/Qwen2.5-7B-Instruct-Turbo",
Together_Qwen_Qwen2_5_72B_Turbo = "Qwen/Qwen2.5-72B-Instruct-Turbo",
Together_Qwen_Qwen2_5_72B = "Qwen/Qwen2-72B-Instruct",
Together_Qwen_Qwen2_VL_72B = "Qwen/Qwen2-VL-72B-Instruct",
Together_Qwen_Qwen2_5_32B_Coder = "Qwen/Qwen2.5-Coder-32B-Instruct",
Together_mistralai_Mistral_7B_Instruct = "together/mistralai/Mistral-7B-Instruct-v0.1",
Together_mistralai_Mistral_7B_Instruct_v0_2 = "together/mistralai/Mistral-7B-Instruct-v0.2",
Together_mistralai_Mixtral8x7B_Instruct_46_7B = "together/mistralai/Mixtral-8x7B-Instruct-v0.1",
Together_mistralai_Mixtral8x22B_Instruct_141B = "together/mistralai/Mixtral-8x22B-Instruct-v0.1",
Together_NousResearch_Nous_Capybara_v1_9_7B = "together/NousResearch/Nous-Capybara-7B-V1p9",
Together_NousResearch_Nous_Hermes_2__Mistral_DPO_7B = "together/NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
Together_NousResearch_Nous_Hermes_2__Mixtral_8x7BDPO_46_7B = "together/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
Together_NousResearch_Nous_Hermes_2__Mixtral_8x7BSFT_46_7B = "together/NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
Together_NousResearch_Nous_Hermes_LLaMA2_7B = "together/NousResearch/Nous-Hermes-llama-2-7b",
Together_NousResearch_Nous_Hermes_Llama2_13B = "together/NousResearch/Nous-Hermes-Llama2-13b",
Together_NousResearch_Nous_Hermes2_Yi_34B = "together/NousResearch/Nous-Hermes-2-Yi-34B",
Together_OpenChat_OpenChat_3_5_7B = "together/openchat/openchat-3.5-1210",
Together_OpenOrca_OpenOrca_Mistral_7B_8K = "together/Open-Orca/Mistral-7B-OpenOrca",
Together_Qwen_Qwen_1_5_Chat_0_5B = "together/Qwen/Qwen1.5-0.5B-Chat",
Together_Qwen_Qwen_1_5_Chat_1_8B = "together/Qwen/Qwen1.5-1.8B-Chat",
Together_Qwen_Qwen_1_5_Chat_4B = "together/Qwen/Qwen1.5-4B-Chat",
Together_Qwen_Qwen_1_5_Chat_7B = "together/Qwen/Qwen1.5-7B-Chat",
Together_Qwen_Qwen_1_5_Chat_14B = "together/Qwen/Qwen1.5-14B-Chat",
Together_Qwen_Qwen_1_5_Chat_32B = "together/Qwen/Qwen1.5-32B-Chat",
Together_Qwen_Qwen_1_5_Chat_72B = "together/Qwen/Qwen1.5-72B-Chat",
Together_Qwen_Qwen_1_5_Chat_110B = "together/Qwen/Qwen1.5-110B-Chat",
Together_SnorkelAI_Snorkel_Mistral_PairRM_DPO_7B = "together/snorkelai/Snorkel-Mistral-PairRM-DPO",
Together_Snowflake_Snowflake_Arctic_Instruct = "together/Snowflake/snowflake-arctic-instruct",
Together_Stanford_Alpaca_7B = "together/togethercomputer/alpaca-7b",
Together_Teknium_OpenHermes2Mistral_7B = "together/teknium/OpenHermes-2-Mistral-7B",
Together_Teknium_OpenHermes2_5Mistral_7B = "together/teknium/OpenHermes-2p5-Mistral-7B",
Together_LLaMA27B32KInstruct_7B = "together/togethercomputer/Llama-2-7B-32K-Instruct",
Together_RedPajamaINCITE_Chat_3B = "together/togethercomputer/RedPajama-INCITE-Chat-3B-v1",
Together_RedPajamaINCITE_Chat_7B = "together/togethercomputer/RedPajama-INCITE-7B-Chat",
Together_StripedHyena_Nous_7B = "together/togethercomputer/StripedHyena-Nous-7B",
Together_Undi95_ReMM_SLERP_L2_13B = "together/Undi95/ReMM-SLERP-L2-13B",
Together_Undi95_Toppy_M_7B = "together/Undi95/Toppy-M-7B",
Together_WizardLM_WizardLM_v1_2_13B = "together/WizardLM/WizardLM-13B-V1.2",
Together_upstage_Upstage_SOLAR_Instruct_v1_11B = "together/upstage/SOLAR-10.7B-Instruct-v1.0",
}
export type LLM = string | NativeLLM;
@ -131,6 +236,8 @@ export enum LLMProvider {
Aleph_Alpha = "alephalpha",
Ollama = "ollama",
Bedrock = "bedrock",
Together = "together",
DeepSeek = "deepseek",
Custom = "__custom",
}
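// Illustrative sketch (not part of the diff): provider detection appears to key off the
// name of the NativeLLM enum member, so the new Together and DeepSeek members resolve to
// their providers, and "__custom/"-prefixed strings resolve to Custom.
function exampleProviderResolution(): void {
  console.assert(getProvider(NativeLLM.Together_DeepSeek_R1) === LLMProvider.Together);
  console.assert(getProvider(NativeLLM.DeepSeek_Chat) === LLMProvider.DeepSeek);
  console.assert(getProvider("__custom/my-endpoint") === LLMProvider.Custom);
}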
@ -151,6 +258,8 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
else if (llm_name?.startsWith("Aleph_Alpha")) return LLMProvider.Aleph_Alpha;
else if (llm_name?.startsWith("Ollama")) return LLMProvider.Ollama;
else if (llm_name?.startsWith("Bedrock")) return LLMProvider.Bedrock;
else if (llm_name?.startsWith("Together")) return LLMProvider.Together;
else if (llm_name?.startsWith("DeepSeek")) return LLMProvider.DeepSeek;
else if (llm.toString().startsWith("__custom/")) return LLMProvider.Custom;
return undefined;
@ -186,24 +295,32 @@ export const RATE_LIMIT_BY_MODEL: { [key in LLM]?: number } = {
[NativeLLM.Bedrock_Jurassic_Ultra]: 25,
[NativeLLM.Bedrock_Titan_Light]: 800,
[NativeLLM.Bedrock_Titan_Express]: 400, // 400 RPM
[NativeLLM.Bedrock_Claude_2]: 500, // 500 RPM
[NativeLLM.Bedrock_Claude_2_1]: 500, // 500 RPM
[NativeLLM.Bedrock_Claude_3_Haiku]: 1000, // 1000 RPM
[NativeLLM.Bedrock_Claude_3_Sonnet]: 100, // 100 RPM
[NativeLLM.Bedrock_Claude_3_Opus]: 50, // 50 RPM
[NativeLLM.Bedrock_Claude_Instant_1]: 1000, // 1000 RPM
[NativeLLM.Bedrock_Command_Text]: 400, // 400 RPM
[NativeLLM.Bedrock_Command_Text_Light]: 800, // 800 RPM
[NativeLLM.Bedrock_Meta_LLama2Chat_70b]: 400, // 400 RPM
[NativeLLM.Bedrock_Meta_LLama2Chat_13b]: 800, // 800 RPM
[NativeLLM.Bedrock_Meta_LLama3Instruct_8b]: 400, // 400 RPM
[NativeLLM.Bedrock_Meta_LLama3Instruct_70b]: 800, // 800 RPM
[NativeLLM.Bedrock_Mistral_Mixtral]: 400, // 400 RPM
[NativeLLM.Bedrock_Mistral_Mistral_Large]: 400, // 400 RPM
[NativeLLM.Bedrock_Mistral_Mistral]: 800, // 800 RPM
};
export const RATE_LIMIT_BY_PROVIDER: { [key in LLMProvider]?: number } = {
[LLMProvider.Anthropic]: 25, // Tier 1 pricing limit is 50 per minute, across all models; we halve this, to be safe.
[LLMProvider.Together]: 30, // Paid tier limit is 60 per minute, across all models; we halve this, to be safe.
[LLMProvider.Google]: 1000, // RPM for Google Gemini models 1.5 is quite generous; at base it is 1000 RPM. If you are using the free version it's 15 RPM, but we can expect most CF users to be using paid (and anyway you can just re-run prompt node until satisfied).
[LLMProvider.DeepSeek]: 1000, // DeepSeek does not constrain users atm but they might in the future. To be safe we are limiting it to 1000 queries per minute.
};
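// Illustrative sketch (not part of the diff): the effective requests-per-minute is
// resolved model-first, then by provider, then the global default, mirroring the
// lookup inside `getLimiter` below.
function exampleEffectiveRPM(model: LLM, provider: LLMProvider): number {
  return (
    RATE_LIMIT_BY_MODEL[model] ??
    RATE_LIMIT_BY_PROVIDER[provider] ??
    DEFAULT_RATE_LIMIT
  );
}
// e.g. exampleEffectiveRPM(NativeLLM.Bedrock_Claude_3_Haiku, LLMProvider.Bedrock) === 1000,
// while a DeepSeek model with no model-specific entry falls back to the provider's 1000 RPM.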
// Max concurrent requests. Add to this to further constrain the rate limiter.
export const MAX_CONCURRENT: { [key in string | NativeLLM]?: number } = {};
const DEFAULT_RATE_LIMIT = 100; // RPM for any models not listed above
@ -229,14 +346,14 @@ export class RateLimiter {
}
/** Get the Bottleneck limiter for the given model. If it doesn't already exist, instantiates it dynamically. */
private getLimiter(model: LLM, provider: LLMProvider): Bottleneck {
// Find if there's an existing limiter for this model
if (!(model in this.limiters)) {
// If there isn't, make one:
// Find the RPM. First search if the model is present in predefined rate limits; then search for pre-defined RLs by provider; then set to default.
const rpm =
RATE_LIMIT_BY_MODEL[model] ??
RATE_LIMIT_BY_PROVIDER[provider ?? LLMProvider.Custom] ??
DEFAULT_RATE_LIMIT;
this.limiters[model] = new Bottleneck({
reservoir: rpm, // max requests per minute
@ -257,13 +374,14 @@ export class RateLimiter {
*/
public static throttle<T>(
model: LLM,
provider: LLMProvider,
func: () => PromiseLike<T>,
should_cancel?: () => boolean,
): Promise<T> {
// Rate limit per model, and abort if the API request takes 3 minutes or more.
return this.getInstance()
.getLimiter(model, provider)
.schedule({}, () => {
if (should_cancel && should_cancel())
throw new UserForcedPrematureExit();
return func();

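// Illustrative sketch (not part of the diff): callers now pass the provider alongside
// the model so the limiter can fall back to provider-level rate limits. The arrow
// function below stands in for a real LLM API call.
async function exampleThrottledCall(): Promise<string> {
  return RateLimiter.throttle(
    NativeLLM.DeepSeek_Chat,
    LLMProvider.DeepSeek,
    async () => "response text", // stand-in for the actual API request
    () => false, // should_cancel: return true to abort before the call is scheduled
  );
}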
View File

@ -1,6 +1,6 @@
import { v4 as uuid } from "uuid";
import { PromptTemplate, PromptPermutationGenerator } from "./template";
import { LLM, LLMProvider, NativeLLM, RateLimiter } from "./models";
import {
Dict,
LLMResponseError,
@ -21,7 +21,7 @@ import {
repairCachedResponses,
compressBase64Image,
} from "./utils";
import StorageCache, { StringLookup } from "./cache";
import { UserForcedPrematureExit } from "./errors";
import { typecastSettingsDict } from "../ModelSettingSchemas";
@ -75,6 +75,7 @@ export class PromptPipeline {
private async collect_LLM_response(
result: _IntermediateLLMResponseType,
llm: LLM,
provider: LLMProvider,
cached_responses: Dict,
): Promise<RawLLMResponseObject | LLMResponseError> {
const {
@ -97,7 +98,7 @@ export class PromptPipeline {
const metavars = prompt.metavars;
// Extract and format the responses into `LLMResponseData`
const extracted_resps = extract_responses(response, llm, provider);
// Detect any images and downrez them if the user has approved of automatic compression.
// This saves a lot of performance and storage. We also need to disable storing the raw response here, to save space.
@ -130,7 +131,6 @@ export class PromptPipeline {
query: query ?? {},
uid: uuid(),
responses: extracted_resps,
raw_response: contains_imgs ? {} : response ?? {}, // don't double-store images
llm,
vars: mergeDicts(info, chat_history?.fill_history) ?? {},
metavars: mergeDicts(metavars, chat_history?.metavars) ?? {},
@ -140,6 +140,9 @@ export class PromptPipeline {
if (chat_history !== undefined)
resp_obj.chat_history = chat_history.messages;
// Hash strings present in the response object, to improve performance
StringLookup.internDict(resp_obj, true);
// Merge the response obj with the past one, if necessary
if (past_resp_obj)
resp_obj = merge_response_objs(
@ -149,14 +152,14 @@ export class PromptPipeline {
// Save the current state of cache'd responses to a JSON file
// NOTE: We do this to save money --in case something breaks between calls, can ensure we got the data!
const prompt_str = prompt.toString();
if (!(prompt_str in cached_responses)) cached_responses[prompt_str] = [];
else if (!Array.isArray(cached_responses[prompt_str]))
cached_responses[prompt_str] = [cached_responses[prompt_str]];
if (past_resp_obj_cache_idx !== undefined && past_resp_obj_cache_idx > -1)
cached_responses[prompt_str][past_resp_obj_cache_idx] = resp_obj;
else cached_responses[prompt_str].push(resp_obj);
this._cache_responses(cached_responses);
@ -183,6 +186,7 @@ export class PromptPipeline {
* @param vars The 'vars' dict to fill variables in the root prompt template. For instance, for 'Who is {person}?', vars might = { person: ['TJ', 'MJ', 'AD'] }.
* @param llm The specific LLM model to call. See the LLM enum for supported models.
* @param provider The specific LLM provider to call. See the LLMProvider enum for supported providers.
* @param n How many generations per prompt sent to the LLM.
* @param temperature The temperature to use when querying the LLM.
* @param llm_params Optional. The model-specific settings to pass into the LLM API call. Varies by LLM.
@ -195,6 +199,7 @@ export class PromptPipeline {
async *gen_responses(
vars: Dict,
llm: LLM,
provider: LLMProvider,
n = 1,
temperature = 1.0,
llm_params?: Dict,
@ -284,11 +289,10 @@ export class PromptPipeline {
if (cached_resp && extracted_resps.length >= n) {
// console.log(` - Found cache'd response for prompt ${prompt_str}. Using...`);
const resp: RawLLMResponseObject = {
prompt: cached_resp.prompt,
query: cached_resp.query,
uid: cached_resp.uid ?? uuid(),
responses: extracted_resps.slice(0, n),
raw_response: cached_resp.raw_response,
llm: cached_resp.llm || NativeLLM.OpenAI_ChatGPT,
// We want to use the new info, since 'vars' could have changed even though
// the prompt text is the same (e.g., "this is a tool -> this is a {x} where x='tool'")
@ -297,6 +301,7 @@ export class PromptPipeline {
};
if (chat_history !== undefined)
resp.chat_history = chat_history.messages;
yield resp;
continue;
}
@ -305,6 +310,7 @@ export class PromptPipeline {
tasks.push(
this._prompt_llm(
llm,
provider,
prompt,
n,
temperature,
@ -319,7 +325,9 @@ export class PromptPipeline {
},
chat_history,
should_cancel,
).then((result) =>
this.collect_LLM_response(result, llm, provider, responses),
),
);
}
}
@ -358,6 +366,7 @@ export class PromptPipeline {
async _prompt_llm(
llm: LLM,
provider: LLMProvider,
prompt: PromptTemplate,
n = 1,
temperature = 1.0,
@ -393,9 +402,11 @@ export class PromptPipeline {
// It's not perfect, but it's simpler than throttling at the call-specific level.
[query, response] = await RateLimiter.throttle(
llm,
provider,
() =>
call_llm(
llm,
provider,
prompt.toString(),
n,
temperature,

View File

@ -0,0 +1,51 @@
import { v4 as uuidv4 } from "uuid";
import { Dict, TabularDataRowType, TabularDataColType } from "./typing";
/*
This file contains utility functions for parsing raw table data
into a format for TabularDataNode
*/
export function parseTableData(rawTableData: any[]): {
columns: TabularDataColType[];
rows: TabularDataRowType[];
} {
if (!Array.isArray(rawTableData)) {
throw new Error(
"Table data is not in array format: " +
(rawTableData !== undefined && rawTableData !== null
? String(rawTableData)
: ""),
);
}
// Extract unique column names
const headers = new Set<string>();
rawTableData.forEach((row) =>
Object.keys(row).forEach((key) => headers.add(key)),
);
// Create columns with unique IDs
const columns = Array.from(headers).map((header, idx) => ({
header,
key: `c${idx}`,
}));
// Create a lookup table for column keys
const columnKeyLookup: Dict<string> = {};
columns.forEach((col) => {
columnKeyLookup[col.header] = col.key;
});
// Map rows to the new column keys
const rows = rawTableData.map((row) => {
const parsedRow: TabularDataRowType = { __uid: uuidv4() };
Object.keys(row).forEach((header) => {
const rawValue = row[header];
const value =
typeof rawValue === "object" ? JSON.stringify(rawValue) : rawValue;
parsedRow[columnKeyLookup[header]] = value?.toString() ?? "";
});
return parsedRow;
});
return { columns, rows };
}
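// Illustrative sketch (not part of the diff): feeding GenAI-produced rows through the
// parser. Keys become columns keyed c0, c1, ..., nested objects are stringified, and
// every row receives a fresh __uid.
const exampleParsed = parseTableData([
  { country: "Japan", capital: "Tokyo" },
  { country: "France", capital: "Paris", extra: { note: "added later" } },
]);
// exampleParsed.columns -> [{ header: "country", key: "c0" }, { header: "capital", key: "c1" }, { header: "extra", key: "c2" }]
// exampleParsed.rows[1].c2 -> '{"note":"added later"}'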

View File

@ -1,4 +1,12 @@
import { StringLookup } from "./cache";
import { isEqual } from "./setUtils";
import {
Dict,
PromptVarsDict,
PromptVarType,
StringOrHash,
TemplateVarInfo,
} from "./typing";
function len(o: object | string | Array<any>): number {
// Acts akin to Python's builtin 'len' method
@ -193,8 +201,8 @@ export class PromptTemplate {
print(partial_prompt)
*/
template: string;
fill_history: Dict<StringOrHash>;
metavars: Dict<StringOrHash>;
constructor(templateStr: string) {
/**
@ -234,7 +242,8 @@ export class PromptTemplate {
has_unfilled_settings_var(varname: string): boolean {
return Object.entries(this.fill_history).some(
([key, val]) =>
key.startsWith("=") &&
new StringTemplate(StringLookup.get(val) ?? "").has_vars([varname]),
);
}
@ -257,33 +266,47 @@ export class PromptTemplate {
"PL": "Python"
});
*/
fill(paramDict: Dict<PromptVarType>): PromptTemplate {
// Check for special 'past fill history' format:
let past_fill_history: Dict<StringOrHash> = {};
let past_metavars: Dict<StringOrHash> = {};
const some_key = Object.keys(paramDict).pop();
const some_val = some_key ? paramDict[some_key] : undefined;
if (len(paramDict) > 0 && isDict(some_val)) {
// Transfer over the fill history and metavars
Object.values(paramDict).forEach((obj) => {
if ("fill_history" in (obj as TemplateVarInfo))
past_fill_history = {
...(obj as TemplateVarInfo).fill_history,
...past_fill_history,
};
if ("metavars" in (obj as TemplateVarInfo))
past_metavars = {
...(obj as TemplateVarInfo).metavars,
...past_metavars,
};
});
past_fill_history = StringLookup.concretizeDict(past_fill_history);
past_metavars = StringLookup.concretizeDict(past_metavars);
// Recreate the param dict from just the 'text' property of the fill object
const newParamDict: Dict<StringOrHash> = {};
Object.entries(paramDict).forEach(([param, obj]) => {
newParamDict[param] = (obj as TemplateVarInfo).text as StringOrHash;
});
paramDict = newParamDict;
}
// Concretize the params
paramDict = StringLookup.concretizeDict(paramDict) as Dict<
string | TemplateVarInfo
>;
// For 'settings' template vars of form {=system_msg}, we use the same logic of storing param
// values as before -- the only difference is that, when it comes to the actual substitution of
// the string, we *don't fill the template with anything* --it vanishes.
let params_wo_settings = paramDict as Dict<string>;
// To improve performance, we first check if there's a settings var present at all before deep cloning:
if (Object.keys(paramDict).some((key) => key?.charAt(0) === "=")) {
// A settings var is present; deep clone the param dict and replace it with the empty string:
@ -305,9 +328,9 @@ export class PromptTemplate {
// Perform the fill inside any and all 'settings' template vars
Object.entries(filled_pt.fill_history).forEach(([key, val]) => {
if (!key.startsWith("=")) return;
filled_pt.fill_history[key] = new StringTemplate(
StringLookup.get(val) ?? "",
).safe_substitute(params_wo_settings);
});
// Append any past history passed as vars:
@ -325,7 +348,7 @@ export class PromptTemplate {
});
// Add the new fill history using the passed parameters that we just filled in
Object.entries(paramDict as Dict<string>).forEach(([key, val]) => {
if (key in filled_pt.fill_history)
console.log(
`Warning: PromptTemplate already has fill history for key ${key}.`,
@ -341,11 +364,11 @@ export class PromptTemplate {
* Modifies the prompt template in place.
* @param fill_history A fill history dict.
*/
fill_special_vars(fill_history: Dict<StringOrHash>): void {
// Special variables {#...} denotes filling a variable from a matching var in fill_history or metavars.
// Find any special variables:
const unfilled_vars = new StringTemplate(this.template).get_vars();
const special_vars_to_fill: Dict<StringOrHash> = {};
for (const v of unfilled_vars) {
if (v.length > 0 && v[0] === "#") {
// special template variables must begin with #
@ -360,7 +383,7 @@ export class PromptTemplate {
// Fill any special variables, using the fill history of the template in question:
if (Object.keys(special_vars_to_fill).length > 0)
this.template = new StringTemplate(this.template).safe_substitute(
StringLookup.concretizeDict(special_vars_to_fill),
);
}
}
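// Illustrative sketch (not part of the diff): filling a template with a TemplateVarInfo
// value. Interned (hashed) strings are concretized before substitution, and the carried
// metavars are expected to survive on the filled template's history.
function examplePromptFill(): PromptTemplate {
  const template = new PromptTemplate("Who is {person}?");
  const filled = template.fill({
    person: {
      text: StringLookup.intern("TJ"), // a StringOrHash: a plain "TJ" works too
      metavars: { source: "tabular-data-node" }, // hypothetical metavariable
    },
  });
  // filled.toString() === "Who is TJ?"
  return filled;
}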
@ -389,7 +412,7 @@ export class PromptPermutationGenerator {
*_gen_perm(
template: PromptTemplate,
params_to_fill: Array<string>,
paramDict: PromptVarsDict,
): Generator<PromptTemplate, boolean, undefined> {
if (len(params_to_fill) === 0) return true;
@ -417,24 +440,31 @@ export class PromptPermutationGenerator {
val.forEach((v) => {
if (param === undefined) return;
const param_fill_dict: Dict<PromptVarType> = {};
param_fill_dict[param] = v;
/* If this var has an "associate_id", then it wants to "carry with"
values of other prompt parameters with the same id.
We have to find any parameters with values of the same id,
and fill them in alongside the initial parameter v: */
if (isDict(v) && "associate_id" in (v as object)) {
const v_associate_id = (v as TemplateVarInfo).associate_id;
params_left.forEach((other_param) => {
if (
(template.has_var(other_param) ||
template.has_unfilled_settings_var(other_param)) &&
Array.isArray(paramDict[other_param])
) {
for (
let i = 0;
i < (paramDict[other_param] as PromptVarType[]).length;
i++
) {
const ov = (paramDict[other_param] as PromptVarType[])[i];
if (
isDict(ov) &&
(ov as TemplateVarInfo).associate_id === v_associate_id
) {
// This is a match. We should add the val to our param_fill_dict:
param_fill_dict[other_param] = ov;
break;
@ -447,8 +477,8 @@ export class PromptPermutationGenerator {
// Fill the template with the param values and append it to the list
new_prompt_temps.push(template.fill(param_fill_dict));
});
} else if (typeof val === "string" || typeof val === "number") {
const sub_dict: Dict<StringOrHash> = {};
sub_dict[param] = val;
new_prompt_temps = [template.fill(sub_dict)];
} else
@ -470,9 +500,9 @@ export class PromptPermutationGenerator {
}
// Generator class method to yield permutations of a root prompt template
*generate(
paramDict: PromptVarsDict,
): Generator<PromptTemplate, boolean, undefined> {
const template =
typeof this.template === "string"
? new PromptTemplate(this.template)

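// Illustrative sketch (not part of the diff): generating every permutation of a prompt
// over a PromptVarsDict. Plain strings and interned (numeric) strings may be mixed.
function examplePermutations(): string[] {
  const gen = new PromptPermutationGenerator("Tell me a {adjective} joke about {topic}.");
  const prompts: string[] = [];
  for (const p of gen.generate({
    adjective: ["funny", "dark"],
    topic: [StringLookup.intern("cats"), "dogs"],
  }))
    prompts.push(p.toString());
  return prompts; // 4 prompts (2 adjectives x 2 topics)
}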
View File

@ -13,6 +13,9 @@ export interface Dict<T = any> {
[key: string]: T;
}
/** A string or a number representing the index to a hash table (`StringLookup`). */
export type StringOrHash = string | number;
// Function types
export type Func<T = void> = (...args: any[]) => T;
@ -52,7 +55,7 @@ export interface PaLMChatContext {
export interface GeminiChatMessage {
role: string;
parts: [{ text: string }];
}
export interface GeminiChatContext {
@ -68,8 +71,8 @@ export interface HuggingFaceChatHistory {
// Chat history with 'carried' variable metadata
export interface ChatHistoryInfo {
messages: ChatHistory;
fill_history: Dict<StringOrHash>;
metavars?: Dict<StringOrHash>;
llm?: string;
}
@ -122,6 +125,12 @@ export type LLMSpec = {
progress?: QueryProgress; // only used for front-end to display progress collecting responses for this LLM
};
export type LLMGroup = {
group: string;
emoji: string;
items: LLMSpec[] | LLMGroup[];
};
/** A spec for a user-defined custom LLM provider */
export type CustomLLMProviderSpec = {
name: string;
@ -158,7 +167,7 @@ export type LLMResponseData =
t: "img"; // type
d: string; // payload
}
| StringOrHash;
export function isImageResponseData(
r: LLMResponseData,
@ -171,13 +180,13 @@ export interface BaseLLMResponseObject {
/** A unique ID to refer to this response */
uid: ResponseUID;
/** The concrete prompt that led to this response. */
prompt: StringOrHash;
/** The variables fed into the prompt. */
vars: Dict<StringOrHash>;
/** Any associated metavariables. */
metavars: Dict<StringOrHash>;
/** The LLM to query (usually a dict of settings) */
llm: StringOrHash | LLMSpec;
/** Optional: The chat history to pass the LLM */
chat_history?: ChatHistory;
}
@ -187,7 +196,8 @@ export interface RawLLMResponseObject extends BaseLLMResponseObject {
// A snapshot of the exact query (payload) sent to the LLM's API
query: Dict;
// The raw JSON response from the LLM
// NOTE: This is now deprecated since it wastes precious storage space.
// raw_response: Dict;
// Extracted responses (1 or more) from raw_response
responses: LLMResponseData[];
// Token lengths (if given)
@ -199,6 +209,7 @@ export type EvaluationScore =
| number
| string
| Dict<boolean | number | string>;
export type EvaluationResults = {
items: EvaluationScore[];
dtype:
@ -233,19 +244,19 @@ export type EvaluatedResponsesResults = {
/** The outputs of prompt nodes, text fields or other data passed internally in the front-end and to the PromptTemplate backend.
* Used to populate prompt templates and carry variables/metavariables along the chain. */
export interface TemplateVarInfo {
text?: StringOrHash;
image?: StringOrHash; // base-64 encoding
fill_history?: Dict<StringOrHash>;
metavars?: Dict<StringOrHash>;
associate_id?: StringOrHash;
prompt?: StringOrHash;
uid?: ResponseUID;
llm?: StringOrHash | LLMSpec;
chat_history?: ChatHistory;
}
export type LLMResponsesByVarDict = Dict<
(BaseLLMResponseObject | LLMResponse | TemplateVarInfo | StringOrHash)[]
>;
export type VarsContext = {
@ -253,12 +264,12 @@ export type VarsContext = {
metavars: string[];
};
export type PromptVarType = StringOrHash | TemplateVarInfo;
export type PromptVarsDict = {
[key: string]: PromptVarType[] | StringOrHash;
};
export type TabularDataRowType = Dict<StringOrHash>;
export type TabularDataColType = {
key: string;
header: string;

View File

@ -25,6 +25,9 @@ import {
LLMSpec,
EvaluationScore,
LLMResponseData,
isImageResponseData,
StringOrHash,
PromptVarsDict,
} from "./typing";
import { v4 as uuid } from "uuid";
import { StringTemplate } from "./template";
@ -46,9 +49,9 @@ import {
fromModelId,
ChatMessage as BedrockChatMessage,
} from "@mirai73/bedrock-fm";
import StorageCache, { StringLookup } from "./cache";
import Compressor from "compressorjs";
// import { Models } from "@mirai73/bedrock-fm/lib/bedrock";
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
const ANTHROPIC_AI_PROMPT = "\n\nAssistant:";
@ -143,6 +146,7 @@ function get_environ(key: string): string | undefined {
}
let OPENAI_API_KEY = get_environ("OPENAI_API_KEY");
let OPENAI_BASE_URL = get_environ("OPENAI_BASE_URL");
let ANTHROPIC_API_KEY = get_environ("ANTHROPIC_API_KEY");
let GOOGLE_PALM_API_KEY = get_environ("PALM_API_KEY");
let AZURE_OPENAI_KEY = get_environ("AZURE_OPENAI_KEY");
@ -153,6 +157,8 @@ let AWS_ACCESS_KEY_ID = get_environ("AWS_ACCESS_KEY_ID");
let AWS_SECRET_ACCESS_KEY = get_environ("AWS_SECRET_ACCESS_KEY");
let AWS_SESSION_TOKEN = get_environ("AWS_SESSION_TOKEN");
let AWS_REGION = get_environ("AWS_REGION");
let TOGETHER_API_KEY = get_environ("TOGETHER_API_KEY");
let DEEPSEEK_API_KEY = get_environ("DEEPSEEK_API_KEY");
/**
* Sets the local API keys for the relevant LLM API(s).
@ -160,12 +166,15 @@ let AWS_REGION = get_environ("AWS_REGION");
export function set_api_keys(api_keys: Dict<string>): void {
function key_is_present(name: string): boolean {
return (
(name in api_keys &&
api_keys[name] &&
api_keys[name].trim().length > 0) ||
name === "OpenAI_BaseURL"
);
}
if (key_is_present("OpenAI")) OPENAI_API_KEY = api_keys.OpenAI;
if (key_is_present("OpenAI_BaseURL"))
OPENAI_BASE_URL = api_keys.OpenAI_BaseURL;
if (key_is_present("HuggingFace")) HUGGINGFACE_API_KEY = api_keys.HuggingFace;
if (key_is_present("Anthropic")) ANTHROPIC_API_KEY = api_keys.Anthropic;
if (key_is_present("Google")) GOOGLE_PALM_API_KEY = api_keys.Google;
@ -181,6 +190,8 @@ export function set_api_keys(api_keys: Dict<string>): void {
if (key_is_present("AWS_Session_Token"))
AWS_SESSION_TOKEN = api_keys.AWS_Session_Token;
if (key_is_present("AWS_Region")) AWS_REGION = api_keys.AWS_Region;
if (key_is_present("Together")) TOGETHER_API_KEY = api_keys.Together;
if (key_is_present("DeepSeek")) DEEPSEEK_API_KEY = api_keys.DeepSeek;
}
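// Illustrative sketch (not part of the diff): the settings window passes a flat dict of
// keys, and the new Together / DeepSeek / OpenAI_BaseURL entries ride the same path.
// All values below are placeholders.
function exampleConfigureKeys(): void {
  set_api_keys({
    OpenAI: "sk-placeholder",
    OpenAI_BaseURL: "https://my-openai-proxy.example.com/v1", // hypothetical proxy; this key is accepted even when empty
    Together: "together-placeholder",
    DeepSeek: "deepseek-placeholder",
  });
}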
export function get_azure_openai_api_keys(): [
@ -200,12 +211,16 @@ function construct_openai_chat_history(
prompt: string,
chat_history?: ChatHistory,
system_msg?: string,
system_role_name?: string,
): ChatHistory {
const sys_role_name = system_role_name ?? "system";
const prompt_msg: ChatMessage = { role: "user", content: prompt };
const sys_msg: ChatMessage[] =
system_msg !== undefined
? [{ role: sys_role_name, content: system_msg }]
: [];
if (chat_history !== undefined && chat_history.length > 0) {
if (chat_history[0].role === sys_role_name) {
// In this case, the system_msg is ignored because the prior history already contains one.
return chat_history.concat([prompt_msg]);
} else {
@ -227,6 +242,8 @@ export async function call_chatgpt(
temperature = 1.0,
params?: Dict,
should_cancel?: () => boolean,
BASE_URL?: string,
API_KEY?: string,
): Promise<[Dict, Dict]> {
if (!OPENAI_API_KEY)
throw new Error(
@ -234,7 +251,8 @@ export async function call_chatgpt(
);
const configuration = new OpenAIConfig({
apiKey: API_KEY ?? OPENAI_API_KEY,
basePath: BASE_URL ?? OPENAI_BASE_URL ?? undefined,
});
// Since we are running client-side, we need to remove the user-agent header:
@ -243,6 +261,8 @@ export async function call_chatgpt(
const openai = new OpenAIApi(configuration);
const modelname: string = model.toString();
// Remove empty params
if (
params?.stop !== undefined &&
(!Array.isArray(params.stop) || params.stop.length === 0)
@ -260,18 +280,35 @@ export async function call_chatgpt(
params.function_call.trim().length === 0)
)
delete params.function_call;
if (
params?.tools !== undefined &&
(!Array.isArray(params.tools) || params.tools.length === 0)
)
delete params?.tools;
if (
params?.tool_choice !== undefined &&
(!(typeof params.tool_choice === "string") ||
params.tool_choice.trim().length === 0)
)
delete params.tool_choice;
if (params?.tools === undefined && params?.parallel_tool_calls !== undefined)
delete params?.parallel_tool_calls;
if (!BASE_URL)
console.log(`Querying OpenAI model '${model}' with prompt '${prompt}'...`);
// Determine the system message and whether there's chat history to continue:
const chat_history: ChatHistory | undefined = params?.chat_history;
let system_msg: string | undefined =
params?.system_msg !== undefined ? params.system_msg : undefined;
delete params?.system_msg;
delete params?.chat_history;
// The o1 and later OpenAI models, for whatever reason, block the system message from being sent in the chat history.
// The official API states that "developer" works, but it doesn't for some models, so until OpenAI fixes this
// and fully supports system messages, we have to block them from being sent in the chat history.
if (model.startsWith("o")) system_msg = undefined;
const query: Dict = {
model: modelname,
n,
@ -316,6 +353,36 @@ export async function call_chatgpt(
return [query, response];
}
/**
* Calls DeepSeek models via DeepSeek's API.
*/
export async function call_deepseek(
prompt: string,
model: LLM,
n = 1,
temperature = 1.0,
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (!DEEPSEEK_API_KEY)
throw new Error(
"Could not find a DeepSeek API key. Double-check that your API key is set in Settings or in your local environment.",
);
console.log(`Querying DeepSeek model '${model}' with prompt '${prompt}'...`);
return await call_chatgpt(
prompt,
model,
n,
temperature,
params,
should_cancel,
"https://api.deepseek.com",
DEEPSEEK_API_KEY,
);
}
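// Illustrative sketch (not part of the diff): DeepSeek reuses the OpenAI-compatible code
// path, so a call looks like a ChatGPT call with a different model enum. Assumes the
// DeepSeek key was set via set_api_keys() or the environment.
async function exampleDeepSeekCall(): Promise<[Dict, Dict]> {
  return call_deepseek(
    "Summarize the plot of Hamlet in one sentence.",
    NativeLLM.DeepSeek_Chat,
    1, // n
    1.0, // temperature
    { system_msg: "You are a terse literary critic." },
  );
}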
/**
* Calls OpenAI Image models via OpenAI's API.
@returns raw query and response JSON dicts.
@ -533,18 +600,37 @@ export async function call_anthropic(
// Required non-standard params
const max_tokens_to_sample = params?.max_tokens_to_sample ?? 1024;
const stop_sequences = params?.stop_sequences ?? [ANTHROPIC_HUMAN_PROMPT];
let system_msg = params?.system_msg;
delete params?.custom_prompt_wrapper;
delete params?.max_tokens_to_sample;
delete params?.system_msg;
// Tool usage -- remove tool params before passing, if they are empty
if (
params?.tools !== undefined &&
(!Array.isArray(params.tools) || params.tools.length === 0)
)
delete params?.tools;
if (
params?.tool_choice !== undefined &&
(!(typeof params.tool_choice === "string") ||
params.tool_choice.trim().length === 0)
)
delete params.tool_choice;
if (params?.tools === undefined) delete params?.parallel_tool_calls;
else {
if (params?.tool_choice === undefined) params.tool_choice = { type: "any" };
params.tool_choice.disable_parallel_tool_use = !params.parallel_tool_calls;
delete params?.parallel_tool_calls;
}
// Detect whether to use old text completions or new messaging API
const use_messages_api = is_newer_anthropic_model(model);
// Carry chat history
// :: See https://docs.anthropic.com/claude/docs/human-and-assistant-formatting#use-human-and-assistant-to-put-words-in-claudes-mouth
let chat_history: ChatHistory | undefined = params?.chat_history;
if (chat_history !== undefined) {
// FOR OLD TEXT COMPLETIONS API ONLY: Carry chat history by prepending it to the prompt
if (!use_messages_api) {
@ -558,6 +644,13 @@ export async function call_anthropic(
anthr_chat_context += " " + chat_msg.content;
}
wrapped_prompt = anthr_chat_context + wrapped_prompt; // prepend the chat context
} else {
// The new messages API doesn't allow a first "system" message inside chat history, like OpenAI does.
// We need to detect a "system" message and eject it:
if (chat_history.some((m) => m.role === "system")) {
system_msg = chat_history.filter((m) => m.role === "system")[0].content;
chat_history = chat_history.filter((m) => m.role !== "system");
}
}
// For newer models Claude 2.1 and Claude 3, we carry chat history directly below; no need to do anything else.
@ -651,26 +744,27 @@ export async function call_google_ai(
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (
model === NativeLLM.PaLM2_Chat_Bison ||
model === NativeLLM.PaLM2_Text_Bison
)
return call_google_palm(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
else
return call_google_gemini(
prompt,
model,
n,
temperature,
params,
should_cancel,
);
}
/**
@ -812,18 +906,21 @@ export async function call_google_gemini(
"Could not find an API key for Google Gemini models. Double-check that your API key is set in Settings or in your local environment.",
);
// Required non-standard params
const max_output_tokens = params?.max_output_tokens || 1000;
const chat_history: ChatHistory = params?.chat_history;
const system_msg = params?.system_msg;
delete params?.chat_history;
delete params?.system_msg;
const genAI = new GoogleGenerativeAI(GOOGLE_PALM_API_KEY);
const gemini_model = genAI.getGenerativeModel({
model: model.toString(),
systemInstruction:
typeof system_msg === "string" && chat_history === undefined
? system_msg
: undefined,
});
const query: Dict = {
model: `models/${model}`,
@ -875,10 +972,20 @@ export async function call_google_gemini(
for (const chat_msg of chat_history) {
if (chat_msg.role === "system") {
// Carry the system message over as PaLM's chat 'context':
gemini_messages.push({
role: "model",
parts: [{ text: chat_msg.content }],
});
} else if (chat_msg.role === "user") {
gemini_messages.push({
role: "user",
parts: [{ text: chat_msg.content }],
});
} else
gemini_messages.push({
role: "model",
parts: [{ text: chat_msg.content }],
});
}
gemini_chat_context.history = gemini_messages;
}
@ -1156,6 +1263,7 @@ export async function call_ollama_provider(
const model_type: string = params?.model_type ?? "text";
const system_msg: string = params?.system_msg ?? "";
const chat_history: ChatHistory | undefined = params?.chat_history;
const format: Dict | string | undefined = params?.format;
// Cleanup
for (const name of [
@ -1164,6 +1272,7 @@ export async function call_ollama_provider(
"model_type",
"system_msg",
"chat_history",
"format",
])
if (params && name in params) delete params[name];
@ -1172,8 +1281,10 @@ export async function call_ollama_provider(
const query: Dict = {
model: ollama_model,
stream: false,
temperature,
...params, // 'the rest' of the settings, passed from the front-end settings
options: {
temperature,
...params, // 'the rest' of the settings, passed from the front-end settings
},
};
// If the model type is explicitly or implicitly set to "chat", pass chat history instead:
@ -1195,6 +1306,17 @@ export async function call_ollama_provider(
`Calling Ollama API at ${url} for model '${ollama_model}' with prompt '${prompt}' n=${n} times. Please be patient...`,
);
// If there are structured outputs specified, convert to an object:
if (typeof format === "string" && format.trim().length > 0) {
try {
query.format = JSON.parse(format);
} catch (err) {
throw Error(
"Cannot parse structured output format into JSON: JSON schema is incorrectly structured.",
);
}
}
// Call Ollama API
const resps: Response[] = [];
for (let i = 0; i < n; i++) {
@ -1263,25 +1385,25 @@ export async function call_bedrock(
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (
!AWS_ACCESS_KEY_ID ||
!AWS_SECRET_ACCESS_KEY ||
!AWS_SESSION_TOKEN ||
!AWS_REGION
) {
throw new Error(
"Could not find credentials value for the Bedrock API. Double-check that your AWS Credentials are set in Settings or in your local environment.",
);
}
const modelName: string = model.toString();
let stopWords = [];
if (
params?.stop_sequences !== undefined &&
Array.isArray(params.stop_sequences) &&
params.stop_sequences.length > 0
) {
stopWords = params.stop_sequences;
}
const bedrockConfig = {
credentials: {
accessKeyId: AWS_ACCESS_KEY_ID,
@ -1291,17 +1413,15 @@ export async function call_bedrock(
region: AWS_REGION,
};
delete params?.stop;
delete params?.stop_sequences;
const query: Dict = {
stopSequences: stopWords,
temperature,
topP: params?.top_p ?? 1.0,
maxTokenCount: params?.max_tokens_to_sample ?? 512,
};
const fm = fromModelId(modelName, {
region: bedrockConfig.region,
credentials: bedrockConfig.credentials,
...query,
});
@ -1315,32 +1435,134 @@ export async function call_bedrock(
// Grab the response
let response: string;
if (
modelName.startsWith("anthropic") ||
modelName.startsWith("mistral") ||
modelName.startsWith("meta")
) {
const chat_history: ChatHistory = construct_openai_chat_history(
prompt,
params?.chat_history,
params?.system_msg,
);
response = (
await fm.chat(to_bedrock_chat_history(chat_history), {
modelArgs: { ...(params as Map<string, any>) },
})
).message;
} else {
response = await fm.generate(prompt, {
modelArgs: { ...(params as Map<string, any>) },
});
}
responses.push(response);
}
} catch (error: any) {
console.error("Error", error);
throw new Error(error?.message ?? error.toString());
}
return [query, responses];
}
/**
* Calls Together.ai text + chat models via Together's API.
@returns raw query and response JSON dicts.
*/
export async function call_together(
prompt: string,
model: LLM,
n = 1,
temperature = 1.0,
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (!TOGETHER_API_KEY)
throw new Error(
"Could not find a Together API key. Double-check that your API key is set in Settings or in your local environment.",
);
const togetherBaseUrl = "https://api.together.xyz/v1";
// Together.ai uses OpenAI's API, so we can use the OpenAI API client to make the call:
const configuration = new OpenAIConfig({
apiKey: TOGETHER_API_KEY,
basePath: togetherBaseUrl,
});
// Since we are running client-side, we need to remove the user-agent header:
delete configuration.baseOptions.headers["User-Agent"];
const together = new OpenAIApi(configuration);
// Strip the "together/" prefix:
const modelname: string = model.toString().substring(9);
if (
params?.stop !== undefined &&
(!Array.isArray(params.stop) || params.stop.length === 0)
)
delete params.stop;
if (params?.seed && params.seed.toString().length === 0) delete params?.seed;
if (
params?.functions !== undefined &&
(!Array.isArray(params.functions) || params.functions.length === 0)
)
delete params?.functions;
if (
params?.function_call !== undefined &&
(!(typeof params.function_call === "string") ||
params.function_call.trim().length === 0)
)
delete params.function_call;
console.log(
`Querying Together model '${modelname}' with prompt '${prompt}'...`,
);
// Determine the system message and whether there's chat history to continue:
const chat_history: ChatHistory | undefined = params?.chat_history;
const system_msg: string =
params?.system_msg !== undefined
? params.system_msg
: "You are a helpful assistant.";
delete params?.system_msg;
delete params?.chat_history;
const query: Dict = {
model: modelname,
n,
temperature,
...params, // 'the rest' of the settings, passed from the front-end settings
};
// Create call to chat model
const together_call: any = together.createChatCompletion.bind(together);
// Carry over chat history, if present:
query.messages = construct_openai_chat_history(
prompt,
chat_history,
system_msg,
);
// Try to call Together
let response: Dict = {};
try {
const completion = await together_call(query);
response = completion.data;
} catch (error: any) {
if (error?.response) {
throw new Error(error.response.data?.error?.message);
// throw new Error(error.response.status);
} else {
console.log(error?.message || error);
throw new Error(error?.message || error);
}
}
return [query, response];
}
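// Illustrative sketch (not part of the diff): Together enum values that carry the
// "together/" prefix have it stripped before the OpenAI-compatible request is sent,
// so the wire-level model name is e.g. "meta-llama/Llama-2-7b-chat-hf". Assumes the
// Together key was set via set_api_keys() or the environment.
async function exampleTogetherCall(): Promise<[Dict, Dict]> {
  return call_together(
    "Write a haiku about rate limits.",
    NativeLLM.Together_Meta_LLaMA2_Chat_7B, // "together/meta-llama/Llama-2-7b-chat-hf"
    1, // n
    0.7, // temperature
  );
}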
async function call_custom_provider(
prompt: string,
model: LLM,
@ -1397,6 +1619,7 @@ async function call_custom_provider(
*/
export async function call_llm(
llm: LLM,
provider: LLMProvider,
prompt: string,
n: number,
temperature: number,
@ -1405,7 +1628,7 @@ export async function call_llm(
): Promise<[Dict, Dict]> {
// Get the correct API call for the given LLM:
let call_api: LLMAPICall | undefined;
const llm_provider: LLMProvider | undefined = provider ?? getProvider(llm); // backwards compatibility if there's no explicit provider
if (llm_provider === undefined)
throw new Error(`Language model ${llm} is not supported.`);
@ -1425,6 +1648,8 @@ export async function call_llm(
else if (llm_provider === LLMProvider.Ollama) call_api = call_ollama_provider;
else if (llm_provider === LLMProvider.Custom) call_api = call_custom_provider;
else if (llm_provider === LLMProvider.Bedrock) call_api = call_bedrock;
else if (llm_provider === LLMProvider.Together) call_api = call_together;
else if (llm_provider === LLMProvider.DeepSeek) call_api = call_deepseek;
if (call_api === undefined)
throw new Error(
`Adapter for Language model ${llm} and ${llm_provider} not found`,
@ -1445,8 +1670,27 @@ function _extract_openai_chat_choice_content(choice: Dict): string {
) {
const func = choice.message.function_call;
return "[[FUNCTION]] " + func.name + func.arguments.toString();
} else if (
choice.finish_reason === "tool_calls" ||
("tool_calls" in choice.message && choice.message.tool_calls.length > 0)
) {
const tools = choice.message.tool_calls;
return (
"[[TOOLS]] " +
tools
.map((t: Dict) => t.function.name + " " + t.function.arguments)
.join("\n\n")
);
} else {
// Extract the content. Note that structured outputs in OpenAI's API as of late 2024
// can sometimes output a response to a "refusal" key, which is annoying. We check for that here:
if (
"refusal" in choice.message &&
typeof choice.message.refusal === "string"
)
return choice.message.refusal;
// General chat outputs
else return choice.message.content;
}
}
@ -1498,10 +1742,12 @@ function _extract_google_ai_responses(
llm: LLM | string,
): Array<string> {
switch (llm) {
case NativeLLM.PaLM2_Chat_Bison:
return _extract_palm_responses(response);
case NativeLLM.PaLM2_Text_Bison:
return _extract_palm_responses(response);
default:
return _extract_gemini_responses(response as Array<Dict>);
}
}
@ -1528,7 +1774,28 @@ function _extract_gemini_responses(completions: Array<Dict>): Array<string> {
function _extract_anthropic_chat_responses(
response: Array<Dict>,
): Array<string> {
return response.map((r: Dict) =>
r.content
.map((c: Dict) => {
// Regular text response
if (c?.type === "text") return c.text.trim();
// Anthropic tool usage
else if (c?.type === "tool_use")
return (
"[[TOOLS]] " +
JSON.stringify({
name: c.name,
input: c.input,
})
);
// Unknown type of message
else
throw Error(
`Unknown content block type '${c?.type}' found in Anthropic response. If this is a new type, raise an Issue on the ChainForge GitHub.`,
);
})
.join("\n\n"),
);
}
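For reference, a hedged sketch of an Anthropic response the new mapper handles, with one text block and one tool-use block (the content values are invented):
// Illustration only -- not part of the diff.
const exampleAnthropicResponse = [
  {
    content: [
      { type: "text", text: "Let me check the weather. " },
      { type: "tool_use", name: "get_weather", input: { city: "Paris" } },
    ],
  },
];
// _extract_anthropic_chat_responses(exampleAnthropicResponse) yields one joined string:
//   'Let me check the weather.\n\n[[TOOLS]] {"name":"get_weather","input":{"city":"Paris"}}'
// Any block with an unrecognized `type` throws, asking the user to file a GitHub issue.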
/**
@ -1570,8 +1837,10 @@ function _extract_ollama_responses(
export function extract_responses(
response: Array<string | Dict> | Dict,
llm: LLM | string,
provider: LLMProvider,
): Array<LLMResponseData> {
const llm_provider: LLMProvider | undefined = getProvider(llm as LLM);
const llm_provider: LLMProvider | undefined =
provider ?? getProvider(llm as LLM);
const llm_name = llm.toString().toLowerCase();
switch (llm_provider) {
case LLMProvider.OpenAI:
@ -1600,13 +1869,18 @@ export function extract_responses(
return _extract_ollama_responses(response as Dict[]);
case LLMProvider.Bedrock:
return response as Array<string>;
case LLMProvider.Together:
return _extract_openai_responses(response as Dict[]);
case LLMProvider.DeepSeek:
return _extract_openai_responses(response as Dict[]);
default:
if (
Array.isArray(response) &&
response.length > 0 &&
typeof response[0] === "string"
(typeof response[0] === "string" ||
(typeof response[0] === "object" && isImageResponseData(response[0])))
)
return response as string[];
return response as LLMResponseData[];
else
throw new Error(
`No method defined to extract responses for LLM ${llm}.`,
@ -1631,18 +1905,13 @@ export function merge_response_objs(
else if (!resp_obj_A && resp_obj_B) return resp_obj_B;
resp_obj_A = resp_obj_A as RawLLMResponseObject; // required by typescript
resp_obj_B = resp_obj_B as RawLLMResponseObject;
let raw_resp_A = resp_obj_A.raw_response;
let raw_resp_B = resp_obj_B.raw_response;
if (!Array.isArray(raw_resp_A)) raw_resp_A = [raw_resp_A];
if (!Array.isArray(raw_resp_B)) raw_resp_B = [raw_resp_B];
const res: RawLLMResponseObject = {
responses: resp_obj_A.responses.concat(resp_obj_B.responses),
raw_response: raw_resp_A.concat(raw_resp_B),
prompt: resp_obj_B.prompt,
query: resp_obj_B.query,
llm: resp_obj_B.llm,
vars: resp_obj_B.vars,
metavars: resp_obj_B.metavars,
vars: resp_obj_B.vars ?? (resp_obj_B as any).info ?? {}, // backwards compatibility---vars used to be 'info'
metavars: resp_obj_B.metavars ?? {},
uid: resp_obj_B.uid,
};
if (resp_obj_B.chat_history !== undefined)
@ -1693,7 +1962,7 @@ export const transformDict = (
*
* Returns empty dict {} if no settings vars found.
*/
export const extractSettingsVars = (vars?: Dict) => {
export const extractSettingsVars = (vars?: PromptVarsDict) => {
if (
vars !== undefined &&
Object.keys(vars).some((k) => k.charAt(0) === "=")
@ -1710,8 +1979,8 @@ export const extractSettingsVars = (vars?: Dict) => {
* Given two info vars dicts, detects whether any + all vars (keys) match values.
*/
export const areEqualVarsDicts = (
A: Dict | undefined,
B: Dict | undefined,
A: PromptVarsDict | undefined,
B: PromptVarsDict | undefined,
): boolean => {
if (A === undefined || B === undefined) {
if (A === undefined && B === undefined) return true;
@ -1792,12 +2061,15 @@ export const stripLLMDetailsFromResponses = (
): LLMResponse[] =>
resps.map((r) => ({
...r,
llm: typeof r?.llm === "string" ? r?.llm : r?.llm?.name ?? "undefined",
llm:
(typeof r?.llm === "string" || typeof r?.llm === "number"
? StringLookup.get(r?.llm)
: r?.llm?.name) ?? "undefined",
}));
// NOTE: The typing is purposefully general since we are trying to cast to an expected format.
export const toStandardResponseFormat = (r: Dict | string) => {
if (typeof r === "string")
if (typeof r === "string" || typeof r === "number")
return {
vars: {},
metavars: {},
@ -1812,7 +2084,7 @@ export const toStandardResponseFormat = (r: Dict | string) => {
uid: r?.uid ?? r?.batch_id ?? uuid(),
llm: r?.llm ?? undefined,
prompt: r?.prompt ?? "",
responses: [typeof r === "string" ? r : r?.text],
responses: [typeof r === "string" || typeof r === "number" ? r : r?.text],
tokens: r?.raw_response?.usage ?? {},
};
if (r?.eval_res !== undefined) resp_obj.eval_res = r.eval_res;
@ -1838,8 +2110,10 @@ export const tagMetadataWithLLM = (input_data: LLMResponsesByVarDict) => {
if (
!r ||
typeof r === "string" ||
typeof r === "number" ||
!r?.llm ||
typeof r.llm === "string" ||
typeof r.llm === "number" ||
!r.llm.key
)
return r;
@ -1853,20 +2127,21 @@ export const tagMetadataWithLLM = (input_data: LLMResponsesByVarDict) => {
export const extractLLMLookup = (
input_data: Dict<
(string | TemplateVarInfo | BaseLLMResponseObject | LLMResponse)[]
(StringOrHash | TemplateVarInfo | BaseLLMResponseObject | LLMResponse)[]
>,
) => {
const llm_lookup: Dict<string | LLMSpec> = {};
const llm_lookup: Dict<StringOrHash | LLMSpec> = {};
Object.values(input_data).forEach((resp_objs) => {
resp_objs.forEach((r) => {
const llm_name =
typeof r === "string"
typeof r === "string" || typeof r === "number"
? undefined
: !r.llm || typeof r.llm === "string"
? r.llm
: !r.llm || typeof r.llm === "string" || typeof r.llm === "number"
? StringLookup.get(r.llm)
: r.llm.key;
if (
typeof r === "string" ||
typeof r === "number" ||
!r.llm ||
!llm_name ||
llm_name in llm_lookup
@ -1946,6 +2221,13 @@ export const batchResponsesByUID = (
.concat(unspecified_id_group);
};
export function llmResponseDataToString(data: LLMResponseData): string {
if (typeof data === "string") return data;
else if (typeof data === "number")
return StringLookup.get(data) ?? "(string lookup failed)";
else return data.d;
}
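Because LLMResponseData can now be a plain string, a numeric hash into the StringLookup table, or an object carrying its payload in `.d`, here is a self-contained toy sketch of the idea (the interning scheme and the object shape are simplified assumptions, not the repo's implementation):
// Illustration only -- not part of the diff.
type ToyResponseData = string | number | { d: string };
const toyStringTable: string[] = []; // stand-in for StringLookup's interned strings
function toyIntern(s: string): number {
  const idx = toyStringTable.indexOf(s);
  if (idx >= 0) return idx;
  toyStringTable.push(s);
  return toyStringTable.length - 1;
}
function toyDataToString(data: ToyResponseData): string {
  if (typeof data === "string") return data;                 // raw string
  else if (typeof data === "number")
    return toyStringTable[data] ?? "(string lookup failed)"; // interned hash
  else return data.d;                                        // object payload (e.g. image data)
}
// toyDataToString("hello")            === "hello"
// toyDataToString(toyIntern("hello")) === "hello"
// toyDataToString({ d: "<base64>" })  === "<base64>"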
/**
* Naive method to sample N items at random from an array.
* @param arr an array of items


@ -19,13 +19,16 @@ import {
import { DuplicateVariableNameError } from "./backend/errors";
import {
Dict,
LLMGroup,
LLMSpec,
PromptVarType,
PromptVarsDict,
TemplateVarInfo,
TabularDataColType,
TabularDataRowType,
} from "./backend/typing";
import { TogetherChatSettings } from "./ModelSettingSchemas";
import { NativeLLM } from "./backend/models";
import { StringLookup } from "./backend/cache";
// Initial project settings
const initialAPIKeys = {};
@ -87,10 +90,7 @@ const refreshableOutputNodeTypes = new Set([
"split",
]);
export const initLLMProviderMenu: (
| LLMSpec
| { group: string; emoji: string; items: LLMSpec[] }
)[] = [
export const initLLMProviderMenu: (LLMSpec | LLMGroup)[] = [
{
group: "OpenAI",
emoji: "🤖",
@ -109,6 +109,20 @@ export const initLLMProviderMenu: (
base_model: "gpt-4",
temp: 1.0,
},
{
name: "GPT4o",
emoji: "👄",
model: "gpt-4o",
base_model: "gpt-4",
temp: 1.0,
},
{
name: "GPT4o-mini",
emoji: "👄",
model: "gpt-4o-mini",
base_model: "gpt-4",
temp: 1.0,
},
{
name: "Dall-E",
emoji: "🖼",
@ -119,18 +133,85 @@ export const initLLMProviderMenu: (
],
},
{
name: "Claude",
group: "Claude",
emoji: "📚",
model: "claude-2",
base_model: "claude-v1",
temp: 0.5,
items: [
{
name: "Claude 3.5 Sonnet",
emoji: "📚",
model: "claude-3-5-sonnet-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 3.5 Haiku",
emoji: "📗",
model: "claude-3-5-haiku-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 3 Opus",
emoji: "📙",
model: "claude-3-opus-latest",
base_model: "claude-v1",
temp: 0.5,
},
{
name: "Claude 2",
emoji: "📓",
model: "claude-2",
base_model: "claude-v1",
temp: 0.5,
},
],
},
{
name: "Gemini",
group: "Gemini",
emoji: "♊",
model: "gemini-pro",
base_model: "palm2-bison",
temp: 0.7,
items: [
{
name: "Gemini 1.5",
emoji: "♊",
model: "gemini-1.5-pro",
base_model: "palm2-bison",
temp: 0.7,
},
{
name: "Gemini 1.5 Flash",
emoji: "📸",
model: "gemini-1.5-flash",
base_model: "palm2-bison",
temp: 0.7,
},
{
name: "Gemini 1.5 Flash 8B",
emoji: "⚡️",
model: "gemini-1.5-flash-8b",
base_model: "palm2-bison",
temp: 0.7,
},
],
},
{
group: "DeepSeek",
emoji: "🐋",
items: [
{
name: "DeepSeek Chat",
emoji: "🐋",
model: "deepseek-chat",
base_model: "deepseek",
temp: 1.0,
}, // The base_model designates what settings form will be used, and must be unique.
{
name: "DeepSeek Reasoner",
emoji: "🐳",
model: "deepseek-reasoner",
base_model: "deepseek",
temp: 1.0,
},
],
},
{
group: "HuggingFace",
@ -173,56 +254,98 @@ export const initLLMProviderMenu: (
{
name: "Anthropic Claude",
emoji: "👨‍🏫",
model: "anthropic.claude-v2:1",
model: NativeLLM.Bedrock_Claude_3_Haiku,
base_model: "br.anthropic.claude",
temp: 0.9,
},
{
name: "AI21 Jurassic 2",
emoji: "🦖",
model: "ai21.j2-ultra",
model: NativeLLM.Bedrock_Jurassic_Ultra,
base_model: "br.ai21.j2",
temp: 0.9,
},
{
name: "Amazon Titan",
emoji: "🏛️",
model: "amazon.titan-tg1-large",
model: NativeLLM.Bedrock_Titan_Large,
base_model: "br.amazon.titan",
temp: 0.9,
},
{
name: "Cohere Command Text 14",
emoji: "📚",
model: "cohere.command-text-v14",
model: NativeLLM.Bedrock_Command_Text,
base_model: "br.cohere.command",
temp: 0.9,
},
{
name: "Mistral Mistral",
emoji: "💨",
model: "mistral.mistral-7b-instruct-v0:2",
model: NativeLLM.Bedrock_Mistral_Mistral,
base_model: "br.mistral.mistral",
temp: 0.9,
},
{
name: "Mistral Mixtral",
emoji: "🌪️",
model: "mistral.mixtral-8x7b-instruct-v0:1",
model: NativeLLM.Bedrock_Mistral_Mixtral,
base_model: "br.mistral.mixtral",
temp: 0.9,
},
{
name: "Meta Llama2 Chat",
emoji: "🦙",
model: "meta.llama2-13b-chat-v1",
model: NativeLLM.Bedrock_Meta_LLama2Chat_13b,
base_model: "br.meta.llama2",
temp: 0.9,
},
{
name: "Meta Llama3 Instruct",
emoji: "🦙",
model: NativeLLM.Bedrock_Meta_LLama3Instruct_8b,
base_model: "br.meta.llama3",
temp: 0.9,
},
],
},
];
const togetherModels = TogetherChatSettings.schema.properties.model
.enum as string[];
const togetherGroups = () => {
const groupNames: string[] = [];
const groups: { [key: string]: LLMGroup } = {};
togetherModels.forEach((model) => {
const [groupName, modelName] = model.split("/");
const spec: LLMSpec = {
name: modelName,
emoji: "🤝",
model: "together/" + model,
base_model: "together",
temp: 0.9,
};
if (groupName in groups) {
(groups[groupName].items as LLMSpec[]).push(spec);
} else {
groups[groupName] = {
group: groupName,
emoji: "🤝",
items: [spec],
};
groupNames.push(groupName);
}
});
return groupNames.map((name) => groups[name]);
};
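To illustrate the grouping above with a hypothetical Together model id (not necessarily one present in the settings enum): the id is split on "/", the organization part names the submenu, and the full id is prefixed with "together/" so the backend can route it, with base_model "together" selecting the Together settings form.
// Illustration only -- not part of the diff.
const exampleTogetherId = "exampleorg/example-chat-7b";
const [exampleGroup, exampleModelName] = exampleTogetherId.split("/");
// exampleGroup     === "exampleorg"      -> creates or reuses the "exampleorg" submenu
// exampleModelName === "example-chat-7b" -> shown as the menu item's name
// The resulting spec uses model "together/exampleorg/example-chat-7b" and base_model "together".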
console.log(togetherGroups());
const togetherLLMProviderMenu: LLMGroup = {
group: "Together",
emoji: "🤝",
items: togetherGroups(),
};
initLLMProviderMenu.push(togetherLLMProviderMenu);
if (APP_IS_RUNNING_LOCALLY()) {
initLLMProviderMenu.push({
name: "Ollama",
@ -235,9 +358,20 @@ if (APP_IS_RUNNING_LOCALLY()) {
// initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
// -------------------------
}
export const initLLMProviders = initLLMProviderMenu
.map((item) => ("group" in item && "items" in item ? item.items : item))
.flat();
function flattenLLMGroup(group: LLMGroup): LLMSpec[] {
return group.items.flatMap((item) =>
"group" in item && "items" in item ? flattenLLMGroup(item) : item,
);
}
function flattenLLMProviders(providers: (LLMSpec | LLMGroup)[]): LLMSpec[] {
return providers.flatMap((item) =>
"group" in item && "items" in item ? flattenLLMGroup(item) : item,
);
}
export const initLLMProviders = flattenLLMProviders(initLLMProviderMenu);
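A self-contained sketch of what the recursive flattening buys over the old single-level `.map(...).flat()` (toy names and types, not the repo's): nested groups, such as per-organization submenus inside the Together group, are unrolled into one flat list of specs.
// Illustration only -- not part of the diff.
type ToySpec = { name: string; model: string };
type ToyGroup = { group: string; items: (ToySpec | ToyGroup)[] };
const toyMenu: (ToySpec | ToyGroup)[] = [
  { name: "GPT4o", model: "gpt-4o" },
  {
    group: "Together",
    items: [
      {
        group: "exampleorg",
        items: [{ name: "example-chat-7b", model: "together/exampleorg/example-chat-7b" }],
      },
    ],
  },
];
function flattenToyMenu(items: (ToySpec | ToyGroup)[]): ToySpec[] {
  return items.flatMap((item) => ("group" in item ? flattenToyMenu(item.items) : item));
}
// flattenToyMenu(toyMenu).map((s) => s.name) -> ["GPT4o", "example-chat-7b"]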
export interface StoreHandles {
// Nodes and edges
@ -337,7 +471,9 @@ const useStore = create<StoreHandles>((set, get) => ({
// Filter out any empty or incorrectly formatted API key values:
const new_keys = transformDict(
apiKeys,
(key) => typeof apiKeys[key] === "string" && apiKeys[key].length > 0,
(key) =>
(typeof apiKeys[key] === "string" && apiKeys[key].length > 0) ||
key === "OpenAI_BaseURL",
);
// Only update API keys present in the new array; don't delete existing ones:
set({ apiKeys: { ...get().apiKeys, ...new_keys } });
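A small self-contained sketch of the filter's effect (key names other than OpenAI_BaseURL are illustrative, and the motivation noted in the comment is an assumption): non-empty string values are kept as before, while OpenAI_BaseURL is now passed through even when empty.
// Illustration only -- not part of the diff.
const exampleApiKeys: Record<string, unknown> = {
  OpenAI: "sk-xxxx",       // kept: non-empty string
  Anthropic: "",           // dropped: empty string
  HuggingFace: undefined,  // dropped: not a string
  OpenAI_BaseURL: "",      // kept: explicitly allowed through even when empty (assumed: so it can be reset)
};
const keptKeys = Object.keys(exampleApiKeys).filter(
  (key) =>
    (typeof exampleApiKeys[key] === "string" && (exampleApiKeys[key] as string).length > 0) ||
    key === "OpenAI_BaseURL",
);
// keptKeys -> ["OpenAI", "OpenAI_BaseURL"]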
@ -488,7 +624,9 @@ const useStore = create<StoreHandles>((set, get) => ({
(key) =>
key === "__uid" ||
!row[key] ||
(typeof row[key] === "string" && row[key].trim() === ""),
((typeof row[key] === "string" ||
typeof row[key] === "number") &&
StringLookup.get(row[key])?.trim() === ""),
)
)
return undefined;
@ -496,14 +634,20 @@ const useStore = create<StoreHandles>((set, get) => ({
const row_excluding_col: Dict<string> = {};
row_keys.forEach((key) => {
if (key !== src_col.key && key !== "__uid")
row_excluding_col[col_header_lookup[key]] =
row[key].toString();
row_excluding_col[col_header_lookup[key]] = (
StringLookup.get(row[key]) ?? "(string lookup failed)"
).toString();
});
return {
// We escape any braces in the source text before they're passed downstream.
// This is a special property of tabular data nodes: we don't want their text to be treated as prompt templates.
text: escapeBraces(
src_col.key in row ? row[src_col.key].toString() : "",
src_col.key in row
? (
StringLookup.get(row[src_col.key]) ??
"(string lookup failed)"
).toString()
: "",
),
metavars: row_excluding_col,
associate_id: row.__uid, // this is used by the backend to 'carry' certain values together
@ -572,7 +716,7 @@ const useStore = create<StoreHandles>((set, get) => ({
const store_data = (
_texts: PromptVarType[],
_varname: string,
_data: PromptVarsDict,
_data: Dict<PromptVarType[]>,
) => {
if (_varname in _data) _data[_varname] = _data[_varname].concat(_texts);
else _data[_varname] = _texts;


@ -4,6 +4,9 @@
.monofont {
font-family: var(--monofont);
}
.linebreaks {
white-space: pre-wrap;
}
.text-fields-node {
background-color: #fff;
@ -390,7 +393,7 @@ g.ytick text {
padding-bottom: 20px;
min-width: 160px;
border-right: 1px solid #eee;
padding-left: 8px !important;
padding-left: 0px !important;
padding-right: 0px !important;
}
.inspect-responses-drawer {
@ -646,17 +649,18 @@ g.ytick text {
cursor: text;
}
.small-response-metrics {
font-size: 10pt;
font-size: 9pt;
font-family: -apple-system, "Segoe UI", "Roboto", "Oxygen", "Ubuntu",
"Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
font-weight: 500;
text-align: center;
border-top-left-radius: 20px;
border-top-right-radius: 20px;
padding: 0px 2px 1px 0px;
padding: 0px 2px 2px 0px;
margin: 8px 20% -6px 20%;
background-color: rgba(255, 255, 255, 0.3);
/* background-color: rgba(255, 255, 255, 0.3); */
color: #333;
white-space: pre-wrap;
}
.num-same-responses {
position: relative;


@ -5,6 +5,5 @@ requests
openai
dalaipy==2.0.2
urllib3==1.26.6
anthropic
google-generativeai
mistune>=2.0
mistune>=2.0
platformdirs


@ -5,14 +5,14 @@ def readme():
return f.read()
setup(
name='chainforge',
version='0.3.1.0',
name="chainforge",
version="0.3.4.3",
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",
long_description=readme(),
long_description_content_type='text/markdown',
keywords='prompt engineering LLM response evaluation',
long_description_content_type="text/markdown",
keywords="prompt engineering LLM response evaluation",
license="MIT",
url="https://github.com/ianarawjo/ChainForge/",
install_requires=[
@ -21,28 +21,29 @@ setup(
"flask[async]",
"flask_cors",
"requests",
"platformdirs",
"urllib3==1.26.6",
"openai",
"anthropic",
"google-generativeai",
"dalaipy>=2.0.2",
"mistune>=2.0", # for LLM response markdown parsing
],
entry_points={
'console_scripts': [
'chainforge = chainforge.app:main',
"console_scripts": [
"chainforge = chainforge.app:main",
],
},
classifiers=[
# Package classifiers
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
python_requires=">=3.8",
include_package_data=True,
)
)