diff --git a/lollms/personality.py b/lollms/personality.py index 95ad445..a826c1d 100644 --- a/lollms/personality.py +++ b/lollms/personality.py @@ -4461,39 +4461,38 @@ transition-all duration-300 ease-in-out"> Returns: int: Index of the selected option within the possible_ansers list. Or -1 if there was not match found among any of them. """ - start_header_id_template = self.config.start_header_id_template - end_header_id_template = self.config.end_header_id_template - system_message_template = self.config.system_message_template - choices = "\n".join([f"{i}. {possible_answer}" for i, possible_answer in enumerate(possible_answers)]) elements = [conditionning] if conditionning!="" else [] elements += [ - f"{start_header_id_template}{system_message_template}{end_header_id_template}", - "Answer this multi choices question.", + f"{self.system_full_header}", + "Answer this multi choices question in form of a json in this form:\n", + """```json +{ + "justification": "A justification for your choice", + "choice_index": the index of the choice made +} +``` + """, ] if context!="": elements+=[ - f"{start_header_id_template}Context{end_header_id_template}", + self.system_custom_header("Context"), f"{context}", ] - elements +=[ - "Answer with an id from the possible answers.", - "Do not answer with an id outside this possible answers.", - "Do not explain your reasons or add comments.", - "the output should be an integer." - ] elements += [ - f"{start_header_id_template}question{end_header_id_template}{question}", - f"{start_header_id_template}possible answers{end_header_id_template}", + self.system_custom_header("question"), + question, + self.system_custom_header("possible answers"), f"{choices}", ] - elements += [f"{start_header_id_template}answer{end_header_id_template}"] + elements += [self.system_custom_header("answer")] prompt = self.build_prompt(elements) - gen = self.generate(prompt, max_answer_length, temperature=0.1, top_k=50, top_p=0.9, repeat_penalty=1.0, repeat_last_n=50, callback=self.sink).strip().replace("","").replace("","") - if len(gen)>0: - selection = gen.strip().split()[0].replace(",","").replace(".","") - self.print_prompt("Multi choice selection",prompt+gen) + code = self.generate_code(prompt, self.personality.image_files, max_answer_length, temperature=0.1, top_k=50, top_p=0.9, repeat_penalty=1.0, repeat_last_n=50, callback=self.sink).strip().replace("","").replace("","") + if len(code)>0: + json_code = json.loads(code) + selection = json_code["choice_index"] + self.print_prompt("Multi choice selection",prompt+code) try: return int(selection) except: