diff --git a/README.md b/README.md index 8fa30084..4ee1a23e 100644 --- a/README.md +++ b/README.md @@ -215,8 +215,8 @@ Now you're ready to work! # Supported backends Two backends are now supported: -1 - The llama_cpp backend -2 - The GPT-j backend +1 - [The llama_cpp backend](https://github.com/nomic-ai/pygpt4all) +2 - [The GPT-j backend](https://github.com/marella/gpt4all-j) 3 - Hugging face's Transformers (under construction) # Supported models diff --git a/app.py b/app.py index 5885a315..42deb15f 100644 --- a/app.py +++ b/app.py @@ -146,8 +146,8 @@ class Gpt4AllWebUI(GPT4AllAPI): ) def list_backends(self): - backends_dir = Path('./pyGpt4All/backends') # replace with the actual path to the models folder - backends = [f.stem for f in backends_dir.glob('*.py') if f.stem!="backend" and f.stem!="__init__"] + backends_dir = Path('./backends') # replace with the actual path to the models folder + backends = [f.stem for f in backends_dir.iterdir() if f.is_dir()] return jsonify(backends) diff --git a/pyGpt4All/backends/__init__.py b/backends/__init__.py similarity index 100% rename from pyGpt4All/backends/__init__.py rename to backends/__init__.py diff --git a/pyGpt4All/backends/gpt_j.py b/backends/gpt_j/__init__.py similarity index 98% rename from pyGpt4All/backends/gpt_j.py rename to backends/gpt_j/__init__.py index 5e5cfc6f..1f0c638a 100644 --- a/pyGpt4All/backends/gpt_j.py +++ b/backends/gpt_j/__init__.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Callable from gpt4allj import Model -from pyGpt4All.backends.backend import GPTBackend +from pyGpt4All.backend import GPTBackend __author__ = "parisneo" __github__ = "https://github.com/nomic-ai/gpt4all-ui" diff --git a/pyGpt4All/backends/transformers.py b/backends/gpt_q/__init__.py similarity index 98% rename from pyGpt4All/backends/transformers.py rename to backends/gpt_q/__init__.py index 791b2cfe..d0c48d43 100644 --- a/pyGpt4All/backends/transformers.py +++ b/backends/gpt_q/__init__.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Callable from transformers import AutoTokenizer from transformers import AutoModelForCausalLM -from pyGpt4All.backends.backend import GPTBackend +from pyGpt4All.backend import GPTBackend __author__ = "parisneo" __github__ = "https://github.com/nomic-ai/gpt4all-ui" diff --git a/pyGpt4All/backends/llama_cpp.py b/backends/llama_cpp/__init__.py similarity index 97% rename from pyGpt4All/backends/llama_cpp.py rename to backends/llama_cpp/__init__.py index 6332ffc3..f84b3e4b 100644 --- a/pyGpt4All/backends/llama_cpp.py +++ b/backends/llama_cpp/__init__.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Callable from pyllamacpp.model import Model -from pyGpt4All.backends.backend import GPTBackend +from pyGpt4All.backend import GPTBackend __author__ = "parisneo" __github__ = "https://github.com/nomic-ai/gpt4all-ui" diff --git a/pyGpt4All/api.py b/pyGpt4All/api.py index 17fb9a4c..d184afbd 100644 --- a/pyGpt4All/api.py +++ b/pyGpt4All/api.py @@ -46,7 +46,7 @@ class GPT4AllAPI(): self.full_message_list = [] # Select backend - self.BACKENDS_LIST = {f.stem:f for f in (Path("pyGpt4All")/"backends").glob("*.py") if f.stem not in ["__init__","backend"]} + self.BACKENDS_LIST = {f.stem:f for f in Path("backends").iterdir() if f.is_dir()} self.load_backend(self.BACKENDS_LIST[self.config["backend"]]) @@ -86,7 +86,7 @@ class GPT4AllAPI(): module_name = backend_path.stem # use importlib to load the module from the file path - loader = importlib.machinery.SourceFileLoader(module_name, str(absolute_path)) + loader = importlib.machinery.SourceFileLoader(module_name, str(absolute_path/"__init__.py")) backend_module = loader.load_module() backend_class = getattr(backend_module, backend_module.backend_name) self.backend = backend_class diff --git a/pyGpt4All/backends/backend.py b/pyGpt4All/backend.py similarity index 100% rename from pyGpt4All/backends/backend.py rename to pyGpt4All/backend.py diff --git a/requirements.txt b/requirements.txt index 82cb2a22..dfb0be34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ pyyaml markdown pyllamacpp==1.0.6 gpt4all-j==0.2.1 +--find-links https://download.pytorch.org/whl/cu117 +torch==2.0.0 +torchvision +torchaudio transformers \ No newline at end of file diff --git a/static/js/main.js b/static/js/main.js index a1026008..e99726c7 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -14,119 +14,134 @@ function update_main(){ }); }) - userInput.addEventListener('keydown', function(event) { - if (event.shiftKey && event.key === 'Enter') { - event.preventDefault(); - userInput.style.height = userInput.scrollHeight + 'px'; - userInput.value += '\n'; - } - }); - chatForm.addEventListener('submit', event => { - event.preventDefault(); - console.log("Submitting") + + function submit_form(){ + console.log("Submitting") - // get user input and clear input field - message = userInput.value; - userInput.value = ''; - - // add user message to chat window - const sendbtn = document.querySelector("#submit-input") - const waitAnimation = document.querySelector("#wait-animation") - const stopGeneration = document.querySelector("#stop-generation") - - sendbtn.style.display="none"; - waitAnimation.style.display="block"; - stopGeneration.style.display = "block"; - console.log("Sending message to bot") + // get user input and clear input field + message = userInput.value; + userInput.value = ''; + + // add user message to chat window + const sendbtn = document.querySelector("#submit-input") + const waitAnimation = document.querySelector("#wait-animation") + const stopGeneration = document.querySelector("#stop-generation") + + sendbtn.style.display="none"; + waitAnimation.style.display="block"; + stopGeneration.style.display = "block"; + console.log("Sending message to bot") - user_msg = addMessage('',message, 0, 0, can_edit=true); - bot_msg = addMessage('', '', 0, 0, can_edit=true); + user_msg = addMessage('',message, 0, 0, can_edit=true); + bot_msg = addMessage('', '', 0, 0, can_edit=true); - fetch('/generate', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ message }) - }).then(function(response) { - const stream = new ReadableStream({ - start(controller) { - const reader = response.body.getReader(); - function push() { - reader.read().then(function(result) { - if (result.done) { - sendbtn.style.display="block"; - waitAnimation.style.display="none"; - stopGeneration.style.display = "none"; - console.log(result) - controller.close(); - return; - } - controller.enqueue(result.value); - push(); - }) - } - push(); - } - }); - const textDecoder = new TextDecoder(); - const readableStreamDefaultReader = stream.getReader(); - let entry_counter = 0 - function readStream() { - readableStreamDefaultReader.read().then(function(result) { - if (result.done) { - return; - } - - text = textDecoder.decode(result.value); - - // The server will first send a json containing information about the message just sent - if(entry_counter==0) - { - // We parse it and - infos = JSON.parse(text); + fetch('/generate', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ message }) + }).then(function(response) { + const stream = new ReadableStream({ + start(controller) { + const reader = response.body.getReader(); + function push() { + reader.read().then(function(result) { + if (result.done) { + sendbtn.style.display="block"; + waitAnimation.style.display="none"; + stopGeneration.style.display = "none"; + console.log(result) + controller.close(); + return; + } + controller.enqueue(result.value); + push(); + }) + } + push(); + } + }); + const textDecoder = new TextDecoder(); + const readableStreamDefaultReader = stream.getReader(); + let entry_counter = 0 + function readStream() { + readableStreamDefaultReader.read().then(function(result) { + if (result.done) { + return; + } - user_msg.setSender(infos.user); - user_msg.setMessage(infos.message); - user_msg.setID(infos.id); - bot_msg.setSender(infos.bot); - bot_msg.setID(infos.response_id); + text = textDecoder.decode(result.value); - bot_msg.messageTextElement; - bot_msg.hiddenElement; - entry_counter ++; + // The server will first send a json containing information about the message just sent + if(entry_counter==0) + { + // We parse it and + infos = JSON.parse(text); + + user_msg.setSender(infos.user); + user_msg.setMessage(infos.message); + user_msg.setID(infos.id); + bot_msg.setSender(infos.bot); + bot_msg.setID(infos.response_id); + + bot_msg.messageTextElement; + bot_msg.hiddenElement; + entry_counter ++; + } + else{ + entry_counter ++; + prefix = "FINAL:"; + if(text.startsWith(prefix)){ + text = text.substring(prefix.length); + bot_msg.hiddenElement.innerHTML = text + bot_msg.messageTextElement.innerHTML = text } else{ - entry_counter ++; - prefix = "FINAL:"; - if(text.startsWith(prefix)){ - text = text.substring(prefix.length); - bot_msg.hiddenElement.innerHTML = text - bot_msg.messageTextElement.innerHTML = text - } - else{ - // For the other enrtries, these are just the text of the chatbot - for (const char of text) { - txt = bot_msg.hiddenElement.innerHTML; - if (char != '\f') { - txt += char - bot_msg.hiddenElement.innerHTML = txt; - bot_msg.messageTextElement.innerHTML = txt; - } + // For the other enrtries, these are just the text of the chatbot + for (const char of text) { + txt = bot_msg.hiddenElement.innerHTML; + if (char != '\f') { + txt += char + bot_msg.hiddenElement.innerHTML = txt; + bot_msg.messageTextElement.innerHTML = txt; + } - // scroll to bottom of chat window - chatWindow.scrollTop = chatWindow.scrollHeight; - } - - } + // scroll to bottom of chat window + chatWindow.scrollTop = chatWindow.scrollHeight; } - - readStream(); - }); - } - readStream(); - }); - + + } + } + + readStream(); + }); + } + readStream(); + }); + } + chatForm.addEventListener('submit', event => { + event.preventDefault(); + submit_form(); }); + userInput.addEventListener("keyup", function(event) { + // Check if Enter key was pressed while holding Shift + // Also check if Shift + Ctrl keys were pressed while typing + // These combinations override the submit action + const shiftPressed = event.shiftKey; + const ctrlPressed = event.ctrlKey && !event.metaKey; + + if ((!shiftPressed) && event.key === "Enter") { + submit_form(); + } + // Restore original functionality for the remaining cases + else if (!shiftPressed && ctrlPressed) { + setTimeout(() => { + userInput.focus(); + contentEditable.value += event.data; + lastValue.innerHTML = userInput.value; + }, 0); + } + }); } \ No newline at end of file diff --git a/templates/main.html b/templates/main.html index 4d88f6ab..28f5b5c8 100644 --- a/templates/main.html +++ b/templates/main.html @@ -35,7 +35,7 @@
- +