diff --git a/lollms/functions/generate_music.py b/lollms/functions/generate_music.py
index 5010453..3bc2bcb 100644
--- a/lollms/functions/generate_music.py
+++ b/lollms/functions/generate_music.py
@@ -4,70 +4,76 @@
 # Description: This function generates music based on a given prompt and duration, saving the output to a unique file in the discussion folder.

 # Import necessary libraries
-import torchaudio
-from audiocraft.models import musicgen
-import torch
-from pathlib import Path
-from lollms.utilities import PackageManager
-from ascii_colors import trace_exception
-from functools import partial
+try:
+    import torchaudio
+    from audiocraft.models import musicgen
+    import torch
+    from pathlib import Path
+    from lollms.utilities import PackageManager
+    from ascii_colors import trace_exception
+    from functools import partial

-# Check for required packages and install if necessary
-if not PackageManager.check_package_installed("audiocraft"):
-    PackageManager.install_package("audiocraft")
+    # Check for required packages and install if necessary
+    if not PackageManager.check_package_installed("audiocraft"):
+        PackageManager.install_package("audiocraft")

-# Function to generate music
-def generate_music(processor, client, generation_prompt: str, duration: int, model_name: str = "facebook/musicgen-melody", device: str="cuda:0") -> str:
-    """
-    Generates music based on the given prompt and duration, saving it to a unique file in the discussion folder.
-
-    Parameters:
-    - processor: The processor object used for managing the generation process.
-    - client: The client object containing discussion information.
-    - generation_prompt: The prompt for music generation.
-    - duration: The duration of the music in seconds.
-    - model_name: The name of the pretrained music generation model.
-    - device: The device to run the model on (e.g., 'cpu' or 'cuda').
-
-    Returns:
-    - The path of the saved music file.
-    """
-
-    try:
-        # Load the pretrained music generation model
-        music_model = musicgen.MusicGen.get_pretrained(model_name, device=device)
+    # Function to generate music
+    def generate_music(processor, client, generation_prompt: str, duration: int, model_name: str = "facebook/musicgen-melody", device: str="cuda:0") -> str:
+        """
+        Generates music based on the given prompt and duration, saving it to a unique file in the discussion folder.

-        # Set generation parameters
-        music_model.set_generation_params(duration=duration)
+        Parameters:
+        - processor: The processor object used for managing the generation process.
+        - client: The client object containing discussion information.
+        - generation_prompt: The prompt for music generation.
+        - duration: The duration of the music in seconds.
+        - model_name: The name of the pretrained music generation model.
+        - device: The device to run the model on (e.g., 'cpu' or 'cuda').

-        # Generate music
-        res = music_model.generate([generation_prompt])
+        Returns:
+        - The path of the saved music file.
+        """

-        # Create output folder if it doesn't exist
-        output_folder = client.discussion.discussion_folder / "generated_music"
-        output_folder.mkdir(parents=True, exist_ok=True)
+        try:
+            # Load the pretrained music generation model
+            music_model = musicgen.MusicGen.get_pretrained(model_name, device=device)
+
+            # Set generation parameters
+            music_model.set_generation_params(duration=duration)
+
+            # Generate music
+            res = music_model.generate([generation_prompt])
+
+            # Create output folder if it doesn't exist
+            output_folder = client.discussion.discussion_folder / "generated_music"
+            output_folder.mkdir(parents=True, exist_ok=True)

-        # Generate a unique file name
-        output_file = output_folder / f"music_generation_{len(list(output_folder.glob('*.wav')))}.wav"
+            # Generate a unique file name
+            output_file = output_folder / f"music_generation_{len(list(output_folder.glob('*.wav')))}.wav"

-        # Save the generated music to the specified file
-        torchaudio.save(output_file, res.reshape(1, -1).cpu(), 32000)
-
-        # Return the path of the saved file
-        return str(output_file)
-    except Exception as e:
-        return trace_exception(e)
+            # Save the generated music to the specified file
+            torchaudio.save(output_file, res.reshape(1, -1).cpu(), 32000)
+
+            # Return the path of the saved file
+            return str(output_file)
+        except Exception as e:
+            return trace_exception(e)

-# Metadata function for the music generation function
-def generate_music_function(processor, client):
-    return {
-        "function_name": "generate_music", # The function name in string
-        "function": partial(generate_music, processor=processor, client=client), # The function to be called with preset parameters
-        "function_description": "Generates music based on a prompt and duration, saving it to a unique file in the discussion folder.", # Description of the function
-        "function_parameters": [ # Parameters needed for the function
-            {"name": "generation_prompt", "type": "str"},
-            {"name": "duration", "type": "int"},
-            {"name": "model_name", "type": "str"},
-            {"name": "device", "type": "str"}
-        ]
-    }
+    # Metadata function for the music generation function
+    def generate_music_function(processor, client):
+        return {
+            "function_name": "generate_music", # The function name in string
+            "function": partial(generate_music, processor=processor, client=client), # The function to be called with preset parameters
+            "function_description": "Generates music based on a prompt and duration, saving it to a unique file in the discussion folder.", # Description of the function
+            "function_parameters": [ # Parameters needed for the function
+                {"name": "generation_prompt", "type": "str"},
+                {"name": "duration", "type": "int"},
+                {"name": "model_name", "type": "str"},
+                {"name": "device", "type": "str"}
+            ]
+        }
+except:
+    def generate_music(processor, client, generation_prompt: str, duration: int, model_name: str = "facebook/musicgen-melody", device: str="cuda:0") -> str:
+        pass
+    def generate_music_function(processor, client):
+        pass
diff --git a/lollms/server/endpoints/lollms_generator.py b/lollms/server/endpoints/lollms_generator.py
index 8a33790..a8c5734 100644
--- a/lollms/server/endpoints/lollms_generator.py
+++ b/lollms/server/endpoints/lollms_generator.py
@@ -137,7 +137,7 @@ async def lollms_generate(request: LollmsGenerateRequest):

     async def generate_chunks():
         lk = threading.Lock()
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if elf_server.cancel_gen:
                 return False
@@ -305,7 +305,7 @@ async def lollms_generate_with_images(request: LollmsGenerateRequest):

     async def generate_chunks():
         lk = threading.Lock()
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if elf_server.cancel_gen:
                 return False
@@ -359,7 +359,7 @@ async def lollms_generate_with_images(request: LollmsGenerateRequest):
         elf_server.cancel_gen = False
         return StreamingResponse(generate_chunks(), media_type="text/plain", headers=headers)
     else:
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             # Yield each chunk of data
             if chunk is None:
                 return True
@@ -509,7 +509,7 @@ async def v1_chat_completions(request: ChatGenerationRequest):

     async def generate_chunks():
         lk = threading.Lock()
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if elf_server.cancel_gen:
                 return False
@@ -567,7 +567,7 @@ async def v1_chat_completions(request: ChatGenerationRequest):
         elf_server.cancel_gen = False
         return StreamingResponse(generate_chunks(), media_type="application/json")
     else:
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             # Yield each chunk of data
             if chunk is None:
                 return True
@@ -651,7 +651,7 @@ async def ollama_chat_completion(request: ChatGenerationRequest):

     async def generate_chunks():
         lk = threading.Lock()
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if elf_server.cancel_gen:
                 return False
@@ -709,7 +709,7 @@ async def ollama_chat_completion(request: ChatGenerationRequest):
         elf_server.cancel_gen = False
         return StreamingResponse(generate_chunks(), media_type="application/json")
     else:
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             # Yield each chunk of data
             if chunk is None:
                 return True
@@ -805,7 +805,7 @@ async def ollama_generate(request: CompletionGenerationRequest):
     if stream:
         output = {"text":""}
         def generate_chunks():
-            def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+            def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
                 # Yield each chunk of data
                 output["text"] += chunk
                 antiprompt = detect_antiprompt(output["text"], [start_header_id_template, end_header_id_template])
@@ -826,7 +826,7 @@ async def ollama_generate(request: CompletionGenerationRequest):
         return StreamingResponse(generate_chunks())
     else:
         output = {"text":""}
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if chunk is None:
                 return
             # Yield each chunk of data
@@ -891,7 +891,7 @@ async def ollama_completion(request: CompletionGenerationRequest):

     async def generate_chunks():
         lk = threading.Lock()
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             if elf_server.cancel_gen:
                 return False
@@ -944,7 +944,7 @@ async def ollama_completion(request: CompletionGenerationRequest):
         elf_server.cancel_gen = False
         return StreamingResponse(generate_chunks(), media_type="text/plain")
     else:
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             # Yield each chunk of data
             if chunk is None:
                 return True
@@ -995,7 +995,7 @@ async def v1_completion(request: CompletionGenerationRequest):
     if stream:
         output = {"text":""}
         def generate_chunks():
-            def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+            def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
                 # Yield each chunk of data
                 output["text"] += chunk
                 antiprompt = detect_antiprompt(output["text"])
@@ -1016,7 +1016,7 @@ async def v1_completion(request: CompletionGenerationRequest):
         return StreamingResponse(generate_chunks())
     else:
         output = {"text":""}
-        def callback(chunk, chunk_type:MSG_TYPE_CONTENT=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
+        def callback(chunk, chunk_type:MSG_OPERATION_TYPE=MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK):
             # Yield each chunk of data
             output["text"] += chunk
             antiprompt = detect_antiprompt(output["text"])
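Note on lollms/functions/generate_music.py: the patch wraps the imports and both function definitions in a try/except so that a missing optional dependency (audiocraft, torchaudio) no longer breaks importing the module; the except branch instead defines stub versions of generate_music and generate_music_function. The sketch below shows the same guarded-import pattern in isolation, as an illustration only: the module name some_optional_lib and the function do_work are hypothetical placeholders, not part of the lollms codebase, and it catches ImportError rather than the bare except used in the patch.

# Minimal sketch of the guarded-import pattern, with assumed (hypothetical) names.
try:
    import some_optional_lib  # hypothetical heavy/optional dependency

    def do_work(prompt: str) -> str:
        # Real implementation, available only when the dependency imports cleanly.
        return some_optional_lib.run(prompt)
except ImportError:
    # Fallback stub keeps the module importable; calling it just returns None,
    # mirroring the `pass` stubs added in the patch above.
    def do_work(prompt: str) -> str:
        pass

One trade-off worth noting: stubs that silently return None push the failure to call time; raising a clear error inside the stub would surface the missing dependency sooner.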
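Note on lollms/server/endpoints/lollms_generator.py: every hunk applies the same one-token fix, changing the type annotation of the callback's chunk_type parameter from MSG_TYPE_CONTENT to MSG_OPERATION_TYPE so the annotation names the same enum as the default value. The sketch below shows the corrected signature shape; the enum body and the callback logic are illustrative stand-ins, not the lollms implementations.

from enum import Enum

# Illustrative stand-in for the MSG_OPERATION_TYPE enum referenced in the diff;
# only the member used as the default value is shown.
class MSG_OPERATION_TYPE(Enum):
    MSG_OPERATION_TYPE_ADD_CHUNK = 0

def callback(chunk: str,
             chunk_type: MSG_OPERATION_TYPE = MSG_OPERATION_TYPE.MSG_OPERATION_TYPE_ADD_CHUNK) -> bool:
    # Returning True keeps the generation going; the endpoints above return
    # False when elf_server.cancel_gen is set in order to stop streaming.
    if chunk is None:
        return True
    print(chunk, end="")
    return True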