reenabled responder

This commit is contained in:
SevaSk 2023-05-09 17:07:51 -04:00
parent aa9557e13b
commit 7d4df03ee1
3 changed files with 33 additions and 26 deletions

View File

@@ -3,10 +3,12 @@ import whisper
import torch import torch
import wave import wave
import os import os
import threading
class AudioTranscriber: class AudioTranscriber:
def __init__(self): def __init__(self):
self.transcript = [] self.transcript_data = []
self.transcript_changed_event = threading.Event()
self.audio_model = whisper.load_model(os.getcwd() + r'\tiny.en' + '.pt') self.audio_model = whisper.load_model(os.getcwd() + r'\tiny.en' + '.pt')
def transcribe(self, audio_data): def transcribe(self, audio_data):
@@ -28,7 +30,8 @@ class AudioTranscriber:
audio_data_transcription = self.transcribe(audio_data) audio_data_transcription = self.transcribe(audio_data)
# whisper gives "you" on many null inputs # whisper gives "you" on many null inputs
if audio_data_transcription != '' and audio_data_transcription.lower() != 'you': if audio_data_transcription != '' and audio_data_transcription.lower() != 'you':
self.transcript = [source + ": [" + audio_data_transcription + ']\n\n'] + self.transcript self.transcript_data = [source + ": [" + audio_data_transcription + ']\n\n'] + self.transcript_data
self.transcript_changed_event.set()
def get_transcript(self): def get_transcript(self):
return "".join(self.transcript) return "".join(self.transcript_data)

View File

@@ -1,21 +1,17 @@
import openai import openai
from keys import OPENAI_API_KEY from keys import OPENAI_API_KEY
from prompts import create_prompt, INITIAL_RESPONSE from prompts import create_prompt
openai.api_key = OPENAI_API_KEY openai.api_key = OPENAI_API_KEY
class GPTResponder: def generate_response_from_transcript(transcript):
def __init__(self): response = openai.ChatCompletion.create(
self.last_response = INITIAL_RESPONSE model="gpt-3.5-turbo-0301",
messages=[{"role": "system", "content": create_prompt(transcript)}],
def generate_response_from_transcript(self, transcript): temperature = 0.0
response = openai.ChatCompletion.create( )
model="gpt-3.5-turbo-0301", full_response = response.choices[0].message.content
messages=[{"role": "system", "content": create_prompt(transcript)}], try:
temperature = 0.0 return full_response.split('[')[1].split(']')[0]
) except:
full_response = response.choices[0].message.content return ''
try:
self.last_response = full_response.split('[')[1].split(']')[0]
except:
pass

22
main.py
View File

@@ -1,11 +1,12 @@
import soundcard as sc import soundcard as sc
import threading import threading
from AudioTranscriber import AudioTranscriber from AudioTranscriber import AudioTranscriber
from GPTResponder import GPTResponder import GPTResponder
import customtkinter as ctk import customtkinter as ctk
from Microphone import Microphone from Microphone import Microphone
from AudioRecorder import AudioRecorder from AudioRecorder import AudioRecorder
import queue import queue
from prompts import INITIAL_RESPONSE
def write_in_textbox(textbox, text): def write_in_textbox(textbox, text):
textbox.delete("0.0", "end") textbox.delete("0.0", "end")
@@ -17,15 +18,22 @@ def update_transcript_UI(transcriber, textbox):
textbox.insert("0.0", transcript_string) textbox.insert("0.0", transcript_string)
textbox.after(300, update_transcript_UI, transcriber, textbox) textbox.after(300, update_transcript_UI, transcriber, textbox)
def update_response_UI(transcriber_mic, transcriber_speaker, responder, textbox, update_interval_slider_label, update_interval_slider): def update_response(transcriber, last_response, textbox, update_interval_slider_label, update_interval_slider):
#transcript_string = create_transcript_string(transcriber_mic, transcriber_speaker,reverse=False)
textbox.configure(state="normal") textbox.configure(state="normal")
textbox.delete("0.0", "end") textbox.delete("0.0", "end")
textbox.insert("0.0", responder.last_response)
if transcriber.transcript_changed_event.is_set():
transcriber.transcript_changed_event.clear()
transcript_string = transcriber.get_transcript()
response = GPTResponder.generate_response_from_transcript(transcript_string)
if response != '':
last_response = response
textbox.insert("0.0", last_response)
textbox.configure(state="disabled") textbox.configure(state="disabled")
update_interval = int(update_interval_slider.get()) update_interval = int(update_interval_slider.get())
update_interval_slider_label.configure(text=f"Update interval: {update_interval} seconds") update_interval_slider_label.configure(text=f"Update interval: {update_interval} seconds")
textbox.after(int(update_interval * 1000), update_response_UI, transcriber_mic, transcriber_speaker, responder, textbox, update_interval_slider_label, update_interval_slider) textbox.after(int(update_interval * 1000), update_response, transcriber, last_response, textbox, update_interval_slider_label, update_interval_slider)
def clear_transcript_data(transcriber_mic, transcriber_speaker): def clear_transcript_data(transcriber_mic, transcriber_speaker):
transcriber_mic.transcript_data.clear() transcriber_mic.transcript_data.clear()
@@ -87,6 +95,6 @@ if __name__ == "__main__":
root.grid_columnconfigure(1, weight=1) root.grid_columnconfigure(1, weight=1)
update_transcript_UI(global_transcriber, transcript_textbox) update_transcript_UI(global_transcriber, transcript_textbox)
#update_response_UI(user_transcriber, transcriber_speaker, responder, response_textbox, update_interval_slider_label, update_interval_slider) update_response(global_transcriber, INITIAL_RESPONSE, response_textbox, update_interval_slider_label, update_interval_slider)
root.mainloop() root.mainloop()