Mirror of https://github.com/ParisNeo/lollms.git (synced 2024-12-24 06:46:40 +00:00)

Commit 504c936288 (parent 9963df3e2b): added vision to all models
lollms/binding.py

@@ -67,6 +67,16 @@ class LLMBinding:
         self.config = config
         self.binding_config = binding_config
+        binding_config.addConfigs([
+            {"name":"clip_model_name","type":"str","value":'ViT-L-14/openai','options':["ViT-L-14/openai","ViT-H-14/laion2b_s32b_b79k"], "help":"Clip model to be used for images understanding"},
+            {"name":"caption_model_name","type":"str","value":'blip-large','options':['blip-base', 'git-large-coco', 'blip-large','blip2-2.7b', 'blip2-flan-t5-xl'], "help":"Caption model to be used for images understanding"},
+            {"name":"vqa_model_name","type":"str","value":'Salesforce/blip-vqa-capfilt-large','options':['Salesforce/blip-vqa-capfilt-large', 'Salesforce/blip-vqa-base', 'Salesforce/blip-image-captioning-large','Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-flan-t5-xxl'], "help":"Salesforce question/answer model"},
+        ])
+        self.interrogatorStorer = None
+
         self.supported_file_extensions = supported_file_extensions
         self.seed = config["seed"]
         self.notification_callback = notification_callback
@@ -327,6 +337,24 @@ class LLMBinding:
         """
         self.binding_config.config.save_config(self.configuration_file_path)
 
+    def interrogate_blip(self, images):
+        if self.interrogatorStorer is None:
+            from lollms.image_gen_modules.clip_interrogator import InterrogatorStorer
+            self.interrogatorStorer = InterrogatorStorer(self.binding_config.clip_model_name, self.binding_config.caption_model_name)
+        descriptions = []
+        for image in images:
+            descriptions.append(self.interrogatorStorer.interrogate(image))
+        return descriptions
+
+    def qna_blip(self, images, question=""):
+        if self.interrogatorStorer is None:
+            from lollms.image_gen_modules.blip_vqa import BlipInterrogatorStorer
+            self.interrogatorStorer = BlipInterrogatorStorer()
+        descriptions = []
+        for image in images:
+            descriptions.append(self.interrogatorStorer.interrogate(image,question))
+        return descriptions
+
     def generate_with_images(self,
                              prompt:str,
                              images:list=[],
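A minimal usage sketch of the two new helpers, assuming `binding` is an already loaded LLMBinding instance and the CLIP/BLIP dependencies are installed (the function names and file paths below are illustrative, not part of the commit):

from typing import List
from PIL import Image

def describe_images(binding, image_paths: List[str]) -> List[str]:
    """Caption a batch of images through an already loaded LLMBinding instance."""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    # interrogate_blip lazily builds the CLIP interrogator on first use
    return binding.interrogate_blip(images)

def ask_about_image(binding, image_path: str, question: str) -> str:
    """Run BLIP visual question answering on a single image."""
    img = Image.open(image_path).convert("RGB")
    return binding.qna_blip([img], question)[0]

Note that both helpers cache whichever interrogator was built first in `self.interrogatorStorer`, so a given binding instance is effectively expected to use one of the two paths at a time.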
lollms/config.py

@@ -422,6 +422,10 @@ class TypedConfig:
         # Fill the template values from the config values
         self.sync()
 
+    def addConfigs(self, cfg_template:list):
+        self.config_template.template += cfg_template
+        self.sync()
+
     def update_template(self, new_template):
         self.config_template.template = new_template
         self.config = BaseConfig.from_template(self.config_template,self.config.exceptional_keys, self.config.file_path)
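A short sketch of how the new `addConfigs` hook is used, assuming `binding_config` is a `TypedConfig` built from an existing template (the entry shown mirrors the binding default above):

# Sketch only: binding_config is assumed to be an existing TypedConfig instance.
binding_config.addConfigs([
    {"name": "clip_model_name", "type": "str", "value": "ViT-L-14/openai",
     "help": "Clip model to be used for images understanding"},
])

# addConfigs appends to the template and calls sync(), so the new entry is
# immediately reachable as an attribute, exactly as interrogate_blip reads it:
print(binding_config.clip_model_name)  # -> 'ViT-L-14/openai'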
lollms/image_gen_modules/blip_vqa.py (new file, 16 lines)

@@ -0,0 +1,16 @@
import torch
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering


class BlipInterrogatorStorer():
    def __init__(self, vqa_model_name="Salesforce/blip-vqa-base"):
        self.vqa_model_name = vqa_model_name
        self.processor = BlipProcessor.from_pretrained(vqa_model_name)
        self.model = BlipForQuestionAnswering.from_pretrained(vqa_model_name, torch_dtype=torch.float16).to("cuda")

    def interrogate(self, raw_image:Image, question:str, max_length:int=256):
        inputs = self.processor(raw_image, question, return_tensors="pt").to("cuda", torch.float16)
        out = self.model.generate(**inputs, max_length=max_length)
        return self.processor.decode(out[0], skip_special_tokens=True)
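A minimal usage sketch for this helper. The class hard-codes "cuda" and float16, so it assumes a CUDA-capable GPU; the image path is illustrative:

from PIL import Image
from lollms.image_gen_modules.blip_vqa import BlipInterrogatorStorer

# Downloads Salesforce/blip-vqa-base from Hugging Face on first run (CUDA required).
vqa = BlipInterrogatorStorer()
img = Image.open("photo.jpg").convert("RGB")
print(vqa.interrogate(img, "How many people are in the picture?"))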
lollms/image_gen_modules/clip_interrogator.py (new file, 426 lines)

@@ -0,0 +1,426 @@
# Title LollmsSD
# Licence: MIT
# Author : Paris Neo
# Adapted from the work of pharmapsychotic's clip-interrogator
# check it out : https://github.com/pharmapsychotic/clip-interrogator
# Here is a copy of the LICENCE https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE
# All rights are reserved

from PIL import Image
from lollms.utilities import PackageManager
import hashlib
import math
import numpy as np
import open_clip
import os
import requests
import time
import torch

from dataclasses import dataclass
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
from tqdm import tqdm
from typing import List, Optional

from safetensors.numpy import load_file, save_file

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',   # 990MB
    'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',               # 15.5GB
    'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl',       # 15.77GB
    'git-large-coco': 'microsoft/git-large-coco',            # 1.58GB
}

CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'


@dataclass
class LoLLMS_CLIP_Config:
    # models can optionally be passed in directly
    caption_model = None
    caption_processor = None
    clip_model = None
    clip_preprocess = None

    # blip settings
    caption_max_length: int = 256
    caption_model_name: Optional[str] = 'blip-large'  # use a key from CAPTION_MODELS or None
    caption_offload: bool = False

    # clip settings
    clip_model_name: str = 'ViT-L-14/openai'
    clip_model_path: Optional[str] = None
    clip_offload: bool = False

    # interrogator settings
    cache_path: str = 'cache'    # path to store cached text embeddings
    download_cache: bool = True  # when true, cached embeds are downloaded from huggingface
    chunk_size: int = 2048       # batch size for CLIP, use smaller for lower VRAM
    data_path: str = os.path.join(os.path.dirname(__file__), 'data')
    device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    flavor_intermediate_count: int = 2048
    quiet: bool = False  # when quiet progress bars are not shown

    def apply_low_vram_defaults(self):
        self.caption_model_name = 'blip-base'
        self.caption_offload = True
        self.clip_offload = True
        self.chunk_size = 1024
        self.flavor_intermediate_count = 1024


class LoLLMS_CLIP_Interrogator():
    def __init__(self, config: LoLLMS_CLIP_Config):
        self.config = config
        self.device = config.device
        self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
        self.caption_offloaded = True
        self.clip_offloaded = True
        self.load_caption_model()
        self.load_clip_model()

    def load_caption_model(self):
        if self.config.caption_model is None and self.config.caption_model_name:
            if not self.config.quiet:
                print(f"Loading caption model {self.config.caption_model_name}...")

            model_path = CAPTION_MODELS[self.config.caption_model_name]
            if self.config.caption_model_name.startswith('git-'):
                caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
            elif self.config.caption_model_name.startswith('blip2-'):
                caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
            else:
                caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
            self.caption_processor = AutoProcessor.from_pretrained(model_path)

            caption_model.eval()
            if not self.config.caption_offload:
                caption_model = caption_model.to(self.config.device)
            self.caption_model = caption_model
        else:
            self.caption_model = self.config.caption_model
            self.caption_processor = self.config.caption_processor

    def load_clip_model(self):
        start_time = time.time()
        config = self.config

        clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)

        if config.clip_model is None:
            if not config.quiet:
                print(f"Loading CLIP model {config.clip_model_name}...")

            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                clip_model_name,
                pretrained=clip_model_pretrained_name,
                precision='fp16' if config.device == 'cuda' else 'fp32',
                device=config.device,
                jit=False,
                cache_dir=config.clip_model_path
            )
            self.clip_model.eval()
        else:
            self.clip_model = config.clip_model
            self.clip_preprocess = config.clip_preprocess
        self.tokenize = open_clip.get_tokenizer(clip_model_name)
        self._prepare_clip()
        end_time = time.time()
        if not config.quiet:
            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")

    def chain(
        self,
        image_features: torch.Tensor,
        phrases: List[str],
        best_prompt: str="",
        best_sim: float=0,
        min_count: int=8,
        max_count: int=32,
        desc="Chaining",
        reverse: bool=False
    ) -> str:
        self._prepare_clip()

        phrases = set(phrases)
        if not best_prompt:
            best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
            best_sim = self.similarity(image_features, best_prompt)
            phrases.remove(best_prompt)
        curr_prompt, curr_sim = best_prompt, best_sim

        def check(addition: str, idx: int) -> bool:
            nonlocal best_prompt, best_sim, curr_prompt, curr_sim
            prompt = curr_prompt + ", " + addition
            sim = self.similarity(image_features, prompt)
            if reverse:
                sim = -sim

            if sim > best_sim:
                best_prompt, best_sim = prompt, sim
            if sim > curr_sim or idx < min_count:
                curr_prompt, curr_sim = prompt, sim
                return True
            return False

        for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
            best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
            flave = best[len(curr_prompt)+2:]
            if not check(flave, idx):
                break
            if _prompt_at_max_len(curr_prompt, self.tokenize):
                break
            phrases.remove(flave)

        return best_prompt

    def generate_caption(self, pil_image: Image) -> str:
        assert self.caption_model is not None, "No caption model loaded."
        self._prepare_caption()
        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
        if not self.config.caption_model_name.startswith('git-'):
            inputs = inputs.to(self.dtype)
        tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
        return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()

    def image_to_features(self, image: Image) -> torch.Tensor:
        self._prepare_clip()
        images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str:
        """Classic mode creates a prompt in a standard format first describing the image,
        then listing the artist, trending, movement, and flavor text modifiers."""
        caption = caption or self.generate_caption(image)
        image_features = self.image_to_features(image)

        medium = self.mediums.rank(image_features, 1)[0]
        artist = self.artists.rank(image_features, 1)[0]
        trending = self.trendings.rank(image_features, 1)[0]
        movement = self.movements.rank(image_features, 1)[0]
        flaves = ", ".join(self.flavors.rank(image_features, max_flavors))

        if caption.startswith(medium):
            prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
        else:
            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"

        return _truncate_to_fit(prompt, self.tokenize)

    def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str:
        """Fast mode simply adds the top ranked terms after a caption. It generally results in
        better similarity between generated prompt and image than classic mode, but the prompts
        are less readable."""
        caption = caption or self.generate_caption(image)
        return _truncate_to_fit(caption, self.tokenize)

    def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str:
        caption = caption or self.generate_caption(image)
        return caption

    def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
        self._prepare_clip()
        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
            if reverse:
                similarity = -similarity
        return text_array[similarity.argmax().item()]

    def similarity(self, image_features: torch.Tensor, text: str) -> float:
        self._prepare_clip()
        text_tokens = self.tokenize([text]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity[0][0].item()

    def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
        self._prepare_clip()
        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity.T[0].tolist()

    def _prepare_caption(self):
        if self.config.clip_offload and not self.clip_offloaded:
            self.clip_model = self.clip_model.to('cpu')
            self.clip_offloaded = True
        if self.caption_offloaded:
            self.caption_model = self.caption_model.to(self.device)
            self.caption_offloaded = False

    def _prepare_clip(self):
        if self.config.caption_offload and not self.caption_offloaded:
            self.caption_model = self.caption_model.to('cpu')
            self.caption_offloaded = True
        if self.clip_offloaded:
            self.clip_model = self.clip_model.to(self.device)
            self.clip_offloaded = False


class LoLLMS_CLIP_LabelTable():
    def __init__(self, labels:List[str], desc:str, ci: LoLLMS_CLIP_Interrogator):
        clip_model, config = ci.clip_model, ci.config
        self.chunk_size = config.chunk_size
        self.config = config
        self.device = config.device
        self.embeds = []
        self.labels = labels
        self.tokenize = ci.tokenize

        hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
        sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
        self._load_cached(desc, hash, sanitized_name)

        if len(self.labels) != len(self.embeds):
            self.embeds = []
            chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
            for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
                text_tokens = self.tokenize(chunk).to(self.device)
                with torch.no_grad(), torch.cuda.amp.autocast():
                    text_features = clip_model.encode_text(text_tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    text_features = text_features.half().cpu().numpy()
                for i in range(text_features.shape[0]):
                    self.embeds.append(text_features[i])

            if desc and self.config.cache_path:
                os.makedirs(self.config.cache_path, exist_ok=True)
                cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
                tensors = {
                    "embeds": np.stack(self.embeds),
                    "hash": np.array([ord(c) for c in hash], dtype=np.int8)
                }
                save_file(tensors, cache_filepath)

        if self.device == 'cpu' or self.device == torch.device('cpu'):
            self.embeds = [e.astype(np.float32) for e in self.embeds]

    def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
        if self.config.cache_path is None or desc is None:
            return False

        cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")

        if self.config.download_cache and not os.path.exists(cached_safetensors):
            download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors"
            try:
                os.makedirs(self.config.cache_path, exist_ok=True)
                _download_file(download_url, cached_safetensors, quiet=self.config.quiet)
            except Exception as e:
                print(f"Failed to download {download_url}")
                print(e)
                return False

        if os.path.exists(cached_safetensors):
            try:
                tensors = load_file(cached_safetensors)
            except Exception as e:
                print(f"Failed to load {cached_safetensors}")
                print(e)
                return False
            if 'hash' in tensors and 'embeds' in tensors:
                if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
                    self.embeds = tensors['embeds']
                    if len(self.embeds.shape) == 2:
                        self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
                    return True

        return False

    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
        top_count = min(top_count, len(text_embeds))
        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
        with torch.cuda.amp.autocast():
            similarity = image_features @ text_embeds.T
            if reverse:
                similarity = -similarity
        _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
        return [top_labels[0][i].numpy() for i in range(top_count)]

    def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
        if len(self.labels) <= self.chunk_size:
            tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
            return [self.labels[i] for i in tops]

        num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
        keep_per_chunk = int(self.chunk_size / num_chunks)

        top_labels, top_embeds = [], []
        for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
            start = chunk_idx*self.chunk_size
            stop = min(start+self.chunk_size, len(self.embeds))
            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
            top_labels.extend([self.labels[start+i] for i in tops])
            top_embeds.extend([self.embeds[start+i] for i in tops])

        tops = self._rank(image_features, top_embeds, top_count=top_count)
        return [top_labels[i] for i in tops]


def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False):
    r = requests.get(url, stream=True)
    if r.status_code != 200:
        return

    file_size = int(r.headers.get("Content-Length", 0))
    filename = url.split("/")[-1]
    progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
    with open(filepath, "wb") as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:
                f.write(chunk)
                progress.update(len(chunk))
    progress.close()


def _merge_tables(tables: List[LoLLMS_CLIP_LabelTable], ci: LoLLMS_CLIP_Interrogator) -> LoLLMS_CLIP_LabelTable:
    m = LoLLMS_CLIP_LabelTable([], None, ci)
    for table in tables:
        m.labels.extend(table.labels)
        m.embeds.extend(table.embeds)
    return m


def _prompt_at_max_len(text: str, tokenize) -> bool:
    tokens = tokenize([text])
    return tokens[0][-1] != 0


def _truncate_to_fit(text: str, tokenize) -> str:
    parts = text.split(', ')
    new_text = parts[0]
    for part in parts[1:]:
        if _prompt_at_max_len(new_text + part, tokenize):
            break
        new_text += ', ' + part
    return new_text


def list_caption_models() -> List[str]:
    return list(CAPTION_MODELS.keys())


def list_clip_models() -> List[str]:
    return ['/'.join(x) for x in open_clip.list_pretrained()]


def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
    """Load a list of strings from a file."""
    if filename is not None:
        data_path = os.path.join(data_path, filename)
    with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items


class InterrogatorStorer():
    def __init__(self, clip_model_name='ViT-L-14/openai', caption_model_name='blip-large'):
        self.clip_model_name = clip_model_name
        self.interrogator = LoLLMS_CLIP_Interrogator(LoLLMS_CLIP_Config(clip_model_name=clip_model_name, caption_model_name=caption_model_name))

    def interrogate(self, image:Image):
        return self.interrogator.interrogate(image)
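For orientation, a small sketch of how the InterrogatorStorer wrapper at the end of this file is driven (the image path is illustrative). With the settings above, interrogate() returns the raw BLIP caption; the CLIP machinery (LoLLMS_CLIP_LabelTable, rank_top, chain) remains available for richer prompt building:

from PIL import Image
from lollms.image_gen_modules.clip_interrogator import (
    InterrogatorStorer, list_caption_models, list_clip_models
)

print(list_caption_models())    # ['blip-base', 'blip-large', 'blip2-2.7b', ...]
# list_clip_models() enumerates every open_clip pretrained pair, e.g. 'ViT-L-14/openai'

storer = InterrogatorStorer(clip_model_name='ViT-L-14/openai', caption_model_name='blip-large')
img = Image.open("photo.jpg").convert("RGB")
print(storer.interrogate(img))  # BLIP caption of the image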
lollms/personality.py

@@ -11,7 +11,7 @@ from pathlib import Path
 from lollms.config import InstallOption, TypedConfig, BaseConfig
 from lollms.main_config import LOLLMSConfig
 from lollms.paths import LollmsPaths
-from lollms.binding import LLMBinding
+from lollms.binding import LLMBinding, BindingType
 from lollms.utilities import PromptReshaper, PackageManager
 import pkg_resources
 from pathlib import Path
@@ -32,7 +32,7 @@ from safe_store import TextVectorizer, GenericDataLoader, VisualizationMethod, V
 from functools import partial
 from typing import Dict, Any
 
-from lollms.helpers import get_trace_exception
+from lollms.helpers import trace_exception
 
 def is_package_installed(package_name):
     try:
@@ -93,6 +93,7 @@ class AIPersonality:
         self.notify = None
         self.text_files = []
         self.image_files = []
+        self.images_descriptions = []
         self.vectorizer = None
 
         self.installation_option = installation_option
@@ -184,6 +185,101 @@ Date: {{date}}
         self.personality_output_folder = lollms_paths.personal_outputs_path/self.name
         self.personality_output_folder.mkdir(parents=True, exist_ok=True)
 
+    def new_message(self, message_text:str, message_type:MSG_TYPE= MSG_TYPE.MSG_TYPE_FULL, metadata=[], callback: Callable[[str, int, dict, list], bool]=None):
+        """This opens a new message in the front end
+
+        Args:
+            message_text (str): The text of the new message
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the message to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(message_text, MSG_TYPE.MSG_TYPE_NEW_MESSAGE, parameters={'type':message_type.value,'metadata':metadata})
+
+    def full(self, full_text:str, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
+        """This sends full text to front end
+
+        Args:
+            full_text (str): The full text to send
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the text to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(full_text, MSG_TYPE.MSG_TYPE_FULL)
+
+    def full_invisible_to_ai(self, full_text:str, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
+        """This sends full text to front end (INVISIBLE to AI)
+
+        Args:
+            full_text (str): The full text to send
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the text to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_AI)
+
+    def full_invisible_to_user(self, full_text:str, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
+        """This sends full text to front end (INVISIBLE to user)
+
+        Args:
+            full_text (str): The full text to send
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the text to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_USER)
+
+    def step_start(self, step_text, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
+        """This triggers a step start
+
+        Args:
+            step_text (str): The step text
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the step start to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_START)
+
+    def step_end(self, step_text, status=True, callback: Callable[[str, int, dict, list], bool]=None):
+        """This triggers a step end
+
+        Args:
+            step_text (str): The step text
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the step end to. Defaults to None.
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_END, {'status':status})
+
+    def step(self, step_text, callback: Callable[[str, MSG_TYPE, dict, list], bool]=None):
+        """This triggers a step information
+
+        Args:
+            step_text (str): The step text
+            callback (callable, optional): A callable with this signature (str, MSG_TYPE, dict, list) to send the step to. Defaults to None.
+                The callback receives these fields:
+                - chunk
+                - Message Type : the type of message
+                - Parameters (optional) : a dictionary of parameters
+                - Metadata (optional) : a list of metadata
+        """
+        if not callback and self.callback:
+            callback = self.callback
+
+        if callback:
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP)
+
     def print_prompt(self, title, prompt):
         ASCIIColors.red("*-*-*-*-*-*-*-* ", end="")
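All of these helpers funnel into one callback signature, so a front end only needs to supply a single function. A hedged sketch of a compatible callback follows; the console-printing body and the import path for MSG_TYPE are assumptions, not taken from this diff:

from lollms.types import MSG_TYPE  # assumed module path for MSG_TYPE

def console_callback(chunk: str, msg_type: MSG_TYPE, parameters: dict = None, metadata: list = None) -> bool:
    # A real front end would update its widgets here; this sketch just logs to the console.
    print(f"[{msg_type.name}] {chunk} {parameters or ''}")
    return True  # the helpers expect a bool return value

Passing such a function as the callback argument, or storing it on self.callback, makes new_message, full, step_start, step_end, and step all observable from a plain terminal.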
@@ -485,7 +581,7 @@ Date: {{date}}
         db_path = self.lollms_paths.personal_databases_path / "personalities" / self.name / "db.json"
         db_path.parent.mkdir(parents=True, exist_ok=True)
         path = Path(path)
-        if path.suffix in [".png",".jpg",".gif",".bmp"]:
+        if path.suffix in [".png",".jpg",".gif",".bmp",".webp"]:
             if self.callback:
                 try:
                     if callback:
@@ -493,14 +589,35 @@ Date: {{date}}
                         if "uploads" in pth:
                             idx = pth.index("uploads")
                             pth = "/".join(pth[idx:])
-                            callback(f'<img src="{pth}" width="300">', MSG_TYPE.MSG_TYPE_NEW_MESSAGE, parameters={'type':MSG_TYPE.MSG_TYPE_FULL.value,'metadata':[]})
+                            self.new_message("",MSG_TYPE.MSG_TYPE_FULL)
+                            output = f'<img src="{pth}" width="300">\n\n'
+                            self.full(output)
+
+                            if self.model.binding_type not in [BindingType.TEXT_IMAGE, BindingType.TEXT_IMAGE_VIDEO]:
+                                self.step_start("Understanding image (please wait)")
+                                from PIL import Image
+                                img = Image.open(str(path))
+                                # Convert the image to RGB mode
+                                img = img.convert("RGB")
+                                output += "## image description :\n"+ self.model.interrogate_blip([img])[0]
+                                # output += "## image description :\n"+ self.model.qna_blip([img],"Describe this photo with details.\n")[0]
+                                self.full(output)
+                                self.step_end("Understanding image (please wait)")
+                                if self.config.debug:
+                                    ASCIIColors.yellow(output)
+                            else:
+                                self.step_start("Importing image (please wait)")
+                                self.step_end("Importing image (please wait)")
+                                self.full(output)
+
                 except Exception as ex:
+                    trace_exception(ex)
+                    self.step_end("Understanding image (please wait)", False)
                     ASCIIColors.error("Couldn't create new message")
             self.image_files.append(path)
             ASCIIColors.info("Received image file")
             if callback is not None:
-                callback("Image file added successfully",MSG_TYPE.MSG_TYPE_INFO)
+                callback("Image file added successfully", MSG_TYPE.MSG_TYPE_INFO)
         else:
             self.text_files.append(path)
             ASCIIColors.info("Received text compatible file")
@@ -1206,6 +1323,7 @@ class APScript(StateMachine):
         self.notify = personality.app.notify
         self.text_files = []
         self.image_files = []
+        self.images_descriptions=[]
 
         self.personality = personality
         self.personality_config = personality_config
setup.py (2 changes)

@@ -26,7 +26,7 @@ def get_all_files(path):
 
 setuptools.setup(
     name="lollms",
-    version="6.5.2",
+    version="6.6.0",
     author="Saifeddine ALOUI",
     author_email="aloui.saifeddine@gmail.com",
     description="A python library for AI personality definition",