Implement generate and replace backend function

This commit is contained in:
Sean Yang 2023-10-16 13:40:36 -04:00
parent 410d452552
commit 418308041a
2 changed files with 58 additions and 1 deletions

View File

@ -1,4 +1,4 @@
import { autofill } from "../ai";
import { autofill, generateAndReplace } from "../ai";
describe("autofill", () => {
it("should return an array of n rows", async () => {
@ -11,3 +11,15 @@ describe("autofill", () => {
});
});
});
describe("generateAndReplace", () => {
it("should return an array of n rows", async () => {
const prompt = "animals";
const n = 3;
const result = await generateAndReplace(prompt, n);
expect(result).toHaveLength(n);
result.forEach((row) => {
expect(typeof row).toBe("string");
});
});
});

View File

@ -28,6 +28,21 @@ function autofillSystemMessage(n: number): string {
</rows>`;
}
/**
* Generate the system message used for generate and replace (GAR).
*/
function GARSystemMessage(n: number): string {
return `Pretend you are an autofill system helping to fill out a spreadsheet column. Here is the pattern you should follow in <pattern>. Generate exactly ${n} rows following the pattern. Format your response in XML using the <row> and <rows> tag. Do not ever repeat anything. Here is an example of the structure that your response should follow:
<rows>
<row>first row</row>
<row>second row</row>
<row>third row</row>
<row>fourth row</row>
<row>fifth row</row>
</rows>`;
}
/**
* Returns an XML string representing the given rows using the <rows> and <row> tags.
* @param rows to encode
@ -84,5 +99,35 @@ export async function autofill(input: Row[], n: number): Promise<Row[]> {
/*vars=*/ {},
/*chat_history=*/ history);
return decode(result.responses[0].responses[0])
}
/**
* Uses an LLM to generate `n` new rows based on the pattern explained in `prompt`.
* @param prompt
* @param n
*/
export async function generateAndReplace(prompt: string, n: number) {
// hash the arguments to get a unique id
let id = JSON.stringify([prompt, n]);
let history: ChatHistoryInfo[] = [{
messages: [{
"role": "system",
"content": GARSystemMessage(n),
}],
fill_history: {},
}]
let input = `<pattern>${prompt}</pattern>`;
let result = await queryLLM(
/*id=*/ id,
/*llm=*/ LLM,
/*n=*/ 1,
/*prompt=*/ input,
/*vars=*/ {},
/*chat_history=*/ history);
return decode(result.responses[0].responses[0])
}