# ai-server/source/model/base/BaseModel.py
import abc
import gc
import typing
from pathlib import Path
import fastapi
from source.registry import ModelRegistry
class BaseModel(abc.ABC):
    """
    Represent a model.

    Subclasses implement the actual behavior through the `_load`, `_unload`,
    `_infer` and `_mount` hooks; this base class tracks the loaded state and
    exposes the public `load` / `unload` / `infer` / `mount` entry points.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
        """
        :param registry: the model registry that owns this model
        :param configuration: the configuration dictionary of the model
        :param path: the environment directory of the model
        """
        # the model registry
        self.registry = registry
        # get the documentation of the model (None when not configured)
        self.summary = configuration.get("summary")
        self.description = configuration.get("description")
        # the environment directory of the model
        self.path = path
        # the mimetype of the model responses
        self.output_type: str = configuration.get("output_type", "application/json")
        # get the tags of the model
        self.tags = configuration.get("tags", [])
        # get the selected interface of the model and instantiate it, if any
        interface_name: typing.Optional[str] = configuration.get("interface", None)
        self.interface = (
            self.registry.interface_registry.interface_types[interface_name](self)
            if interface_name is not None
            else None
        )
        # is the model currently loaded
        self._loaded = False

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    @property
    def api_base(self) -> str:
        """
        Base for the API routes

        :return: the base for the API routes
        """
        return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self):
        """
        Get the name of the model

        :return: the name of the model (the directory name of its environment)
        """
        return self.path.name

    def get_information(self):
        """
        Get information about the model

        :return: information about the model
        """
        return {
            "name": self.name,
            "output_type": self.output_type,
            "tags": self.tags,
        }

    def load(self) -> None:
        """
        Load the model within the model manager
        """
        # if the model is already loaded, skip
        if self._loaded:
            return
        # load the model depending on the implementation
        self._load()
        # mark the model as loaded
        self._loaded = True

    @abc.abstractmethod
    def _load(self):
        """
        Load the model

        Do not call manually, use `load` instead.
        """

    def unload(self) -> None:
        """
        Unload the model within the model manager
        """
        # if we are not already loaded, stop
        if not self._loaded:
            return
        # unload the model depending on the implementation
        self._unload()
        # force the garbage collector to clean the memory
        gc.collect()
        # mark the model as unloaded
        self._loaded = False

    @abc.abstractmethod
    def _unload(self):
        """
        Unload the model

        Do not call manually, use `unload` instead.
        """

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model within the model manager

        :return: the response of the model
        """
        # make sure we are loaded before an inference
        self.load()
        # model specific inference part
        return await self._infer(**kwargs)

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model

        Do not call manually, use `infer` instead.

        :return: the response of the model
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api

        :param application: the fastapi application
        """
        # mount the interface if selected
        if self.interface is not None:
            self.interface.mount(application)
        # implementation specific mount
        self._mount(application)

    @abc.abstractmethod
    def _mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api

        Do not call manually, use `mount` instead.

        :param application: the fastapi application
        """