ai-server/source/model/PythonModel.py

139 lines
5.2 KiB
Python

import asyncio
import os
import signal
import subprocess
import tempfile
import textwrap
import time
import typing
from pathlib import Path

import Pyro5
import Pyro5.api

from source.model import base
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix
# Enable serpent to represent bytes directly. This is a module-level Pyro5
# setting, applied as a side effect of importing this file.
# NOTE(review): the script started by PythonModel sets the same flag in the
# child process; presumably both sides have to agree on it - confirm.
Pyro5.config.SERPENT_BYTES_REPR = True
class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model directory must contain an ``env`` directory (a conda environment)
    and an importable ``model`` module exposing a ``Model`` class. The model is
    run in a separate process started with ``conda run`` and is reached through
    a Pyro5 proxy bound to a unix socket located in a temporary directory.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
        """
        Prepare the model without starting its process.

        :param registry: the registry this model belongs to
        :param configuration: the configuration of the model
        :param path: the directory of the model
        :raises Exception: if the model directory has no ``env`` environment
        """

        super().__init__(registry, configuration, path)

        # get the environment
        self.environment = self.path / "env"
        if not self.environment.exists():
            raise Exception("The model is missing an environment")

        # resources held while the model is loaded; None while it is not, so that
        # cleaning up is always safe, even if the model was never loaded
        self._storage: typing.Optional[tempfile.TemporaryDirectory] = None
        self._process: typing.Optional[subprocess.Popen] = None
        self._model: typing.Optional[Pyro5.api.Proxy] = None

    async def _load(self) -> None:
        """
        Start the environment process and connect the proxy to it.

        :raises Exception: if the process stops before creating its socket
        """

        # create a temporary space for the unix socket
        self._storage = tempfile.TemporaryDirectory()
        socket_file = Path(self._storage.name) / 'socket.unix'

        try:
            # create a process inside the conda environment
            self._process = subprocess.Popen(
                [
                    "conda", "run",  # run a command within conda
                    "--prefix", str(self.environment.relative_to(self.path)),  # use the model environment
                    "python3", "-c",  # run a python command
                    textwrap.dedent(f"""
                    # make sure that Pyro5 is installed for communication
                    import sys
                    import subprocess
                    subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])

                    import os
                    import Pyro5
                    import Pyro5.api
                    import model

                    # allow Pyro5 to return bytes objects directly
                    Pyro5.config.SERPENT_BYTES_REPR = True

                    # helper to check if a process is still alive
                    def is_pid_alive(pid: int) -> bool:
                        try:
                            # do nothing if the process is alive, raise an exception if it does not exist
                            os.kill(pid, 0)
                        except OSError:
                            return False
                        else:
                            return True

                    # create a pyro daemon
                    daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
                    # export our model through it
                    daemon.register(Pyro5.api.expose(model.Model), objectId="model")

                    # handle requests
                    # stop the process if the manager is no longer alive
                    daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
                    """)
                ],
                cwd=self.path,  # use the model directory as the working directory
                # put the process in a new group to avoid killing ourselves when we unload the process
                start_new_session=True,
            )

            # wait for the process to be initialized properly
            while True:
                # check if the process is still alive
                if self._process.poll() is not None:
                    # if the process stopped, raise an error (it shall stay alive until the unloading)
                    raise Exception("Could not load the model.")

                # if the socket file has been created, the program is running successfully
                if socket_file.exists():
                    break

                # yield to the event loop instead of blocking it while waiting
                await asyncio.sleep(0.5)

            # get the proxy model object from the environment
            self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")

        except BaseException:
            # a failed (or cancelled) loading shall not leak the process nor the temporary directory
            self._cleanup()
            raise

    async def _unload(self) -> None:
        """
        Release the proxy, stop the environment process and delete the socket directory.
        """

        self._cleanup()

    def _cleanup(self) -> None:
        """
        Release every resource currently held, skipping those that were never acquired.
        """

        try:
            # clear the proxy object
            model, self._model = self._model, None
            if model is not None:
                model._pyroRelease()  # NOQA
        finally:
            try:
                # stop the environment process
                process, self._process = self._process, None
                if process is not None:
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                    except ProcessLookupError:
                        # the process is already gone, there is nothing left to stop
                        pass
                    process.wait()
            finally:
                # clear the storage
                storage, self._storage = self._storage, None
                if storage is not None:
                    storage.cleanup()

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Run the inference in the environment process and yield its output chunks.

        :param kwargs: the inference arguments; every ``UploadFileFix`` value is
            saved to a temporary file and replaced by the path of that file
        :return: the chunks produced by the ``infer`` method of the remote model
        """

        # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
        with tempfile.TemporaryDirectory() as working_directory:
            # iterate over a snapshot since the arguments are replaced along the way
            for key, value in list(kwargs.items()):
                # check if the argument is a file
                if not isinstance(value, UploadFileFix):
                    continue

                # the filename is provided by the client: only keep its last component
                # so that it cannot point outside of our working directory
                filename = Path(value.filename or "").name
                if filename in ("", ".", ".."):
                    filename = "upload"

                # give every upload its own directory so that two files sharing a name do not overwrite each other
                path = Path(tempfile.mkdtemp(dir=working_directory)) / filename

                # copy the uploaded file to our working directory
                with open(path, "wb") as file:
                    while content := await value.read(1024 * 1024):
                        file.write(content)

                # replace the argument
                kwargs[key] = str(path)

            # run the inference while the temporary files still exist
            # NOTE: this iteration is synchronous, it holds the event loop while the model produces a chunk
            for chunk in self._model.infer(**kwargs):
                yield chunk

    # TODO(Faraphel): if FastAPI closes, it seems to wait for conda to finish (or for the async tasks?)