added --api flag

This commit is contained in:
SevaSk 2023-05-29 20:34:23 -04:00
parent e0c72ff24c
commit c0b43f3680
4 changed files with 55 additions and 15 deletions

View File

@ -15,10 +15,10 @@ PHRASE_TIMEOUT = 3.05
MAX_PHRASES = 10 MAX_PHRASES = 10
class AudioTranscriber: class AudioTranscriber:
def __init__(self, mic_source, speaker_source): def __init__(self, mic_source, speaker_source, model):
self.transcript_data = {"You": [], "Speaker": []} self.transcript_data = {"You": [], "Speaker": []}
self.transcript_changed_event = threading.Event() self.transcript_changed_event = threading.Event()
self.audio_model = whisper.load_model(os.path.join(os.getcwd(), 'tiny.en.pt')) self.audio_model = model
self.audio_sources = { self.audio_sources = {
"You": { "You": {
"sample_rate": mic_source.SAMPLE_RATE, "sample_rate": mic_source.SAMPLE_RATE,
@ -46,7 +46,7 @@ class AudioTranscriber:
self.update_last_sample_and_phrase_status(who_spoke, data, time_spoken) self.update_last_sample_and_phrase_status(who_spoke, data, time_spoken)
source_info = self.audio_sources[who_spoke] source_info = self.audio_sources[who_spoke]
temp_file = source_info["process_data_func"](source_info["last_sample"]) temp_file = source_info["process_data_func"](source_info["last_sample"])
text = self.get_transcription(temp_file) text = self.audio_model.get_transcription(temp_file)
if text != '' and text.lower() != 'you': if text != '' and text.lower() != 'you':
self.update_transcript(who_spoke, text, time_spoken) self.update_transcript(who_spoke, text, time_spoken)
@ -81,10 +81,6 @@ class AudioTranscriber:
wf.writeframes(data) wf.writeframes(data)
return temp_file return temp_file
def get_transcription(self, file_path):
result = self.audio_model.transcribe(file_path, fp16=torch.cuda.is_available())
return result['text'].strip()
def update_transcript(self, who_spoke, text, time_spoken): def update_transcript(self, who_spoke, text, time_spoken):
source_info = self.audio_sources[who_spoke] source_info = self.audio_sources[who_spoke]
transcript = self.transcript_data[who_spoke] transcript = self.transcript_data[who_spoke]

View File

@ -68,7 +68,15 @@ Run the main script:
python main.py python main.py
``` ```
Now, Ecoute will start transcribing your microphone input and speaker output in real-time, and provide a suggested response based on the conversation. It may take a couple of seconds to warm up before the transcription becomes real-time. For a better and faster version, use:
```
python main.py --api
```
Upon initiation, Ecoute will begin transcribing your microphone input and speaker output in real-time, generating a suggested response based on the conversation. Please note that it might take a few seconds for the system to warm up before the transcription becomes real-time.
The --api flag significantly enhances transcription speed and accuracy, and it's expected to be the default option in future releases. However, keep in mind that using the Whisper API will consume more OpenAI credits than using the local model. This increased cost is attributed to the advanced features and capabilities that the Whisper API provides. Despite the additional cost, the considerable improvements in speed and transcription accuracy might make it a worthwhile investment for your use case.
### ⚠️ Limitations ### ⚠️ Limitations

34
TranscriberModels.py Normal file
View File

@ -0,0 +1,34 @@
import openai
import whisper
import os
import torch
def get_model(use_api):
if use_api:
return APIWhisperTranscriber()
else:
return WhisperTranscriber()
class WhisperTranscriber:
def __init__(self):
self.audio_model = whisper.load_model(os.path.join(os.getcwd(), 'tiny.en.pt'))
print(f"[INFO] Whisper using GPU: " + str(torch.cuda.is_available()))
def get_transcription(self, wav_file_path):
try:
result = self.audio_model.transcribe(wav_file_path, fp16=torch.cuda.is_available())
except Exception as e:
print(e)
return result['text'].strip()
class APIWhisperTranscriber:
def get_transcription(self, wav_file_path):
new_file_path = wav_file_path + '.wav'
os.rename(wav_file_path, new_file_path)
audio_file= open(new_file_path, "rb")
try:
result = openai.Audio.translate("whisper-1", audio_file)
except Exception as e:
print(e)
return result['text'].strip()

16
main.py
View File

@ -6,6 +6,8 @@ import AudioRecorder
import queue import queue
import time import time
import torch import torch
import sys
import TranscriberModels
def write_in_textbox(textbox, text): def write_in_textbox(textbox, text):
textbox.delete("0.0", "end") textbox.delete("0.0", "end")
@ -76,15 +78,15 @@ def main():
speaker_audio_recorder = AudioRecorder.DefaultSpeakerRecorder() speaker_audio_recorder = AudioRecorder.DefaultSpeakerRecorder()
speaker_audio_recorder.record_into_queue(audio_queue) speaker_audio_recorder.record_into_queue(audio_queue)
global_transcriber = AudioTranscriber(user_audio_recorder.source, speaker_audio_recorder.source) model = TranscriberModels.get_model('--api' in sys.argv)
transcribe = threading.Thread(target=global_transcriber.transcribe_audio_queue, args=(audio_queue,))
transcriber = AudioTranscriber(user_audio_recorder.source, speaker_audio_recorder.source, model)
transcribe = threading.Thread(target=transcriber.transcribe_audio_queue, args=(audio_queue,))
transcribe.daemon = True transcribe.daemon = True
transcribe.start() transcribe.start()
print(f"[INFO] Whisper using GPU: " + str(torch.cuda.is_available()))
responder = GPTResponder() responder = GPTResponder()
respond = threading.Thread(target=responder.respond_to_transcriber, args=(global_transcriber,)) respond = threading.Thread(target=responder.respond_to_transcriber, args=(transcriber,))
respond.daemon = True respond.daemon = True
respond.start() respond.start()
@ -98,7 +100,7 @@ def main():
root.grid_columnconfigure(1, weight=1) root.grid_columnconfigure(1, weight=1)
# Add the clear transcript button to the UI # Add the clear transcript button to the UI
clear_transcript_button = ctk.CTkButton(root, text="Clear Transcript", command=lambda: clear_context(global_transcriber, audio_queue, )) clear_transcript_button = ctk.CTkButton(root, text="Clear Transcript", command=lambda: clear_context(transcriber, audio_queue, ))
clear_transcript_button.grid(row=1, column=0, padx=10, pady=3, sticky="nsew") clear_transcript_button.grid(row=1, column=0, padx=10, pady=3, sticky="nsew")
freeze_state = [False] # Using list to be able to change its content inside inner functions freeze_state = [False] # Using list to be able to change its content inside inner functions
@ -110,7 +112,7 @@ def main():
update_interval_slider_label.configure(text=f"Update interval: {update_interval_slider.get()} seconds") update_interval_slider_label.configure(text=f"Update interval: {update_interval_slider.get()} seconds")
update_transcript_UI(global_transcriber, transcript_textbox) update_transcript_UI(transcriber, transcript_textbox)
update_response_UI(responder, response_textbox, update_interval_slider_label, update_interval_slider, freeze_state) update_response_UI(responder, response_textbox, update_interval_slider_label, update_interval_slider, freeze_state)
root.mainloop() root.mainloop()