faster models zoo loading

This commit is contained in:
Saifeddine ALOUI 2024-03-03 00:32:47 +01:00
parent 66e0af1548
commit 6fc1f27d4e
2 changed files with 155 additions and 5 deletions

View File

@ -34,6 +34,7 @@ from enum import Enum
from lollms.utilities import trace_exception
from tqdm import tqdm
from lollms.databases.models_database import ModelsDB
import sys
__author__ = "parisneo"
@ -785,13 +786,19 @@ class LLMBinding:
# Create the file path relative to the child class's directory
full_data = []
for models_dir_name in self.models_dir_names:
file_path = self.lollms_paths.models_zoo_path/f"{models_dir_name}.yaml"
with open(file_path, 'r') as file:
yaml_data = yaml.safe_load(file)
full_data+=yaml_data
self.models_db = ModelsDB(self.lollms_paths.models_zoo_path/f"{models_dir_name}.db")
full_data+=self.models_db.query()
return full_data
def search_models(self, app:LoLLMsCom=None):
# Create the file path relative to the child class's directory
full_data = []
for models_dir_name in self.models_dir_names:
self.models_db = ModelsDB(self.lollms_paths.models_zoo_path/f"{models_dir_name}.db")
full_data+=self.models_db.query()
return full_data
@staticmethod
def vram_usage():

View File

@ -0,0 +1,143 @@
import sqlite3
import yaml
from datetime import datetime
from ascii_colors import ASCIIColors
class ModelsDB:
def __init__(self, db_name='models.db'):
self.conn = sqlite3.connect(db_name)
self.cursor = self.conn.cursor()
self.create_table()
def create_table(self):
self.cursor.execute('''CREATE TABLE IF NOT EXISTS models (
id INTEGER PRIMARY KEY,
category TEXT,
icon TEXT,
datasets TEXT,
last_commit_time TEXT,
license TEXT,
model_creator TEXT,
model_creator_link TEXT,
name TEXT UNIQUE,
quantizer TEXT,
ctx_size INTEGER,
rank REAL,
type TEXT
)''')
self.cursor.execute('''CREATE TABLE IF NOT EXISTS variants (
id INTEGER PRIMARY KEY,
model_id INTEGER,
name TEXT,
size INTEGER,
FOREIGN KEY(model_id) REFERENCES models(id)
)''')
self.conn.commit()
def add_entry(self, entry):
# Check if a model with the same name already exists
self.cursor.execute("SELECT id FROM models WHERE name=?", (entry.get('name'),))
model_id = self.cursor.fetchone()
if model_id is None:
datasets = entry.get('datasets', [])
license = entry.get('license')
# If the model does not exist, insert a new one
data = (entry.get('category'), entry.get('icon'), ','.join(datasets) if type(datasets)==list else datasets, entry.get('last_commit_time'),
','.join(license) if type(license)==list else license,
entry.get('model_creator'), entry.get('model_creator_link'), entry.get('name'),
entry.get('quantizer'), entry.get('ctx_size', 4096), entry.get('rank'), entry.get('type'))
self.cursor.execute('''INSERT INTO models VALUES (NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''', data)
model_id = self.cursor.lastrowid
# Insert the variants
for variant in entry.get('variants', []):
variant_name = variant.get('name')
variant_size = variant.get('size')
variant_data = (model_id, variant_name, variant_size)
self.cursor.execute('''INSERT INTO variants VALUES (NULL, ?, ?, ?)''', variant_data)
else:
ASCIIColors.warning(f"A duplicate of the model {entry.get('name')} has been detected")
self.conn.commit()
def import_from_yaml(self, file_path):
with open(file_path) as file:
data = yaml.safe_load(file)
for entry in data:
self.add_entry(entry)
def query(self, n=None, model_types=None, variant_name=None, keyword=None):
query = "SELECT models.*, variants.name as variant_name, variants.size as variant_size FROM models LEFT JOIN variants ON models.id = variants.model_id WHERE 1=1"
params = []
if model_types:
query += " AND models.type IN (" + ", ".join("?" for _ in model_types) + ")"
params.extend(model_types)
if variant_name:
query += " AND variants.name=?"
params.append(variant_name)
if keyword:
query += " AND models.name LIKE ?"
params.append('%' + keyword + '%')
if n:
query += " ORDER BY models.last_commit_time DESC LIMIT ?"
params.append(n)
self.cursor.execute(query, params)
results = self.cursor.fetchall()
# Group the results by model id and convert the row to a dictionary
model_results = {}
for row in results:
model_id = row[0]
if model_id not in model_results:
model_results[model_id] = {
"id": row[0],
"category": row[1],
"icon": row[2],
"datasets": row[3].split(','),
"last_commit_time": row[4],
"license": row[5],
"model_creator": row[6],
"model_creator_link": row[7],
"name": row[8],
"quantizer": row[9],
"ctx_size": row[10],
"rank": row[11],
"type": row[12],
"variants": []
}
model_results[model_id]["variants"].append({"name": row[13], "size": row[14]})
return list(model_results.values())
def remove_entry(self, model_name):
self.cursor.execute("DELETE FROM variants WHERE model_id IN (SELECT id FROM models WHERE name=?)", (model_name,))
self.cursor.execute("DELETE FROM models WHERE name=?", (model_name,))
self.conn.commit()
def main():
import argparse
parser = argparse.ArgumentParser(description='Build a database from a YAML file.')
parser.add_argument('--yaml_path', type=str, help='Path to the YAML file')
parser.add_argument('--db_path', type=str, help='Path to the database file')
args = parser.parse_args()
# Create a database instance
db = ModelsDB(db_name=args.db_path)
# Import entries from a YAML file
db.import_from_yaml(args.yaml_path)
# Query the database
print(db.query(n=3, model_types=['gguf']))
if __name__=="__main__":
main()