Add Amazon Bedrock models to main (#251)

* Adding support for Amazon Bedrock models (#247)

* Create global setting for GenAI features provider, to support Bedrock (Anthropic) models as an alternative

* Reformats dropdown in PromptNode to use Mantine ContextMenu with a nested menu, to save space. 

* Remove build folder from git

* Fix context menu to close on click-off. Refactor context menu array code.

* Ensure context menu is positioned below the Add+ button, like a proper dropdown. 

* Toggle context menu off when clicking btn again.

---------

Co-authored-by: Massimiliano Angelino <angmas@amazon.com>
ianarawjo 2024-03-30 17:59:17 -04:00 committed by GitHub
parent ad84cfdecc
commit eb51d1cee9
39 changed files with 2595 additions and 441 deletions

.gitignore vendored
View File

@ -2,6 +2,7 @@
chainforge/cache
chainforge/examples/oaievals/
chainforge/react-server/node_modules
chainforge/react-server/build
# == Below was generated by https://www.toptal.com/developers/gitignore/api/python ==
# Edit at https://www.toptal.com/developers/gitignore?templates=python

View File

@ -1,14 +1,16 @@
# ⛓️🛠️ ChainForge
**An open-source visual programming environment for battle-testing prompts to LLMs.**
<img width="1517" alt="banner" src="https://github.com/ianarawjo/ChainForge/assets/5251713/570879ef-ef8a-4e00-b37c-b49bc3c1a370">
ChainForge is a data flow prompt engineering environment for analyzing and evaluating LLM responses. It is geared towards early-stage, quick-and-dirty exploration of prompts, chat responses, and response quality that goes beyond ad-hoc chatting with individual LLMs. With ChainForge, you can:
- **Query multiple LLMs at once** to test prompt ideas and variations quickly and effectively.
- **Compare response quality across prompt permutations, across models, and across model settings** to choose the best prompt and model for your use case.
- **Set up evaluation metrics** (scoring functions) and immediately visualize results across prompts, prompt parameters, models, and model settings.
- **Hold multiple conversations at once across template parameters and chat models.** Template not just prompts, but follow-up chat messages, and inspect and evaluate outputs at each turn of a chat conversation.
[Read the docs to learn more.](https://chainforge.ai/docs/) ChainForge comes with a number of example evaluation flows to give you a sense of what's possible, including 188 example flows generated from benchmarks in OpenAI evals.
**This is an open beta of ChainForge.** We support model providers OpenAI, HuggingFace, Anthropic, Google PaLM2, Azure OpenAI endpoints, and [Dalai](https://github.com/cocktailpeanut/dalai)-hosted models Alpaca and Llama. You can change the exact model and individual model settings. Visualization nodes support numeric and boolean evaluation metrics. Try it and let us know what you think! :)
@ -16,12 +18,13 @@ ChainForge is a data flow prompt engineering environment for analyzing and evalu
ChainForge is built on [ReactFlow](https://reactflow.dev) and [Flask](https://flask.palletsprojects.com/en/2.3.x/).
# Table of Contents
- [Documentation](https://chainforge.ai/docs/)
- [Installation](#installation)
- [Example Experiments](#example-experiments)
- [Share with Others](#share-with-others)
- [Features](#features) (see the [docs](https://chainforge.ai/docs/nodes/) for more comprehensive info)
- [Development and How to Cite](#development)
# Installation
@ -41,7 +44,7 @@ chainforge serve
Open [localhost:8000](http://localhost:8000/) in a Google Chrome, Firefox, Microsoft Edge, or Brave browser.
You can set your API keys by clicking the Settings icon in the top-right corner. If you prefer not to worry about this every time you open ChainForge, we recommend that you save your OpenAI, Anthropic, and/or Google PaLM API keys to your local environment. For more details, see the [How to Install](https://chainforge.ai/docs/getting_started/).
You can set your API keys by clicking the Settings icon in the top-right corner. If you prefer not to worry about this every time you open ChainForge, we recommend that you save your OpenAI, Anthropic, Google PaLM API keys, and/or Amazon AWS credentials to your local environment. For more details, see the [How to Install](https://chainforge.ai/docs/getting_started/).
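For instance, on macOS or Linux you can export the keys in the shell you launch ChainForge from. This is a minimal sketch: only set the providers you use; the AWS variables below are the ones the local server reads for Amazon Bedrock, and the session token is included because Bedrock access here expects temporary credentials:

```bash
export OPENAI_API_KEY="sk-..."
# Amazon Bedrock uses the standard AWS credential variables:
export AWS_ACCESS_KEY_ID="..."
export AWS_SECRET_ACCESS_KEY="..."
export AWS_SESSION_TOKEN="..."
export AWS_REGION="us-east-1"
chainforge serve
```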
# Supported providers
@ -52,6 +55,7 @@ You can set your API keys by clicking the Settings icon in the top-right corner.
- [Ollama](https://github.com/jmorganca/ollama) (locally-hosted models)
- Microsoft Azure OpenAI Endpoints
- [AlephAlpha](https://app.aleph-alpha.com/)
- Foundation models via Amazon Bedrock on-demand inference, including Anthropic Claude 3
- ...and any other provider through [custom provider scripts](https://chainforge.ai/docs/custom_providers/)!
# Example experiments
@ -67,15 +71,15 @@ You can also conduct **ground truth evaluations** using Tabular Data nodes. For
# Compare responses across models and prompts
Compare across models and prompt variables with an interactive response inspector, including a formatted table and exportable data:
<img width="1460" alt="Screen Shot 2023-07-19 at 5 03 55 PM" src="https://github.com/ianarawjo/ChainForge/assets/5251713/6aca2bd7-7820-4256-9e8b-3a87795f3e50">
Here's [a tutorial to get started comparing across prompt templates](https://chainforge.ai/docs/compare_prompts/).
# Share with others
The web version of ChainForge (https://chainforge.ai/play/) includes a Share button.
Simply click Share to generate a unique link for your flow and copy it to your clipboard:
@ -93,6 +97,7 @@ For finer details about the features of specific nodes, check out the [List of N
# Features
A key goal of ChainForge is facilitating **comparison** and **evaluation** of prompts and models. Basic features are:
- **Prompt permutations**: Set up a prompt template and feed it variations of input variables. ChainForge will prompt all selected LLMs with all possible permutations of the input prompt, so that you can get a better sense of prompt quality. You can also chain prompt templates at arbitrary depth (e.g., to compare templates). See the example after this list.
- **Chat turns**: Go beyond prompts and template follow-up chat messages, just like prompts. You can test how the wording of the user's query might change an LLM's output, or compare quality of later responses across multiple chat models (or the same chat model with different settings!).
- **Model settings**: Change the settings of supported models, and compare across settings. For instance, you can measure the impact of a system message on ChatGPT by adding several ChatGPT models, changing individual settings, and nicknaming each one. ChainForge will send out queries to each version of the model.
@ -100,14 +105,15 @@ A key goal of ChainForge is facilitating **comparison** and **evaluation** of pr
- **Visualization nodes**: Visualize evaluation results on plots like grouped box-and-whisker (for numeric metrics) and histograms (for boolean metrics). Currently we only support numeric and boolean metrics. We aim to provide users more control and options for plotting in the future.
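For instance, a prompt template is just text with input variables in braces, e.g. `What is the capital of {country}?` (an illustrative template; the variable name is made up). Supplying several values for `country` produces one prompt permutation per value, and each permutation is sent to every selected model.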
Taken together, these features let you easily:
- **Compare across prompts and prompt parameters**: Choose the best set of prompts that maximizes your eval target metrics (e.g., lowest code error rate). Or, see how changing parameters in a prompt template affects the quality of responses.
- **Compare across models**: Compare responses for every prompt across models and different model settings.
We've also found that some users simply want to use ChainForge to make tons of parametrized queries to LLMs (e.g., chaining prompt templates into prompt templates), possibly score them, and then output the results to a spreadsheet (Excel `xlsx`). To do this, attach an Inspect node to the output of a Prompt node and click `Export Data`.
For more specific details, see our [documentation](https://chainforge.ai/docs/nodes/).
----------------------------------
---
# Development
@ -120,6 +126,7 @@ We provide ongoing releases of this tool in the hopes that others find it useful
## Inspiration and Links
ChainForge is meant to be general-purpose, and is not developed for a specific API or LLM back-end. Our ultimate goal is integration into other tools for the systematic evaluation and auditing of LLMs. We hope to help others who are developing prompt-analysis flows in LLMs, or otherwise auditing LLM outputs. This project was inspired by our own use case, but also shares some camaraderie with two related (closed-source) research projects, both led by [Sherry Wu](https://www.cs.cmu.edu/~sherryw/):
- "PromptChainer: Chaining Large Language Model Prompts through Visual Programming" (Wu et al., CHI 22 LBW) [Video](https://www.youtube.com/watch?v=p6MA8q19uo0)
- "AI Chains: Transparent and Controllable Human-AI Interaction by Chaining Large Language Model Prompts" (Wu et al., CHI 22)
@ -129,7 +136,8 @@ Unlike these projects, we are focusing on supporting evaluation across prompts,
We welcome open-source collaborators. If you want to report a bug or request a feature, open an [Issue](https://github.com/ianarawjo/ChainForge/issues). We also encourage users to implement the requested feature / bug fix and submit a Pull Request.
------------------
---
# Cite Us
If you use ChainForge for research purposes, or build upon the source code, we ask that you cite our [arXiv pre-print](https://arxiv.org/abs/2309.09128) in any related publications.
@ -137,7 +145,7 @@ The BibTeX you can use is:
```bibtex
@misc{arawjo2023chainforge,
title={ChainForge: A Visual Toolkit for Prompt Engineering and LLM Hypothesis Testing},
author={Ian Arawjo and Chelse Swoopes and Priyan Vaithilingam and Martin Wattenberg and Elena Glassman},
year={2023},
eprint={2309.09128},

View File

@ -1 +0,0 @@
3.10.6/envs/chainforge

View File

@ -465,7 +465,11 @@ def fetchEnvironAPIKeys():
'HUGGINGFACE_API_KEY': 'HuggingFace',
'AZURE_OPENAI_KEY': 'Azure_OpenAI',
'AZURE_OPENAI_ENDPOINT': 'Azure_OpenAI_Endpoint',
'ALEPH_ALPHA_API_KEY': 'AlephAlpha'
'ALEPH_ALPHA_API_KEY': 'AlephAlpha',
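# AWS credentials, used for Amazon Bedrock; temporary credentials also require a session token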
'AWS_ACCESS_KEY_ID': 'AWS_Access_Key_ID',
'AWS_SECRET_ACCESS_KEY': 'AWS_Secret_Access_Key',
'AWS_REGION': 'AWS_Region',
'AWS_SESSION_TOKEN': 'AWS_Session_Token'
}
d = { alias: os.environ.get(key) for key, alias in keymap.items() }
ret = jsonify(d)

View File

@ -1,17 +0,0 @@
{
"files": {
"main.css": "/static/css/main.85149714.css",
"main.js": "/static/js/main.d09374ca.js",
"static/js/477.650352c6.chunk.js": "/static/js/477.650352c6.chunk.js",
"static/js/787.4c72bb55.chunk.js": "/static/js/787.4c72bb55.chunk.js",
"index.html": "/index.html",
"main.85149714.css.map": "/static/css/main.85149714.css.map",
"main.d09374ca.js.map": "/static/js/main.d09374ca.js.map",
"477.650352c6.chunk.js.map": "/static/js/477.650352c6.chunk.js.map",
"787.4c72bb55.chunk.js.map": "/static/js/787.4c72bb55.chunk.js.map"
},
"entrypoints": [
"static/css/main.85149714.css",
"static/js/main.d09374ca.js"
]
}

Binary file not shown.


View File

@ -1 +0,0 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><script async src="https://www.googletagmanager.com/gtag/js?id=G-RN3FDBLMCR"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","G-RN3FDBLMCR")</script><link rel="icon" href="/favicon.ico"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="A visual programming environment for prompt engineering"/><link rel="apple-touch-icon" href="/logo192.png"/><link rel="manifest" href="/manifest.json"/><title>ChainForge</title><script defer="defer" src="/static/js/main.d09374ca.js"></script><link href="/static/css/main.85149714.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>

Binary file not shown.


Binary file not shown.


View File

@ -1,25 +0,0 @@
{
"short_name": "ChainForge",
"name": "ChainForge",
"icons": [
{
"src": "favicon.ico",
"sizes": "64x64 32x32 24x24 16x16",
"type": "image/x-icon"
},
{
"src": "logo192.png",
"type": "image/png",
"sizes": "192x192"
},
{
"src": "logo512.png",
"type": "image/png",
"sizes": "512x512"
}
],
"start_url": ".",
"display": "standalone",
"theme_color": "#000000",
"background_color": "#ffffff"
}

View File

@ -1,3 +0,0 @@
# https://www.robotstxt.org/robotstxt.html
User-agent: *
Disallow:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1 +0,0 @@
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */

File diff suppressed because one or more lines are too long

View File

@ -1,2 +0,0 @@
"use strict";(self.webpackChunkchain_forge=self.webpackChunkchain_forge||[]).push([[787],{787:function(e,n,t){t.r(n),t.d(n,{getCLS:function(){return y},getFCP:function(){return g},getFID:function(){return C},getLCP:function(){return P},getTTFB:function(){return D}});var i,r,a,o,u=function(e,n){return{name:e,value:void 0===n?-1:n,delta:0,entries:[],id:"v2-".concat(Date.now(),"-").concat(Math.floor(8999999999999*Math.random())+1e12)}},c=function(e,n){try{if(PerformanceObserver.supportedEntryTypes.includes(e)){if("first-input"===e&&!("PerformanceEventTiming"in self))return;var t=new PerformanceObserver((function(e){return e.getEntries().map(n)}));return t.observe({type:e,buffered:!0}),t}}catch(e){}},f=function(e,n){var t=function t(i){"pagehide"!==i.type&&"hidden"!==document.visibilityState||(e(i),n&&(removeEventListener("visibilitychange",t,!0),removeEventListener("pagehide",t,!0)))};addEventListener("visibilitychange",t,!0),addEventListener("pagehide",t,!0)},s=function(e){addEventListener("pageshow",(function(n){n.persisted&&e(n)}),!0)},m=function(e,n,t){var i;return function(r){n.value>=0&&(r||t)&&(n.delta=n.value-(i||0),(n.delta||void 0===i)&&(i=n.value,e(n)))}},v=-1,p=function(){return"hidden"===document.visibilityState?0:1/0},d=function(){f((function(e){var n=e.timeStamp;v=n}),!0)},l=function(){return v<0&&(v=p(),d(),s((function(){setTimeout((function(){v=p(),d()}),0)}))),{get firstHiddenTime(){return v}}},g=function(e,n){var t,i=l(),r=u("FCP"),a=function(e){"first-contentful-paint"===e.name&&(f&&f.disconnect(),e.startTime<i.firstHiddenTime&&(r.value=e.startTime,r.entries.push(e),t(!0)))},o=window.performance&&performance.getEntriesByName&&performance.getEntriesByName("first-contentful-paint")[0],f=o?null:c("paint",a);(o||f)&&(t=m(e,r,n),o&&a(o),s((function(i){r=u("FCP"),t=m(e,r,n),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,t(!0)}))}))})))},h=!1,T=-1,y=function(e,n){h||(g((function(e){T=e.value})),h=!0);var t,i=function(n){T>-1&&e(n)},r=u("CLS",0),a=0,o=[],v=function(e){if(!e.hadRecentInput){var n=o[0],i=o[o.length-1];a&&e.startTime-i.startTime<1e3&&e.startTime-n.startTime<5e3?(a+=e.value,o.push(e)):(a=e.value,o=[e]),a>r.value&&(r.value=a,r.entries=o,t())}},p=c("layout-shift",v);p&&(t=m(i,r,n),f((function(){p.takeRecords().map(v),t(!0)})),s((function(){a=0,T=-1,r=u("CLS",0),t=m(i,r,n)})))},E={passive:!0,capture:!0},w=new Date,L=function(e,n){i||(i=n,r=e,a=new Date,F(removeEventListener),S())},S=function(){if(r>=0&&r<a-w){var e={entryType:"first-input",name:i.type,target:i.target,cancelable:i.cancelable,startTime:i.timeStamp,processingStart:i.timeStamp+r};o.forEach((function(n){n(e)})),o=[]}},b=function(e){if(e.cancelable){var n=(e.timeStamp>1e12?new Date:performance.now())-e.timeStamp;"pointerdown"==e.type?function(e,n){var t=function(){L(e,n),r()},i=function(){r()},r=function(){removeEventListener("pointerup",t,E),removeEventListener("pointercancel",i,E)};addEventListener("pointerup",t,E),addEventListener("pointercancel",i,E)}(n,e):L(n,e)}},F=function(e){["mousedown","keydown","touchstart","pointerdown"].forEach((function(n){return e(n,b,E)}))},C=function(e,n){var t,a=l(),v=u("FID"),p=function(e){e.startTime<a.firstHiddenTime&&(v.value=e.processingStart-e.startTime,v.entries.push(e),t(!0))},d=c("first-input",p);t=m(e,v,n),d&&f((function(){d.takeRecords().map(p),d.disconnect()}),!0),d&&s((function(){var a;v=u("FID"),t=m(e,v,n),o=[],r=-1,i=null,F(addEventListener),a=p,o.push(a),S()}))},k={},P=function(e,n){var 
t,i=l(),r=u("LCP"),a=function(e){var n=e.startTime;n<i.firstHiddenTime&&(r.value=n,r.entries.push(e),t())},o=c("largest-contentful-paint",a);if(o){t=m(e,r,n);var v=function(){k[r.id]||(o.takeRecords().map(a),o.disconnect(),k[r.id]=!0,t(!0))};["keydown","click"].forEach((function(e){addEventListener(e,v,{once:!0,capture:!0})})),f(v,!0),s((function(i){r=u("LCP"),t=m(e,r,n),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,k[r.id]=!0,t(!0)}))}))}))}},D=function(e){var n,t=u("TTFB");n=function(){try{var n=performance.getEntriesByType("navigation")[0]||function(){var e=performance.timing,n={entryType:"navigation",startTime:0};for(var t in e)"navigationStart"!==t&&"toJSON"!==t&&(n[t]=Math.max(e[t]-e.navigationStart,0));return n}();if(t.value=t.delta=n.responseStart,t.value<0||t.value>performance.now())return;t.entries=[n],e(t)}catch(e){}},"complete"===document.readyState?setTimeout(n,0):addEventListener("load",(function(){return setTimeout(n,0)}))}}}]);
//# sourceMappingURL=787.4c72bb55.chunk.js.map

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,241 +0,0 @@
/*
* @copyright 2016 Sean Connelly (@voidqk), http://syntheti.cc
* @license MIT
* @preserve Project Home: https://github.com/voidqk/polybooljs
*/
/*
object-assign
(c) Sindre Sorhus
@license MIT
*/
/*
* based on code from:
*
* @license RequireJS text 0.25.0 Copyright (c) 2010-2011, The Dojo Foundation All Rights Reserved.
* Available via the MIT or new BSD license.
* see: http://github.com/jrburke/requirejs for details
*/
/* @license
Papa Parse
v5.4.1
https://github.com/mholt/PapaParse
License: MIT
*/
/*!
Copyright (c) 2018 Jed Watson.
Licensed under the MIT License (MIT), see
http://jedwatson.github.io/classnames
*/
/*!
* The buffer module from node.js, for the browser.
*
* @author Feross Aboukhadijeh <https://feross.org>
* @license MIT
*/
/*!
* Determine if an object is a Buffer
*
* @author Feross Aboukhadijeh <https://feross.org>
* @license MIT
*/
/*!
* pad-left <https://github.com/jonschlinkert/pad-left>
*
* Copyright (c) 2014-2015, Jon Schlinkert.
* Licensed under the MIT license.
*/
/*!
* repeat-string <https://github.com/jonschlinkert/repeat-string>
*
* Copyright (c) 2014-2015, Jon Schlinkert.
* Licensed under the MIT License.
*/
/*!
* The buffer module from node.js, for the browser.
*
* @author Feross Aboukhadijeh <feross@feross.org> <http://feross.org>
* @license MIT
*/
/*!
* The buffer module from node.js, for the browser.
*
* @author Feross Aboukhadijeh <https://feross.org>
* @license MIT
*/
/*!
* The buffer module from node.js, for the browser.
*
* @author Feross Aboukhadijeh <https://feross.org>
* @license MIT
*/
/*! @license
==========================================================================
SproutCore -- JavaScript Application Framework
copyright 2006-2009, Sprout Systems Inc., Apple Inc. and contributors.
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
SproutCore and the SproutCore logo are trademarks of Sprout Systems, Inc.
For more information about SproutCore, visit http://www.sproutcore.com
==========================================================================
@license */
/*! Native Promise Only
v0.8.1 (c) Kyle Simpson
MIT License: http://getify.mit-license.org
*/
/*! ieee754. BSD-3-Clause License. Feross Aboukhadijeh <https://feross.org/opensource> */
/*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
/*! sheetjs (C) 2013-present SheetJS -- http://sheetjs.com */
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @license
* Lodash <https://lodash.com/>
* Copyright OpenJS Foundation and other contributors <https://openjsf.org/>
* Released under MIT license <https://lodash.com/license>
* Based on Underscore.js 1.8.3 <http://underscorejs.org/LICENSE>
* Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors
*/
/**
* @license React
* react-dom.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react-is.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react-jsx-runtime.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* scheduler.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* use-sync-external-store-shim.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* use-sync-external-store-shim/with-selector.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* Prism: Lightweight, robust, elegant syntax highlighting
*
* @license MIT <https://opensource.org/licenses/MIT>
* @author Lea Verou <https://lea.verou.me>
* @namespace
* @public
*/
/** @license React v16.13.1
* react-is.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/** @license React v17.0.2
* react-is.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/** @license URI.js v4.4.1 (c) 2011 Gary Court. License: http://github.com/garycourt/uri-js */

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@ -20,6 +20,7 @@
"@mantine/dropzone": "^6.0.19",
"@mantine/form": "^6.0.11",
"@mantine/prism": "^6.0.15",
"@mirai73/bedrock-fm": "^0.4.3",
"@reactflow/background": "^11.2.0",
"@reactflow/controls": "^11.1.11",
"@reactflow/core": "^11.7.0",
@ -27,7 +28,7 @@
"@rjsf/core": "^5.7.3",
"@rjsf/utils": "^5.7.3",
"@rjsf/validator-ajv8": "^5.7.3",
"@tabler/icons-react": "^2.39.0",
"@tabler/icons-react": "^2.47.0",
"@testing-library/jest-dom": "^5.16.5",
"@testing-library/react": "^13.4.0",
"@testing-library/user-event": "^13.5.0",
@ -59,7 +60,7 @@
"https-browserify": "^1.0.0",
"lodash": "^4.17.21",
"lz-string": "^1.5.0",
"mantine-contextmenu": "^1.2.15",
"mantine-contextmenu": "^6.1.0",
"mantine-react-table": "^1.0.0-beta.8",
"markdown-it": "^13.0.1",
"mathjs": "^11.8.2",

View File

@ -11,7 +11,12 @@ import {
Textarea,
Alert,
} from "@mantine/core";
import { autofill, generateAndReplace, AIError } from "./backend/ai";
import {
autofill,
generateAndReplace,
AIError,
getAIFeaturesModels,
} from "./backend/ai";
import { IconSparkles, IconAlertCircle } from "@tabler/icons-react";
import AlertModal from "./AlertModal";
import useStore from "./store";
@ -124,16 +129,34 @@ export function AIPopover({
}) {
// API keys
const apiKeys = useStore((state) => state.apiKeys);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
// To check for OpenAI API key
const noOpenAIKeyMessage = useMemo(() => {
if (apiKeys && apiKeys.OpenAI) return undefined;
else
// To check for provider selection and credentials/api keys
const invalidAIFeaturesSetup = useMemo(() => {
if (!aiFeaturesProvider) {
return (
<Alert
variant="light"
color="grape"
title="No OpenAI API key detected."
title="No provider selected"
mt="xs"
maw={200}
fz="xs"
icon={<IconAlertCircle />}
>
You need to select a provider in the settings to use this feature
</Alert>
);
} else if (
apiKeys &&
aiFeaturesProvider.toLowerCase().includes("openai") &&
!apiKeys.OpenAI
) {
return (
<Alert
variant="light"
color="grape"
title="No OpenAI API key detected"
mt="xs"
maw={200}
fz="xs"
@ -143,7 +166,32 @@ export function AIPopover({
support features.
</Alert>
);
}, [apiKeys]);
} else if (
apiKeys &&
aiFeaturesProvider.toLowerCase().includes("bedrock") &&
!(
apiKeys.AWS_Access_Key_ID &&
apiKeys.AWS_Secret_Access_Key &&
apiKeys.AWS_Session_Token
)
) {
return (
<Alert
variant="light"
color="grape"
title="No AWS Credentials detected"
mt="xs"
maw={200}
fz="xs"
icon={<IconAlertCircle />}
>
You must set temporary AWS Credentials before you can use generative
AI support features.
</Alert>
);
}
return undefined;
}, [apiKeys, aiFeaturesProvider]);
return (
<Popover
@ -165,9 +213,9 @@ export function AIPopover({
variant="light"
leftSection={<IconSparkles size={10} stroke={3} />}
>
Generative AI
Generative AI ({aiFeaturesProvider ?? "None"})
</Badge>
{noOpenAIKeyMessage || children}
{invalidAIFeaturesSetup || children}
</Stack>
</Popover.Dropdown>
</Popover>
@ -192,6 +240,8 @@ export function AIGenReplaceItemsPopover({
// API keys
const apiKeys = useStore((state) => state.apiKeys);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
// Alerts
const alertModal = useRef(null);
@ -228,7 +278,12 @@ export function AIGenReplaceItemsPopover({
const handleCommandFill = () => {
setIsCommandFillLoading(true);
setDidCommandFillError(false);
autofill(Object.values(values), commandFillNumber, apiKeys)
autofill(
Object.values(values),
commandFillNumber,
aiFeaturesProvider,
apiKeys,
)
.then(onAddValues)
.catch((e) => {
if (e instanceof AIError) {
@ -248,6 +303,7 @@ export function AIGenReplaceItemsPopover({
generateAndReplacePrompt,
generateAndReplaceNumber,
generateAndReplaceIsUnconventional,
aiFeaturesProvider,
apiKeys,
)
.then(onReplaceValues)
@ -425,6 +481,7 @@ export function AIGenCodeEvaluatorPopover({
}) {
// API keys
const apiKeys = useStore((state) => state.apiKeys);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
// State
const [replacePrompt, setReplacePrompt] = useState("");
@ -464,7 +521,7 @@ export function AIGenCodeEvaluatorPopover({
queryLLM(
replacePrompt,
"gpt-4",
getAIFeaturesModels(aiFeaturesProvider).large,
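// (the selected provider's "large" model, e.g. GPT-4 for OpenAI or Claude 3 for Bedrock)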
1,
escapeBraces(template),
{},
@ -538,7 +595,7 @@ ${currentEvalCode}
queryLLM(
editPrompt,
"gpt-4",
getAIFeaturesModels(aiFeaturesProvider).large,
1,
escapeBraces(template),
{},

View File

@ -11,6 +11,7 @@ import {
Tooltip,
} from "@mantine/core";
import { useClipboard } from "@mantine/hooks";
import { useContextMenu } from "mantine-contextmenu";
import {
IconSettings,
IconTextPlus,
@ -22,7 +23,6 @@ import {
IconArrowMerge,
IconArrowsSplit,
IconForms,
IconAbacus,
} from "@tabler/icons-react";
import RemoveEdge from "./RemoveEdge";
import TextFieldsNode from "./TextFieldsNode"; // Import a custom node
@ -231,6 +231,10 @@ const App = () => {
message: "Are you sure?",
});
// For Mantine Context Menu forced closing
// (for some reason the menu doesn't close automatically upon click-off)
const { hideContextMenu } = useContextMenu();
// For displaying error messages to user
const alertModal = useRef(null);
@ -1035,6 +1039,7 @@ const App = () => {
<div
id="cf-root-container"
style={{ display: "flex", height: "100vh" }}
onPointerDown={hideContextMenu}
>
<div
style={{ height: "100%", backgroundColor: "#eee", flexGrow: "1" }}

View File

@ -21,6 +21,7 @@ import {
Badge,
Card,
Switch,
Select,
} from "@mantine/core";
import { useDisclosure } from "@mantine/hooks";
import { useForm } from "@mantine/form";
@ -35,6 +36,7 @@ import useStore from "./store";
import { APP_IS_RUNNING_LOCALLY } from "./backend/utils";
import fetch_from_backend from "./fetch_from_backend";
import { setCustomProviders } from "./ModelSettingSchemas";
import { getAIFeaturesModelProviders } from "./backend/ai";
const _LINK_STYLE = { color: "#1E90FF", textDecoration: "none" };
@ -153,6 +155,10 @@ const GlobalSettingsModal = forwardRef(
const nodes = useStore((state) => state.nodes);
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const alertModal = props?.alertModal;
const setAIFeaturesProvider = useStore(
(state) => state.setAIFeaturesProvider,
);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
const [aiSupportActive, setAISupportActive] = useState(
getFlag("aiSupport"),
@ -257,6 +263,11 @@ const GlobalSettingsModal = forwardRef(
Azure_OpenAI_Endpoint: "",
HuggingFace: "",
AlephAlpha: "",
AWS_Access_Key_ID: "",
AWS_Secret_Access_Key: "",
AWS_Session_Token: "",
AWS_Region: "us-east-1",
AmazonBedrock: JSON.stringify({ credentials: {}, region: "us-east-1" }),
},
validate: {
@ -287,7 +298,7 @@ const GlobalSettingsModal = forwardRef(
closeOnClickOutside={false}
style={{ position: "relative", left: "-5%" }}
>
<Box maw={400} mx="auto">
<Box maw={600} mx="auto">
<Tabs defaultValue="api-keys">
<Tabs.List>
<Tabs.Tab value="api-keys">API Keys</Tabs.Tab>
@ -355,6 +366,44 @@ const GlobalSettingsModal = forwardRef(
{...form.getInputProps("AlephAlpha")}
/>
<br />
<Divider
my="xs"
label="Amazon Web Services"
labelPosition="center"
/>
<TextInput
description={
"AWS credentials are used to access the AWS API. You must use" +
"temporary credentials and associated to an IAM role with the" +
"right permission."
}
label="AWS Access Key ID"
placeholder="Paste your AWS Access Key ID here"
{...form.getInputProps("AWS_Access_Key_ID")}
style={{ marginBottom: "8pt" }}
/>
<TextInput
label="AWS Secret Access Key"
placeholder="Paste your AWS Secret Access Key here"
{...form.getInputProps("AWS_Secret_Access_Key")}
style={{ marginBottom: "8pt" }}
/>
<TextInput
label="AWS Session Token"
placeholder="Paste your AWS Session Token here"
{...form.getInputProps("AWS_Session_Token")}
style={{ marginBottom: "8pt" }}
/>
<TextInput
label="AWS Region"
placeholder="Paste your AWS Region here"
{...form.getInputProps("AWS_Region")}
/>
<br />
<Divider
my="xs"
label="Microsoft Azure"
@ -406,7 +455,7 @@ const GlobalSettingsModal = forwardRef(
<Switch
label="AI Support Features"
size="sm"
description="Adds purple sparkly AI buttons to nodes. Must have OpenAI API key access to use."
description="Adds purple sparkly AI buttons to nodes. These buttons allow you to generate in-context data or code."
checked={aiSupportActive}
onChange={handleAISupportChecked}
/>
@ -421,6 +470,16 @@ const GlobalSettingsModal = forwardRef(
checked={aiAutocompleteActive}
onChange={handleAIAutocompleteChecked}
/>
<Select
label="LLM Provider"
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-3.5 and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
dropdownPosition="bottom"
withinPortal
defaultValue={getAIFeaturesModelProviders()[0]}
data={getAIFeaturesModelProviders()}
value={aiFeaturesProvider}
onChange={setAIFeaturesProvider}
></Select>
</Group>
) : (
<></>

View File

@ -9,20 +9,20 @@ import React, {
useMemo,
} from "react";
import { DragDropContext, Draggable } from "react-beautiful-dnd";
import { Menu } from "@mantine/core";
import { v4 as uuid } from "uuid";
import LLMListItem, { LLMListItemClone } from "./LLMListItem";
import { StrictModeDroppable } from "./StrictModeDroppable";
import ModelSettingsModal from "./ModelSettingsModal";
import { getDefaultModelSettings } from "./ModelSettingSchemas";
import useStore, { initLLMProviders } from "./store";
import useStore, { initLLMProviderMenu, initLLMProviders } from "./store";
import { useContextMenu } from "mantine-contextmenu";
// The LLM(s) to include by default on a PromptNode whenever one is created.
// Defaults to ChatGPT (GPT3.5) when running locally, and HF-hosted falcon-7b for online version since it's free.
const DEFAULT_INIT_LLMS = [initLLMProviders[0]];
// Helper funcs
// Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)")
/** Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)") */
const ensureUniqueName = (_name, _prev_names) => {
// Strip whitespace around names
const prev_names = _prev_names.map((n) => n.trim());
@ -41,6 +41,18 @@ const ensureUniqueName = (_name, _prev_names) => {
return new_name;
};
/** Get position CSS style below and left-aligned to the input element */
const getPositionCSSStyle = (elem) => {
const rect = elem.getBoundingClientRect();
return {
style: {
position: "absolute",
left: `${rect.left}px`,
top: `${rect.bottom}px`,
},
};
};
export function LLMList({ llms, onItemsChange, hideTrashIcon }) {
const [items, setItems] = useState(llms);
const settingsModal = useRef(null);
@ -220,7 +232,8 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
) {
// All available LLM providers, for the dropdown list
const AvailableLLMs = useStore((state) => state.AvailableLLMs);
const { showContextMenu, hideContextMenu, isContextMenuVisible } =
useContextMenu();
// For some reason, when the AvailableLLMs list is updated in the store/, it is not
// immediately updated here. I've tried all kinds of things, but cannot seem to fix this problem.
// We must force a re-render of the component:
@ -355,34 +368,71 @@ export const LLMListContainer = forwardRef(function LLMListContainer(
[bgColor],
);
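// Build the nested model-picker menu for the Mantine ContextMenu:
// plain entries for single models, and submenus for grouped providers.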
const menuItems = useMemo(() => {
const res = [];
for (const item of initLLMProviderMenu) {
if (!("group" in item)) {
res.push({
key: item.model,
title: `${item.emoji} ${item.name}`,
onClick: () => handleSelectModel(item.base_model),
});
} else {
res.push({
key: item.group,
title: `${item.emoji} ${item.group}`,
items: item.items.map((k) => ({
key: k.model,
title: `${k.emoji} ${k.name}`,
onClick: () => handleSelectModel(k.base_model),
})),
});
}
}
return res;
}, [AvailableLLMs, handleSelectModel]);
// Mantine ContextMenu does not fix the position of the menu
// to be below the clicked button, so we must do it ourselves.
const addBtnRef = useRef(null);
const [wasContextMenuToggled, setWasContextMenuToggled] = useState(false);
return (
<div className="llm-list-container nowheel" style={_bgStyle}>
<div className="llm-list-backdrop" style={_bgStyle}>
{description || "Models to query:"}
<div className="add-llm-model-btn nodrag">
<Menu
transitionProps={{ transition: "pop-top-left" }}
position="bottom-start"
width={220}
withinPortal={true}
<button
ref={addBtnRef}
style={_bgStyle}
onPointerDownCapture={() => {
setWasContextMenuToggled(
isContextMenuVisible && wasContextMenuToggled,
);
}}
onClick={(evt) => {
if (wasContextMenuToggled) {
setWasContextMenuToggled(false);
return; // abort
}
// This is a hack: without hiding, the context menu position is not always updated.
// This is the case even if hideContextMenu() was triggered elsewhere.
hideContextMenu();
// Now show the context menu below the button:
showContextMenu(
menuItems,
addBtnRef?.current
? getPositionCSSStyle(addBtnRef.current)
: undefined,
)(evt);
// Save whether the context menu was open, before
// onPointerDown in App.tsx could auto-close the menu.
setWasContextMenuToggled(true);
}}
>
<Menu.Target>
<button style={_bgStyle}>
{modelSelectButtonText || "Add +"}
</button>
</Menu.Target>
<Menu.Dropdown>
{AvailableLLMs.map((item) => (
<Menu.Item
key={item.model}
onClick={() => handleSelectModel(item.base_model)}
icon={item.emoji}
>
{item.name}
</Menu.Item>
))}
</Menu.Dropdown>
</Menu>
{modelSelectButtonText ?? "Add +"}
</button>
</div>
</div>
<div className="nodrag">

View File

@ -1243,6 +1243,691 @@ const OllamaSettings = {
},
};
const BedrockClaudeSettings = {
fullName: "Claude (Anthropic) via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Claude",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Claude to query. For more details on the differences, see the Anthropic API documentation.",
enum: [
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-v2:1",
"anthropic.claude-v2",
"anthropic.claude-instant-v1",
],
default: "anthropic.claude-3-haiku-20240307-v1:0",
shortname_map: {
"anthropic.claude-3-sonnet-20240229-v1:0": "claude-3-sonnet",
"anthropic.claude-3-haiku-20240307-v1:0": "claude-3-haiku",
},
},
system_msg: {
type: "string",
title: "system_msg",
description: "A system message to use with the model",
default: "",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
max_tokens_to_sample: {
type: "integer",
title: "max_tokens_to_sample",
description:
"A maximum number of tokens to generate before stopping. Lower this if you want shorter responses. By default, ChainForge uses the value 1024, although the Anthropic API does not specify a default value.",
default: 1024,
minimum: 1,
},
custom_prompt_wrapper: {
type: "string",
title: "Prompt Wrapper (ChainForge)",
description:
// eslint-disable-next-line no-template-curly-in-string
'Anthropic models expect prompts in the form "\\n\\nHuman: ${prompt}\\n\\nAssistant:". ChainForge wraps all prompts in this template by default. If you wish to ' +
// eslint-disable-next-line no-template-curly-in-string
"explore custom prompt wrappers that deviate, write a Python template here with a single variable, ${prompt}, where the actual prompt text should go. Otherwise, leave this field blank. (Note that you should enter newlines as newlines, not escape codes like \\n.)",
default: "",
},
stop_sequences: {
type: "string",
title: "stop_sequences",
description:
'Anthropic models stop on "\\n\\nHuman:", and may include additional built-in stop sequences in the future. By providing the stop_sequences parameter, you may include additional strings that will cause the model to stop generating.\nEnclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: '"\n\nHuman:"',
},
top_k: {
type: "integer",
title: "top_k",
description:
'Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Defaults to -1, which disables it.',
minimum: 1,
default: 1,
},
top_p: {
type: "number",
title: "top_p",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 0.9,
minimum: 0.001,
maximum: 1,
multipleOf: 0.001,
},
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help":
"Defaults to claude-2. Note that Anthropic models in particular are subject to change. Model names prior to Claude 2, including 100k context window, are no longer listed on the Anthropic site, so they may or may not work.",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
max_tokens_to_sample: {
"ui:help": "Defaults to 1024.",
},
top_k: {
"ui:help": "Defaults to -1 (none).",
},
top_p: {
"ui:help": "Defaults to -1 (none).",
},
stop_sequences: {
"ui:widget": "textarea",
"ui:help": 'Defaults to one stop sequence, "\\n\\nHuman: "',
},
custom_prompt_wrapper: {
"ui:widget": "textarea",
"ui:help":
'Defaults to Anthropic\'s internal wrapper "\\n\\nHuman: {prompt}\\n\\nAssistant".',
},
},
postprocessors: {
stop_sequences: (str) => {
if (str.trim().length === 0) return ["\n\nHuman:"];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
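// e.g., the input '"\n\nHuman:" "STOP"' yields ["\n\nHuman:", "STOP"]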
},
},
};
const BedrockJurassic2Settings = {
fullName: "Jurassic-2 (Ai21) via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Jurassic2",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Jurassic 2 to query. For more details on the differences, see the AI21 API documentation.",
enum: ["ai21.j2-ultra", "ai21.j2-mid"],
default: "ai21.j2-ultra",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
maxTokens: {
type: "integer",
title: "maxTokens",
description:
"The maximum number of tokens to generate for each response.",
default: 1024,
minimum: 1,
},
minTokens: {
type: "integer",
title: "maxTokens",
description:
"The minimum number of tokens to generate for each response.",
default: 1,
minimum: 1,
},
numResults: {
type: "integer",
title: "numResults",
description: "The number of responses to generate for a given prompt.",
default: 1,
minimum: 1,
},
stopSequences: {
type: "string",
title: "stopSequences",
description:
'Enclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: "",
},
topKReturn: {
type: "integer",
title: "topKReturn",
description:
"The number of top-scoring tokens to consider for each generation step.",
minimum: 0,
default: 0,
},
topP: {
type: "number",
title: "topP",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 1,
minimum: 0.01,
maximum: 1,
multipleOf: 0.001,
},
},
},
postprocessors: {
stopSequences: (str) => {
if (str.trim().length === 0) return [];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to Jurassic 2 Ultra. ",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
maxTokens: {
"ui:help": "Defaults to 1024.",
},
minTokens: {
"ui:help": "Defaults to 1.",
},
topKReturn: {
"ui:help": "Defaults to 0.",
},
topP: {
"ui:help": "Defaults to 1.",
},
stopSequences: {
"ui:widget": "textarea",
"ui:help": "Defaults to no sequence",
},
},
};
const BedrockTitanSettings = {
fullName: "Titan (Amazon) via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Titan",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Amazon Titan to query. For more details on the differences, see the Amazon Titan API documentation.",
enum: [
"amazon.titan-tg1-large",
"amazon.titan-text-lite-v1",
"amazon.titan-text-express-v1",
],
default: "amazon.titan-tg1-large",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
maxTokenCount: {
type: "integer",
title: "maxTokens",
description:
"The maximum number of tokens to generate for each response.",
default: 1024,
minimum: 1,
},
stopSequences: {
type: "string",
title: "stopSequences",
description:
'Enclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: "",
},
topP: {
type: "number",
title: "topP",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 1,
minimum: 0.01,
maximum: 1,
multipleOf: 0.001,
},
},
},
postprocessors: {
stopSequences: (str) => {
if (str.trim().length === 0) return [];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to Titan Large",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
maxTokenCount: {
"ui:help": "Defaults to 1024.",
},
topP: {
"ui:help": "Defaults to 1.",
},
stopSequences: {
"ui:widget": "textarea",
"ui:help": "Defaults to no sequence",
},
},
};
const BedrockCommandTextSettings = {
fullName: "Command Text (Cohere) via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "CommandText",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Command Cohere to query. For more details on the differences, see the Cohere API documentation.",
enum: ["cohere.command-text-v14", "cohere.command-light-text-v14"],
default: "cohere.command-text-v14",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
max_tokens: {
type: "integer",
title: "max_tokens",
description:
"The maximum number of tokens to generate for each response.",
default: 1024,
minimum: 1,
},
num_generations: {
type: "integer",
title: "num_generations",
description: "The number of responses to generate for a given prompt.",
default: 1,
minimum: 1,
},
stop_sequences: {
type: "string",
title: "stop_sequences",
description:
'Enclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: "",
},
k: {
type: "integer",
title: "k",
description:
"The number of top-scoring tokens to consider for each generation step.",
minimum: 0,
default: 0,
},
p: {
type: "number",
title: "p",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 1,
minimum: 0.01,
maximum: 1,
multipleOf: 0.001,
},
},
},
postprocessors: {
stop_sequences: (str) => {
if (str.trim().length === 0) return [];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to Command Text",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
max_tokens: {
"ui:help": "Defaults to 1024.",
},
num_generations: {
"ui:help": "Defaults to 1.",
},
k: {
"ui:help": "Defaults to 0.",
},
p: {
"ui:help": "Defaults to 1.",
},
stop_sequences: {
"ui:widget": "textarea",
"ui:help": "Defaults to no sequence",
},
},
};
const MistralSettings = {
fullName: "Mistral models via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "Mistral",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Mistral model to query. For more details on the differences, see the Mistral API documentation.",
enum: ["mistral.mistral-7b-instruct-v0:2"],
default: "mistral.mistral-7b-instruct-v0:2",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
max_tokens: {
type: "integer",
title: "max_tokens",
description:
"The maximum number of tokens to generate for each response.",
default: 1024,
minimum: 1,
},
stop: {
type: "string",
title: "stop",
description:
'Enclose stop sequences in double-quotes "" and use whitespace to separate them.',
default: "",
},
top_k: {
type: "integer",
title: "top_k",
description:
"The number of top-scoring tokens to consider for each generation step.",
minimum: 0,
default: 0,
},
top_p: {
type: "number",
title: "top_p",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 1,
minimum: 0.01,
maximum: 1,
multipleOf: 0.001,
},
},
},
postprocessors: {
stop: (str) => {
if (str.trim().length === 0) return [];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to Mistral",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
max_tokens: {
"ui:help": "Defaults to 1024.",
},
top_k: {
"ui:help": "Defaults to 0.",
},
top_p: {
"ui:help": "Defaults to 1.",
},
stop: {
"ui:widget": "textarea",
"ui:help": "Defaults to no sequence",
},
},
};
// Mixtral shares Mistral's settings, but overrides the model list and nickname.
// Build fresh schema/uiSchema objects so the shared MistralSettings are not mutated.
const MixtralSettings = {
...MistralSettings,
schema: {
...MistralSettings.schema,
properties: {
...MistralSettings.schema.properties,
model: {
type: "string",
title: "Model Version",
description:
"Select a version of the Mixtral model to query. For more details on the differences, see the Mistral API documentation.",
enum: ["mistral.mixtral-8x7b-instruct-v0:1"],
default: "mistral.mixtral-8x7b-instruct-v0:1",
},
shortname: {
type: "string",
title: "Nickname",
description: "Unique identifier to appear in ChainForge. Keep it short.",
default: "Mixtral",
},
},
},
uiSchema: {
...MistralSettings.uiSchema,
model: { "ui:help": "Defaults to Mixtral" },
},
};
const MetaLlama2ChatSettings = {
fullName: "Llama2Chat (Meta) via Amazon Bedrock",
schema: {
type: "object",
required: ["shortname"],
properties: {
shortname: {
type: "string",
title: "Nickname",
description:
"Unique identifier to appear in ChainForge. Keep it short.",
default: "LlamaChat",
},
model: {
type: "string",
title: "Model Version",
description:
"Select a version of Command Cohere to query. For more details on the differences, see the Cohere API documentation.",
enum: ["meta.llama2-13b-chat-v1", "meta.llama2-70b-chat-v1"],
default: "meta.llama2-13b-chat-v1",
},
temperature: {
type: "number",
title: "temperature",
description:
"Amount of randomness injected into the response. Ranges from 0 to 1. Use temp closer to 0 for analytical / multiple choice, and temp closer to 1 for creative and generative tasks.",
default: 1,
minimum: 0,
maximum: 1,
multipleOf: 0.01,
},
max_gen_len: {
type: "integer",
title: "max_gen_len",
description:
"The maximum number of tokens to generate for each response.",
default: 1024,
minimum: 1,
},
top_p: {
type: "number",
title: "top_p",
description:
"Does nucleus sampling, in which we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. Defaults to -1, which disables it. Note that you should either alter temperature or top_p, but not both.",
default: 1,
minimum: 0.01,
maximum: 1,
multipleOf: 0.001,
},
},
},
postprocessors: {
stop_sequences: (str) => {
if (str.trim().length === 0) return [];
return str
.match(/"((?:[^"\\]|\\.)*)"/g)
.map((s) => s.substring(1, s.length - 1)); // split on double-quotes but exclude escaped double-quotes inside the group
},
},
uiSchema: {
"ui:submitButtonOptions": UI_SUBMIT_BUTTON_SPEC,
shortname: {
"ui:autofocus": true,
},
model: {
"ui:help": "Defaults to LlamaChat 13B",
},
temperature: {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range",
},
max_gen_len: {
"ui:help": "Defaults to 1024.",
},
top_p: {
"ui:help": "Defaults to 1.",
},
},
};
// A lookup table indexed by base_model.
export const ModelSettings = {
"gpt-3.5-turbo": ChatGPTSettings,
@ -1254,6 +1939,13 @@ export const ModelSettings = {
hf: HuggingFaceTextInferenceSettings,
"luminous-base": AlephAlphaLuminousSettings,
ollama: OllamaSettings,
"br.anthropic.claude": BedrockClaudeSettings,
"br.ai21.j2": BedrockJurassic2Settings,
"br.amazon.titan": BedrockTitanSettings,
"br.cohere.command": BedrockCommandTextSettings,
"br.mistral.mistral": MistralSettings,
"br.mistral.mixtral": MixtralSettings,
"br.meta.llama2": MetaLlama2ChatSettings,
};
export function getSettingsSchemaForLLM(llm_name) {
@ -1273,7 +1965,9 @@ export function getSettingsSchemaForLLM(llm_name) {
if (llm_provider === LLMProvider.Custom) return ModelSettings[llm_name];
else if (llm_provider in provider_to_settings_schema)
return provider_to_settings_schema[llm_provider];
else {
else if (llm_provider === LLMProvider.Bedrock) {
return ModelSettings[llm_name.split("-")[0]];
} else {
console.error(`Could not find provider for llm ${llm_name}`);
return {};
}
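// For illustration: how the Bedrock branch above resolves its settings schema.
// Bedrock base_model keys contain no dashes, so if llm_name carries a per-node
// suffix such as "br.anthropic.claude-1" (an assumed example), splitting on "-"
// recovers the base_model key used in the lookup table:
console.assert("br.anthropic.claude-1".split("-")[0] === "br.anthropic.claude");
console.assert(ModelSettings["br.anthropic.claude"] === BedrockClaudeSettings);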
@ -1410,7 +2104,8 @@ export const setCustomProvider = (
rate_limit > 0
) {
if (rate_limit >= 60)
RATE_LIMITS[base_model] = [Math.trunc(rate_limit / 60), 1]; // for instance, 300 rpm means 5 every second
RATE_LIMITS[base_model] = [Math.trunc(rate_limit / 60), 1];
// for instance, 300 rpm means 5 every second
else RATE_LIMITS[base_model] = [1, Math.trunc(60 / rate_limit)]; // for instance, 10 rpm means 1 every 6 seconds
}
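// For illustration: the conversion above maps a requests-per-minute limit onto
// [requests per batch, seconds between batches]:
console.assert(Math.trunc(300 / 60) === 5); // 300 rpm -> [5, 1], i.e. 5 every second
console.assert(Math.trunc(60 / 10) === 6); // 10 rpm -> [1, 6], i.e. 1 every 6 seconds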

View File

@ -35,6 +35,7 @@ const TextFieldsNode = ({ data, id }) => {
const setDataPropsForNode = useStore((state) => state.setDataPropsForNode);
const pingOutputNodes = useStore((state) => state.pingOutputNodes);
const apiKeys = useStore((state) => state.apiKeys);
const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider);
const flags = useStore((state) => state.flags);
const [textfieldsValues, setTextfieldsValues] = useState(data.fields || {});
@ -51,6 +52,7 @@ const TextFieldsNode = ({ data, id }) => {
undefined,
// When suggestions are refreshed, throw out existing placeholders.
() => setPlaceholders({}),
() => aiFeaturesProvider,
() => apiKeys,
),
);

View File

@ -1,10 +1,20 @@
import { autofill, generateAndReplace } from "../ai";
describe("autofill", () => {
const apiKeys = {
OpenAI: process.env.OPENAI_API_KEY,
AWS_Access_Key_ID: process.env.AWS_ACCESS_KEY_ID,
AWS_Secret_Access_Key: process.env.AWS_SECRET_ACCESS_KEY,
AWS_Session_Token: process.env.AWS_SESSION_TOKEN,
};
describe("autofill-openai", () => {
if (!apiKeys.OpenAI) {
return;
}
it("should return an array of n rows", async () => {
const input = ["1", "2", "3", "4", "5"];
const n = 3;
const result = await autofill(input, n);
const result = await autofill(input, n, "OpenAI", apiKeys);
expect(result).toHaveLength(n);
result.forEach((row) => {
expect(typeof row).toBe("string");
@ -12,11 +22,56 @@ describe("autofill", () => {
});
});
describe("generateAndReplace", () => {
describe("generateAndReplace-openai", () => {
if (!apiKeys.OpenAI) {
return;
}
it("should return an array of n rows", async () => {
const prompt = "animals";
const n = 3;
const result = await generateAndReplace(prompt, n);
const result = await generateAndReplace(
prompt,
n,
false,
"OpenAI",
apiKeys,
);
expect(result).toHaveLength(n);
result.forEach((row) => {
expect(typeof row).toBe("string");
});
});
});
describe("autofill-bedrock-anthropic", () => {
if (!apiKeys.AWS_Access_Key_ID) {
return;
}
it("should return an array of n rows", async () => {
const input = ["1", "2", "3", "4", "5"];
const n = 3;
const result = await autofill(input, n, "Bedrock", apiKeys);
expect(result).toHaveLength(n);
result.forEach((row) => {
expect(typeof row).toBe("string");
});
});
});
describe("generateAndReplace-bedrock-anthropic", () => {
if (!apiKeys.AWS_Access_Key_ID) {
return;
}
it("should return an array of n rows", async () => {
const prompt = "animals";
const n = 3;
const result = await generateAndReplace(
prompt,
n,
false,
"Bedrock",
apiKeys,
);
expect(result).toHaveLength(n);
result.forEach((row) => {
expect(typeof row).toBe("string");

View File

@ -20,8 +20,40 @@ export class AIError extends Error {
// Input and outputs of autofill are both rows of strings.
export type Row = string;
// LLM to use for AI features.
const LLM = "gpt-3.5-turbo";
// The list of LLM models that can be used with AI features
const AIFeaturesLLMs = [
{
provider: "OpenAI",
small: { value: "gpt-3.5-turbo", label: "OpenAI GPT3.5" },
large: { value: "gpt-4", label: "OpenAI GPT4" },
},
{
provider: "Bedrock",
small: {
value: "anthropic.claude-3-haiku-20240307-v1:0",
label: "Claude 3 Haiku",
},
large: {
value: "anthropic.claude-3-sonnet-20240229-v1:0",
label: "Claude 3 Sonnet",
},
},
];
export function getAIFeaturesModelProviders() {
return AIFeaturesLLMs.map((m) => m.provider);
}
export function getAIFeaturesModels(
provider: string,
): { small: string; large: string } | undefined {
const models = AIFeaturesLLMs.filter((m) => m.provider === provider);
if (models.length === 0) return undefined;
return {
small: models[0].small.value,
large: models[0].large.value,
};
}
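// For illustration: a sketch of how the helpers above are meant to be used by
// the AI features UI (model ids taken from AIFeaturesLLMs above):
console.assert(getAIFeaturesModelProviders().includes("Bedrock"));
console.assert(
  getAIFeaturesModels("Bedrock")?.small ===
    "anthropic.claude-3-haiku-20240307-v1:0",
);
console.assert(getAIFeaturesModels("SomeUnknownProvider") === undefined);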
/**
* Flattens markdown AST to text
@ -137,6 +169,7 @@ function decode(mdText: string): Row[] {
export async function autofill(
input: Row[],
n: number,
provider: string,
apiKeys?: Dict,
): Promise<Row[]> {
// hash the arguments to get a unique id
@ -162,7 +195,7 @@ export async function autofill(
const result = await queryLLM(
/* id= */ id,
/* llm= */ LLM,
/* llm= */ getAIFeaturesModels(provider).small,
/* n= */ 1,
/* prompt= */ encoded,
/* vars= */ {},
@ -196,8 +229,9 @@ export async function autofill(
export async function generateAndReplace(
prompt: string,
n: number,
creative?: boolean,
apiKeys?: Dict,
creative: boolean,
provider: string,
apiKeys: Dict,
): Promise<Row[]> {
// hash the arguments to get a unique id
const id = JSON.stringify([prompt, n]);
@ -221,7 +255,7 @@ export async function generateAndReplace(
const result = await queryLLM(
/* id= */ id,
/* llm= */ LLM,
/* llm= */ getAIFeaturesModels(provider).small,
/* n= */ 1,
/* prompt= */ input,
/* vars= */ {},

View File

@ -50,17 +50,21 @@ class AISuggestionsManager {
onSuggestionsRefreshed?: (suggestions: Row[]) => void;
// Fetches API keys from front-end
getAPIKeys?: () => Dict;
// Fetches the model provider from front-end
getModelProvider?: () => string;
// Whether the suggestions are loading.
isLoading = false;
constructor(
onSuggestionsChanged?: (suggestions: Row[]) => void,
onSuggestionsRefreshed?: (suggestions: Row[]) => void,
getModelProvider?: () => string,
getAPIKeys?: () => Dict,
) {
this.onSuggestionsChanged = onSuggestionsChanged;
this.onSuggestionsRefreshed = onSuggestionsRefreshed;
this.getAPIKeys = getAPIKeys;
this.getModelProvider = getModelProvider;
}
/**
@ -105,6 +109,7 @@ class AISuggestionsManager {
autofill(
this.base,
NUM_SUGGESTIONS_TO_CACHE,
this.getModelProvider(),
this.getAPIKeys ? this.getAPIKeys() : undefined,
)
// Update suggestions.

View File

@ -171,8 +171,8 @@ function get_cache_keys_related_to_id(
);
else return include_basefile ? [base_file] : [];
}
// eslint-disable-next-line
async function setAPIKeys(api_keys: StringDict): Promise<void> {
export async function setAPIKeys(api_keys: StringDict): Promise<void> {
if (api_keys !== undefined) set_api_keys(api_keys);
}
@ -223,7 +223,9 @@ function extract_llm_key(llm_spec: Dict | string): string {
else if (llm_spec.key !== undefined) return llm_spec.key;
else
throw new Error(
`Could not find a key property on spec ${JSON.stringify(llm_spec)} for LLM`,
`Could not find a key property on spec ${JSON.stringify(
llm_spec,
)} for LLM`,
);
}
@ -782,7 +784,7 @@ export async function queryLLM(
// For each LLM, generate and cache responses:
const responses: { [key: string]: Array<LLMResponseObject> } = {};
const all_errors = {};
const num_generations = n !== undefined ? n : 1;
const num_generations = n ?? 1;
async function query(llm_spec: string | Dict): Promise<LLMPrompterResults> {
// Get LLM model name and any params
const llm_str = extract_llm_name(llm_spec);
@ -935,7 +937,6 @@ export async function queryLLM(
errors: all_errors,
};
}
/**
* A convenience function for a simpler call to queryLLM.
* This is queryLLM with "no_cache" turned on, no variables, and n=1 responses per prompt.
@ -1054,6 +1055,7 @@ export async function executejs(
// Run the user-defined 'evaluate' function over the responses:
// NOTE: 'evaluate' here was defined dynamically from 'eval' above. We've already checked that it exists.
processed_resps = await run_over_responses(
iframe ? process_func : code,
responses,
@ -1221,8 +1223,7 @@ export async function evalWithLLM(
// Load all responses with the given ID:
let all_evald_responses: StandardizedLLMResponse[] = [];
let all_errors: string[] = [];
for (let i = 0; i < response_ids.length; i++) {
const cache_id = response_ids[i];
for (const cache_id of response_ids) {
const fname = `${cache_id}.json`;
if (!StorageCache.has(fname))
return { error: `Did not find cache file for id ${cache_id}` };
@ -1333,8 +1334,7 @@ export async function evalWithLLM(
export async function grabResponses(responses: Array<string>): Promise<Dict> {
// Grab all responses with the given ID:
let grabbed_resps: Dict[] = [];
for (let i = 0; i < responses.length; i++) {
const cache_id = responses[i];
for (const cache_id of responses) {
const storageKey = `${cache_id}.json`;
if (!StorageCache.has(storageKey))
return { error: `Did not find cache data for id ${cache_id}` };
@ -1362,8 +1362,7 @@ export async function grabResponses(responses: Array<string>): Promise<Dict> {
export async function exportCache(ids: string[]) {
// For each id, extract relevant cache file data
const cache_files = {};
for (let i = 0; i < ids.length; i++) {
const cache_id = ids[i];
for (const cache_id of ids) {
const cache_keys = get_cache_keys_related_to_id(cache_id);
if (cache_keys.length === 0) {
console.warn(

View File

@ -82,6 +82,23 @@ export enum NativeLLM {
// The actual model name will be passed as a param to the LLM call function.
HF_OTHER = "Other (HuggingFace)",
Ollama = "ollama",
Bedrock_Claude_2_1 = "anthropic.claude-v2:1",
Bedrock_Claude_2 = "anthropic.claude-v2",
Bedrock_Claude_3_Sonnet = "anthropic.claude-3-sonnet-20240229-v1:0",
Bedrock_Claude_3_Haiku = "anthropic.claude-3-haiku-20240307-v1:0",
Bedrock_Claude_Instant_1 = "anthropic.claude-instant-v1",
Bedrock_Jurassic_Ultra = "ai21.j2-ultra",
Bedrock_Jurassic_Mid = "ai21.j2-mid",
Bedrock_Titan_Light = "amazon.titan-text-lite-v1",
Bedrock_Titan_Large = "amazon.titan-tg1-large",
Bedrock_Titan_Express = "amazon.titan-text-express-v1",
Bedrock_Command_Text = "cohere.command-text-v14",
Bedrock_Command_Text_Light = "cohere.command-light-text-v14",
Bedrock_Meta_LLama2Chat_13b = "meta.llama2-13b-chat-v1",
Bedrock_Meta_LLama2Chat_70b = "meta.llama2-70b-chat-v1",
Bedrock_Mistral_Mistral = "mistral.mistral-7b-instruct-v0:2",
Bedrock_Mistral_Mixtral = "mistral.mixtral-8x7b-instruct-v0:1",
}
export type LLM = string | NativeLLM;
@ -98,6 +115,7 @@ export enum LLMProvider {
HuggingFace = "hf",
Aleph_Alpha = "alephalpha",
Ollama = "ollama",
Bedrock = "bedrock",
Custom = "__custom",
}
@ -117,6 +135,7 @@ export function getProvider(llm: LLM): LLMProvider | undefined {
else if (llm.toString().startsWith("claude")) return LLMProvider.Anthropic;
else if (llm_name?.startsWith("Aleph_Alpha")) return LLMProvider.Aleph_Alpha;
else if (llm_name?.startsWith("Ollama")) return LLMProvider.Ollama;
else if (llm_name?.startsWith("Bedrock")) return LLMProvider.Bedrock;
else if (llm.toString().startsWith("__custom/")) return LLMProvider.Custom;
return undefined;
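// For illustration: provider resolution keys off the NativeLLM member name, so
// the new Bedrock entries above resolve as (a sketch):
//   getProvider(NativeLLM.Bedrock_Claude_3_Haiku) === LLMProvider.Bedrock
//   getProvider(NativeLLM.Bedrock_Command_Text) === LLMProvider.Bedrock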
@ -143,6 +162,20 @@ export const RATE_LIMITS: { [key in LLM]?: [number, number] } = {
[NativeLLM.Azure_OpenAI]: [30, 10],
[NativeLLM.PaLM2_Text_Bison]: [4, 10], // max 30 requests per minute; so do 4 per batch, 10 seconds between (conservative)
[NativeLLM.PaLM2_Chat_Bison]: [4, 10],
[NativeLLM.Bedrock_Jurassic_Mid]: [20, 5],
[NativeLLM.Bedrock_Jurassic_Ultra]: [5, 5],
[NativeLLM.Bedrock_Titan_Light]: [40, 5],
[NativeLLM.Bedrock_Titan_Express]: [20, 5], // 400 RPM
[NativeLLM.Bedrock_Claude_2]: [20, 15], // 100 RPM
[NativeLLM.Bedrock_Claude_2_1]: [20, 15], // 100 RPM
[NativeLLM.Bedrock_Claude_3_Haiku]: [20, 5], // 100 RPM
[NativeLLM.Bedrock_Claude_3_Sonnet]: [20, 15], // 100 RPM
[NativeLLM.Bedrock_Command_Text]: [20, 5], // 400 RPM
[NativeLLM.Bedrock_Command_Text_Light]: [40, 5], // 800 RPM
[NativeLLM.Bedrock_Meta_LLama2Chat_70b]: [20, 5], // 400 RPM
[NativeLLM.Bedrock_Meta_LLama2Chat_13b]: [40, 5], // 800 RPM
[NativeLLM.Bedrock_Mistral_Mixtral]: [20, 5], // 400 RPM
[NativeLLM.Bedrock_Mistral_Mistral]: [40, 5], // 800 RPM
};
/** Equivalent to a Python enum's .name property */

View File

@ -394,7 +394,6 @@ export class PromptPipeline {
try {
// When/if we emerge from sleep, check if this process has been canceled in the meantime:
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
// Call the LLM, returning when the Promise returns (if it does!)
[query, response] = await call_llm(
llm,

View File

@ -17,7 +17,6 @@ import {
GeminiChatContext,
GeminiChatMessage,
} from "./typing";
import { env as process_env } from "process";
import { v4 as uuid } from "uuid";
import { StringTemplate } from "./template";
@ -29,6 +28,11 @@ import {
} from "@azure/openai";
import { GoogleGenerativeAI } from "@google/generative-ai";
import { UserForcedPrematureExit } from "./errors";
import {
fromModelId,
ChatMessage as BedrockChatMessage,
} from "@mirai73/bedrock-fm";
import { Models } from "@mirai73/bedrock-fm/lib/bedrock";
import StorageCache from "./cache";
const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:";
@ -115,18 +119,14 @@ async function route_fetch(
}
}
// import { DiscussServiceClient, TextServiceClient } from "@google-ai/generativelanguage";
// import { GoogleAuth } from "google-auth-library";
function get_environ(key: string): string | undefined {
if (key in process_env) return process_env[key];
return undefined;
}
function appendEndSlashIfMissing(path: string) {
return path + (path[path.length - 1] === "/" ? "" : "/");
}
function get_environ(key: string): string | undefined {
return process.env[key];
}
let OPENAI_API_KEY = get_environ("OPENAI_API_KEY");
let ANTHROPIC_API_KEY = get_environ("ANTHROPIC_API_KEY");
let GOOGLE_PALM_API_KEY = get_environ("PALM_API_KEY");
@ -134,6 +134,10 @@ let AZURE_OPENAI_KEY = get_environ("AZURE_OPENAI_KEY");
let AZURE_OPENAI_ENDPOINT = get_environ("AZURE_OPENAI_ENDPOINT");
let HUGGINGFACE_API_KEY = get_environ("HUGGINGFACE_API_KEY");
let ALEPH_ALPHA_API_KEY = get_environ("ALEPH_ALPHA_API_KEY");
let AWS_ACCESS_KEY_ID = get_environ("AWS_ACCESS_KEY_ID");
let AWS_SECRET_ACCESS_KEY = get_environ("AWS_SECRET_ACCESS_KEY");
let AWS_SESSION_TOKEN = get_environ("AWS_SESSION_TOKEN");
let AWS_REGION = get_environ("AWS_REGION");
/**
* Sets the local API keys for the relevant LLM API(s).
@ -155,6 +159,13 @@ export function set_api_keys(api_keys: StringDict): void {
AZURE_OPENAI_ENDPOINT = api_keys.Azure_OpenAI_Endpoint;
if (key_is_present("AlephAlpha")) ALEPH_ALPHA_API_KEY = api_keys.AlephAlpha;
// Soft fail for non-present keys
if (key_is_present("AWS_Access_Key_ID"))
AWS_ACCESS_KEY_ID = api_keys.AWS_Access_Key_ID;
if (key_is_present("AWS_Secret_Access_Key"))
AWS_SECRET_ACCESS_KEY = api_keys.AWS_Secret_Access_Key;
if (key_is_present("AWS_Session_Token"))
AWS_SESSION_TOKEN = api_keys.AWS_Session_Token;
if (key_is_present("AWS_Region")) AWS_REGION = api_keys.AWS_Region;
}
export function get_azure_openai_api_keys(): [
@ -172,8 +183,8 @@ export function get_azure_openai_api_keys(): [
*/
function construct_openai_chat_history(
prompt: string,
chat_history: ChatHistory | undefined,
system_msg: string | undefined,
chat_history?: ChatHistory,
system_msg?: string,
): ChatHistory {
const prompt_msg: ChatMessage = { role: "user", content: prompt };
const sys_msg: ChatMessage[] =
@ -450,7 +461,7 @@ export async function call_anthropic(
// Carry chat history
// :: See https://docs.anthropic.com/claude/docs/human-and-assistant-formatting#use-human-and-assistant-to-put-words-in-claudes-mouth
const chat_history: ChatHistory | undefined = params.chat_history;
const chat_history: ChatHistory | undefined = params?.chat_history;
if (chat_history !== undefined) {
// FOR OLD TEXT COMPLETIONS API ONLY: Carry chat history by prepending it to the prompt
if (!use_messages_api) {
@ -467,7 +478,7 @@ export async function call_anthropic(
}
// For newer models Claude 2.1 and Claude 3, we carry chat history directly below; no need to do anything else.
delete params.chat_history;
delete params?.chat_history;
}
// Format query
@ -687,7 +698,9 @@ export async function call_google_palm(
// We need to detect this and fill the response with the safety reasoning:
if (completion.filters && completion.filters.length > 0) {
// Request was blocked. Output why in the response text, repairing the candidate dict to mock up 'n' responses
const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(completion.filters)}`;
const block_error_msg = `[[BLOCKED_REQUEST]] Request was blocked because it triggered safety filters: ${JSON.stringify(
completion.filters,
)}`;
completion.candidates = new Array(n).fill({
author: "1",
content: block_error_msg,
@ -924,9 +937,11 @@ export async function call_huggingface(
// Inference Endpoints for text completion models has the same call,
// except the endpoint is an entire URL. Detect this:
const url =
using_custom_model_endpoint && params?.custom_model?.startsWith("https:")
? params?.custom_model
: `https://api-inference.huggingface.co/models/${using_custom_model_endpoint ? params?.custom_model?.trim() : model}`;
using_custom_model_endpoint && params?.custom_model.startsWith("https:")
? params.custom_model
: `https://api-inference.huggingface.co/models/${
using_custom_model_endpoint ? params?.custom_model.trim() : model
}`;
const responses: Array<Dict> = [];
while (responses.length < n) {
@ -1048,6 +1063,11 @@ export async function call_ollama_provider(
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (!params?.ollama_url)
throw Error(
"Could not find a base URL for Ollama model. Double-check that your base URL is set in the model settings.",
);
let url: string = appendEndSlashIfMissing(params?.ollama_url);
const ollama_model: string = params?.ollamaModel.toString();
const model_type: string = params?.model_type ?? "text";
@ -1125,6 +1145,118 @@ export async function call_ollama_provider(
return [query, responses];
}
/** Convert OpenAI chat history to Bedrock format */
function to_bedrock_chat_history(
chat_history: ChatHistory,
): BedrockChatMessage[] {
const role_map = {
assistant: "ai",
user: "human",
};
// Transform the ChatMessage format in the chat_history array to what is expected by Bedrock
return chat_history.map((msg) =>
transformDict(
msg,
undefined,
(key) => (key === "content" ? "message" : key),
(key, val) => {
if (key === "role") return role_map[val] ?? val;
},
),
) as BedrockChatMessage[];
}
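// For illustration: a standalone sketch of the mapping the helper above is
// meant to produce, assuming @mirai73/bedrock-fm expects { role, message }
// entries with "human"/"ai" roles:
//   [{ role: "user", content: "Hi" }, { role: "assistant", content: "Hello!" }]
//     -> [{ role: "human", message: "Hi" }, { role: "ai", message: "Hello!" }]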
/**
* Calls models hosted on Amazon Bedrock via the Bedrock API.
* @returns raw query and response JSON dicts.
*/
export async function call_bedrock(
prompt: string,
model: LLM,
n = 1,
temperature = 1.0,
params?: Dict,
should_cancel?: () => boolean,
): Promise<[Dict, Dict]> {
if (!AWS_ACCESS_KEY_ID && !AWS_SESSION_TOKEN && !AWS_REGION) {
throw new Error(
"Could not find credentials value for the Bedrock API. Double-check that your API key is set in Settings or in your local environment.",
);
}
const modelName: string = model.toString();
// Use the user-provided stop sequences only if they form a non-empty array.
const stopWords =
  Array.isArray(params?.stop_sequences) && params.stop_sequences.length > 0
    ? params.stop_sequences
    : [];
const bedrockConfig = {
credentials: {
accessKeyId: AWS_ACCESS_KEY_ID,
secretAccessKey: AWS_SECRET_ACCESS_KEY,
sessionToken: AWS_SESSION_TOKEN,
},
region: AWS_REGION,
};
delete params?.stop;
const query: Dict = {
stopSequences: stopWords,
temperature,
topP: params?.top_p ?? 1.0,
maxTokenCount: params?.max_tokens_to_sample ?? 512,
};
const fm = fromModelId(modelName as Models, {
region: bedrockConfig.region ?? "us-west-2",
credentials: bedrockConfig.credentials,
...query,
});
const responses: string[] = [];
try {
// Collect n responses, one at a time
while (responses.length < n) {
// Abort if the user canceled
if (should_cancel && should_cancel()) throw new UserForcedPrematureExit();
// Grab the response
let response: string;
if (modelName.startsWith("anthropic")) {
const chat_history: ChatHistory = construct_openai_chat_history(
prompt,
params?.chat_history,
params?.system_msg,
);
response = (
await fm.chat(to_bedrock_chat_history(chat_history), { ...params })
).message;
} else {
response = await fm.generate(prompt, { ...params });
}
responses.push(response);
}
} catch (error: any) {
console.error("Error", error);
throw new Error(
error?.response?.data?.error?.message ??
error?.message ??
error.toString(),
);
}
return [query, responses];
}
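// For illustration: a minimal usage sketch (prompt and model id are example
// values), assuming set_api_keys() above has populated the AWS credentials
// and region beforehand:
async function exampleBedrockCall(): Promise<void> {
  const [query, responses] = await call_bedrock(
    "Explain prompt templating in one sentence.",
    "anthropic.claude-3-haiku-20240307-v1:0",
    /* n= */ 1,
    /* temperature= */ 0.9,
  );
  console.log(query, responses);
}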
async function call_custom_provider(
prompt: string,
model: LLM,
@ -1205,10 +1337,11 @@ export async function call_llm(
else if (llm_provider === LLMProvider.Aleph_Alpha) call_api = call_alephalpha;
else if (llm_provider === LLMProvider.Ollama) call_api = call_ollama_provider;
else if (llm_provider === LLMProvider.Custom) call_api = call_custom_provider;
else if (llm_provider === LLMProvider.Bedrock) call_api = call_bedrock;
if (call_api === undefined)
throw new Error(`Could not find an API hook for model ${llm}.`);
throw new Error(
`Could not find an API adapter for model ${llm} (provider: ${llm_provider}).`,
);
return call_api(prompt, llm, n, temperature, params, should_cancel);
}
@ -1358,6 +1491,8 @@ export function extract_responses(
return _extract_alephalpha_responses(response);
case LLMProvider.Ollama:
return _extract_ollama_responses(response as Dict[]);
case LLMProvider.Bedrock:
return response as Array<string>;
default:
if (
Array.isArray(response) &&

View File

@ -3,11 +3,14 @@ import ReactDOM from "react-dom/client";
import "./index.css";
import App from "./App";
import reportWebVitals from "./reportWebVitals";
import { ContextMenuProvider } from "mantine-contextmenu";
const root = ReactDOM.createRoot(document.getElementById("root"));
root.render(
<React.StrictMode>
<App />
<ContextMenuProvider>
<App />
</ContextMenuProvider>
</React.StrictMode>,
);

View File

@ -68,15 +68,27 @@ const refreshableOutputNodeTypes = new Set([
"split",
]);
export const initLLMProviders = [
export const initLLMProviderMenu = [
{
name: "GPT3.5",
group: "OpenAI",
emoji: "🤖",
model: "gpt-3.5-turbo",
base_model: "gpt-3.5-turbo",
temp: 1.0,
}, // The base_model designates what settings form will be used, and must be unique.
{ name: "GPT4", emoji: "🥵", model: "gpt-4", base_model: "gpt-4", temp: 1.0 },
items: [
{
name: "GPT3.5",
emoji: "🤖",
model: "gpt-3.5-turbo",
base_model: "gpt-3.5-turbo",
temp: 1.0,
}, // The base_model designates what settings form will be used, and must be unique.
{
name: "GPT4",
emoji: "🥵",
model: "gpt-4",
base_model: "gpt-4",
temp: 1.0,
},
],
},
{
name: "Claude",
emoji: "📚",
@ -92,11 +104,24 @@ export const initLLMProviders = [
temp: 0.7,
},
{
name: "HuggingFace",
group: "HuggingFace",
emoji: "🤗",
model: "tiiuae/falcon-7b-instruct",
base_model: "hf",
temp: 1.0,
items: [
{
name: "Mistral.7B",
emoji: "🤗",
model: "mistralai/Mistral-7B-Instruct-v0.1",
base_model: "hf",
temp: 1.0,
},
{
name: "Falcon.7B",
emoji: "🤗",
model: "tiiuae/falcon-7b-instruct",
base_model: "hf",
temp: 1.0,
},
],
},
{
name: "Aleph Alpha",
@ -112,19 +137,78 @@ export const initLLMProviders = [
base_model: "azure-openai",
temp: 1.0,
},
{
group: "Bedrock",
emoji: "🪨",
items: [
{
name: "Anthropic Claude",
emoji: "👨‍🏫",
model: "anthropic.claude-v2:1",
base_model: "br.anthropic.claude",
temp: 0.9,
},
{
name: "AI21 Jurassic 2",
emoji: "🦖",
model: "ai21.j2-ultra",
base_model: "br.ai21.j2",
temp: 0.9,
},
{
name: "Amazon Titan",
emoji: "🏛️",
model: "amazon.titan-tg1-large",
base_model: "br.amazon.titan",
temp: 0.9,
},
{
name: "Cohere Command Text 14",
emoji: "📚",
model: "cohere.command-text-v14",
base_model: "br.cohere.command",
temp: 0.9,
},
{
name: "Mistral Mistral",
emoji: "💨",
model: "mistral.mistral-7b-instruct-v0:2",
base_model: "br.mistral.mistral",
temp: 0.9,
},
{
name: "Mistral Mixtral",
emoji: "🌪️",
model: "mistral.mixtral-8x7b-instruct-v0:1",
base_model: "br.mistral.mixtral",
temp: 0.9,
},
{
name: "Meta Llama2 Chat",
emoji: "🦙",
model: "meta.llama2-13b-chat-v1",
base_model: "br.meta.llama2",
temp: 0.9,
},
],
},
];
if (APP_IS_RUNNING_LOCALLY()) {
initLLMProviders.push({
initLLMProviderMenu.push({
name: "Ollama",
emoji: "🦙",
model: "ollama",
base_model: "ollama",
provider: null,
temp: 1.0,
});
// -- Deprecated provider --
// initLLMProviders.push({ name: "Dalai (Alpaca.7B)", emoji: "🦙", model: "alpaca.7B", base_model: "dalai", temp: 0.5 });
// -------------------------
}
export const initLLMProviders = initLLMProviderMenu
.map((item) => (item.group !== undefined ? item.items : item))
.flat();
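// For illustration: initLLMProviderMenu mixes flat entries and grouped entries
// ({ group, emoji, items }); the flattening above restores the flat provider
// list the rest of the app expects, e.g. the Bedrock submenu items become
// top-level entries:
console.assert(
  initLLMProviders.some((p) => p.base_model === "br.anthropic.claude"),
);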
// A global store of variables, used for maintaining state
// across ChainForge and ReactFlow components.
@ -138,6 +222,11 @@ const useStore = create((set, get) => ({
set({ AvailableLLMs: llmProviderList });
},
aiFeaturesProvider: "OpenAI",
setAIFeaturesProvider: (llmProvider) => {
set({ aiFeaturesProvider: llmProvider });
},
// Keeping track of LLM API keys
apiKeys: initialAPIKeys,
setAPIKeys: (apiKeys) => {
@ -319,7 +408,6 @@ const useStore = create((set, get) => ({
} else {
// Get the data related to that handle:
if ("fields" in src_node.data) {
console.log(src_node.data);
if (Array.isArray(src_node.data.fields)) return src_node.data.fields;
else {
// We have to filter over a special 'fields_visibility' prop, which

View File

@ -6,7 +6,7 @@ def readme():
setup(
name='chainforge',
version='0.3.0.6',
version='0.3.0.7',
packages=find_packages(),
author="Ian Arawjo",
description="A Visual Programming Environment for Prompt Engineering",