diff --git a/source/api/interface/ChatInterface.py b/source/api/interface/ChatInterface.py index e3b30ee..a4542a6 100644 --- a/source/api/interface/ChatInterface.py +++ b/source/api/interface/ChatInterface.py @@ -20,7 +20,7 @@ class ChatInterface(base.BaseInterface): async def send_message(self, user_message, old_messages: list[dict], system_message: str): # normalize the user message (the type can be wrong, especially when "edited") if isinstance(user_message, str): - user_message: dict = {"files": [], "text": user_message} + user_message: dict = {"text": user_message} # copy the history to avoid modifying it messages: list[dict] = old_messages.copy() @@ -41,8 +41,10 @@ class ChatInterface(base.BaseInterface): }) # infer the message through the model + assistant_message = "" async for chunk in self.model.infer(messages=messages): - yield chunk.decode("utf-8") + assistant_message += " " + chunk.decode("utf-8") + yield assistant_message def get_application(self): # create a gradio interface