mirror of https://github.com/ParisNeo/lollms.git (synced 2025-04-16 06:56:33 +00:00)

Fixed tools

This commit is contained in:
parent cb8a8c7677
commit c7d6aab3f5
@@ -85,7 +85,10 @@ class AIPersonality:
         self.config = config
         self.callback = callback
         self.app = app
-
+        if app is not None:
+            self.notify = app.notify
+        else:
+            self.notify = None
         self.text_files = []
         self.image_files = []
         self.vectorizer = None
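The hunk above wires self.notify to the host application's notify method when an app is attached, and leaves it as None otherwise. A minimal sketch of how calling code might guard that attribute, assuming only what the diff shows (the class and method names below are illustrative, not from the repository):

# Minimal sketch: self.notify is None whenever no app was passed in,
# so callers should check it before forwarding a message.
class PersonalityLike:
    def __init__(self, app=None):
        self.app = app
        self.notify = app.notify if app is not None else None

    def report(self, message: str):
        if self.notify is not None:
            self.notify(message)   # forward to the host application
        else:
            print(message)         # headless fallback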
@@ -379,20 +382,14 @@ Date: {{date}}
         self.scripts_path.mkdir(parents=True, exist_ok=True)

         # Verify if the persona has a data folder
         if self.data_path.exists():
             text = []
             text_files = [file if file.exists() else "" for file in self.data_path.glob("*.txt")]
             for file in text_files:
                 with open(str(file),"r") as f:
                     text.append(f.read())
             # Replace 'example_dir' with your desired directory containing .txt files
             self._data = "\n".join(map((lambda x: f"\n{x}"), text))
             print(self._data)
             self.database_path = self.data_path / "db.json"
             if self.database_path.exists():
                 ASCIIColors.info("Building data ...",end="")
                 self.persona_data_vectorizer = TextVectorizer(
-                    self.config.data_vectorization_method, # supported "model_embedding" or "tfidf_vectorizer"
+                    "tfidf_vectorizer", # self.config.data_vectorization_method, # supported "model_embedding" or "tfidf_vectorizer"
                     model=self.model, #needed in case of using model_embedding
-                    save_db=False,
+                    save_db=True,
                     database_path=self.database_path,
                     data_visualization_method=VisualizationMethod.PCA,
                     database_dict=None)
                 self.persona_data_vectorizer.add_document("persona_data", self._data,512,0)
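Taken together, this block reads every .txt file in the persona's data folder, joins the contents into one string, and indexes it in a TextVectorizer as a single document with a 512-character chunk size and no overlap. A standalone sketch of just the file-gathering step, assuming only what the diff shows (the function name and encoding choice here are illustrative):

from pathlib import Path

def load_persona_data(data_path) -> str:
    """Concatenate every .txt file found in a persona's data folder."""
    folder = Path(data_path)
    if not folder.exists():
        return ""
    texts = []
    for txt_file in sorted(folder.glob("*.txt")):
        texts.append(txt_file.read_text(encoding="utf-8"))
    # Mirrors the join in the diff: each file's text is prefixed with a newline.
    return "\n".join(f"\n{t}" for t in texts)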
@@ -1540,19 +1537,6 @@ class APScript(StateMachine):
         if callback:
             callback(code, MSG_TYPE.MSG_TYPE_CODE)

-    def notify(self, full_text:str, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
-        """This sends full text to front end
-
-        Args:
-            step_text (dict): The step text
-            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the text to. Defaults to None.
-        """
-        if not callback and self.callback:
-            callback = self.callback
-
-        if callback:
-            callback(full_text, MSG_TYPE.MSG_TYPE_INFO)
-
     def full(self, full_text:str, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
         """This sends full text to front end

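The APScript messaging helpers (full, and the now-removed notify) all accept a callback with the signature Callable[[str, MSG_TYPE, dict, list], bool]. A hedged sketch of a callback conforming to that shape; only the signature comes from the diff, while the enum stand-in and the print behaviour are illustrative:

from enum import Enum

# Stand-in for lollms' MSG_TYPE enum; the real one defines more members.
class MSG_TYPE(Enum):
    MSG_TYPE_FULL = 0
    MSG_TYPE_CODE = 1
    MSG_TYPE_INFO = 2

def console_callback(text: str, msg_type: MSG_TYPE, metadata: dict = None, history: list = None) -> bool:
    # Print whatever the script sends and return True to let generation continue.
    print(f"[{msg_type.name}] {text}")
    return True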
@@ -16,8 +16,47 @@ import json
 import re
 import subprocess
+import gc

 from typing import List

+def find_first_available_file_index(folder_path, prefix, extension=""):
+    """
+    Finds the first available file index in a folder with files that have a prefix and an optional extension.
+
+    Args:
+        folder_path (str): The path to the folder.
+        prefix (str): The file prefix.
+        extension (str, optional): The file extension (including the dot). Defaults to "".
+
+    Returns:
+        int: The first available file index.
+    """
+    # Create a Path object for the folder
+    folder = Path(folder_path)
+
+    # Get a list of all files in the folder
+    files = folder.glob(f'{prefix}*'+extension)
+
+    # Initialize the first available number
+    available_number = 1
+
+    # Iterate through the files
+    for file in files:
+        # Extract the number from the file name
+        file_number = int(file.stem[len(prefix):])
+
+        # If the file number is equal to the available number, increment the available number
+        if file_number == available_number:
+            available_number += 1
+        # If the file number is greater than the available number, break the loop
+        elif file_number > available_number:
+            break
+
+    return available_number
+
+
 # Prompting tools
 def detect_antiprompt(text:str, anti_prompts=["!@>"]) -> bool:
     """
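Assuming find_first_available_file_index is available as defined above, a typical use is picking the next free slot for a numbered output file. The folder and prefix below are made up for illustration; note the scan implicitly expects the globbed names to come back in ascending numeric order, so callers with unsorted folders may want to sort the extracted indices first:

from pathlib import Path

output_dir = Path("outputs")          # hypothetical output folder
output_dir.mkdir(parents=True, exist_ok=True)

index = find_first_available_file_index(output_dir, "image_", ".png")
next_file = output_dir / f"image_{index}.png"
print(f"Next free file: {next_file}")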
@@ -69,7 +108,16 @@ def check_torch_version(min_version, min_cuda_versio=12):


 def reinstall_pytorch_with_cuda():
-    result = subprocess.run(["pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir", "--index-url", "https://download.pytorch.org/whl/cu121"])
+    try:
+        ASCIIColors.info("Installing cuda 12.1.1")
+        result = subprocess.run(["conda", "install", "-c", "nvidia/label/cuda-12.1.1", "cuda-toolkit"])
+    except Exception as ex:
+        ASCIIColors.error(ex)
+    try:
+        ASCIIColors.info("Installing pytorch 2.1.1")
+        result = subprocess.run(["pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir", "--index-url", "https://download.pytorch.org/whl/cu121"])
+    except Exception as ex:
+        ASCIIColors.error(ex)
     if result.returncode != 0:
         ASCIIColors.warning("Couldn't find Cuda build tools on your PC. Reverting to CPU.")
         result = subprocess.run(["pip", "install", "--upgrade", "torch", "torchvision", "torchaudio", "--no-cache-dir"])
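After reinstalling along these lines, a quick sanity check is to ask torch itself whether the CUDA build is active; this snippet is a suggested follow-up, not part of the commit:

import torch

# Report whether the freshly installed wheel can actually see a CUDA device.
if torch.cuda.is_available():
    print(f"torch {torch.__version__} with CUDA {torch.version.cuda} on {torch.cuda.get_device_name(0)}")
else:
    print(f"torch {torch.__version__} is running on CPU only")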