replaced the previous venv system with a conda one, allowing for better dependency management
This commit is contained in:
parent
8bf28e4c48
commit
0034c7b31a
17 changed files with 313 additions and 230 deletions
|
@ -1,12 +1,16 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import gc
|
||||
import tempfile
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
import fastapi
|
||||
import inspect
|
||||
|
||||
from source import utils
|
||||
from source.registry import ModelRegistry
|
||||
from source.utils.fastapi import UploadFileFix
|
||||
|
||||
|
||||
class BaseModel(abc.ABC):
|
||||
|
@ -24,6 +28,9 @@ class BaseModel(abc.ABC):
|
|||
|
||||
# the environment directory of the model
|
||||
self.path = path
|
||||
# get the parameters of the model
|
||||
self.inputs = configuration.get("inputs", {})
|
||||
self.parameters = utils.parameters.load(self.inputs)
|
||||
# the mimetype of the model responses
|
||||
self.output_type: str = configuration.get("output_type", "application/json")
|
||||
# get the tags of the model
|
||||
|
@ -71,8 +78,12 @@ class BaseModel(abc.ABC):
|
|||
|
||||
return {
|
||||
"name": self.name,
|
||||
"summary": self.summary,
|
||||
"description": self.description,
|
||||
"inputs": self.inputs,
|
||||
"output_type": self.output_type,
|
||||
"tags": self.tags
|
||||
"tags": self.tags,
|
||||
"interface": self.interface,
|
||||
}
|
||||
|
||||
async def load(self) -> None:
|
||||
|
@ -145,7 +156,8 @@ class BaseModel(abc.ABC):
|
|||
await self.load()
|
||||
|
||||
# model specific inference part
|
||||
return await self._infer(**kwargs)
|
||||
async for chunk in self._infer(**kwargs):
|
||||
yield chunk
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
|
||||
|
@ -164,13 +176,50 @@ class BaseModel(abc.ABC):
|
|||
if self.interface is not None:
|
||||
self.interface.mount(application)
|
||||
|
||||
# implementation specific mount
|
||||
self._mount(application)
|
||||
# create an endpoint wrapping the inference inside a fastapi call
|
||||
# the arguments will be loaded from the configuration files. Use kwargs for the definition
|
||||
async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
|
||||
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
|
||||
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own
|
||||
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
|
||||
# curiosity that may change in the future
|
||||
kwargs = {
|
||||
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
|
||||
for key, value in kwargs.items()
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _mount(self, application: fastapi.FastAPI) -> None:
|
||||
"""
|
||||
Add the model to the api
|
||||
Do not call manually, use `mount` instead.
|
||||
:param application: the fastapi application
|
||||
"""
|
||||
# return a streaming response around the inference call
|
||||
return fastapi.responses.StreamingResponse(
|
||||
content=self.infer(**kwargs),
|
||||
media_type=self.output_type,
|
||||
headers={
|
||||
# if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
|
||||
"content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
|
||||
}
|
||||
)
|
||||
|
||||
# update the signature of the function to use the configuration parameters
|
||||
infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
|
||||
|
||||
# format the description
|
||||
description_sections: list[str] = []
|
||||
if self.description is not None:
|
||||
description_sections.append(f"# Description\n{self.description}")
|
||||
if self.interface is not None:
|
||||
description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
|
||||
|
||||
description: str = "\n".join(description_sections)
|
||||
|
||||
# add the inference endpoint on the API
|
||||
application.add_api_route(
|
||||
f"{self.api_base}/infer",
|
||||
infer_api,
|
||||
methods=["POST"],
|
||||
tags=self.tags,
|
||||
summary=self.summary,
|
||||
description=description,
|
||||
response_class=fastapi.responses.StreamingResponse,
|
||||
responses={
|
||||
200: {"content": {self.output_type: {}}}
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue