added support for additional, more user-friendly interfaces and simplified part of the application loading process
This commit is contained in:
parent
1a49aa3779
commit
f647c960dd
20 changed files with 353 additions and 107 deletions
16
source/registry/InterfaceRegistry.py
Normal file
@@ -0,0 +1,16 @@
import typing

from source.api.interface import base


class InterfaceRegistry:
    """
    The interface registry

    Stores the list of available interface types.
    """

    def __init__(self):
        self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}

    def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
        self.interface_types[name] = interface_type
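To show how the new registry is intended to be used, here is a minimal sketch; ChatInterface is a hypothetical example class, only InterfaceRegistry, BaseInterface and register_type come from the code above.

# minimal usage sketch (ChatInterface is a hypothetical example, not part of this commit)
from source.api.interface import base
from source.registry import InterfaceRegistry


class ChatInterface(base.BaseInterface):  # hypothetical interface type
    ...


interface_registry = InterfaceRegistry()
interface_registry.register_type("chat", ChatInterface)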
151
source/registry/ModelRegistry.py
Normal file
@@ -0,0 +1,151 @@
import asyncio
import json
import os
import typing
import warnings
from pathlib import Path

import fastapi

from source.model.base import BaseModel
from source.registry import InterfaceRegistry


class ModelRegistry:
    """
    The model registry

    Loads the list of available models and ensures that only one model is loaded at a time.
    """

    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
        self.model_library: Path = Path(model_library)
        self.interface_registry = interface_registry
        self._api_base = api_base

        # the model types
        self.model_types: dict[str, typing.Type[BaseModel]] = {}
        # the models
        self.models: dict[str, BaseModel] = {}

        # the currently loaded model
        # TODO(Faraphel): load more than one model at a time ?
        #  would require a way more complex manager to handle memory issues
        #  having two calculations at the same time might not be worth it either
        self.current_loaded_model: typing.Optional[BaseModel] = None

        # lock to avoid concurrent inference and concurrent model loading and unloading
        self.inference_lock = asyncio.Lock()

    @property
    def api_base(self) -> str:
        """
        Base for the api routes
        :return: the base for the api routes
        """

        return self._api_base

    def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
        self.model_types[name] = model_type

    async def load_model(self, model: "BaseModel"):
        # lock to avoid concurrent loading
        async with self.inference_lock:
            # if there is another currently loaded model, unload it inline
            # (awaiting unload_model here would deadlock: asyncio.Lock is not reentrant)
            if self.current_loaded_model is not None and self.current_loaded_model is not model:
                self.current_loaded_model.unload()
                self.current_loaded_model = None

            # load the model
            model.load()

            # mark the model as the currently loaded model of the manager
            self.current_loaded_model = model

    async def unload_model(self, model: "BaseModel"):
        # lock to avoid concurrent unloading
        async with self.inference_lock:
            # if we were the currently loaded model of the manager, demote ourselves
            if self.current_loaded_model is model:
                self.current_loaded_model = None

            # model specific unloading part
            model.unload()

    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
        # lock to avoid concurrent inference
        async with self.inference_lock:
            return await model.infer(**kwargs)

    def reload_models(self) -> None:
        """
        Reload the list of available models
        """

        # reset the model list
        for model in self.models.values():
            model.unload()
        self.models.clear()

        # load all the models in the library
        for model_path in self.model_library.iterdir():
            model_name: str = model_path.name
            model_configuration_path: Path = model_path / "config.json"

            # check if the configuration file exists
            if not model_configuration_path.exists():
                warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                continue

            # load the configuration file
            model_configuration = json.loads(model_configuration_path.read_text())

            # get the model type for this model
            model_type_name: typing.Optional[str] = model_configuration.get("type")
            if model_type_name is None:
                warnings.warn(f"Model {model_name!r}: field 'type' missing from the model configuration file.")
                continue

            # get the class of this model type
            model_type = self.model_types.get(model_type_name)
            if model_type is None:
                warnings.warn(f"Model type {model_type_name!r} does not exist. Has it been registered?")
                continue

            # load the model
            self.models[model_name] = model_type(self, model_configuration, model_path)

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the models endpoints into a fastapi application
        :param application: the fastapi application
        """

        @application.get(self.api_base)
        async def get_models() -> list[str]:
            """
            Get the list of available models
            :return: the list of available models
            """

            # list the models found
            return list(self.models.keys())

        @application.get(f"{self.api_base}/{{model_name}}")
        async def get_model(model_name: str) -> dict:
            """
            Get information about a specific model
            :param model_name: the name of the model
            :return: the information about the corresponding model
            """

            # get the corresponding model
            model = self.models.get(model_name)
            if model is None:
                raise fastapi.HTTPException(status_code=404, detail="Model not found")

            # return the model information
            return model.get_information()

        # mount all the models in the registry
        for model in self.models.values():
            model.mount(application)
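As a rough illustration of how the pieces fit together, here is a hedged wiring sketch; DummyModel, the "./models" library path and the "/models" api base are assumptions made for the example, only the registry classes and methods shown above come from this commit.

# hypothetical wiring sketch: DummyModel, "./models" and "/models" are example values
import fastapi

from source.model.base import BaseModel
from source.registry import InterfaceRegistry, ModelRegistry


class DummyModel(BaseModel):  # hypothetical model type, not part of this commit
    ...


application = fastapi.FastAPI()

interface_registry = InterfaceRegistry()
model_registry = ModelRegistry("./models", "/models", interface_registry)

# make the "dummy" type resolvable from a model's config.json "type" field
model_registry.register_type("dummy", DummyModel)

# scan the model library, then expose GET /models and GET /models/{model_name}
model_registry.reload_models()
model_registry.mount(application)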
2
source/registry/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .ModelRegistry import ModelRegistry
from .InterfaceRegistry import InterfaceRegistry
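Continuing the wiring sketch above, the mounted routes could be exercised like this; the api base "/models" and the model name "dummy" are the same example values as before.

# hypothetical client calls against the routes mounted by ModelRegistry.mount()
from fastapi.testclient import TestClient

client = TestClient(application)  # "application" comes from the wiring sketch above

print(client.get("/models").json())        # list of model names, e.g. ["dummy"]
print(client.get("/models/dummy").json())  # whatever the model's get_information() returns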