mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 08:16:37 +00:00
Fix bugs. Change OpenAI small model for GenAI features to GPT-4o.
This commit is contained in:
parent
29278d5f3d
commit
4718fe55ba
@ -11,6 +11,7 @@ import {
|
||||
Textarea,
|
||||
Alert,
|
||||
Divider,
|
||||
Tooltip,
|
||||
} from "@mantine/core";
|
||||
import {
|
||||
autofill,
|
||||
@ -239,6 +240,8 @@ export function AIPopover({
|
||||
export interface AIGenReplaceTablePopoverProps {
|
||||
// Values in the rows of the table's columns
|
||||
values: TabularDataRowType[];
|
||||
// Names of the table's columns
|
||||
colValues: TabularDataColType[];
|
||||
// Function to add new rows
|
||||
onAddRows: (newRows: TabularDataRowType[]) => void;
|
||||
// Function to replace the table
|
||||
@ -262,6 +265,7 @@ export interface AIGenReplaceTablePopoverProps {
|
||||
*/
|
||||
export function AIGenReplaceTablePopover({
|
||||
values,
|
||||
colValues,
|
||||
onAddRows,
|
||||
onReplaceTable,
|
||||
onAddColumns,
|
||||
@ -276,12 +280,12 @@ export function AIGenReplaceTablePopover({
|
||||
const showAlert = useContext(AlertModalContext);
|
||||
|
||||
// Command Fill state
|
||||
const [commandFillNumber, setCommandFillNumber] = useState<number>(3);
|
||||
const [commandFillNumber, setCommandFillNumber] = useState<number>(5);
|
||||
const [isCommandFillLoading, setIsCommandFillLoading] = useState(false);
|
||||
const [didCommandFillError, setDidCommandFillError] = useState(false);
|
||||
|
||||
// Generate and Replace state
|
||||
const [generateAndReplaceNumber, setGenerateAndReplaceNumber] = useState(3);
|
||||
const [generateAndReplaceNumber, setGenerateAndReplaceNumber] = useState(5);
|
||||
const [generateAndReplacePrompt, setGenerateAndReplacePrompt] = useState("");
|
||||
const [genDiverseOutputs, setGenDiverseOutputs] = useState(false);
|
||||
const [didGenerateAndReplaceTableError, setDidGenerateAndReplaceTableError] =
|
||||
@ -360,9 +364,7 @@ export function AIGenReplaceTablePopover({
|
||||
|
||||
try {
|
||||
// Extract columns from the values, excluding the __uid column
|
||||
const tableColumns = Object.keys(values[0] || {}).filter(
|
||||
(col) => col !== "__uid",
|
||||
);
|
||||
const tableColumns = colValues.map((col) => col.key);
|
||||
|
||||
// Extract rows as strings, excluding the __uid column and handling empty rows
|
||||
const tableRows = values
|
||||
@ -410,9 +412,7 @@ export function AIGenReplaceTablePopover({
|
||||
|
||||
try {
|
||||
// Extract columns from the values, excluding the __uid column
|
||||
const tableColumns = Object.keys(values[0] || {}).filter(
|
||||
(col) => col !== "__uid",
|
||||
);
|
||||
const tableColumns = colValues;
|
||||
|
||||
// Extract rows as strings, excluding the __uid column and handling empty rows
|
||||
const lastRow = values[values.length - 1]; // Get the last row
|
||||
@ -420,13 +420,14 @@ export function AIGenReplaceTablePopover({
|
||||
const tableRows = values
|
||||
.slice(0, emptyLastRow ? -1 : values.length)
|
||||
.map((row) =>
|
||||
tableColumns.map((col) => row[col]?.trim() || "").join(" | "),
|
||||
tableColumns.map((col) => row[col.key]?.trim() || "").join(" | "),
|
||||
);
|
||||
|
||||
const tableInput = {
|
||||
cols: tableColumns,
|
||||
rows: tableRows,
|
||||
};
|
||||
|
||||
// Fetch the generated column
|
||||
const generatedColumn = await generateColumn(
|
||||
tableInput,
|
||||
@ -497,17 +498,23 @@ export function AIGenReplaceTablePopover({
|
||||
value={generateColumnPrompt}
|
||||
onChange={(e) => setGenerateColumnPrompt(e.currentTarget.value)}
|
||||
/>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="light"
|
||||
color="grape"
|
||||
fullWidth
|
||||
onClick={handleGenerateColumn}
|
||||
disabled={!enoughRowsForSuggestions}
|
||||
loading={isGenerateColumnLoading}
|
||||
<Tooltip
|
||||
label="Can take awhile if you have many rows. Please be patient."
|
||||
withArrow
|
||||
position="bottom"
|
||||
>
|
||||
Add Column
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="light"
|
||||
color="grape"
|
||||
fullWidth
|
||||
onClick={handleGenerateColumn}
|
||||
disabled={!enoughRowsForSuggestions}
|
||||
loading={isGenerateColumnLoading}
|
||||
>
|
||||
Add Column
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</Stack>
|
||||
);
|
||||
|
||||
@ -526,7 +533,7 @@ export function AIGenReplaceTablePopover({
|
||||
<NumberInput
|
||||
label="Rows to generate"
|
||||
min={1}
|
||||
max={10}
|
||||
max={50}
|
||||
value={generateAndReplaceNumber}
|
||||
onChange={(num) => setGenerateAndReplaceNumber(num || 1)}
|
||||
/>
|
||||
|
@ -518,7 +518,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
|
||||
/>
|
||||
<Select
|
||||
label="LLM Provider"
|
||||
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-3.5 and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
|
||||
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-4o and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
|
||||
dropdownPosition="bottom"
|
||||
withinPortal
|
||||
defaultValue={getAIFeaturesModelProviders()[0]}
|
||||
|
@ -25,7 +25,6 @@ import { sampleRandomElements } from "./backend/utils";
|
||||
import { Dict, TabularDataRowType, TabularDataColType } from "./backend/typing";
|
||||
import { Position } from "reactflow";
|
||||
import { AIGenReplaceTablePopover } from "./AiPopover";
|
||||
import AISuggestionsManager from "./backend/aiSuggestionsManager";
|
||||
import { parseTableData } from "./backend/tableUtils";
|
||||
|
||||
const defaultRows: TabularDataRowType[] = [
|
||||
@ -476,12 +475,13 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
|
||||
// Function to add new columns to the right of the existing columns (with optional row values)
|
||||
const addColumns = (
|
||||
newColumns: TabularDataColType[],
|
||||
rowValues?: string[] // If values are passed, they will be used to populate the new columns
|
||||
rowValues?: string[], // If values are passed, they will be used to populate the new columns
|
||||
) => {
|
||||
setTableColumns((prevColumns) => {
|
||||
// Filter out columns that already exist
|
||||
const filteredNewColumns = newColumns.filter(
|
||||
(col) => !prevColumns.some((existingCol) => existingCol.key === col.key)
|
||||
(col) =>
|
||||
!prevColumns.some((existingCol) => existingCol.key === col.key),
|
||||
);
|
||||
|
||||
// If no genuinely new columns, return previous columns
|
||||
@ -538,7 +538,12 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
|
||||
const addMultipleRows = (newRows: TabularDataRowType[]) => {
|
||||
setTableData((prev) => {
|
||||
// Remove the last row of the current table data as it is a blank row (if table is not empty)
|
||||
const newTableData = prev.length > 0 ? prev.slice(0, -1) : [];
|
||||
let newTableData = prev;
|
||||
if (prev.length > 0) {
|
||||
const lastRow = prev[prev.length - 1]; // Get the last row
|
||||
const emptyLastRow = Object.values(lastRow).every((val) => !val); // Check if the last row is empty
|
||||
if (emptyLastRow) newTableData = prev.slice(0, -1); // Remove the last row if it is empty
|
||||
}
|
||||
|
||||
// Add the new rows to the table
|
||||
const addedRows = newRows.map((value) => {
|
||||
@ -611,6 +616,7 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
|
||||
<AIGenReplaceTablePopover
|
||||
key="ai-popover"
|
||||
values={tableData}
|
||||
colValues={tableColumns}
|
||||
onAddRows={addMultipleRows}
|
||||
onAddColumns={addColumns}
|
||||
onReplaceTable={replaceTable}
|
||||
|
@ -7,8 +7,9 @@ import {
|
||||
escapeBraces,
|
||||
containsSameTemplateVariables,
|
||||
} from "./template";
|
||||
import { ChatHistoryInfo, Dict } from "./typing";
|
||||
import { ChatHistoryInfo, Dict, TabularDataColType } from "./typing";
|
||||
import { fromMarkdown } from "mdast-util-from-markdown";
|
||||
import { sampleRandomElements } from "./utils";
|
||||
|
||||
export class AIError extends Error {
|
||||
constructor(message: string) {
|
||||
@ -24,7 +25,7 @@ export type Row = string;
|
||||
const AIFeaturesLLMs = [
|
||||
{
|
||||
provider: "OpenAI",
|
||||
small: { value: "gpt-3.5-turbo", label: "OpenAI GPT3.5" },
|
||||
small: { value: "gpt-4o", label: "OpenAI GPT4o" },
|
||||
large: { value: "gpt-4", label: "OpenAI GPT4" },
|
||||
},
|
||||
{
|
||||
@ -116,11 +117,8 @@ function autofillSystemMessage(
|
||||
* @param n number of rows to generate
|
||||
* @param templateVariables list of template variables to use
|
||||
*/
|
||||
function autofillTableSystemMessage(
|
||||
n: number,
|
||||
templateVariables?: string[],
|
||||
): string {
|
||||
return `Here is a table. Generate ${n} more commands or items following the pattern. You must format your response as a markdown table with labeled columns and a divider with only the next ${n} generated commands or items of the table. ${templateVariables && templateVariables.length > 0 ? templateVariableMessage(templateVariables) : ""}`;
|
||||
function autofillTableSystemMessage(n: number): string {
|
||||
return `Here is a table. Generate ${n} more commands or items following the pattern. You must format your response as a markdown table with labeled columns and a divider with only the next ${n} generated commands or items of the table.`;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -347,26 +345,23 @@ export async function autofillTable(
|
||||
provider: string,
|
||||
apiKeys: Dict,
|
||||
): Promise<{ cols: string[]; rows: Row[] }> {
|
||||
// Get a random sample of the table rows, if there are more than 30 (as an estimate):
|
||||
// TODO: This is a temporary solution to avoid sending large tables to the LLM. In future, check the number of characters too.
|
||||
const sampleRows =
|
||||
input.rows.length > 30 ? sampleRandomElements(input.rows, 30) : input.rows;
|
||||
|
||||
// Hash the arguments to get a unique id
|
||||
const id = JSON.stringify([input.cols, input.rows, n]);
|
||||
const id = JSON.stringify([input.cols, sampleRows, n]);
|
||||
|
||||
// Encode the input table to a markdown table
|
||||
const encoded = encodeTable(input.cols, input.rows);
|
||||
|
||||
// Extract template variables from the columns and rows
|
||||
const templateVariables = [
|
||||
...new Set([
|
||||
...new StringTemplate(input.rows.join("\n")).get_vars(),
|
||||
...new StringTemplate(input.cols.join("\n")).get_vars(),
|
||||
]),
|
||||
];
|
||||
const encoded = encodeTable(input.cols, sampleRows);
|
||||
|
||||
const history: ChatHistoryInfo[] = [
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
role: "system",
|
||||
content: autofillTableSystemMessage(n, templateVariables),
|
||||
content: autofillTableSystemMessage(n),
|
||||
},
|
||||
],
|
||||
fill_history: {},
|
||||
@ -415,17 +410,18 @@ async function fillMissingFieldForRow(
|
||||
apiKeys: Dict,
|
||||
): Promise<string> {
|
||||
// Generate a user prompt for the LLM pass over existing row data in list format
|
||||
const userPrompt = `
|
||||
You are given partial data for a row in a table. Here is the data:
|
||||
${Object.entries(existingRowData)
|
||||
.map(([key, val]) => `- ${key}: ${val}`)
|
||||
.join("\n")}
|
||||
// const userPrompt = `You are given partial data for a row of a table. Here is the data:
|
||||
// ${Object.entries(existingRowData)
|
||||
// .map(([key, val]) => `- ${key}: ${val}`)
|
||||
// .join("\n")}
|
||||
|
||||
This is the requirement of the new column:"${prompt}", produce a single appropriate value for the item.
|
||||
// This is the requirement of the new column: "${prompt}". Produce an appropriate value for the item. Respond with just the new field's value, and nothing else.`;
|
||||
|
||||
|
||||
Respond with just the new field’s value, and nothing else.
|
||||
`;
|
||||
const userPrompt = `Fill in the last piece of information. Respond with just the missing information, nothing else.
|
||||
${Object.entries(existingRowData)
|
||||
.map(([key, val]) => `${key}: ${val}`)
|
||||
.join("\n")}
|
||||
${prompt}: ?`;
|
||||
|
||||
const history: ChatHistoryInfo[] = [
|
||||
{
|
||||
@ -452,6 +448,8 @@ async function fillMissingFieldForRow(
|
||||
true,
|
||||
);
|
||||
|
||||
console.log("LLM said: ", result.responses[0].responses[0]);
|
||||
|
||||
// Handle any errors in the response
|
||||
if (result.errors && Object.keys(result.errors).length > 0) {
|
||||
throw new AIError(Object.values(result.errors)[0].toString());
|
||||
@ -469,12 +467,11 @@ async function fillMissingFieldForRow(
|
||||
* @returns A promise resolving to an array of strings (column values).
|
||||
*/
|
||||
export async function generateColumn(
|
||||
tableData: { cols: string[]; rows: string[] },
|
||||
tableData: { cols: TabularDataColType[]; rows: string[] },
|
||||
prompt: string,
|
||||
provider: string,
|
||||
apiKeys: Dict,
|
||||
): Promise<{ col: string; rows: string[] }> {
|
||||
|
||||
// If the length of the prompt is less than 20 characters, use the prompt
|
||||
// Else, use the LLM to generate an appropriate column name for the prompt
|
||||
let colName: string;
|
||||
@ -485,23 +482,27 @@ export async function generateColumn(
|
||||
JSON.stringify([prompt]),
|
||||
getAIFeaturesModels(provider).small,
|
||||
1,
|
||||
`Generate an appropriate column name for the prompt: "${prompt}"`,
|
||||
`You produce column names for a table. The column names must be short, less than 20 characters, and in natural language, like "Column Name." Return only the column name. Generate an appropriate column name for the prompt: "${prompt}"`,
|
||||
{},
|
||||
[],
|
||||
[],
|
||||
apiKeys,
|
||||
true,
|
||||
);
|
||||
colName = result.responses[0].responses[0] as string;
|
||||
colName = (result.responses[0].responses[0] as string).replace("_", " ");
|
||||
}
|
||||
|
||||
// Remove any leading/trailing whitespace from the column name as well as any double quotes
|
||||
colName = colName.trim().replace(/"/g, "");
|
||||
|
||||
// Parse the existing table into mark down row objects
|
||||
const columnNames = tableData.cols;
|
||||
const columnNames = tableData.cols.map((col) => col.header);
|
||||
const parsedRows = tableData.rows.map((rowStr) => {
|
||||
// Remove leading/trailing "|" along with any whitespace
|
||||
const cells = rowStr.replace(/^\|/, "").replace(/\|$/, "").split("|").map((cell) => cell.trim());
|
||||
const cells = rowStr
|
||||
.replace(/^\|/, "")
|
||||
.replace(/\|$/, "")
|
||||
.split("|")
|
||||
.map((cell) => cell.trim());
|
||||
const rowData: Record<string, string> = {};
|
||||
columnNames.forEach((colName, index) => {
|
||||
rowData[colName] = cells[index] || "";
|
||||
@ -517,7 +518,7 @@ export async function generateColumn(
|
||||
rowData,
|
||||
prompt,
|
||||
provider,
|
||||
apiKeys
|
||||
apiKeys,
|
||||
);
|
||||
newColumnValues.push(newValue);
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ export function parseTableData(rawTableData: any[]): {
|
||||
columns: TabularDataColType[];
|
||||
rows: TabularDataRowType[];
|
||||
} {
|
||||
|
||||
if (!Array.isArray(rawTableData)) {
|
||||
throw new Error(
|
||||
"Table data is not in array format: " +
|
||||
|
Loading…
x
Reference in New Issue
Block a user