101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
import importlib.util
|
|
import subprocess
|
|
import sys
|
|
import typing
|
|
import uuid
|
|
import inspect
|
|
from pathlib import Path
|
|
|
|
import fastapi
|
|
|
|
from source import utils
|
|
from source.model import base
|
|
from source.registry import ModelRegistry
|
|
from source.utils.fastapi import UploadFileFix
|
|
|
|
|
|
class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model code is loaded from a user-provided python file declared in the
    configuration under the 'file' key and executed as an anonymous module.
    The module is expected to expose ``load``, ``unload`` and ``infer``
    callables taking the model instance as their first argument.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
        """
        Args:
            registry: the model registry this model belongs to.
            configuration: the parsed model configuration.
            path: the directory containing the model files.

        Raises:
            ValueError: if the 'file' field is missing from the configuration.
            RuntimeError: if the module specification for the model file cannot
                be created (bad path or unsupported file type).
            subprocess.CalledProcessError: if installing the custom
                requirements fails.
        """
        super().__init__(registry, configuration, path)

        # get the parameters of the model
        self.parameters = utils.parameters.load(configuration.get("inputs", {}))

        # install custom requirements
        # SECURITY NOTE: this installs arbitrary packages named in the model
        # configuration and later executes arbitrary code from the model file;
        # configurations must come from a trusted source.
        requirements = configuration.get("requirements", [])
        if requirements:
            # check=True: without it a failed install is silently ignored and
            # only surfaces later as an obscure ImportError inside the model code
            subprocess.run(
                [sys.executable, "-m", "pip", "install", *requirements],
                check=True,
            )

        # get the name of the file containing the model code
        file = configuration.get("file")
        if file is None:
            raise ValueError("Field 'file' is missing from the configuration")

        # create the module specification
        # NOTE: a random module name avoids collisions between models
        module_spec = importlib.util.spec_from_file_location(
            f"model-{uuid.uuid4()}",
            self.path / file
        )
        # spec_from_file_location returns None for unloadable paths; fail with
        # a clear message instead of an AttributeError on the next line
        if module_spec is None or module_spec.loader is None:
            raise RuntimeError(f"Could not create a module specification for {self.path / file}")

        # get the module
        self.module = importlib.util.module_from_spec(module_spec)
        # load the module
        module_spec.loader.exec_module(self.module)

    def _load(self) -> None:
        # delegate loading to the user-provided module
        return self.module.load(self)

    def _unload(self) -> None:
        # delegate unloading to the user-provided module
        return self.module.unload(self)

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        # delegate inference to the user-provided module
        # NOTE(review): the result is returned un-awaited — presumably the user
        # module's infer() returns an async iterator directly; confirm contract
        return self.module.infer(self, **kwargs)

    def _mount(self, application: fastapi.FastAPI):
        """
        Mount the model's inference endpoint on the given FastAPI application.

        Args:
            application: the FastAPI application to register the endpoint on.
        """
        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?

        # create an endpoint wrapping the inference inside a fastapi call
        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
            # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
            # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
            # curiosity that may change in the future
            kwargs = {
                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
                for key, value in kwargs.items()
            }

            return fastapi.responses.StreamingResponse(
                content=await self.registry.infer_model(self, **kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
                }
            )

        # expose the model's declared inputs as the endpoint's parameters so
        # FastAPI generates validation and OpenAPI docs from them
        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)

        # format the description
        description_sections: list[str] = []
        if self.description is not None:
            description_sections.append(self.description)
        if self.interface is not None:
            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")

        # add the inference endpoint on the API
        application.add_api_route(
            f"{self.api_base}/infer",
            infer_api,
            methods=["POST"],
            tags=self.tags,
            summary=self.summary,
            description="<br>".join(description_sections),
            response_class=fastapi.responses.StreamingResponse,
            responses={
                200: {"content": {self.output_type: {}}}
            },
        )
|