replaced the previous venv system with a conda one, allowing for better dependency management
This commit is contained in:
parent 8bf28e4c48
commit 0034c7b31a

17 changed files with 313 additions and 230 deletions
@@ -3,6 +3,7 @@ import os
 from source import registry, model, api
+from source.api import interface
 
 
 # create a fastapi application
 application = api.Application()
 
@@ -30,19 +30,19 @@ class ChatInterface(base.BaseInterface):
             messages.insert(0, {"role": "system", "content": system_message})
 
         # add the user message
-        # NOTE: gradio.ChatInterface add our message and the assistant message
-        # TODO(Faraphel): add support for files
+        # NOTE: gradio.ChatInterface add our message and the assistant message automatically
+        # TODO(Faraphel): add support for files - directory use user_message ? apparently, field "image" is supported.
+        #  check "https://huggingface.co/docs/transformers/main_classes/pipelines" at "ImageTextToTextPipeline"
+
+        # TODO(Faraphel): add a "MultimodalChatInterface" to support images
         messages.append({
             "role": "user",
             "content": user_message["text"],
         })
 
         # infer the message through the model
-        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
-        assistant_message: str = b"".join(chunks).decode("utf-8")
-
-        # send back the messages, clear the user prompt, disable the system prompt
-        return assistant_message
+        async for chunk in self.model.infer(messages=messages):
+            yield chunk.decode("utf-8")
 
     def get_application(self):
         # create a gradio interface
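The change above turns a blocking handler (collect every chunk, return one string) into an async generator that streams partial output. A minimal self-contained sketch of that pattern, assuming gradio's documented streaming behavior for ChatInterface, where each yielded value replaces the displayed reply; the token loop here is a stand-in for the model:

import asyncio
import gradio

async def send_message(user_message, history):
    # stand-in for model inference: yield a growing partial reply
    partial = ""
    for token in ("Hello", ", ", "world", "!"):
        await asyncio.sleep(0.1)  # simulate per-chunk inference latency
        partial += token
        yield partial

demo = gradio.ChatInterface(fn=send_message, type="messages")

if __name__ == "__main__":
    demo.launch()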
@@ -65,7 +65,7 @@ class ChatInterface(base.BaseInterface):
             gradio.ChatInterface(
                 fn=self.send_message,
                 type="messages",
-                multimodal=True,
+                multimodal=False,  # TODO(Faraphel): should handle at least image and text files
                 editable=True,
                 save_history=True,
                 additional_inputs=[system_prompt],
@@ -1,18 +1,22 @@
-import importlib.util
 import subprocess
 import sys
+import tempfile
+import time
 import typing
-import uuid
-import inspect
+import textwrap
+import os
+import signal
 from pathlib import Path
 
-import fastapi
+import Pyro5
+import Pyro5.api
 
 from source import utils
 from source.model import base
 from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix
 
+# enable serpent to represent bytes directly
+Pyro5.config.SERPENT_BYTES_REPR = True
 
 
 class PythonModel(base.BaseModel):
     """
@@ -22,85 +26,114 @@ class PythonModel(base.BaseModel):
     def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
         super().__init__(registry, configuration, path)
 
-        # get the parameters of the model
-        self.parameters = utils.parameters.load(configuration.get("inputs", {}))
+        # get the environment
+        self.environment = self.path / "env"
+        if not self.environment.exists():
+            raise Exception("The model is missing an environment")
 
-        # install custom requirements
-        requirements = configuration.get("requirements", [])
-        if len(requirements) > 0:
-            subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
-
-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
-
-        # create the module specification
-        module_spec = importlib.util.spec_from_file_location(
-            f"model-{uuid.uuid4()}",
-            self.path / file
-        )
-        # get the module
-        module = importlib.util.module_from_spec(module_spec)
-        # load the module
-        module_spec.loader.exec_module(module)
-
-        # create the internal model from the class defined in the module
-        self._model_type = module.Model
-        self._model: typing.Optional[module.Model] = None
+        # prepare the process that will hold the environment python interpreter
+        self._storage: typing.Optional[tempfile.TemporaryDirectory]
+        self._process: typing.Optional[subprocess.Popen]
+        self._model: typing.Optional[Pyro5.api.Proxy]
 
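A model is therefore no longer a Python file imported into the server process plus pip requirements; it is a directory that ships its own interpreter under env/. A sketch of the layout this __init__ now expects; the name model.py and the env/ contents are assumptions, inferred from the `import model` line and the `--prefix` argument in the bootstrap script below:

from pathlib import Path

# hypothetical model directory, names inferred from the diff:
#   my-model/
#   ├── config.json   # model configuration ("inputs", "output_type", "type", ...)
#   ├── model.py      # defines a `Model` class, imported as `model` inside env/
#   └── env/          # conda environment the model runs in
model_path = Path("my-model")
assert (model_path / "env").exists(), "The model is missing an environment"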
     async def _load(self) -> None:
-        self._model = self._model_type()
+        # create a temporary space for the unix socket
+        self._storage: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory()
+        socket_file = Path(self._storage.name) / 'socket.unix'
+
+        # create a process inside the conda environment
+        self._process = subprocess.Popen(
+            [
+                "conda", "run",  # run a command within conda
+                "--prefix", self.environment.relative_to(self.path),  # use the model environment
+                "python3", "-c",  # run a python command
+
+                textwrap.dedent(f"""
+                    # make sure that Pyro5 is installed for communication
+                    import sys
+                    import subprocess
+                    subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])
+
+                    import os
+                    import Pyro5
+                    import Pyro5.api
+                    import model
+
+                    # allow Pyro5 to return bytes objects directly
+                    Pyro5.config.SERPENT_BYTES_REPR = True
+
+                    # helper to check if a process is still alive
+                    def is_pid_alive(pid: int) -> bool:
+                        try:
+                            # do nothing if the process is alive, raise an exception if it does not exists
+                            os.kill(pid, 0)
+                        except OSError:
+                            return False
+                        else:
+                            return True
+
+                    # create a pyro daemon
+                    daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
+                    # export our model through it
+                    daemon.register(Pyro5.api.expose(model.Model), objectId="model")
+                    # handle requests
+                    # stop the process if the manager is no longer alive
+                    daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
+                """)
+            ],
+
+            cwd=self.path,  # use the model directory as the working directory
+            start_new_session=True,  # put the process in a new group to avoid killing ourselves when we unload the process
+        )
+
+        # wait for the process to be initialized properly
+        while True:
+            # check if the process is still alive
+            if self._process.poll() is not None:
+                # if the process stopped, raise an error (it shall stay alive until the unloading)
+                raise Exception("Could not load the model.")
+
+            # if the socket file have been created, the program is running successfully
+            if socket_file.exists():
+                break
+
+            time.sleep(0.5)
+
+        # get the proxy model object from the environment
+        self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")
 
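The manager and the environment interpreter talk over Pyro5 RPC on a unix socket. A minimal sketch of the same daemon/proxy pair outside the subprocess machinery; EchoModel and the socket location are illustrative only, and both sides run in one process here for brevity:

import pathlib
import tempfile
import threading

import Pyro5
import Pyro5.api

# as in the commit: let serpent return bytes directly instead of wrapping them
Pyro5.config.SERPENT_BYTES_REPR = True

@Pyro5.api.expose
class EchoModel:
    def infer(self, text: str) -> bytes:
        return text.encode("utf-8")

with tempfile.TemporaryDirectory() as storage:
    socket_file = pathlib.Path(storage) / "socket.unix"

    # server side (played by the conda subprocess in the commit)
    daemon = Pyro5.api.Daemon(unixsocket=str(socket_file))
    daemon.register(EchoModel, objectId="model")
    threading.Thread(target=daemon.requestLoop, daemon=True).start()

    # client side (played by the manager process)
    model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")
    assert model.infer("hi") == b"hi"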
     async def _unload(self) -> None:
-        self._model = None
+        # clear the proxy object
+        self._model._pyroRelease()  # NOQA
+        del self._model
+        # stop the environment process
+        os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)
+        self._process.wait()
+        del self._process
+        # clear the storage
+        self._storage.cleanup()
+        del self._storage
 
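_unload has to stop `conda run` and the python interpreter it spawned, not just the immediate child. start_new_session=True gives the child its own session and process group, so signalling that group reaches the whole tree without touching the manager itself. A minimal POSIX-only sketch of the same trick:

import os
import signal
import subprocess

# the child gets its own session/process group thanks to start_new_session=True
process = subprocess.Popen(["sleep", "60"], start_new_session=True)

# signal the whole group, then reap the child
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
process.wait()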
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
-        return self._model.infer(**kwargs)
+        # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
+        with tempfile.TemporaryDirectory() as working_directory:
+            for key, value in kwargs.items():
+                # check if this the argument is a file
+                if not isinstance(value, UploadFileFix):
+                    continue
 
-    def _mount(self, application: fastapi.FastAPI):
-        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
+                # copy the uploaded file to our working directory
+                path = Path(working_directory) / value.filename
+                with open(path, "wb") as file:
+                    while content := await value.read(1024*1024):
+                        file.write(content)
 
-        # create an endpoint wrapping the inference inside a fastapi call
-        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
-            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
-            # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
-            # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
-            # curiosity that may change in the future
-            kwargs = {
-                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
-                for key, value in kwargs.items()
-            }
+                # replace the argument
+                kwargs[key] = str(path)
 
-            return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
-                media_type=self.output_type,
-                headers={
-                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
-                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
-                }
-            )
+            # run the inference
+            for chunk in self._model.infer(**kwargs):
+                yield chunk
 
-        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
-
-        # format the description
-        description_sections: list[str] = []
-        if self.description is not None:
-            description_sections.append(f"# Description\n{self.description}")
-        if self.interface is not None:
-            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
-
-        description: str = "\n".join(description_sections)
-
-        # add the inference endpoint on the API
-        application.add_api_route(
-            f"{self.api_base}/infer",
-            infer_api,
-            methods=["POST"],
-            tags=self.tags,
-            summary=self.summary,
-            description=description,
-            response_class=fastapi.responses.StreamingResponse,
-            responses={
-                200: {"content": {self.output_type: {}}}
-            },
-        )
+
+        # TODO(Faraphel): if the FastAPI close, it seem like it wait for conda to finish (or the async tasks ?)
 
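Since Pyro5 cannot serialize an UploadFile, _infer now spills uploads to disk and forwards plain paths. The chunked copy with the walrus operator is a useful pattern on its own; a sketch assuming any object exposing an async read(size), such as fastapi.UploadFile:

from pathlib import Path

async def save_upload(upload, destination: Path) -> None:
    # read the upload in 1 MiB chunks so large files never sit fully in memory
    with open(destination, "wb") as file:
        while content := await upload.read(1024 * 1024):
            file.write(content)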
@@ -1,12 +1,16 @@
 import abc
 import asyncio
 import gc
+import tempfile
 import typing
 from pathlib import Path
 
 import fastapi
+import inspect
 
+from source import utils
 from source.registry import ModelRegistry
+from source.utils.fastapi import UploadFileFix
 
 
 class BaseModel(abc.ABC):
@@ -24,6 +28,9 @@ class BaseModel(abc.ABC):
 
         # the environment directory of the model
         self.path = path
+        # get the parameters of the model
+        self.inputs = configuration.get("inputs", {})
+        self.parameters = utils.parameters.load(self.inputs)
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")
         # get the tags of the model
@@ -71,8 +78,12 @@ class BaseModel(abc.ABC):
 
         return {
             "name": self.name,
             "summary": self.summary,
             "description": self.description,
+            "inputs": self.inputs,
+            "output_type": self.output_type,
-            "tags": self.tags
+            "tags": self.tags,
+            "interface": self.interface,
         }
 
     async def load(self) -> None:
@@ -145,7 +156,8 @@ class BaseModel(abc.ABC):
         await self.load()
 
         # model specific inference part
-        return await self._infer(**kwargs)
+        async for chunk in self._infer(**kwargs):
+            yield chunk
 
     @abc.abstractmethod
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
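The difference between the two forms matters: once a function body contains yield it is an async generator and cannot return a value, and forwarding chunk by chunk keeps the load/unload bookkeeping wrapped around the entire stream instead of just its creation. A minimal runnable sketch of the delegation pattern:

import asyncio
import typing

async def _infer() -> typing.AsyncIterator[bytes]:
    for chunk in (b"he", b"llo"):
        yield chunk

async def infer() -> typing.AsyncIterator[bytes]:
    # setup such as `await self.load()` would run here, before streaming
    async for chunk in _infer():
        yield chunk

async def main() -> None:
    assert [chunk async for chunk in infer()] == [b"he", b"llo"]

asyncio.run(main())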
@@ -164,13 +176,50 @@ class BaseModel(abc.ABC):
         if self.interface is not None:
             self.interface.mount(application)
 
-        # implementation specific mount
-        self._mount(application)
+        # create an endpoint wrapping the inference inside a fastapi call
+        # the arguments will be loaded from the configuration files. Use kwargs for the definition
+        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
+            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
+            # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
+            # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
+            # curiosity that may change in the future
+            kwargs = {
+                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
+                for key, value in kwargs.items()
+            }
 
-    @abc.abstractmethod
-    def _mount(self, application: fastapi.FastAPI) -> None:
-        """
-        Add the model to the api
-        Do not call manually, use `unload` instead.
-        :param application: the fastapi application
-        """
+            # return a streaming response around the inference call
+            return fastapi.responses.StreamingResponse(
+                content=self.infer(**kwargs),
+                media_type=self.output_type,
+                headers={
+                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
+                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
+                }
+            )
+
+        # update the signature of the function to use the configuration parameters
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
+
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(f"# Description\n{self.description}")
+        if self.interface is not None:
+            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
+
+        description: str = "\n".join(description_sections)
+
+        # add the inference endpoint on the API
+        application.add_api_route(
+            f"{self.api_base}/infer",
+            infer_api,
+            methods=["POST"],
+            tags=self.tags,
+            summary=self.summary,
+            description=description,
+            response_class=fastapi.responses.StreamingResponse,
+            responses={
+                200: {"content": {self.output_type: {}}}
+            },
+        )
 
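The generic mount relies on a small but non-obvious trick: FastAPI builds the request schema from inspect.signature(endpoint), and inspect honors a function's __signature__ attribute, so a **kwargs endpoint can advertise parameters decided at runtime from config.json. A runnable sketch with hypothetical parameters:

import inspect

import fastapi

app = fastapi.FastAPI()

async def infer_api(**kwargs) -> dict:
    return kwargs

# hypothetical parameters; in the commit they come from utils.parameters.load()
infer_api.__signature__ = inspect.Signature(parameters=[
    inspect.Parameter("prompt", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
    inspect.Parameter("temperature", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=float, default=1.0),
])

app.add_api_route("/infer", infer_api, methods=["POST"])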
@@ -60,7 +60,12 @@ class ModelRegistry:
 
         # load all the models in the library
         for model_path in self.model_library.iterdir():
+            # get the model name
             model_name: str = model_path.name
+            if model_name.startswith("."):
+                # ignore model starting with a dot
+                continue
+
             model_configuration_path: Path = model_path / "config.json"
 
             # check if the configuration file exists
@@ -68,8 +73,11 @@ class ModelRegistry:
                 warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                 continue
 
-            # load the configuration file
-            model_configuration = json.loads(model_configuration_path.read_text())
+            try:
+                # load the configuration file
+                model_configuration = json.loads(model_configuration_path.read_text())
+            except json.decoder.JSONDecodeError:
+                raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.")
 
             # get the model type for this model
             model_type_name: str = model_configuration.get("type")
@@ -1,24 +1,5 @@
 import inspect
-from datetime import datetime
-
-from fastapi import UploadFile
-
-# the list of types and their name that can be used by the API
-types: dict[str, type] = {
-    "bool": bool,
-    "int": int,
-    "float": float,
-    "str": str,
-    "bytes": bytes,
-    "dict": dict,
-    "datetime": datetime,
-    "file": UploadFile,
-
-    # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
-    "list[dict]": list[dict],
-    # "tuple": tuple,
-    # "set": set,
-}
+import fastapi
 
 
 def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
@@ -31,7 +12,6 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
     >>> parameters_definition = {
     ...     "boolean": {"type": "bool", "default": False},
-    ...     "list": {"type": "list", "default": [1, 2, 3]},
     ...     "datetime": {"type": "datetime"},
    ...     "file": {"type": "file"},
     ... }
     >>> parameters = load_parameters(parameters_definition)
@@ -40,12 +20,19 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
     parameters: list[inspect.Parameter] = []
 
     for name, definition in parameters_definition.items():
+        # preprocess the type
+        match definition["type"]:
+            case "file":
+                # shortcut for uploading a file
+                definition["type"] = fastapi.UploadFile
+
+
         # deserialize the parameter
         parameter = inspect.Parameter(
             name,
             inspect.Parameter.POSITIONAL_OR_KEYWORD,
             default=definition.get("default", inspect.Parameter.empty),
-            annotation=types[definition["type"]],
+            annotation=definition["type"],
         )
         parameters.append(parameter)
 
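A quick usage sketch of the new load(): the only special case left is "file", which is rewritten to fastapi.UploadFile before the annotation is attached; every other "type" value is now passed through to inspect.Parameter unchanged:

import inspect

import fastapi

definition = {"type": "file"}

# the same preprocessing step load() performs
match definition["type"]:
    case "file":
        definition["type"] = fastapi.UploadFile

parameter = inspect.Parameter(
    "photo",
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    default=definition.get("default", inspect.Parameter.empty),
    annotation=definition["type"],
)
print(parameter)  # roughly: photo: UploadFile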