added support for additional more user-friendly interfaces, improved some part of the application loading process to make it a bit simpler
This commit is contained in:
parent
1a49aa3779
commit
f647c960dd
20 changed files with 353 additions and 107 deletions
75
source/api/interface/ChatInterface.py
Normal file
75
source/api/interface/ChatInterface.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import textwrap
|
||||
|
||||
import gradio
|
||||
|
||||
from source import meta
|
||||
from source.api.interface import base
|
||||
from source.model.base import BaseModel
|
||||
|
||||
|
||||
class ChatInterface(base.BaseInterface):
|
||||
"""
|
||||
An interface for Chat-like models.
|
||||
Use the OpenAI convention (list of dict with roles and content)
|
||||
"""
|
||||
|
||||
def __init__(self, model: "BaseModel"):
|
||||
# Function to send and receive chat messages
|
||||
super().__init__(model)
|
||||
|
||||
async def send_message(self, user_message, old_messages: list[dict], system_message: str):
|
||||
# normalize the user message (the type can be wrong, especially when "edited")
|
||||
if isinstance(user_message, str):
|
||||
user_message: dict = {"files": [], "text": user_message}
|
||||
|
||||
# copy the history to avoid modifying it
|
||||
messages: list[dict] = old_messages.copy()
|
||||
|
||||
# add the system instruction
|
||||
if system_message:
|
||||
messages.insert(0, {"role": "system", "content": system_message})
|
||||
|
||||
# add the user message
|
||||
# NOTE: gradio.ChatInterface add our message and the assistant message
|
||||
# TODO(Faraphel): add support for files
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": user_message["text"],
|
||||
})
|
||||
|
||||
# infer the message through the model
|
||||
chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
|
||||
assistant_message: str = b"".join(chunks).decode("utf-8")
|
||||
|
||||
# send back the messages, clear the user prompt, disable the system prompt
|
||||
return assistant_message
|
||||
|
||||
def get_gradio_application(self):
|
||||
# create a gradio interface
|
||||
with gradio.Blocks(analytics_enabled=False) as application:
|
||||
# header
|
||||
gradio.Markdown(textwrap.dedent(f"""
|
||||
# {meta.name}
|
||||
## {self.model.name}
|
||||
"""))
|
||||
|
||||
# additional settings
|
||||
with gradio.Accordion("Advanced Settings") as advanced_settings:
|
||||
system_prompt = gradio.Textbox(
|
||||
label="System prompt",
|
||||
placeholder="You are an expert in C++...",
|
||||
lines=2,
|
||||
)
|
||||
|
||||
# chat interface
|
||||
gradio.ChatInterface(
|
||||
fn=self.send_message,
|
||||
type="messages",
|
||||
multimodal=True,
|
||||
editable=True,
|
||||
save_history=True,
|
||||
additional_inputs=[system_prompt],
|
||||
additional_inputs_accordion=advanced_settings,
|
||||
)
|
||||
|
||||
return application
|
3
source/api/interface/__init__.py
Normal file
3
source/api/interface/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from . import base
|
||||
|
||||
from .ChatInterface import ChatInterface
|
40
source/api/interface/base/BaseInterface.py
Normal file
40
source/api/interface/base/BaseInterface.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import abc
|
||||
|
||||
import fastapi
|
||||
import gradio
|
||||
|
||||
import source
|
||||
|
||||
|
||||
class BaseInterface(abc.ABC):
|
||||
def __init__(self, model: "source.model.base.BaseModel"):
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def route(self) -> str:
|
||||
"""
|
||||
The route to the interface
|
||||
:return: the route to the interface
|
||||
"""
|
||||
|
||||
return f"{self.model.api_base}/interface"
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_gradio_application(self) -> gradio.Blocks:
|
||||
"""
|
||||
Get a gradio application
|
||||
:return: a gradio application
|
||||
"""
|
||||
|
||||
def mount(self, application: fastapi.FastAPI) -> None:
|
||||
"""
|
||||
Mount the interface on an application
|
||||
:param application: the application to mount the interface on
|
||||
:param path: the path where to mount the application
|
||||
"""
|
||||
|
||||
gradio.mount_gradio_app(
|
||||
application,
|
||||
self.get_gradio_application(),
|
||||
self.route
|
||||
)
|
1
source/api/interface/base/__init__.py
Normal file
1
source/api/interface/base/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .BaseInterface import BaseInterface
|
Loading…
Add table
Add a link
Reference in a new issue