ai-server/source/model/PythonModel.py
faraphel 7bd84c8570 added support of inputs parameters that are recognised by the API.
Models are now loaded in separate endpoints for the inputs to be easier to recognise
2025-01-09 23:12:54 +01:00

70 lines
2.2 KiB
Python

import importlib.util
import subprocess
import sys
import typing
import uuid
import inspect
from pathlib import Path
import fastapi
from source import utils
from source.manager import ModelManager
from source.model import base
class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model code lives in a python file named by the "file" field of the
    configuration. That file is imported as a module, and its ``load``,
    ``unload`` and ``infer`` functions are called with this instance as
    their first argument.
    """

    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
        """
        Import the model code and register its inference endpoint.

        :param manager: the manager owning this model; its ``application`` receives the inference route
        :param configuration: the model configuration. Keys read here: "file" (required),
            "requirements" (optional list of pip requirement strings) and "inputs" (optional)
        :param path: forwarded to the base class
            (NOTE(review): presumably the model directory, exposed as ``self.path`` — confirm in base.BaseModel)
        :raises ValueError: if the "file" field is missing from the configuration
        :raises subprocess.CalledProcessError: if installing the requirements fails
        :raises ImportError: if the model file cannot be loaded as a python module
        """
        super().__init__(manager, configuration, path)

        ## Configuration

        # get the name of the file containing the model code
        file = configuration.get("file")
        if file is None:
            raise ValueError("Field 'file' is missing from the configuration")

        # install custom requirements
        # (truthiness check so that an explicit null / empty list are both skipped)
        requirements = configuration.get("requirements", [])
        if requirements:
            # check=True: fail immediately if pip fails, rather than hitting a
            # confusing import error later while executing the model file
            subprocess.run(
                [sys.executable, "-m", "pip", "install", *requirements],
                check=True,
            )

        # create the module specification
        # a random module name is used so that each instance gets its own distinct module
        module_path = self.path / file
        module_spec = importlib.util.spec_from_file_location(
            f"model-{uuid.uuid4()}",
            module_path
        )
        # spec_from_file_location returns None when no loader can handle the file
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot load the model code from '{module_path}'")

        # get the module
        self.module = importlib.util.module_from_spec(module_spec)
        # load the module
        module_spec.loader.exec_module(self.module)

        ## Api

        # load the inputs data into the inference function signature (used by FastAPI)
        parameters = utils.parameters.load(configuration.get("inputs", {}))

        # create an endpoint wrapping the inference inside a fastapi call
        async def infer_api(*args, **kwargs):
            return fastapi.responses.StreamingResponse(
                content=self.infer(*args, **kwargs),
                media_type=self.output_type,
            )

        # replace the generic (*args, **kwargs) signature with the configured parameters
        infer_api.__signature__ = inspect.Signature(parameters=parameters)

        # add the inference endpoint on the API
        self.manager.application.add_api_route(f"/models/{self.name}/infer", infer_api, methods=["POST"])

    def _load(self) -> None:
        """
        Load the model by calling the ``load`` function of the model module.
        """
        return self.module.load(self)

    def _unload(self) -> None:
        """
        Unload the model by calling the ``unload`` function of the model module.
        """
        return self.module.unload(self)

    def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
        """
        Run an inference by calling the ``infer`` function of the model module.

        :param args: positional arguments forwarded to the module's ``infer``
        :param kwargs: keyword arguments forwarded to the module's ``infer``
        :return: whatever the module's ``infer`` returns (annotated as an iterator of bytes)
        """
        return self.module.infer(self, *args, **kwargs)