Extended prompt permutation gen w recursive template substitution'

This commit is contained in:
Ian Arawjo 2023-05-02 08:50:34 -04:00
parent e54b98fe12
commit 775b61e89c

View File

@ -37,6 +37,12 @@ class PromptTemplate:
def __repr__(self) -> str:
return self.__str__()
def has_var(self, varname) -> bool:
""" Returns True if the template has a variable with the given name.
"""
subbed_str = Template(self.template).safe_substitute({varname: '_'})
return subbed_str != self.template # if the strings differ, a replacement occurred
def is_concrete(self) -> bool:
""" Returns True if no template variables are left in template string.
"""
@ -93,9 +99,18 @@ class PromptPermutationGenerator:
def _gen_perm(self, template, params_to_fill, paramDict):
if len(params_to_fill) == 0: return []
# Peel off first element
param = params_to_fill[0]
params_left = params_to_fill[1:]
# Extract the first param that occurs in the current template
param = None
params_left = params_to_fill
for p in params_to_fill:
if template.has_var(p):
param = p
params_left = [_p for _p in params_to_fill if _p != p]
break
if param is None:
print("Did not find any more params left to fill in current template. Returning empty list...")
return []
# Generate new prompts by filling in its value(s) into the PromptTemplate
val = paramDict[param]
@ -112,7 +127,7 @@ class PromptPermutationGenerator:
else:
res = []
for p in new_prompt_temps:
res.extend(self._gen_perm(p, params_to_fill[1:], paramDict))
res.extend(self._gen_perm(p, params_left, paramDict))
return res
def __call__(self, paramDict: Dict[str, Union[str, List[str]]]):
@ -121,4 +136,26 @@ class PromptPermutationGenerator:
return
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
yield p
yield p
# Test cases
if __name__ == '__main__':
# Single template
gen = PromptPermutationGenerator('What is the ${timeframe} when ${person} was born?')
res = [r for r in gen({'timeframe': ['year', 'decade', 'century'], 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']})]
for r in res:
print(r)
assert len(res) == 9
# Nested templates
gen = PromptPermutationGenerator('${prefix}... ${suffix}')
res = [r for r in gen({
'prefix': ['Who invented ${tool}?', 'When was ${tool} invented?', 'What can you do with ${tool}?'],
'suffix': ['Phrase your answer in the form of a ${response_type}', 'Respond with a ${response_type}'],
'tool': ['the flashlight', 'CRISPR', 'rubber'],
'response_type': ['question', 'poem', 'nightmare']
})]
for r in res:
print(r)
assert len(res) == (3*3)*(2*3)