mirror of
https://github.com/ianarawjo/ChainForge.git
synced 2025-03-14 16:26:45 +00:00
Extended prompt permutation gen w recursive template substitution'
This commit is contained in:
parent
e54b98fe12
commit
775b61e89c
@ -37,6 +37,12 @@ class PromptTemplate:
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def has_var(self, varname) -> bool:
|
||||
""" Returns True if the template has a variable with the given name.
|
||||
"""
|
||||
subbed_str = Template(self.template).safe_substitute({varname: '_'})
|
||||
return subbed_str != self.template # if the strings differ, a replacement occurred
|
||||
|
||||
def is_concrete(self) -> bool:
|
||||
""" Returns True if no template variables are left in template string.
|
||||
"""
|
||||
@ -93,9 +99,18 @@ class PromptPermutationGenerator:
|
||||
def _gen_perm(self, template, params_to_fill, paramDict):
|
||||
if len(params_to_fill) == 0: return []
|
||||
|
||||
# Peel off first element
|
||||
param = params_to_fill[0]
|
||||
params_left = params_to_fill[1:]
|
||||
# Extract the first param that occurs in the current template
|
||||
param = None
|
||||
params_left = params_to_fill
|
||||
for p in params_to_fill:
|
||||
if template.has_var(p):
|
||||
param = p
|
||||
params_left = [_p for _p in params_to_fill if _p != p]
|
||||
break
|
||||
|
||||
if param is None:
|
||||
print("Did not find any more params left to fill in current template. Returning empty list...")
|
||||
return []
|
||||
|
||||
# Generate new prompts by filling in its value(s) into the PromptTemplate
|
||||
val = paramDict[param]
|
||||
@ -112,7 +127,7 @@ class PromptPermutationGenerator:
|
||||
else:
|
||||
res = []
|
||||
for p in new_prompt_temps:
|
||||
res.extend(self._gen_perm(p, params_to_fill[1:], paramDict))
|
||||
res.extend(self._gen_perm(p, params_left, paramDict))
|
||||
return res
|
||||
|
||||
def __call__(self, paramDict: Dict[str, Union[str, List[str]]]):
|
||||
@ -121,4 +136,26 @@ class PromptPermutationGenerator:
|
||||
return
|
||||
|
||||
for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict):
|
||||
yield p
|
||||
yield p
|
||||
|
||||
|
||||
# Test cases
|
||||
if __name__ == '__main__':
|
||||
# Single template
|
||||
gen = PromptPermutationGenerator('What is the ${timeframe} when ${person} was born?')
|
||||
res = [r for r in gen({'timeframe': ['year', 'decade', 'century'], 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']})]
|
||||
for r in res:
|
||||
print(r)
|
||||
assert len(res) == 9
|
||||
|
||||
# Nested templates
|
||||
gen = PromptPermutationGenerator('${prefix}... ${suffix}')
|
||||
res = [r for r in gen({
|
||||
'prefix': ['Who invented ${tool}?', 'When was ${tool} invented?', 'What can you do with ${tool}?'],
|
||||
'suffix': ['Phrase your answer in the form of a ${response_type}', 'Respond with a ${response_type}'],
|
||||
'tool': ['the flashlight', 'CRISPR', 'rubber'],
|
||||
'response_type': ['question', 'poem', 'nightmare']
|
||||
})]
|
||||
for r in res:
|
||||
print(r)
|
||||
assert len(res) == (3*3)*(2*3)
|
Loading…
x
Reference in New Issue
Block a user