95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
import importlib.util
|
|
import subprocess
|
|
import sys
|
|
import typing
|
|
import uuid
|
|
import inspect
|
|
from pathlib import Path
|
|
|
|
import fastapi
|
|
|
|
from source import utils
|
|
from source.manager import ModelManager
|
|
from source.model import base
|
|
from source.utils.fastapi import UploadFileFix
|
|
|
|
|
|
class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model code lives in a Python file named by the configuration. That file is imported when
    the model is constructed, and this class delegates to the `load(model)`, `unload(model)` and
    `infer(model, **kwargs)` callables it looks up on the imported module.
    """

    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
        """
        Import the custom model code and register its inference endpoint.

        :param manager: forwarded to the base class; the route is added on `self.manager.application`
        :param configuration: the model configuration. The keys read here are "file" (required),
            "requirements" (optional, pip requirement specifiers) and "inputs" (optional)
        :param path: forwarded to the base class; the model file is resolved against `self.path`
        :raises ValueError: if the "file" field is missing from the configuration
        :raises subprocess.CalledProcessError: if pip fails to install the custom requirements
        :raises ImportError: if no import specification can be built for the model file
        """
        super().__init__(manager, configuration, path)

        ## Configuration

        # get the name of the file containing the model code
        file = configuration.get("file")
        if file is None:
            raise ValueError("Field 'file' is missing from the configuration")

        # install custom requirements
        # NOTE: `or []` also covers a configuration that explicitly sets "requirements" to null
        requirements = configuration.get("requirements") or []
        if requirements:
            # NOTE: check=True so a failed installation stops here with a clear error, instead of
            # surfacing later as an unrelated import error inside the model code
            subprocess.run([sys.executable, "-m", "pip", "install", *requirements], check=True)

        # create the module specification
        # NOTE: a random module name is used so that several models never share a module name
        module_path = self.path / file
        module_spec = importlib.util.spec_from_file_location(f"model-{uuid.uuid4()}", module_path)
        # no specification (or no loader) is returned when the file cannot be handled by the import system
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot create an import specification for the model file '{module_path}'")

        # get the module
        self.module = importlib.util.module_from_spec(module_spec)
        # load the module
        module_spec.loader.exec_module(self.module)

        ## Api

        # load the inputs data into the inference function signature (used by FastAPI)
        parameters = utils.parameters.load(configuration.get("inputs", {}))

        # create an endpoint wrapping the inference inside a fastapi call
        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
            # NOTE: the check is a naive type(value).__name__ == "UploadFile" comparison because fastapi
            #  does not hand over its own fastapi.UploadFile class but the starlette UploadFile class,
            #  which is an implementation detail that may change in the future
            kwargs = {
                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
                for key, value in kwargs.items()
            }

            return fastapi.responses.StreamingResponse(
                content=await self.infer(**kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
                },
            )

        # expose the configured inputs as the endpoint signature instead of the generic **kwargs
        infer_api.__signature__ = inspect.Signature(parameters=parameters)

        # add the inference endpoint on the API
        self.manager.application.add_api_route(
            f"/models/{self.name}/infer",
            infer_api,
            methods=["POST"],
            tags=self.tags,
            # summary=...,
            # description=...,
            response_class=fastapi.responses.StreamingResponse,
            responses={
                200: {"content": {self.output_type: {}}}
            },
        )

    def _load(self) -> None:
        """
        Delegate the loading of the model to the `load` function of the custom module.
        """
        return self.module.load(self)

    def _unload(self) -> None:
        """
        Delegate the unloading of the model to the `unload` function of the custom module.
        """
        return self.module.unload(self)

    # NOTE(review): the previous annotation listed typing.Iterator[bytes] twice in the union; it is
    #  collapsed here to the equivalent single type. If the second member was meant to be
    #  typing.AsyncIterator[bytes], widen it after confirming against base.BaseModel
    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
        """
        Delegate the inference to the `infer` function of the custom module.

        :param kwargs: the inference inputs, forwarded untouched to the custom module
        :return: whatever the custom `infer` function returns
        """
        return self.module.infer(self, **kwargs)
|