ai-server/source/model/base/BaseModel.py

import abc
import asyncio
import gc
import tempfile
import typing
from pathlib import Path
import fastapi
import inspect
from source import utils
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix


class BaseModel(abc.ABC):
"""
    Represents a model: holds its configuration, handles loading/unloading and inference,
    and mounts its API routes.
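
    Example (illustrative sketch only; ``EchoModel`` is a hypothetical subclass,
    not part of the code base)::

        class EchoModel(BaseModel):
            async def _load(self):
                # nothing to load for this toy model
                pass

            async def _unload(self):
                pass

            async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
                # echo the received parameters back as a single chunk
                yield repr(kwargs).encode("utf-8")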
"""

    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
# the model registry
self.registry = registry
# get the documentation of the model
self.summary = configuration.get("summary")
self.description = configuration.get("description")
# the environment directory of the model
self.path = path
# get the parameters of the model
self.inputs = configuration.get("inputs", {})
self.parameters = utils.parameters.load(self.inputs)
# the mimetype of the model responses
self.output_type: str = configuration.get("output_type", "application/json")
# get the tags of the model
self.tags = configuration.get("tags", [])
# get the selected interface of the model
interface_name: typing.Optional[str] = configuration.get("interface", None)
self.interface = (
self.registry.interface_registry.interface_types[interface_name](self)
if interface_name is not None else None
)
# is the model currently loaded
self._loaded = False
# lock to avoid loading and unloading at the same time
self.load_lock = asyncio.Lock()

    def __repr__(self):
return f"<{self.__class__.__name__}: {self.name}>"

    @property
def api_base(self) -> str:
"""
Base for the API routes
:return: the base for the API routes
"""
return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self) -> str:
"""
Get the name of the model
:return: the name of the model
"""
return self.path.name

    def get_information(self):
"""
Get information about the model
:return: information about the model
"""
return {
"name": self.name,
"summary": self.summary,
"description": self.description,
"inputs": self.inputs,
"output_type": self.output_type,
"tags": self.tags,
"interface": self.interface,
}

    async def load(self) -> None:
"""
Load the model within the model manager
"""
async with self.load_lock:
# if the model is already loaded, skip
if self._loaded:
return
# unload the currently loaded model if any
if self.registry.current_loaded_model is not None:
await self.registry.current_loaded_model.unload()
# load the model depending on the implementation
await self._load()
# mark the model as loaded
self._loaded = True
# mark the model as the registry loaded model
self.registry.current_loaded_model = self

    @abc.abstractmethod
async def _load(self):
"""
Load the model
Do not call manually, use `load` instead.
"""

    async def unload(self) -> None:
"""
Unload the model within the model manager
"""
async with self.load_lock:
            # if the model is not already loaded, stop
if not self._loaded:
return
# unload the model depending on the implementation
await self._unload()
# force the garbage collector to clean the memory
gc.collect()
# mark the model as unloaded
self._loaded = False
            # if this model is the registry's currently loaded model, remove that status
if self.registry.current_loaded_model is self:
self.registry.current_loaded_model = None

    @abc.abstractmethod
async def _unload(self):
"""
Unload the model
Do not call manually, use `unload` instead.
"""

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
"""
        Run the given payload through the model within the model manager
        :return: the response of the model, streamed as chunks of bytes
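
        Example (hypothetical caller; the accepted keyword arguments depend on the
        ``inputs`` section of the model configuration, ``prompt`` is only an illustration)::

            chunks = [chunk async for chunk in model.infer(prompt="hello")]
            response = b"".join(chunks)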
"""
async with self.registry.infer_lock:
# ensure that the model is loaded
await self.load()
# model specific inference part
async for chunk in self._infer(**kwargs):
yield chunk

    @abc.abstractmethod
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
"""
        Run the given payload through the model
:return: the response of the model
"""

    def mount(self, application: fastapi.FastAPI) -> None:
"""
        Add the model routes to the API
:param application: the fastapi application
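
        Example (illustrative; how the server startup might expose a registered model,
        ``app`` and ``model`` being hypothetical names)::

            app = fastapi.FastAPI()
            model.mount(app)  # exposes POST <api_base>/infer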
"""
# mount the interface if selected
if self.interface is not None:
self.interface.mount(application)
        # create an endpoint wrapping the inference inside a fastapi call
        # the arguments are declared in the configuration file, so the definition uses **kwargs
        # and the real signature is attached to the function afterwards
async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
            # NOTE: perform a naive `type(value).__name__ == "UploadFile"` check because FastAPI does not
            # use its own fastapi.UploadFile class but the Starlette UploadFile class, which is an
            # implementation detail that may change in the future
kwargs = {
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
for key, value in kwargs.items()
}
# return a streaming response around the inference call
return fastapi.responses.StreamingResponse(
content=self.infer(**kwargs),
media_type=self.output_type,
headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
"content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
}
)
        # update the signature of the function to use the configuration parameters
        # (see the illustrative sketch at the end of this file)
infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
# format the description
description_sections: list[str] = []
if self.description is not None:
description_sections.append(f"# Description\n{self.description}")
if self.interface is not None:
description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
description: str = "\n".join(description_sections)
# add the inference endpoint on the API
application.add_api_route(
f"{self.api_base}/infer",
infer_api,
methods=["POST"],
tags=self.tags,
summary=self.summary,
description=description,
response_class=fastapi.responses.StreamingResponse,
responses={
200: {"content": {self.output_type: {}}}
},
)
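

# ---------------------------------------------------------------------------
# Illustrative sketch (kept as a comment so it is never executed, and not part
# of the server): how overriding `__signature__` lets FastAPI expose
# configuration-driven parameters, as done in `BaseModel.mount` above.
# The parameter names and the standalone `app` below are hypothetical examples.
#
#   import inspect
#   import fastapi
#
#   app = fastapi.FastAPI()
#
#   async def example_infer(**kwargs):
#       return kwargs
#
#   # declare the parameters FastAPI should expose, as if they were loaded
#   # from a model configuration file
#   example_infer.__signature__ = inspect.Signature(parameters=[
#       inspect.Parameter("prompt", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
#       inspect.Parameter("temperature", inspect.Parameter.POSITIONAL_OR_KEYWORD,
#                         annotation=float, default=1.0),
#   ])
#
#   app.add_api_route("/example/infer", example_infer, methods=["POST"])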