Fix bugs. Change OpenAI small model for GenAI features to GPT-4o.

This commit is contained in:
Ian Arawjo 2024-12-19 15:20:32 -05:00
parent 29278d5f3d
commit 4718fe55ba
5 changed files with 74 additions and 61 deletions

View File

@ -11,6 +11,7 @@ import {
Textarea,
Alert,
Divider,
Tooltip,
} from "@mantine/core";
import {
autofill,
@ -239,6 +240,8 @@ export function AIPopover({
export interface AIGenReplaceTablePopoverProps {
// Values in the rows of the table's columns
values: TabularDataRowType[];
// Names of the table's columns
colValues: TabularDataColType[];
// Function to add new rows
onAddRows: (newRows: TabularDataRowType[]) => void;
// Function to replace the table
@ -262,6 +265,7 @@ export interface AIGenReplaceTablePopoverProps {
*/
export function AIGenReplaceTablePopover({
values,
colValues,
onAddRows,
onReplaceTable,
onAddColumns,
@ -276,12 +280,12 @@ export function AIGenReplaceTablePopover({
const showAlert = useContext(AlertModalContext);
// Command Fill state
const [commandFillNumber, setCommandFillNumber] = useState<number>(3);
const [commandFillNumber, setCommandFillNumber] = useState<number>(5);
const [isCommandFillLoading, setIsCommandFillLoading] = useState(false);
const [didCommandFillError, setDidCommandFillError] = useState(false);
// Generate and Replace state
const [generateAndReplaceNumber, setGenerateAndReplaceNumber] = useState(3);
const [generateAndReplaceNumber, setGenerateAndReplaceNumber] = useState(5);
const [generateAndReplacePrompt, setGenerateAndReplacePrompt] = useState("");
const [genDiverseOutputs, setGenDiverseOutputs] = useState(false);
const [didGenerateAndReplaceTableError, setDidGenerateAndReplaceTableError] =
@ -360,9 +364,7 @@ export function AIGenReplaceTablePopover({
try {
// Extract columns from the values, excluding the __uid column
const tableColumns = Object.keys(values[0] || {}).filter(
(col) => col !== "__uid",
);
const tableColumns = colValues.map((col) => col.key);
// Extract rows as strings, excluding the __uid column and handling empty rows
const tableRows = values
@ -410,9 +412,7 @@ export function AIGenReplaceTablePopover({
try {
// Extract columns from the values, excluding the __uid column
const tableColumns = Object.keys(values[0] || {}).filter(
(col) => col !== "__uid",
);
const tableColumns = colValues;
// Extract rows as strings, excluding the __uid column and handling empty rows
const lastRow = values[values.length - 1]; // Get the last row
@ -420,13 +420,14 @@ export function AIGenReplaceTablePopover({
const tableRows = values
.slice(0, emptyLastRow ? -1 : values.length)
.map((row) =>
tableColumns.map((col) => row[col]?.trim() || "").join(" | "),
tableColumns.map((col) => row[col.key]?.trim() || "").join(" | "),
);
const tableInput = {
cols: tableColumns,
rows: tableRows,
};
// Fetch the generated column
const generatedColumn = await generateColumn(
tableInput,
@ -497,17 +498,23 @@ export function AIGenReplaceTablePopover({
value={generateColumnPrompt}
onChange={(e) => setGenerateColumnPrompt(e.currentTarget.value)}
/>
<Button
size="sm"
variant="light"
color="grape"
fullWidth
onClick={handleGenerateColumn}
disabled={!enoughRowsForSuggestions}
loading={isGenerateColumnLoading}
<Tooltip
label="Can take awhile if you have many rows. Please be patient."
withArrow
position="bottom"
>
Add Column
</Button>
<Button
size="sm"
variant="light"
color="grape"
fullWidth
onClick={handleGenerateColumn}
disabled={!enoughRowsForSuggestions}
loading={isGenerateColumnLoading}
>
Add Column
</Button>
</Tooltip>
</Stack>
);
@ -526,7 +533,7 @@ export function AIGenReplaceTablePopover({
<NumberInput
label="Rows to generate"
min={1}
max={10}
max={50}
value={generateAndReplaceNumber}
onChange={(num) => setGenerateAndReplaceNumber(num || 1)}
/>

View File

@ -518,7 +518,7 @@ const GlobalSettingsModal = forwardRef<GlobalSettingsModalRef, object>(
/>
<Select
label="LLM Provider"
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-3.5 and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
description="The LLM provider to use for generative AI features. Currently only supports OpenAI and Bedrock (Anthropic). OpenAI will query gpt-4o and gpt-4 models. Bedrock will query Claude-3 models. You must have set the relevant API keys to use the provider."
dropdownPosition="bottom"
withinPortal
defaultValue={getAIFeaturesModelProviders()[0]}

View File

@ -25,7 +25,6 @@ import { sampleRandomElements } from "./backend/utils";
import { Dict, TabularDataRowType, TabularDataColType } from "./backend/typing";
import { Position } from "reactflow";
import { AIGenReplaceTablePopover } from "./AiPopover";
import AISuggestionsManager from "./backend/aiSuggestionsManager";
import { parseTableData } from "./backend/tableUtils";
const defaultRows: TabularDataRowType[] = [
@ -476,12 +475,13 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
// Function to add new columns to the right of the existing columns (with optional row values)
const addColumns = (
newColumns: TabularDataColType[],
rowValues?: string[] // If values are passed, they will be used to populate the new columns
rowValues?: string[], // If values are passed, they will be used to populate the new columns
) => {
setTableColumns((prevColumns) => {
// Filter out columns that already exist
const filteredNewColumns = newColumns.filter(
(col) => !prevColumns.some((existingCol) => existingCol.key === col.key)
(col) =>
!prevColumns.some((existingCol) => existingCol.key === col.key),
);
// If no genuinely new columns, return previous columns
@ -538,7 +538,12 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
const addMultipleRows = (newRows: TabularDataRowType[]) => {
setTableData((prev) => {
// Remove the last row of the current table data as it is a blank row (if table is not empty)
const newTableData = prev.length > 0 ? prev.slice(0, -1) : [];
let newTableData = prev;
if (prev.length > 0) {
const lastRow = prev[prev.length - 1]; // Get the last row
const emptyLastRow = Object.values(lastRow).every((val) => !val); // Check if the last row is empty
if (emptyLastRow) newTableData = prev.slice(0, -1); // Remove the last row if it is empty
}
// Add the new rows to the table
const addedRows = newRows.map((value) => {
@ -611,6 +616,7 @@ const TabularDataNode: React.FC<TabularDataNodeProps> = ({ data, id }) => {
<AIGenReplaceTablePopover
key="ai-popover"
values={tableData}
colValues={tableColumns}
onAddRows={addMultipleRows}
onAddColumns={addColumns}
onReplaceTable={replaceTable}

View File

@ -7,8 +7,9 @@ import {
escapeBraces,
containsSameTemplateVariables,
} from "./template";
import { ChatHistoryInfo, Dict } from "./typing";
import { ChatHistoryInfo, Dict, TabularDataColType } from "./typing";
import { fromMarkdown } from "mdast-util-from-markdown";
import { sampleRandomElements } from "./utils";
export class AIError extends Error {
constructor(message: string) {
@ -24,7 +25,7 @@ export type Row = string;
const AIFeaturesLLMs = [
{
provider: "OpenAI",
small: { value: "gpt-3.5-turbo", label: "OpenAI GPT3.5" },
small: { value: "gpt-4o", label: "OpenAI GPT4o" },
large: { value: "gpt-4", label: "OpenAI GPT4" },
},
{
@ -116,11 +117,8 @@ function autofillSystemMessage(
* @param n number of rows to generate
* @param templateVariables list of template variables to use
*/
function autofillTableSystemMessage(
n: number,
templateVariables?: string[],
): string {
return `Here is a table. Generate ${n} more commands or items following the pattern. You must format your response as a markdown table with labeled columns and a divider with only the next ${n} generated commands or items of the table. ${templateVariables && templateVariables.length > 0 ? templateVariableMessage(templateVariables) : ""}`;
function autofillTableSystemMessage(n: number): string {
return `Here is a table. Generate ${n} more commands or items following the pattern. You must format your response as a markdown table with labeled columns and a divider with only the next ${n} generated commands or items of the table.`;
}
/**
@ -347,26 +345,23 @@ export async function autofillTable(
provider: string,
apiKeys: Dict,
): Promise<{ cols: string[]; rows: Row[] }> {
// Get a random sample of the table rows, if there are more than 30 (as an estimate):
// TODO: This is a temporary solution to avoid sending large tables to the LLM. In future, check the number of characters too.
const sampleRows =
input.rows.length > 30 ? sampleRandomElements(input.rows, 30) : input.rows;
// Hash the arguments to get a unique id
const id = JSON.stringify([input.cols, input.rows, n]);
const id = JSON.stringify([input.cols, sampleRows, n]);
// Encode the input table to a markdown table
const encoded = encodeTable(input.cols, input.rows);
// Extract template variables from the columns and rows
const templateVariables = [
...new Set([
...new StringTemplate(input.rows.join("\n")).get_vars(),
...new StringTemplate(input.cols.join("\n")).get_vars(),
]),
];
const encoded = encodeTable(input.cols, sampleRows);
const history: ChatHistoryInfo[] = [
{
messages: [
{
role: "system",
content: autofillTableSystemMessage(n, templateVariables),
content: autofillTableSystemMessage(n),
},
],
fill_history: {},
@ -415,17 +410,18 @@ async function fillMissingFieldForRow(
apiKeys: Dict,
): Promise<string> {
// Generate a user prompt for the LLM pass over existing row data in list format
const userPrompt = `
You are given partial data for a row in a table. Here is the data:
${Object.entries(existingRowData)
.map(([key, val]) => `- ${key}: ${val}`)
.join("\n")}
// const userPrompt = `You are given partial data for a row of a table. Here is the data:
// ${Object.entries(existingRowData)
// .map(([key, val]) => `- ${key}: ${val}`)
// .join("\n")}
This is the requirement of the new column:"${prompt}", produce a single appropriate value for the item.
// This is the requirement of the new column: "${prompt}". Produce an appropriate value for the item. Respond with just the new field's value, and nothing else.`;
Respond with just the new fields value, and nothing else.
`;
const userPrompt = `Fill in the last piece of information. Respond with just the missing information, nothing else.
${Object.entries(existingRowData)
.map(([key, val]) => `${key}: ${val}`)
.join("\n")}
${prompt}: ?`;
const history: ChatHistoryInfo[] = [
{
@ -452,6 +448,8 @@ async function fillMissingFieldForRow(
true,
);
console.log("LLM said: ", result.responses[0].responses[0]);
// Handle any errors in the response
if (result.errors && Object.keys(result.errors).length > 0) {
throw new AIError(Object.values(result.errors)[0].toString());
@ -469,12 +467,11 @@ async function fillMissingFieldForRow(
* @returns A promise resolving to an array of strings (column values).
*/
export async function generateColumn(
tableData: { cols: string[]; rows: string[] },
tableData: { cols: TabularDataColType[]; rows: string[] },
prompt: string,
provider: string,
apiKeys: Dict,
): Promise<{ col: string; rows: string[] }> {
// If the length of the prompt is less than 20 characters, use the prompt
// Else, use the LLM to generate an appropriate column name for the prompt
let colName: string;
@ -485,23 +482,27 @@ export async function generateColumn(
JSON.stringify([prompt]),
getAIFeaturesModels(provider).small,
1,
`Generate an appropriate column name for the prompt: "${prompt}"`,
`You produce column names for a table. The column names must be short, less than 20 characters, and in natural language, like "Column Name." Return only the column name. Generate an appropriate column name for the prompt: "${prompt}"`,
{},
[],
[],
apiKeys,
true,
);
colName = result.responses[0].responses[0] as string;
colName = (result.responses[0].responses[0] as string).replace("_", " ");
}
// Remove any leading/trailing whitespace from the column name as well as any double quotes
colName = colName.trim().replace(/"/g, "");
// Parse the existing table into mark down row objects
const columnNames = tableData.cols;
const columnNames = tableData.cols.map((col) => col.header);
const parsedRows = tableData.rows.map((rowStr) => {
// Remove leading/trailing "|" along with any whitespace
const cells = rowStr.replace(/^\|/, "").replace(/\|$/, "").split("|").map((cell) => cell.trim());
const cells = rowStr
.replace(/^\|/, "")
.replace(/\|$/, "")
.split("|")
.map((cell) => cell.trim());
const rowData: Record<string, string> = {};
columnNames.forEach((colName, index) => {
rowData[colName] = cells[index] || "";
@ -517,7 +518,7 @@ export async function generateColumn(
rowData,
prompt,
provider,
apiKeys
apiKeys,
);
newColumnValues.push(newValue);
}

View File

@ -8,7 +8,6 @@ export function parseTableData(rawTableData: any[]): {
columns: TabularDataColType[];
rows: TabularDataRowType[];
} {
if (!Array.isArray(rawTableData)) {
throw new Error(
"Table data is not in array format: " +