mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Implement generate and replace backend function
This commit is contained in:
parent
410d452552
commit
418308041a
@ -1,4 +1,4 @@
|
||||
import { autofill } from "../ai";
|
||||
import { autofill, generateAndReplace } from "../ai";
|
||||
|
||||
describe("autofill", () => {
|
||||
it("should return an array of n rows", async () => {
|
||||
@ -11,3 +11,15 @@ describe("autofill", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("generateAndReplace", () => {
|
||||
it("should return an array of n rows", async () => {
|
||||
const prompt = "animals";
|
||||
const n = 3;
|
||||
const result = await generateAndReplace(prompt, n);
|
||||
expect(result).toHaveLength(n);
|
||||
result.forEach((row) => {
|
||||
expect(typeof row).toBe("string");
|
||||
});
|
||||
});
|
||||
});
|
@ -28,6 +28,21 @@ function autofillSystemMessage(n: number): string {
|
||||
</rows>`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate the system message used for generate and replace (GAR).
|
||||
*/
|
||||
function GARSystemMessage(n: number): string {
|
||||
return `Pretend you are an autofill system helping to fill out a spreadsheet column. Here is the pattern you should follow in <pattern>. Generate exactly ${n} rows following the pattern. Format your response in XML using the <row> and <rows> tag. Do not ever repeat anything. Here is an example of the structure that your response should follow:
|
||||
|
||||
<rows>
|
||||
<row>first row</row>
|
||||
<row>second row</row>
|
||||
<row>third row</row>
|
||||
<row>fourth row</row>
|
||||
<row>fifth row</row>
|
||||
</rows>`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an XML string representing the given rows using the <rows> and <row> tags.
|
||||
* @param rows to encode
|
||||
@ -84,5 +99,35 @@ export async function autofill(input: Row[], n: number): Promise<Row[]> {
|
||||
/*vars=*/ {},
|
||||
/*chat_history=*/ history);
|
||||
|
||||
return decode(result.responses[0].responses[0])
|
||||
}
|
||||
|
||||
/**
|
||||
* Uses an LLM to generate `n` new rows based on the pattern explained in `prompt`.
|
||||
* @param prompt
|
||||
* @param n
|
||||
*/
|
||||
export async function generateAndReplace(prompt: string, n: number) {
|
||||
// hash the arguments to get a unique id
|
||||
let id = JSON.stringify([prompt, n]);
|
||||
|
||||
let history: ChatHistoryInfo[] = [{
|
||||
messages: [{
|
||||
"role": "system",
|
||||
"content": GARSystemMessage(n),
|
||||
}],
|
||||
fill_history: {},
|
||||
}]
|
||||
|
||||
let input = `<pattern>${prompt}</pattern>`;
|
||||
|
||||
let result = await queryLLM(
|
||||
/*id=*/ id,
|
||||
/*llm=*/ LLM,
|
||||
/*n=*/ 1,
|
||||
/*prompt=*/ input,
|
||||
/*vars=*/ {},
|
||||
/*chat_history=*/ history);
|
||||
|
||||
return decode(result.responses[0].responses[0])
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user