import importlib.util
import inspect
import subprocess
import sys
import typing
import uuid
from pathlib import Path

import fastapi

from source import utils
from source.manager import ModelManager
from source.model import base
from source.utils.fastapi import UploadFileFix


class PythonModel(base.BaseModel):
    """
    A model that runs custom Python code loaded from a user-provided file.
    """

    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
        super().__init__(manager, configuration, path)

        ## Configuration

        # get the name of the file containing the model code
        file = configuration.get("file")
        if file is None:
            raise ValueError("Field 'file' is missing from the configuration")

        # install the custom requirements, failing fast if the installation fails
        requirements = configuration.get("requirements", [])
        if len(requirements) > 0:
            subprocess.run([sys.executable, "-m", "pip", "install", *requirements], check=True)
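
        # NOTE: for reference, an illustrative configuration based only on the keys read in
        # this constructor (the values are assumptions, not a definitive schema):
        #   {
        #       "file": "model.py",
        #       "requirements": ["numpy"],
        #       "inputs": {...},
        #   }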

        # create the module specification
        module_spec = importlib.util.spec_from_file_location(
            f"model-{uuid.uuid4()}",
            self.path / file,
        )
        # get the module
        self.module = importlib.util.module_from_spec(module_spec)
        # load the module
        module_spec.loader.exec_module(self.module)

        ## Api

        # load the inputs data into the inference function signature (used by FastAPI)
        parameters = utils.parameters.load(configuration.get("inputs", {}))
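        # NOTE: inspect.Signature (used below) expects a sequence of inspect.Parameter objects,
        # so utils.parameters.load is assumed to build one from the configured inputs; an
        # illustrative (assumed) element would be:
        #   inspect.Parameter("image", inspect.Parameter.KEYWORD_ONLY, annotation=fastapi.UploadFile)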

        # create an endpoint wrapping the inference inside a FastAPI call
        async def infer_api(**kwargs):
            # NOTE: fix an issue where an UploadFile cannot be given to a StreamingResponse
            # NOTE: use a naive type(value).__name__ check because FastAPI does not use its own
            # fastapi.UploadFile class but the starlette UploadFile class, an implementation
            # detail that may change in the future
            kwargs = {
                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
                for key, value in kwargs.items()
            }

            return fastapi.responses.StreamingResponse(
                content=self.infer(**kwargs),
                media_type=self.output_type,
            )

        # expose the configured inputs as the endpoint signature so FastAPI can validate them
        infer_api.__signature__ = inspect.Signature(parameters=parameters)

        # add the inference endpoint on the API
        self.manager.application.add_api_route(f"/models/{self.name}/infer", infer_api, methods=["POST"])

    def _load(self) -> None:
        return self.module.load(self)

    def _unload(self) -> None:
        return self.module.unload(self)

    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
        return self.module.infer(self, **kwargs)
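
# NOTE: the custom model file loaded above is expected to expose module-level
# load/unload/infer hooks mirroring the _load/_unload/_infer methods, each taking the
# PythonModel instance as first argument and infer yielding the response bytes. A minimal
# sketch of such a file (the hook names and the bytes-iterator contract come from the
# calls above; the "session" attribute and open_resources helper are illustrative):
#
#   def load(model):
#       model.session = open_resources()   # hypothetical resource allocation
#
#   def unload(model):
#       model.session = None               # release resources
#
#   def infer(model, **kwargs):
#       yield b"result bytes"              # streamed by the /models/<name>/infer endpoint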