139 lines
5.2 KiB
Python
139 lines
5.2 KiB
Python
import subprocess
|
|
import tempfile
|
|
import time
|
|
import typing
|
|
import textwrap
|
|
import os
|
|
import signal
|
|
from pathlib import Path
|
|
|
|
import Pyro5
|
|
import Pyro5.api
|
|
|
|
from source.model import base
|
|
from source.registry import ModelRegistry
|
|
from source.utils.fastapi import UploadFileFix
|
|
|
|
# enable serpent to represent bytes directly
|
|
# NOTE: the interpreter spawned by PythonModel._load sets this same option in
# its bootstrap script, so both ends of the Pyro5 connection are configured alike.
Pyro5.config.SERPENT_BYTES_REPR = True
|
|
|
|
|
|
class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model directory must contain an ``env`` directory (used as a conda
    prefix) and an importable ``model`` module exposing a ``Model`` class.
    The model is executed in a separate interpreter started with ``conda run``
    and is reached from this process through a Pyro5 proxy over a unix socket.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
        """
        :param registry: the registry owning this model (forwarded to the base class)
        :param configuration: the model configuration (forwarded to the base class)
        :param path: the model directory (forwarded to the base class)
        :raises Exception: if the model directory has no ``env`` environment
        """
        super().__init__(registry, configuration, path)

        # get the environment
        self.environment = self.path / "env"
        if not self.environment.exists():
            raise Exception("The model is missing an environment")

        # state of the process holding the environment python interpreter.
        # explicitly initialised: a bare annotation would not create the attributes.
        self._storage: typing.Optional[tempfile.TemporaryDirectory] = None
        self._process: typing.Optional[subprocess.Popen] = None
        self._model: typing.Optional[Pyro5.api.Proxy] = None

    @staticmethod
    def _server_script(socket_file: Path) -> str:
        """
        Build the python source executed inside the model environment.

        The script installs Pyro5, exposes ``model.Model`` on a Pyro5 daemon bound
        to ``socket_file`` and serves requests until the current process dies.
        """
        return textwrap.dedent(f"""
            # make sure that Pyro5 is installed for communication
            import sys
            import subprocess
            subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])

            import os
            import Pyro5
            import Pyro5.api
            import model

            # allow Pyro5 to return bytes objects directly
            Pyro5.config.SERPENT_BYTES_REPR = True

            # helper to check if a process is still alive
            def is_pid_alive(pid: int) -> bool:
                try:
                    # do nothing if the process is alive, raise an exception if it does not exists
                    os.kill(pid, 0)
                except OSError:
                    return False
                else:
                    return True

            # create a pyro daemon
            daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
            # export our model through it
            daemon.register(Pyro5.api.expose(model.Model), objectId="model")
            # handle requests
            # stop the process if the manager is no longer alive
            daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
            """)

    @staticmethod
    def _safe_filename(filename: typing.Optional[str]) -> str:
        """
        Reduce a client-provided file name to a single, harmless path component.

        Directory parts (with either kind of slash) are dropped; names that are
        empty or refer to a directory (``.``, ``..``) are replaced by ``"upload"``.
        """
        name = (filename or "").replace("\\", "/").rsplit("/", 1)[-1]
        return "upload" if name in ("", ".", "..") else name

    def _release_resources(self) -> None:
        """
        Release the proxy, the environment process and the socket storage.

        Every step is attempted even if a previous one failed, and resources
        that were never created are skipped.
        """
        model, process, storage = self._model, self._process, self._storage
        self._model = self._process = self._storage = None

        try:
            # clear the proxy object
            if model is not None:
                model._pyroRelease()  # NOQA
        finally:
            try:
                # stop the environment process
                if process is not None:
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                    except ProcessLookupError:
                        # the process is already gone, there is nothing left to signal
                        pass
                    process.wait()
            finally:
                # clear the storage
                if storage is not None:
                    storage.cleanup()

    async def _load(self) -> None:
        """
        Start the interpreter of the model environment and connect to it.

        :raises Exception: if the process stops before creating its socket
        """
        # local import: only this method needs it
        import asyncio

        # create a temporary space for the unix socket
        self._storage = tempfile.TemporaryDirectory()
        socket_file = Path(self._storage.name) / 'socket.unix'

        try:
            # create a process inside the conda environment
            self._process = subprocess.Popen(
                [
                    "conda", "run",  # run a command within conda
                    "--prefix", self.environment.relative_to(self.path),  # use the model environment
                    "python3", "-c",  # run a python command
                    self._server_script(socket_file),
                ],

                cwd=self.path,  # use the model directory as the working directory
                start_new_session=True,  # put the process in a new group to avoid killing ourselves when we unload the process
            )

            # wait for the process to be initialized properly
            while True:
                # check if the process is still alive
                if self._process.poll() is not None:
                    # if the process stopped, raise an error (it shall stay alive until the unloading)
                    raise Exception("Could not load the model.")

                # if the socket file have been created, the program is running successfully
                if socket_file.exists():
                    break

                # give the control back to the event loop while waiting
                await asyncio.sleep(0.5)

            # get the proxy model object from the environment
            self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")
        except BaseException:
            # do not leak the process nor the temporary directory on failure or cancellation
            self._release_resources()
            raise

    async def _unload(self) -> None:
        """
        Disconnect from the environment process, stop it and remove its socket storage.
        """
        self._release_resources()

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Run an inference in the environment process and yield its output chunks.

        :param kwargs: the arguments forwarded to ``infer`` of the remote model;
            uploaded files are replaced by the path of a temporary copy
        """
        # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
        with tempfile.TemporaryDirectory() as working_directory:
            # iterate over a snapshot since the arguments are replaced along the way
            for index, (key, value) in enumerate(list(kwargs.items())):
                # check if this the argument is a file
                if not isinstance(value, UploadFileFix):
                    continue

                # one directory per upload: two files sharing a name cannot overwrite each other
                directory = Path(working_directory) / str(index)
                directory.mkdir()

                # copy the uploaded file to our working directory.
                # the file name comes from the client: never let it point outside of the directory
                path = directory / self._safe_filename(value.filename)
                with open(path, "wb") as file:
                    while content := await value.read(1024*1024):
                        file.write(content)

                # replace the argument
                kwargs[key] = str(path)

            # run the inference
            # NOTE(review): the proxy call is synchronous and holds the event loop until each chunk arrives
            for chunk in self._model.infer(**kwargs):
                yield chunk
|
|
|
|
|
|
# TODO(Faraphel): when FastAPI shuts down, it seems to wait for conda to finish (or for the async tasks?)
|