diff --git a/lollms/apps/elf/__init__.py b/lollms/apps/elf/__init__.py index c46f79b..ed142ac 100644 --- a/lollms/apps/elf/__init__.py +++ b/lollms/apps/elf/__init__.py @@ -199,12 +199,37 @@ def chat_completions(): full_discussion = "" for message in messages: full_discussion += f'{message["role"]}: {message["content"]}\n' - response = cv.safe_generate(full_discussion=full_discussion, temperature=temperature, top_p=top_p, n_predict=max_tokens) + + def stream_callback(token, message_type): + print(token) + completion_timestamp = int(time.time()) + completion_id = ''.join(random.choices( + 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', k=28)) + + completion_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion.chunk', + 'created': completion_timestamp, + 'choices': [ + { + 'delta': { + 'content': token + }, + 'index': 0, + 'finish_reason': None + } + ] + } + yield 'data: %s\n\n' % json.dumps(completion_data, separators=(',' ':')) + time.sleep(0.02) + return True + completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28)) completion_timestamp = int(time.time()) if not streaming_: + response = cv.safe_generate(full_discussion=full_discussion, temperature=temperature, top_p=top_p, n_predict=max_tokens) completion_timestamp = int(time.time()) completion_id = ''.join(random.choices( 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', k=28)) @@ -230,35 +255,35 @@ def chat_completions(): "total_tokens": None, }, } - - def stream(): - nonlocal response - for token in response: - completion_timestamp = int(time.time()) - completion_id = ''.join(random.choices( - 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', k=28)) - - completion_data = { - 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion.chunk', - 'created': completion_timestamp, - 'choices': [ - { - 'delta': { - 'content': token - }, - 'index': 0, - 'finish_reason': None - } - ] - } - #print(token) - #print(completion_data) - #print('data: %s\n\n' % json.dumps(completion_data, separators=(',' ':'))) - yield 'data: %s\n\n' % json.dumps(completion_data, separators=(',' ':')) - time.sleep(0.02) - print('===Start Streaming===') - return app.response_class(stream(), mimetype='text/event-stream') + else: + print('Streaming') + if True: + response = cv.safe_generate( + full_discussion=full_discussion, + temperature=temperature, + top_p=top_p, + n_predict=max_tokens, + callback=stream_callback + ) + def stream(): + nonlocal response + for token in response: + stream_callback(token, None) + return app.response_class( + response, + mimetype='text/event-stream' + ) + else: + return app.response_class( + cv.safe_generate( + full_discussion=full_discussion, + temperature=temperature, + top_p=top_p, + n_predict=max_tokens, + callback=stream_callback + ), + mimetype='text/event-stream' + ) # define the engines endpoint @app.route('/v1/engines')