import abc
import asyncio
import gc
import inspect
import typing
from pathlib import Path

import fastapi

from source import utils
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix


class BaseModel(abc.ABC):
    """
    Represent a model.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
        # the model registry
        self.registry = registry

        # get the documentation of the model
        self.summary = configuration.get("summary")
        self.description = configuration.get("description")

        # the environment directory of the model
        self.path = path

        # get the parameters of the model
        self.inputs = configuration.get("inputs", {})
        self.parameters = utils.parameters.load(self.inputs)

        # the mimetype of the model responses
        self.output_type: str = configuration.get("output_type", "application/json")

        # get the tags of the model
        self.tags = configuration.get("tags", [])

        # get the selected interface of the model
        interface_name: typing.Optional[str] = configuration.get("interface", None)
        self.interface = (
            self.registry.interface_registry.interface_types[interface_name](self)
            if interface_name is not None
            else None
        )

        # is the model currently loaded
        self._loaded = False
        # lock to avoid loading and unloading at the same time
        self.load_lock = asyncio.Lock()

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    @property
    def api_base(self) -> str:
        """
        Base for the API routes

        :return: the base for the API routes
        """
        return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self):
        """
        Get the name of the model

        :return: the name of the model
        """
        return self.path.name

    def get_information(self):
        """
        Get information about the model

        :return: information about the model
        """
        return {
            "name": self.name,
            "summary": self.summary,
            "description": self.description,
            "inputs": self.inputs,
            "output_type": self.output_type,
            "tags": self.tags,
            "interface": self.interface,
        }

    async def load(self) -> None:
        """
        Load the model within the model manager
        """
        async with self.load_lock:
            # if the model is already loaded, skip
            if self._loaded:
                return

            # unload the currently loaded model if any
            if self.registry.current_loaded_model is not None:
                await self.registry.current_loaded_model.unload()

            # load the model depending on the implementation
            await self._load()

            # mark the model as loaded
            self._loaded = True
            # mark the model as the registry loaded model
            self.registry.current_loaded_model = self

    @abc.abstractmethod
    async def _load(self):
        """
        Load the model

        Do not call manually, use `load` instead.
        """

    async def unload(self) -> None:
        """
        Unload the model within the model manager
        """
        async with self.load_lock:
            # if we are not already loaded, stop
            if not self._loaded:
                return

            # unload the model depending on the implementation
            await self._unload()

            # force the garbage collector to clean the memory
            gc.collect()

            # mark the model as unloaded
            self._loaded = False
            # if we are the registry current loaded model, remove this status
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None

    @abc.abstractmethod
    async def _unload(self):
        """
        Unload the model

        Do not call manually, use `unload` instead.
        """

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer the payload through the model within the model manager

        :return: the response of the model
        """
        async with self.registry.infer_lock:
            # ensure that the model is loaded
            await self.load()

            # model specific inference part
            async for chunk in self._infer(**kwargs):
                yield chunk

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer the payload through the model

        Do not call manually, use `infer` instead.

        :return: the response of the model
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the API

        :param application: the FastAPI application
        """
        # mount the interface if selected
        if self.interface is not None:
            self.interface.mount(application)

        # create an endpoint wrapping the inference inside a FastAPI call
        # the arguments are loaded from the configuration files, so use **kwargs in the definition
        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
            # NOTE: perform a naive type(value).__name__ == "type_name" check because FastAPI does not use its own
            #  fastapi.UploadFile class, but the starlette UploadFile class, which is more of an implementation
            #  detail that may change in the future
            kwargs = {
                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
                for key, value in kwargs.items()
            }

            # return a streaming response around the inference call
            return fastapi.responses.StreamingResponse(
                content=self.infer(**kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
                }
            )

        # update the signature of the function to use the configuration parameters
        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)

        # format the description
        description_sections: list[str] = []
        if self.description is not None:
            description_sections.append(f"# Description\n{self.description}")
        if self.interface is not None:
            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
        description: str = "\n".join(description_sections)

        # add the inference endpoint on the API
        application.add_api_route(
            f"{self.api_base}/infer",
            infer_api,
            methods=["POST"],
            tags=self.tags,
            summary=self.summary,
            description=description,
            response_class=fastapi.responses.StreamingResponse,
            responses={
                200: {"content": {self.output_type: {}}}
            },
        )