diff --git a/python-backend/promptengine/template.py b/python-backend/promptengine/template.py index 2d95b6b..41218d0 100644 --- a/python-backend/promptengine/template.py +++ b/python-backend/promptengine/template.py @@ -37,6 +37,12 @@ class PromptTemplate: def __repr__(self) -> str: return self.__str__() + def has_var(self, varname) -> bool: + """ Returns True if the template has a variable with the given name. + """ + subbed_str = Template(self.template).safe_substitute({varname: '_'}) + return subbed_str != self.template # if the strings differ, a replacement occurred + def is_concrete(self) -> bool: """ Returns True if no template variables are left in template string. """ @@ -93,9 +99,18 @@ class PromptPermutationGenerator: def _gen_perm(self, template, params_to_fill, paramDict): if len(params_to_fill) == 0: return [] - # Peel off first element - param = params_to_fill[0] - params_left = params_to_fill[1:] + # Extract the first param that occurs in the current template + param = None + params_left = params_to_fill + for p in params_to_fill: + if template.has_var(p): + param = p + params_left = [_p for _p in params_to_fill if _p != p] + break + + if param is None: + print("Did not find any more params left to fill in current template. Returning empty list...") + return [] # Generate new prompts by filling in its value(s) into the PromptTemplate val = paramDict[param] @@ -112,7 +127,7 @@ class PromptPermutationGenerator: else: res = [] for p in new_prompt_temps: - res.extend(self._gen_perm(p, params_to_fill[1:], paramDict)) + res.extend(self._gen_perm(p, params_left, paramDict)) return res def __call__(self, paramDict: Dict[str, Union[str, List[str]]]): @@ -121,4 +136,26 @@ class PromptPermutationGenerator: return for p in self._gen_perm(self.template, list(paramDict.keys()), paramDict): - yield p \ No newline at end of file + yield p + + +# Test cases +if __name__ == '__main__': + # Single template + gen = PromptPermutationGenerator('What is the ${timeframe} when ${person} was born?') + res = [r for r in gen({'timeframe': ['year', 'decade', 'century'], 'person': ['Howard Hughes', 'Toni Morrison', 'Otis Redding']})] + for r in res: + print(r) + assert len(res) == 9 + + # Nested templates + gen = PromptPermutationGenerator('${prefix}... ${suffix}') + res = [r for r in gen({ + 'prefix': ['Who invented ${tool}?', 'When was ${tool} invented?', 'What can you do with ${tool}?'], + 'suffix': ['Phrase your answer in the form of a ${response_type}', 'Respond with a ${response_type}'], + 'tool': ['the flashlight', 'CRISPR', 'rubber'], + 'response_type': ['question', 'poem', 'nightmare'] + })] + for r in res: + print(r) + assert len(res) == (3*3)*(2*3) \ No newline at end of file