diff --git a/requirements.txt b/requirements.txt
index 26d6902..35452a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+gradio
 python-multipart

 # AI
diff --git a/samples/models/dummy/config.json b/samples/models/dummy/config.json
index c12d549..bba7ae4 100644
--- a/samples/models/dummy/config.json
+++ b/samples/models/dummy/config.json
@@ -2,10 +2,14 @@
   "type": "python",
   "tags": ["dummy"],
   "file": "model.py",
+  "interface": "chat",
-  "output_type": "video/mp4",
+  "summary": "Echo model",
+  "description": "The most basic example model: it simply echoes the input",
   "inputs": {
-    "file": {"type": "file"}
-  }
+    "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
+  },
+
+  "output_type": "text/markdown"
 }
diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py
index 43be902..4661536 100644
--- a/samples/models/dummy/model.py
+++ b/samples/models/dummy/model.py
@@ -7,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass

-async def infer(model, file) -> typing.AsyncIterator[bytes]:
-    yield await file.read()
\ No newline at end of file
+async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+    yield messages[-1]["content"].encode("utf-8")
diff --git a/samples/models/python-bert-1/model.py b/samples/models/python-bert-1/model.py
index 4c013a6..b8dfaf9 100644
--- a/samples/models/python-bert-1/model.py
+++ b/samples/models/python-bert-1/model.py
@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
diff --git a/samples/models/python-bert-2/model.py b/samples/models/python-bert-2/model.py
index 4c013a6..b8dfaf9 100644
--- a/samples/models/python-bert-2/model.py
+++ b/samples/models/python-bert-2/model.py
@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
diff --git a/source/__init__.py b/source/__init__.py
index 29afff1..5781a95 100644
--- a/source/__init__.py
+++ b/source/__init__.py
@@ -1,3 +1,3 @@
 from . import api
 from . import model
-from . import manager
+from . import registry
diff --git a/source/__main__.py b/source/__main__.py
index 42663f4..9253d56 100644
--- a/source/__main__.py
+++ b/source/__main__.py
@@ -1,15 +1,25 @@
 import os

-from source import manager, model, api
+from source import registry, model, api
+from source.api import interface

 # create a fastapi application
 application = api.Application()

-# create the model controller
-model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("python", model.PythonModel)
-model_controller.reload()
+# create the interface registry
+interface_registry = registry.InterfaceRegistry()
+interface_registry.register_type("chat", interface.ChatInterface)
+
+
+# create the model registry
+model_registry = registry.ModelRegistry(os.environ["MODEL_LIBRARY"], "/models", interface_registry)
+model_registry.register_type("python", model.PythonModel)
+model_registry.reload_models()
+
+# add the model registry routes to the fastapi application
+model_registry.mount(application)
+
 # serve the application
 application.serve("0.0.0.0", 8000)
diff --git a/source/api/Application.py b/source/api/Application.py
index 2406f40..9a4328e 100644
--- a/source/api/Application.py
+++ b/source/api/Application.py
@@ -8,7 +8,8 @@ class Application(fastapi.FastAPI):
     def __init__(self):
         super().__init__(
             title=meta.name,
-            description=meta.description
+            description=meta.description,
+            redoc_url=None,
         )

     def serve(self, host: str = "0.0.0.0", port: int = 8080):
diff --git a/source/api/__init__.py b/source/api/__init__.py
index 9041d95..538207c 100644
--- a/source/api/__init__.py
+++ b/source/api/__init__.py
@@ -1 +1,3 @@
+from . import interface
+
 from .Application import Application
diff --git a/source/api/interface/ChatInterface.py b/source/api/interface/ChatInterface.py
new file mode 100644
index 0000000..2cb7f27
--- /dev/null
+++ b/source/api/interface/ChatInterface.py
@@ -0,0 +1,75 @@
+import textwrap
+
+import gradio
+
+from source import meta
+from source.api.interface import base
+from source.model.base import BaseModel
+
+
+class ChatInterface(base.BaseInterface):
+    """
+    An interface for Chat-like models.
+    Use the OpenAI convention (list of dict with roles and content)
+    """
+
+    def __init__(self, model: "BaseModel"):
+        super().__init__(model)
+
+    async def send_message(self, user_message, old_messages: list[dict], system_message: str):
+        # normalize the user message (the type can be wrong, especially when "edited")
+        if isinstance(user_message, str):
+            user_message: dict = {"files": [], "text": user_message}
+
+        # copy the history to avoid modifying it
+        messages: list[dict] = old_messages.copy()
+
+        # add the system instruction
+        if system_message:
+            messages.insert(0, {"role": "system", "content": system_message})
+
+        # add the user message
+        # NOTE: gradio.ChatInterface adds our message and the assistant message to the history itself
+        # TODO(Faraphel): add support for files
+        messages.append({
+            "role": "user",
+            "content": user_message["text"],
+        })
+
+        # infer the message through the model
+        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
+        assistant_message: str = b"".join(chunks).decode("utf-8")
+
+        # send back the assistant message
+        return assistant_message
+
+    def get_gradio_application(self):
+        # create a gradio interface
+        with gradio.Blocks(analytics_enabled=False) as application:
+            # header
+            gradio.Markdown(textwrap.dedent(f"""
+                # {meta.name}
+                ## {self.model.name}
+            """))
+
+            # additional settings
+            with gradio.Accordion("Advanced Settings") as advanced_settings:
+                system_prompt = gradio.Textbox(
+                    label="System prompt",
+                    placeholder="You are an expert in C++...",
+                    lines=2,
+                )
+
+            # chat interface
+            gradio.ChatInterface(
+                fn=self.send_message,
+                type="messages",
+                multimodal=True,
+                editable=True,
+                save_history=True,
+                additional_inputs=[system_prompt],
+                additional_inputs_accordion=advanced_settings,
+            )
+
+        return application
diff --git a/source/api/interface/__init__.py b/source/api/interface/__init__.py
new file mode 100644
index 0000000..2444578
--- /dev/null
+++ b/source/api/interface/__init__.py
@@ -0,0 +1,3 @@
+from . import base
+
+from .ChatInterface import ChatInterface
diff --git a/source/api/interface/base/BaseInterface.py b/source/api/interface/base/BaseInterface.py
new file mode 100644
index 0000000..dc6ad71
--- /dev/null
+++ b/source/api/interface/base/BaseInterface.py
@@ -0,0 +1,40 @@
+import abc
+
+import fastapi
+import gradio
+
+import source
+
+
+class BaseInterface(abc.ABC):
+    def __init__(self, model: "source.model.base.BaseModel"):
+        self.model = model
+
+    @property
+    def route(self) -> str:
+        """
+        The route to the interface
+        :return: the route to the interface
+        """
+
+        return f"{self.model.api_base}/interface"
+
+    @abc.abstractmethod
+    def get_gradio_application(self) -> gradio.Blocks:
+        """
+        Get a gradio application
+        :return: a gradio application
+        """
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the interface on an application, at the path given by `self.route`
+        :param application: the application to mount the interface on
+        """
+
+        gradio.mount_gradio_app(
+            application,
+            self.get_gradio_application(),
+            self.route
+        )
diff --git a/source/api/interface/base/__init__.py b/source/api/interface/base/__init__.py
new file mode 100644
index 0000000..11435c4
--- /dev/null
+++ b/source/api/interface/base/__init__.py
@@ -0,0 +1 @@
+from .BaseInterface import BaseInterface
diff --git a/source/manager/__init__.py b/source/manager/__init__.py
deleted file mode 100644
index ab87b8b..0000000
--- a/source/manager/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .ModelManager import ModelManager
diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py
index c23bceb..9d8efab 100644
--- a/source/model/PythonModel.py
+++ b/source/model/PythonModel.py
@@ -9,8 +9,8 @@ from pathlib import Path
 import fastapi

 from source import utils
-from source.manager import ModelManager
 from source.model import base
+from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix


@@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
     A model running a custom python model.
""" - def __init__(self, manager: ModelManager, configuration: dict, path: Path): - super().__init__(manager, configuration, path) + def __init__(self, registry: ModelRegistry, configuration: dict, path: Path): + super().__init__(registry, configuration, path) - ## Configuration - - # get the name of the file containing the model code - file = configuration.get("file") - if file is None: - raise ValueError("Field 'file' is missing from the configuration") + # get the parameters of the model + self.parameters = utils.parameters.load(configuration.get("inputs", {})) # install custom requirements requirements = configuration.get("requirements", []) if len(requirements) > 0: subprocess.run([sys.executable, "-m", "pip", "install", *requirements]) + # get the name of the file containing the model code + file = configuration.get("file") + if file is None: + raise ValueError("Field 'file' is missing from the configuration") + # create the module specification module_spec = importlib.util.spec_from_file_location( f"model-{uuid.uuid4()}", @@ -44,10 +45,17 @@ class PythonModel(base.BaseModel): # load the module module_spec.loader.exec_module(self.module) - ## Api + def _load(self) -> None: + return self.module.load(self) - # load the inputs data into the inference function signature (used by FastAPI) - parameters = utils.parameters.load(configuration.get("inputs", {})) + def _unload(self) -> None: + return self.module.unload(self) + + async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]: + return self.module.infer(self, **kwargs) + + def _mount(self, application: fastapi.FastAPI): + # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ? # create an endpoint wrapping the inference inside a fastapi call async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse: @@ -61,7 +69,7 @@ class PythonModel(base.BaseModel): } return fastapi.responses.StreamingResponse( - content=await self.infer(**kwargs), + content=await self.registry.infer_model(self, **kwargs), media_type=self.output_type, headers={ # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI @@ -69,27 +77,25 @@ class PythonModel(base.BaseModel): } ) - infer_api.__signature__ = inspect.Signature(parameters=parameters) + infer_api.__signature__ = inspect.Signature(parameters=self.parameters) + + # format the description + description_sections: list[str] = [] + if self.description is not None: + description_sections.append(self.description) + if self.interface is not None: + description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**") # add the inference endpoint on the API - self.manager.application.add_api_route( - f"/models/{self.name}/infer", + application.add_api_route( + f"{self.api_base}/infer", infer_api, methods=["POST"], tags=self.tags, - # summary=..., - # description=..., + summary=self.summary, + description="
".join(description_sections), response_class=fastapi.responses.StreamingResponse, responses={ 200: {"content": {self.output_type: {}}} }, ) - - def _load(self) -> None: - return self.module.load(self) - - def _unload(self) -> None: - return self.module.unload(self) - - def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]: - return self.module.infer(self, **kwargs) diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py index 9072e97..be939c2 100644 --- a/source/model/base/BaseModel.py +++ b/source/model/base/BaseModel.py @@ -3,7 +3,9 @@ import gc import typing from pathlib import Path -from source.manager import ModelManager +import fastapi + +from source.registry import ModelRegistry class BaseModel(abc.ABC): @@ -11,21 +13,43 @@ class BaseModel(abc.ABC): Represent a model. """ - def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path): + def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path): + # the model registry + self.registry = registry + + # get the documentation of the model + self.summary = configuration.get("summary") + self.description = configuration.get("description") + # the environment directory of the model self.path = path - # the model manager - self.manager = manager # the mimetype of the model responses self.output_type: str = configuration.get("output_type", "application/json") # get the tags of the model self.tags = configuration.get("tags", []) + # get the selected interface of the model + interface_name: typing.Optional[str] = configuration.get("interface", None) + self.interface = ( + self.registry.interface_registry.interface_types[interface_name](self) + if interface_name is not None else None + ) + + # is the model currently loaded self._loaded = False def __repr__(self): return f"<{self.__class__.__name__}: {self.name}>" + @property + def api_base(self) -> str: + """ + Base for the API routes + :return: the base for the API routes + """ + + return f"{self.registry.api_base}/{self.name}" + @property def name(self): """ @@ -44,6 +68,7 @@ class BaseModel(abc.ABC): return { "name": self.name, "output_type": self.output_type, + "tags": self.tags } def load(self) -> None: @@ -51,22 +76,13 @@ class BaseModel(abc.ABC): Load the model within the model manager """ - # if we are already loaded, stop + # if the model is already loaded, skip if self._loaded: return - # check if we are the current loaded model - if self.manager.current_loaded_model is not self: - # unload the previous model - if self.manager.current_loaded_model is not None: - self.manager.current_loaded_model.unload() - - # model specific loading + # load the model depending on the implementation self._load() - # declare ourselves as the currently loaded model - self.manager.current_loaded_model = self - # mark the model as loaded self._loaded = True @@ -86,11 +102,7 @@ class BaseModel(abc.ABC): if not self._loaded: return - # if we were the currently loaded model of the manager, demote ourselves - if self.manager.current_loaded_model is self: - self.manager.current_loaded_model = None - - # model specific unloading part + # unload the model depending on the implementation self._unload() # force the garbage collector to clean the memory @@ -106,22 +118,42 @@ class BaseModel(abc.ABC): Do not call manually, use `unload` instead. 
""" - async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: + async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]: """ Infer our payload through the model within the model manager :return: the response of the model """ - async with self.manager.inference_lock: - # make sure we are loaded before an inference - self.load() + # make sure we are loaded before an inference + self.load() - # model specific inference part - return self._infer(**kwargs) + # model specific inference part + return await self._infer(**kwargs) @abc.abstractmethod - def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: + async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]: """ Infer our payload through the model :return: the response of the model """ + + def mount(self, application: fastapi.FastAPI) -> None: + """ + Add the model to the api + :param application: the fastapi application + """ + + # mount the interface if selected + if self.interface is not None: + self.interface.mount(application) + + # implementation specific mount + self._mount(application) + + @abc.abstractmethod + def _mount(self, application: fastapi.FastAPI) -> None: + """ + Add the model to the api + Do not call manually, use `unload` instead. + :param application: the fastapi application + """ diff --git a/source/registry/InterfaceRegistry.py b/source/registry/InterfaceRegistry.py new file mode 100644 index 0000000..dda4173 --- /dev/null +++ b/source/registry/InterfaceRegistry.py @@ -0,0 +1,16 @@ +import typing + +from source.api.interface import base + + +class InterfaceRegistry: + """ + The interface registry + Store the list of other interface available + """ + + def __init__(self): + self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {} + + def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]): + self.interface_types[name] = interface_type diff --git a/source/manager/ModelManager.py b/source/registry/ModelRegistry.py similarity index 54% rename from source/manager/ModelManager.py rename to source/registry/ModelRegistry.py index 6313ab9..0f57197 100644 --- a/source/manager/ModelManager.py +++ b/source/registry/ModelRegistry.py @@ -7,64 +7,80 @@ from pathlib import Path import fastapi -from source import model, api +from source.model.base import BaseModel +from source.registry import InterfaceRegistry -class ModelManager: +class ModelRegistry: """ - The model manager + The model registry Load the list of models available, ensure that only one model is loaded at the same time. """ - def __init__(self, application: api.Application, model_library: os.PathLike | str): - self.application: api.Application = application + def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry): self.model_library: Path = Path(model_library) + self.interface_registry = interface_registry + self._api_base = api_base # the model types - self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {} + self.model_types: dict[str, typing.Type[BaseModel]] = {} # the models - self.models: dict[str, model.base.BaseModel] = {} + self.models: dict[str, BaseModel] = {} # the currently loaded model # TODO(Faraphel): load more than one model at a time ? 
         #  would require a way more complex manager to handle memory issue
         #  having two calculations at the same time might not be worth it either
-        self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
+        self.current_loaded_model: typing.Optional[BaseModel] = None
         # lock to avoid concurrent inference and concurrent model loading and unloading
         self.inference_lock = asyncio.Lock()

-        @self.application.get("/models")
-        async def get_models() -> list[str]:
-            """
-            Get the list of models available
-            :return: the list of models available
-            """
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the api routes
+        :return: the base for the api routes
+        """

-            # list the models found
-            return list(self.models.keys())
+        return self._api_base

-        @self.application.get("/models/{model_name}")
-        async def get_model(model_name: str) -> dict:
-            """
-            Get information about a specific model
-            :param model_name: the name of the model
-            :return: the information about the corresponding model
-            """
-
-            # get the corresponding model
-            model = self.models.get(model_name)
-            if model is None:
-                raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-            # return the model information
-            return model.get_information()
-
-
-    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
+    def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type

-    def reload(self):
+    async def load_model(self, model: "BaseModel"):
+        # lock to avoid concurrent loading
+        async with self.inference_lock:
+            # if there is another currently loaded model, unload it
+            if self.current_loaded_model is not None and self.current_loaded_model is not model:
+                await self.unload_model(self.current_loaded_model)
+
+            # load the model
+            model.load()
+
+            # mark the model as the currently loaded model of the manager
+            self.current_loaded_model = model
+
+    async def unload_model(self, model: "BaseModel"):
+        # lock to avoid concurrent unloading
+        async with self.inference_lock:
+            # if we were the currently loaded model of the manager, demote ourselves
+            if self.current_loaded_model is model:
+                self.current_loaded_model = None
+
+            # model specific unloading part
+            model.unload()
+
+    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
+        # lock to avoid concurrent inference
+        async with self.inference_lock:
+            return await model.infer(**kwargs)
+
+    def reload_models(self) -> None:
+        """
+        Reload the list of models available
+        """
+
         # reset the model list
         for model in self.models.values():
             model.unload()
@@ -97,3 +113,39 @@
             # load the model
             self.models[model_name] = model_type(self, model_configuration, model_path)
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the models endpoints into a fastapi application
+        :param application: the fastapi application
+        """
+
+        @application.get(self.api_base)
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @application.get(f"{self.api_base}/{{model_name}}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
detail="Model not found") + + # return the model information + return model.get_information() + + # mount all the models in the registry + for model_name, model in self.models.items(): + model.mount(application) diff --git a/source/registry/__init__.py b/source/registry/__init__.py new file mode 100644 index 0000000..4012130 --- /dev/null +++ b/source/registry/__init__.py @@ -0,0 +1,2 @@ +from .ModelRegistry import ModelRegistry +from .InterfaceRegistry import InterfaceRegistry diff --git a/source/utils/parameters.py b/source/utils/parameters.py index 35542ae..cc8c41f 100644 --- a/source/utils/parameters.py +++ b/source/utils/parameters.py @@ -10,12 +10,14 @@ types: dict[str, type] = { "float": float, "str": str, "bytes": bytes, - "list": list, - "tuple": tuple, - "set": set, "dict": dict, "datetime": datetime, "file": UploadFile, + + # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ? + "list[dict]": list[dict], + # "tuple": tuple, + # "set": set, }