ai-server/source/registry/ModelRegistry.py

import asyncio
import json
import os
import typing
import warnings
from pathlib import Path

import fastapi

from source.model.base import BaseModel
from source.registry import InterfaceRegistry


class ModelRegistry:
    """
    The model registry.
    Load the list of available models and ensure that only one model is loaded at a time.
    """

    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
        self.model_library: Path = Path(model_library)
        self.interface_registry = interface_registry
        self._api_base = api_base

        # the model types
        self.model_types: dict[str, typing.Type[BaseModel]] = {}
        # the models
        self.models: dict[str, BaseModel] = {}

        # the currently loaded model
        # TODO(Faraphel): load more than one model at a time?
        #  would require a much more complex manager to handle memory issues
        #  having two calculations at the same time might not be worth it either
        self.current_loaded_model: typing.Optional[BaseModel] = None

        # lock to control access to model inference
        self.infer_lock = asyncio.Lock()
    @property
    def api_base(self) -> str:
        """
        Base for the api routes
        :return: the base for the api routes
        """

        return self._api_base

    def register_type(self, name: str, model_type: "typing.Type[BaseModel]") -> None:
        """
        Register a new model type in the registry
        :param name: the name of the model type
        :param model_type: the model class associated with this type
        """

        self.model_types[name] = model_type
    def reload_models(self) -> None:
        """
        Reload the list of models available
        """

        # reset the model list
        for model in self.models.values():
            model.unload()
        self.models.clear()

        # load all the models in the library
        for model_path in self.model_library.iterdir():
            # get the model name
            model_name: str = model_path.name
            if model_name.startswith("."):
                # ignore models starting with a dot
                continue

            model_configuration_path: Path = model_path / "config.json"

            # check if the configuration file exists
            if not model_configuration_path.exists():
                warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                continue

            try:
                # load the configuration file
                model_configuration = json.loads(model_configuration_path.read_text())
            except json.decoder.JSONDecodeError as exception:
                raise Exception(f"Model {model_name!r}'s configuration is invalid.") from exception

            # get the model type for this model
            model_type_name: typing.Optional[str] = model_configuration.get("type")
            if model_type_name is None:
                warnings.warn(f"Model {model_name!r} is missing the 'type' field in its configuration file.")
                continue

            # get the class of this model type
            model_type = self.model_types.get(model_type_name)
            if model_type is None:
                warnings.warn(f"Model type {model_type_name!r} does not exist. Has it been registered?")
                continue

            # load the model
            self.models[model_name] = model_type(self, model_configuration, model_path)
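
    # Illustrative sketch of a model library layout scanned by reload_models().
    # The directory and model names below are assumptions for the example, not
    # requirements of this class; only the "type" field of config.json is read here,
    # and the full configuration is passed through to the model class constructor:
    #
    #   <model_library>/
    #   └── my-model/
    #       └── config.json   e.g. {"type": "my-type", ...}
    #
    # where "my-type" must be a name previously passed to register_type().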
    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the model endpoints into a fastapi application
        :param application: the fastapi application
        """

        @application.get(self.api_base)
        async def get_models() -> list[str]:
            """
            Get the list of models available
            :return: the list of models available
            """

            # list the models found
            return list(self.models.keys())

        @application.get(f"{self.api_base}/{{model_name}}")
        async def get_model(model_name: str) -> dict:
            """
            Get information about a specific model
            :param model_name: the name of the model
            :return: the information about the corresponding model
            """

            # get the corresponding model
            model = self.models.get(model_name)
            if model is None:
                raise fastapi.HTTPException(status_code=404, detail="Model not found")

            # return the model information
            return model.get_information()

        # mount all the models in the registry
        for model in self.models.values():
            model.mount(application)
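

# Minimal usage sketch, assuming an already-constructed InterfaceRegistry and a
# "DummyModel" subclass of BaseModel registered under the name "dummy"; these
# names and the "./models" path are illustrative, not part of this module:
#
#   application = fastapi.FastAPI()
#   registry = ModelRegistry("./models", "/models", interface_registry)
#   registry.register_type("dummy", DummyModel)
#   registry.reload_models()
#   registry.mount(application)
#
# After mounting, GET <api_base> lists the available model names and
# GET <api_base>/{model_name} returns the information of a single model.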