import asyncio
import os
import signal
import subprocess
import tempfile
import textwrap
import time
import typing
from pathlib import Path

import Pyro5
import Pyro5.api

from source.model import base
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix


# enable serpent to represent bytes directly
Pyro5.config.SERPENT_BYTES_REPR = True


class PythonModel(base.BaseModel):
    """
    A model running a custom python model.

    The model code is executed by a separate python interpreter started with
    ``conda run`` inside the ``env`` directory of the model. That interpreter
    exposes ``model.Model`` through a Pyro5 daemon listening on a unix socket,
    and this class talks to it through a Pyro5 proxy.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
        """
        Prepare the model without starting its environment process.

        :param registry: the registry this model belongs to
        :param configuration: the configuration of the model
        :param path: the directory of the model, it must contain an ``env`` directory
        :raises Exception: if the ``env`` directory does not exist
        """

        super().__init__(registry, configuration, path)

        # get the environment
        self.environment = self.path / "env"
        if not self.environment.exists():
            raise Exception("The model is missing an environment")

        # state of the environment process, only set while the model is loaded.
        # initialised to None so that unloading a model that was never loaded is harmless
        self._storage: typing.Optional[tempfile.TemporaryDirectory] = None
        self._process: typing.Optional[subprocess.Popen] = None
        self._model: typing.Optional[Pyro5.api.Proxy] = None

    @staticmethod
    def _terminate_process(process: subprocess.Popen) -> None:
        """
        Send SIGTERM to the process group of the environment process and reap it.

        :param process: the environment process started by ``_load``
        """

        try:
            os.killpg(os.getpgid(process.pid), signal.SIGTERM)
        except ProcessLookupError:
            # the process is already gone, there is nothing left to signal
            pass

        # NOTE: this wait is synchronous, it blocks until the process has exited
        process.wait()

    @staticmethod
    async def _save_upload(upload: UploadFileFix, directory: Path) -> str:
        """
        Copy an uploaded file into its own sub-directory of ``directory``.

        :param upload: the uploaded file to copy
        :param directory: the working directory receiving the copy
        :return: the path of the copy, as a string
        """

        # only keep the last component of the client supplied name, so that an absolute
        # path or a "../" sequence cannot escape the working directory
        filename = Path(upload.filename or "").name
        if filename in ("", ".", ".."):
            filename = "upload"

        # one sub-directory per upload, so two uploads sharing a name do not overwrite each other
        path = Path(tempfile.mkdtemp(dir=directory)) / filename

        with open(path, "wb") as file:
            while content := await upload.read(1024 * 1024):
                file.write(content)

        return str(path)

    async def _load(self) -> None:
        """
        Start the environment process and connect a Pyro5 proxy to it.

        :raises Exception: if the environment process stops before creating its socket
        """

        # create a temporary space for the unix socket
        storage = tempfile.TemporaryDirectory()
        socket_file = Path(storage.name) / 'socket.unix'

        process: typing.Optional[subprocess.Popen] = None
        try:
            # create a process inside the conda environment
            process = subprocess.Popen(
                [
                    "conda", "run",  # run a command within conda
                    "--prefix", str(self.environment.relative_to(self.path)),  # use the model environment
                    "python3", "-c",  # run a python command
                    textwrap.dedent(f"""
                    # make sure that Pyro5 is installed for communication
                    import sys
                    import subprocess
                    subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])

                    import os
                    import Pyro5
                    import Pyro5.api
                    import model

                    # allow Pyro5 to return bytes objects directly
                    Pyro5.config.SERPENT_BYTES_REPR = True

                    # helper to check if a process is still alive
                    def is_pid_alive(pid: int) -> bool:
                        try:
                            # do nothing if the process is alive, raise an exception if it does not exists
                            os.kill(pid, 0)
                        except OSError:
                            return False
                        else:
                            return True

                    # create a pyro daemon
                    daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
                    # export our model through it
                    daemon.register(Pyro5.api.expose(model.Model), objectId="model")

                    # handle requests
                    # stop the process if the manager is no longer alive
                    daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
                    """)
                ],

                cwd=self.path,  # use the model directory as the working directory
                start_new_session=True,  # put the process in a new group to avoid killing ourselves when we unload the process
            )

            # wait for the process to be initialized properly
            while True:
                # check if the process is still alive
                if process.poll() is not None:
                    # if the process stopped, raise an error (it shall stay alive until the unloading)
                    raise Exception("Could not load the model.")

                # if the socket file has been created, the program is running successfully
                if socket_file.exists():
                    break

                # sleep without blocking the event loop while the environment starts
                await asyncio.sleep(0.5)

        except BaseException:
            # the loading failed or was cancelled: do not leak the process nor the socket directory
            if process is not None:
                self._terminate_process(process)
            storage.cleanup()
            raise

        # only publish the state once the environment is known to be running
        self._storage = storage
        self._process = process

        # get the proxy model object from the environment
        self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")

    async def _unload(self) -> None:
        """
        Release the proxy, stop the environment process and delete the socket directory.

        Does nothing for the parts that are not currently loaded.
        """

        # clear the proxy object
        if self._model is not None:
            self._model._pyroRelease()  # NOQA
            self._model = None

        # stop the environment process
        if self._process is not None:
            self._terminate_process(self._process)
            self._process = None

        # clear the storage
        if self._storage is not None:
            self._storage.cleanup()
            self._storage = None

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Run the inference of the model inside its environment.

        :param kwargs: the arguments forwarded to the ``infer`` method of the remote model
        :return: the chunks produced by the remote ``infer`` call
        """

        # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
        with tempfile.TemporaryDirectory() as working_directory:
            # build the forwarded arguments separately instead of rewriting kwargs while iterating over it
            arguments = {}

            for key, value in kwargs.items():
                # check if the argument is a file
                if isinstance(value, UploadFileFix):
                    # copy the uploaded file to our working directory and send its path instead
                    value = await self._save_upload(value, Path(working_directory))

                arguments[key] = value

            # run the inference
            # NOTE: the proxy call is synchronous, the event loop is blocked between two chunks
            for chunk in self._model.infer(**arguments):
                yield chunk

    # TODO(Faraphel): when FastAPI closes, it seems to wait for conda to finish (or for the async tasks ?)