Added support for additional, more user-friendly interfaces and simplified part of the application loading process.

This commit is contained in:
faraphel 2025-01-12 12:52:19 +01:00
parent 1a49aa3779
commit f647c960dd
20 changed files with 353 additions and 107 deletions

View file

@ -0,0 +1,16 @@
import typing
from source.api.interface import base
class InterfaceRegistry:
    """
    The interface registry.

    Keep track of every interface type made available to the application,
    indexed by name.
    """

    def __init__(self):
        # the registered interface types, mapping a name to its interface class
        self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}

    def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
        """
        Register an interface type under the given name.

        :param name: the name used to refer to this interface type
        :param interface_type: the interface class to register
        """
        self.interface_types[name] = interface_type

View file

@ -0,0 +1,151 @@
import asyncio
import json
import os
import typing
import warnings
from pathlib import Path
import fastapi
from source.model.base import BaseModel
from source.registry import InterfaceRegistry
class ModelRegistry:
"""
The model registry
Load the list of models available, ensure that only one model is loaded at the same time.
"""
def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
self.model_library: Path = Path(model_library)
self.interface_registry = interface_registry
self._api_base = api_base
# the model types
self.model_types: dict[str, typing.Type[BaseModel]] = {}
# the models
self.models: dict[str, BaseModel] = {}
# the currently loaded model
# TODO(Faraphel): load more than one model at a time ?
# would require a way more complex manager to handle memory issue
# having two calculations at the same time might not be worth it either
self.current_loaded_model: typing.Optional[BaseModel] = None
# lock to avoid concurrent inference and concurrent model loading and unloading
self.inference_lock = asyncio.Lock()
@property
def api_base(self) -> str:
"""
Base for the api routes
:return: the base for the api routes
"""
return self._api_base
def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
self.model_types[name] = model_type
async def load_model(self, model: "BaseModel"):
# lock to avoid concurrent loading
async with self.inference_lock:
# if there is another currently loaded model, unload it
if self.current_loaded_model is not None and self.current_loaded_model is not model:
await self.unload_model(self.current_loaded_model)
# load the model
model.load()
# mark the model as the currently loaded model of the manager
self.current_loaded_model = model
async def unload_model(self, model: "BaseModel"):
# lock to avoid concurrent unloading
async with self.inference_lock:
# if we were the currently loaded model of the manager, demote ourselves
if self.current_loaded_model is model:
self.current_loaded_model = None
# model specific unloading part
model.unload()
async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
# lock to avoid concurrent inference
async with self.inference_lock:
return await model.infer(**kwargs)
def reload_models(self) -> None:
"""
Reload the list of models available
"""
# reset the model list
for model in self.models.values():
model.unload()
self.models.clear()
# load all the models in the library
for model_path in self.model_library.iterdir():
model_name: str = model_path.name
model_configuration_path: Path = model_path / "config.json"
# check if the configuration file exists
if not model_configuration_path.exists():
warnings.warn(f"Model {model_name!r} is missing a config.json file.")
continue
# load the configuration file
model_configuration = json.loads(model_configuration_path.read_text())
# get the model type for this model
model_type_name: str = model_configuration.get("type")
if model_type_name not in self.model_types:
warnings.warn("Field 'type' missing from the model configuration file.")
continue
# get the class of this model type
model_type = self.model_types.get(model_type_name)
if model_type is None:
warnings.warn(f"Model type {model_type_name!r} does not exists. Has it been registered ?")
continue
# load the model
self.models[model_name] = model_type(self, model_configuration, model_path)
def mount(self, application: fastapi.FastAPI) -> None:
"""
Mount the models endpoints into a fastapi application
:param application: the fastapi application
"""
@application.get(self.api_base)
async def get_models() -> list[str]:
"""
Get the list of models available
:return: the list of models available
"""
# list the models found
return list(self.models.keys())
@application.get(f"{self.api_base}/{{model_name}}")
async def get_model(model_name: str) -> dict:
"""
Get information about a specific model
:param model_name: the name of the model
:return: the information about the corresponding model
"""
# get the corresponding model
model = self.models.get(model_name)
if model is None:
raise fastapi.HTTPException(status_code=404, detail="Model not found")
# return the model information
return model.get_information()
# mount all the models in the registry
for model_name, model in self.models.items():
model.mount(application)

View file

@ -0,0 +1,2 @@
from .ModelRegistry import ModelRegistry
from .InterfaceRegistry import InterfaceRegistry