mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-29 15:44:17 +00:00
Enhance autogptq backend to support VL models (#1860)
* Enhance autogptq backend to support VL models * update dependencies for autogptq * remove redundant auto-gptq dependency * Convert base64 to image_url for Qwen-VL model * implemented model inference for qwen-vl * remove user prompt from generated answer * fixed write image error --------- Co-authored-by: Binghua Wu <bingwu@estee.com>
This commit is contained in:
parent
e58410fa99
commit
b7ffe66219
@ -5,12 +5,14 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import base64
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import backend_pb2
|
import backend_pb2
|
||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
from auto_gptq import AutoGPTQForCausalLM
|
from auto_gptq import AutoGPTQForCausalLM
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from transformers import TextGenerationPipeline
|
from transformers import TextGenerationPipeline
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
@ -28,9 +30,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.Device != "":
|
if request.Device != "":
|
||||||
device = request.Device
|
device = request.Device
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=request.UseFastTokenizer)
|
# support loading local model files
|
||||||
|
model_path = os.path.join(os.environ.get('MODELS_PATH', './'), request.Model)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=request.TrustRemoteCode)
|
||||||
|
|
||||||
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
# support model `Qwen/Qwen-VL-Chat-Int4`
|
||||||
|
if "qwen-vl" in request.Model.lower():
|
||||||
|
self.model_name = "Qwen-VL-Chat"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||||
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
|
use_triton=request.UseTriton,
|
||||||
|
device_map="auto").eval()
|
||||||
|
else:
|
||||||
|
model = AutoGPTQForCausalLM.from_quantized(model_path,
|
||||||
model_basename=request.ModelBaseName,
|
model_basename=request.ModelBaseName,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
trust_remote_code=request.TrustRemoteCode,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
@ -55,6 +67,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.TopP != 0.0:
|
if request.TopP != 0.0:
|
||||||
top_p = request.TopP
|
top_p = request.TopP
|
||||||
|
|
||||||
|
|
||||||
|
prompt_images = self.recompile_vl_prompt(request)
|
||||||
|
compiled_prompt = prompt_images[0]
|
||||||
|
print(f"Prompt: {compiled_prompt}", file=sys.stderr)
|
||||||
|
|
||||||
# Implement Predict RPC
|
# Implement Predict RPC
|
||||||
pipeline = TextGenerationPipeline(
|
pipeline = TextGenerationPipeline(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@ -64,10 +81,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=penalty,
|
repetition_penalty=penalty,
|
||||||
)
|
)
|
||||||
t = pipeline(request.Prompt)[0]["generated_text"]
|
t = pipeline(compiled_prompt)[0]["generated_text"]
|
||||||
# Remove prompt from response if present
|
print(f"generated_text: {t}", file=sys.stderr)
|
||||||
if request.Prompt in t:
|
|
||||||
t = t.replace(request.Prompt, "")
|
if compiled_prompt in t:
|
||||||
|
t = t.replace(compiled_prompt, "")
|
||||||
|
# house keeping. Remove the image files from /tmp folder
|
||||||
|
for img_path in prompt_images[1]:
|
||||||
|
try:
|
||||||
|
os.remove(img_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
|
||||||
|
|
||||||
return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
|
return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
|
||||||
|
|
||||||
@ -78,6 +102,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
# Not implemented yet
|
# Not implemented yet
|
||||||
return self.Predict(request, context)
|
return self.Predict(request, context)
|
||||||
|
|
||||||
|
def recompile_vl_prompt(self, request):
|
||||||
|
prompt = request.Prompt
|
||||||
|
image_paths = []
|
||||||
|
|
||||||
|
if "qwen-vl" in self.model_name.lower():
|
||||||
|
# request.Images is an array which contains base64 encoded images. Iterate the request.Images array, decode and save each image to /tmp folder with a random filename.
|
||||||
|
# Then, save the image file paths to an array "image_paths".
|
||||||
|
# read "request.Prompt", replace "[img-%d]" with the image file paths in the order they appear in "image_paths". Save the new prompt to "prompt".
|
||||||
|
for i, img in enumerate(request.Images):
|
||||||
|
timestamp = str(int(time.time() * 1000)) # Generate timestamp
|
||||||
|
img_path = f"/tmp/vl-{timestamp}.jpg" # Use timestamp in filename
|
||||||
|
with open(img_path, "wb") as f:
|
||||||
|
f.write(base64.b64decode(img))
|
||||||
|
image_paths.append(img_path)
|
||||||
|
prompt = prompt.replace(f"[img-{i}]", "<img>" + img_path + "</img>,")
|
||||||
|
else:
|
||||||
|
prompt = request.Prompt
|
||||||
|
return (prompt, image_paths)
|
||||||
|
|
||||||
def serve(address):
|
def serve(address):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
####
|
||||||
|
# Attention! This file is abandoned.
|
||||||
|
# Please use the ../common-env/transformers/transformers.yml file to manage dependencies.
|
||||||
|
###
|
||||||
name: autogptq
|
name: autogptq
|
||||||
channels:
|
channels:
|
||||||
- defaults
|
- defaults
|
||||||
@ -24,12 +28,12 @@ dependencies:
|
|||||||
- xz=5.4.2=h5eee18b_0
|
- xz=5.4.2=h5eee18b_0
|
||||||
- zlib=1.2.13=h5eee18b_0
|
- zlib=1.2.13=h5eee18b_0
|
||||||
- pip:
|
- pip:
|
||||||
- accelerate==0.23.0
|
- accelerate==0.27.0
|
||||||
- aiohttp==3.8.5
|
- aiohttp==3.8.5
|
||||||
- aiosignal==1.3.1
|
- aiosignal==1.3.1
|
||||||
- async-timeout==4.0.3
|
- async-timeout==4.0.3
|
||||||
- attrs==23.1.0
|
- attrs==23.1.0
|
||||||
- auto-gptq==0.4.2
|
- auto-gptq==0.7.1
|
||||||
- certifi==2023.7.22
|
- certifi==2023.7.22
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
@ -59,6 +63,7 @@ dependencies:
|
|||||||
- nvidia-nccl-cu12==2.18.1
|
- nvidia-nccl-cu12==2.18.1
|
||||||
- nvidia-nvjitlink-cu12==12.2.140
|
- nvidia-nvjitlink-cu12==12.2.140
|
||||||
- nvidia-nvtx-cu12==12.1.105
|
- nvidia-nvtx-cu12==12.1.105
|
||||||
|
- optimum==1.17.1
|
||||||
- packaging==23.2
|
- packaging==23.2
|
||||||
- pandas==2.1.1
|
- pandas==2.1.1
|
||||||
- peft==0.5.0
|
- peft==0.5.0
|
||||||
@ -75,9 +80,11 @@ dependencies:
|
|||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers==0.14.0
|
- tokenizers==0.14.0
|
||||||
- torch==2.1.0
|
|
||||||
- tqdm==4.66.1
|
- tqdm==4.66.1
|
||||||
|
- torch==2.2.1
|
||||||
|
- torchvision==0.17.1
|
||||||
- transformers==4.34.0
|
- transformers==4.34.0
|
||||||
|
- transformers_stream_generator==0.0.5
|
||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- tzdata==2023.3
|
- tzdata==2023.3
|
||||||
|
@ -24,10 +24,11 @@ dependencies:
|
|||||||
- xz=5.4.2=h5eee18b_0
|
- xz=5.4.2=h5eee18b_0
|
||||||
- zlib=1.2.13=h5eee18b_0
|
- zlib=1.2.13=h5eee18b_0
|
||||||
- pip:
|
- pip:
|
||||||
- accelerate==0.23.0
|
- accelerate==0.27.0
|
||||||
- aiohttp==3.8.5
|
- aiohttp==3.8.5
|
||||||
- aiosignal==1.3.1
|
- aiosignal==1.3.1
|
||||||
- async-timeout==4.0.3
|
- async-timeout==4.0.3
|
||||||
|
- auto-gptq==0.7.1
|
||||||
- attrs==23.1.0
|
- attrs==23.1.0
|
||||||
- bark==0.1.5
|
- bark==0.1.5
|
||||||
- bitsandbytes==0.43.0
|
- bitsandbytes==0.43.0
|
||||||
@ -69,6 +70,7 @@ dependencies:
|
|||||||
- nvidia-nccl-cu12==2.18.1
|
- nvidia-nccl-cu12==2.18.1
|
||||||
- nvidia-nvjitlink-cu12==12.2.140
|
- nvidia-nvjitlink-cu12==12.2.140
|
||||||
- nvidia-nvtx-cu12==12.1.105
|
- nvidia-nvtx-cu12==12.1.105
|
||||||
|
- optimum==1.17.1
|
||||||
- packaging==23.2
|
- packaging==23.2
|
||||||
- pandas
|
- pandas
|
||||||
- peft==0.5.0
|
- peft==0.5.0
|
||||||
@ -87,7 +89,8 @@ dependencies:
|
|||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
- torch==2.1.2
|
- torch==2.2.1
|
||||||
|
- torchvision==0.17.1
|
||||||
- torchaudio==2.1.2
|
- torchaudio==2.1.2
|
||||||
- tqdm==4.66.1
|
- tqdm==4.66.1
|
||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
@ -95,7 +98,6 @@ dependencies:
|
|||||||
- tzdata==2023.3
|
- tzdata==2023.3
|
||||||
- urllib3==1.26.17
|
- urllib3==1.26.17
|
||||||
- xxhash==3.4.1
|
- xxhash==3.4.1
|
||||||
- auto-gptq==0.6.0
|
|
||||||
- yarl==1.9.2
|
- yarl==1.9.2
|
||||||
- soundfile
|
- soundfile
|
||||||
- langid
|
- langid
|
||||||
@ -116,5 +118,6 @@ dependencies:
|
|||||||
- vocos
|
- vocos
|
||||||
- vllm==0.3.2
|
- vllm==0.3.2
|
||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
@ -26,7 +26,8 @@ dependencies:
|
|||||||
- pip:
|
- pip:
|
||||||
- --pre
|
- --pre
|
||||||
- --extra-index-url https://download.pytorch.org/whl/nightly/
|
- --extra-index-url https://download.pytorch.org/whl/nightly/
|
||||||
- accelerate==0.23.0
|
- accelerate==0.27.0
|
||||||
|
- auto-gptq==0.7.1
|
||||||
- aiohttp==3.8.5
|
- aiohttp==3.8.5
|
||||||
- aiosignal==1.3.1
|
- aiosignal==1.3.1
|
||||||
- async-timeout==4.0.3
|
- async-timeout==4.0.3
|
||||||
@ -82,7 +83,6 @@ dependencies:
|
|||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- tzdata==2023.3
|
- tzdata==2023.3
|
||||||
- auto-gptq==0.6.0
|
|
||||||
- urllib3==1.26.17
|
- urllib3==1.26.17
|
||||||
- xxhash==3.4.1
|
- xxhash==3.4.1
|
||||||
- yarl==1.9.2
|
- yarl==1.9.2
|
||||||
@ -90,6 +90,7 @@ dependencies:
|
|||||||
- langid
|
- langid
|
||||||
- wget
|
- wget
|
||||||
- unidecode
|
- unidecode
|
||||||
|
- optimum==1.17.1
|
||||||
- pyopenjtalk-prebuilt
|
- pyopenjtalk-prebuilt
|
||||||
- pypinyin
|
- pypinyin
|
||||||
- inflect
|
- inflect
|
||||||
@ -105,5 +106,6 @@ dependencies:
|
|||||||
- vocos
|
- vocos
|
||||||
- vllm==0.3.2
|
- vllm==0.3.2
|
||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
@ -24,9 +24,10 @@ dependencies:
|
|||||||
- xz=5.4.2=h5eee18b_0
|
- xz=5.4.2=h5eee18b_0
|
||||||
- zlib=1.2.13=h5eee18b_0
|
- zlib=1.2.13=h5eee18b_0
|
||||||
- pip:
|
- pip:
|
||||||
- accelerate==0.23.0
|
- accelerate==0.27.0
|
||||||
- aiohttp==3.8.5
|
- aiohttp==3.8.5
|
||||||
- aiosignal==1.3.1
|
- aiosignal==1.3.1
|
||||||
|
- auto-gptq==0.7.1
|
||||||
- async-timeout==4.0.3
|
- async-timeout==4.0.3
|
||||||
- attrs==23.1.0
|
- attrs==23.1.0
|
||||||
- bark==0.1.5
|
- bark==0.1.5
|
||||||
@ -56,6 +57,7 @@ dependencies:
|
|||||||
- multiprocess==0.70.15
|
- multiprocess==0.70.15
|
||||||
- networkx
|
- networkx
|
||||||
- numpy==1.26.0
|
- numpy==1.26.0
|
||||||
|
- optimum==1.17.1
|
||||||
- packaging==23.2
|
- packaging==23.2
|
||||||
- pandas
|
- pandas
|
||||||
- peft==0.5.0
|
- peft==0.5.0
|
||||||
@ -74,13 +76,13 @@ dependencies:
|
|||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
- torch==2.1.2
|
- torch==2.2.1
|
||||||
|
- torchvision==0.17.1
|
||||||
- torchaudio==2.1.2
|
- torchaudio==2.1.2
|
||||||
- tqdm==4.66.1
|
- tqdm==4.66.1
|
||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- tzdata==2023.3
|
- tzdata==2023.3
|
||||||
- auto-gptq==0.6.0
|
|
||||||
- urllib3==1.26.17
|
- urllib3==1.26.17
|
||||||
- xxhash==3.4.1
|
- xxhash==3.4.1
|
||||||
- yarl==1.9.2
|
- yarl==1.9.2
|
||||||
@ -103,5 +105,6 @@ dependencies:
|
|||||||
- vocos
|
- vocos
|
||||||
- vllm==0.3.2
|
- vllm==0.3.2
|
||||||
- transformers>=4.38.2 # Updated Version
|
- transformers>=4.38.2 # Updated Version
|
||||||
|
- transformers_stream_generator==0.0.5
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
Loading…
x
Reference in New Issue
Block a user