import abc
import gc
import typing
from pathlib import Path

import fastapi

from source.registry import ModelRegistry


class BaseModel(abc.ABC):
    """
    Abstract base class of every model handled by the model registry.

    The public methods (`load`, `unload`, `infer`, `mount`) take care of the
    shared bookkeeping and delegate the real work to the hooks that each
    implementation has to provide: `_load`, `_unload`, `_infer` and `_mount`.
    """

    def __init__(
        self,
        registry: ModelRegistry,
        configuration: dict[str, typing.Any],
        path: Path,
    ):
        """
        Build the model from its configuration

        The keys read from the configuration are `summary`, `description`,
        `output_type`, `tags` and `interface`; all of them are optional.

        :param registry: the model registry
        :param configuration: the configuration of the model
        :param path: the environment directory of the model, its last
            component is used as the name of the model
        """

        # the model registry
        self.registry = registry

        # get the documentation of the model (None when not configured)
        self.summary = configuration.get("summary")
        self.description = configuration.get("description")

        # the environment directory of the model
        self.path = path

        # the mimetype of the model responses
        self.output_type: str = configuration.get("output_type", "application/json")

        # get the tags of the model
        self.tags = configuration.get("tags", [])

        # is the model currently loaded
        # set before the interface is built: the interface receives `self`
        # and must be able to query or trigger the loading of the model
        self._loaded = False

        # get the selected interface of the model, None when no interface is
        # configured; an unknown name surfaces as the lookup error raised by
        # `interface_types` (a KeyError if it is a plain dict)
        interface_name: typing.Optional[str] = configuration.get("interface", None)
        self.interface = None
        if interface_name is not None:
            interface_types = self.registry.interface_registry.interface_types
            self.interface = interface_types[interface_name](self)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.name}>"

    @property
    def api_base(self) -> str:
        """
        Base for the API routes

        :return: the base for the API routes
        """

        return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self) -> str:
        """
        Get the name of the model

        :return: the name of the model, i.e. the last component of its path
        """

        return self.path.name

    def get_information(self) -> dict[str, typing.Any]:
        """
        Get information about the model

        :return: the name, the output type and the tags of the model
        """

        return {
            "name": self.name,
            "output_type": self.output_type,
            "tags": self.tags,
        }

    def load(self) -> None:
        """
        Load the model within the model manager

        Does nothing when the model is already loaded.
        """

        # if the model is already loaded, skip
        if self._loaded:
            return

        # load the model depending on the implementation
        self._load()

        # mark the model as loaded, only reached when `_load` did not raise
        self._loaded = True

    @abc.abstractmethod
    def _load(self) -> None:
        """
        Load the model

        Do not call manually, use `load` instead.
        """

    def unload(self) -> None:
        """
        Unload the model within the model manager

        Does nothing when the model is not loaded.
        """

        # if we are not already loaded, stop
        if not self._loaded:
            return

        # unload the model depending on the implementation
        self._unload()

        # force the garbage collector to clean the memory
        gc.collect()

        # mark the model as unloaded, only reached when `_unload` did not raise
        self._loaded = False

    @abc.abstractmethod
    def _unload(self) -> None:
        """
        Unload the model

        Do not call manually, use `unload` instead.
        """

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model within the model manager

        :param kwargs: the payload, forwarded untouched to `_infer`
        :return: the response of the model
        """

        # make sure we are loaded before an inference
        # NOTE: `load` is a regular (blocking) call executed in this coroutine
        self.load()

        # model specific inference part
        return await self._infer(**kwargs)

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model

        Do not call manually, use `infer` instead.

        :return: the response of the model
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api

        The interface (when one is selected) is mounted first, then the
        implementation specific part.

        :param application: the fastapi application
        """

        # mount the interface if selected
        if self.interface is not None:
            self.interface.mount(application)

        # implementation specific mount
        self._mount(application)

    @abc.abstractmethod
    def _mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api

        Do not call manually, use `mount` instead.

        :param application: the fastapi application
        """