import asyncio
import json
import os
import typing
import warnings
from pathlib import Path

import fastapi

from source.model.base import BaseModel
from source.registry import InterfaceRegistry


class ModelRegistry:
    """
    The model registry

    Load the list of models available in the model library and keep track of the model types that can be used to
    instantiate them. Only one model is meant to be loaded at the same time (see ``current_loaded_model``).
    """

    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
        """
        :param model_library: directory scanned by ``reload_models``; each model is a sub-directory containing a
            ``config.json`` file
        :param api_base: base path of the api routes added by ``mount``
        :param interface_registry: the interface registry, stored as ``self.interface_registry``
        """
        self.model_library: Path = Path(model_library)
        self.interface_registry = interface_registry
        self._api_base = api_base

        # the model types, indexed by the name used in the "type" field of a model configuration
        self.model_types: dict[str, typing.Type[BaseModel]] = {}
        # the models, indexed by the name of their directory in the model library
        self.models: dict[str, BaseModel] = {}

        # the currently loaded model
        # TODO(Faraphel): load more than one model at a time ?
        #   would require a way more complex manager to handle memory issue
        #   having two calculations at the same time might not be worth it either
        self.current_loaded_model: typing.Optional[BaseModel] = None

        # lock to control access to model inference
        self.infer_lock = asyncio.Lock()

    @property
    def api_base(self) -> str:
        """
        Base for the api routes

        :return: the base for the api routes
        """
        return self._api_base

    def register_type(self, name: str, model_type: "typing.Type[BaseModel]") -> None:
        """
        Register a model type, so that models whose configuration "type" field equals ``name`` can be loaded.

        Registering a name that is already used replaces the previous model type.

        :param name: the name of the model type, as referenced in the model configuration files
        :param model_type: the class used to instantiate the models of this type
        """
        self.model_types[name] = model_type

    def reload_models(self) -> None:
        """
        Reload the list of models available

        Every known model is unloaded and forgotten, then the model library is scanned again (in sorted order).
        Each sub-directory whose name does not start with a dot becomes a model if it contains a ``config.json``
        file holding a JSON object whose "type" field names a registered model type. Other directories are
        skipped with a warning; plain files are ignored.

        :raises Exception: if a model configuration file is not valid JSON
        """
        # reset the model list
        for model in self.models.values():
            model.unload()
        self.models.clear()
        # NOTE(review): current_loaded_model is not reset here; presumably BaseModel.unload() handles it - confirm

        # load all the models in the library, in a deterministic order
        for model_path in sorted(self.model_library.iterdir()):
            # get the model name
            model_name: str = model_path.name

            # ignore hidden entries (starting with a dot) and plain files: a model is a directory
            if model_name.startswith(".") or not model_path.is_dir():
                continue

            model_configuration_path: Path = model_path / "config.json"

            # check if the configuration file exists
            if not model_configuration_path.exists():
                warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                continue

            try:
                # load the configuration file (JSON text is UTF-8 encoded)
                model_configuration = json.loads(model_configuration_path.read_text(encoding="utf-8"))
            except json.decoder.JSONDecodeError as error:
                # chain explicitly so the decoding error ("above") is reported as the cause
                raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.") from error

            # the configuration must be a JSON object to be able to read its fields
            if not isinstance(model_configuration, dict):
                warnings.warn(f"Model {model_name!r}'s configuration must be a JSON object.")
                continue

            # get the model type name for this model
            model_type_name = model_configuration.get("type")
            if model_type_name is None:
                warnings.warn(f"Field 'type' missing from the configuration file of model {model_name!r}.")
                continue

            # get the class of this model type (a non-string value can never match a registered name)
            model_type = self.model_types.get(model_type_name) if isinstance(model_type_name, str) else None
            if model_type is None:
                warnings.warn(f"Model type {model_type_name!r} does not exists. Has it been registered ?")
                continue

            # load the model
            self.models[model_name] = model_type(self, model_configuration, model_path)

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the models endpoints into a fastapi application

        Adds ``GET {api_base}`` (names of the models) and ``GET {api_base}/{model_name}`` (information about one
        model), then calls ``mount`` on every model currently in the registry.

        :param application: the fastapi application
        """

        @application.get(self.api_base)
        async def get_models() -> list[str]:
            """
            Get the list of models available

            :return: the list of models available
            """
            # list the models found
            return list(self.models.keys())

        @application.get(f"{self.api_base}/{{model_name}}")
        async def get_model(model_name: str) -> dict:
            """
            Get information about a specific model

            :param model_name: the name of the model
            :return: the information about the corresponding model
            """
            # get the corresponding model
            model = self.models.get(model_name)
            if model is None:
                raise fastapi.HTTPException(status_code=404, detail="Model not found")

            # return the model information
            return model.get_information()

        # mount all the models in the registry
        # (only the models known at the time of this call: models loaded later are not mounted by this call)
        for model in self.models.values():
            model.mount(application)