added support for additional, more user-friendly interfaces; simplified parts of the application loading process
parent 1a49aa3779
commit f647c960dd
20 changed files with 353 additions and 107 deletions
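For orientation before the hunks below: the configuration keys read by the new code (`file`, `inputs`, `requirements`, `output_type`, `tags`, `summary`, `description`, `interface`) suggest a model configuration shaped roughly like this sketch. The values are purely illustrative, not taken from the repository:

```python
# hypothetical model configuration; keys match the configuration.get(...)
# calls in the diff, values are made up for illustration
configuration = {
    "file": "model.py",                      # python file with the model code
    "inputs": {"prompt": {"type": "str"}},   # inference parameters (schema assumed)
    "requirements": ["numpy"],               # extra pip packages to install
    "output_type": "application/json",       # mimetype of the model responses
    "tags": ["demo"],                        # tags shown in the API docs
    "summary": "A demo model",               # short documentation
    "description": "A longer description.",  # long documentation
    "interface": "chat",                     # optional dedicated interface name
}
```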
@@ -9,8 +9,8 @@ from pathlib import Path
 import fastapi
 
 from source import utils
-from source.manager import ModelManager
 from source.model import base
+from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix
 
 
@@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
     A model running a custom python model.
     """
 
-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
-        super().__init__(manager, configuration, path)
+    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
+        super().__init__(registry, configuration, path)
+
+        ## Configuration
+
+        # get the name of the file containing the model code
+        file = configuration.get("file")
+        if file is None:
+            raise ValueError("Field 'file' is missing from the configuration")
+        # get the parameters of the model
+        self.parameters = utils.parameters.load(configuration.get("inputs", {}))
 
         # install custom requirements
         requirements = configuration.get("requirements", [])
         if len(requirements) > 0:
             subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
 
-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
-
         # create the module specification
         module_spec = importlib.util.spec_from_file_location(
             f"model-{uuid.uuid4()}",
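For reference, the module loading performed here (and continued in the next hunk) follows the standard `importlib` recipe for executing a Python file as a module. A minimal standalone sketch, with a made-up file path:

```python
import importlib.util
import uuid

# create a module specification pointing at the model's code file
module_spec = importlib.util.spec_from_file_location(
    f"model-{uuid.uuid4()}",  # unique module name to avoid collisions
    "/path/to/model.py",      # hypothetical path to the configured 'file'
)

# materialise and execute the module so its top-level definitions exist
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
```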
@@ -44,10 +45,17 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)
 
-        # load the inputs data into the inference function signature (used by FastAPI)
-        parameters = utils.parameters.load(configuration.get("inputs", {}))
-
+    ## Api
+
+    def _load(self) -> None:
+        return self.module.load(self)
+
+    def _unload(self) -> None:
+        return self.module.unload(self)
+
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
+        return self.module.infer(self, **kwargs)
+
+    def _mount(self, application: fastapi.FastAPI):
         # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
 
         # create an endpoint wrapping the inference inside a fastapi call
         async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
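The `_load`/`_unload`/`_infer` wrappers above delegate to functions of the loaded module, so the user-supplied model file apparently has to export `load`, `unload` and `infer` taking the model as first argument. A hypothetical `model.py` honouring that implied contract (the `pipeline` attribute is an illustration, not part of the project):

```python
import typing


def load(model) -> None:
    # acquire resources, e.g. read weights from model.path
    model.pipeline = object()


def unload(model) -> None:
    # release resources so memory can be reclaimed
    model.pipeline = None


async def infer(model, **kwargs) -> typing.AsyncIterator[bytes]:
    # stream the response chunk by chunk
    yield b"result"
```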
@@ -61,7 +69,7 @@ class PythonModel(base.BaseModel):
            }
 
            return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
+                content=await self.registry.infer_model(self, **kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -69,27 +77,25 @@ class PythonModel(base.BaseModel):
                }
            )
 
-        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
+
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(self.description)
+        if self.interface is not None:
+            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
 
         # add the inference endpoint on the API
-        self.manager.application.add_api_route(
-            f"/models/{self.name}/infer",
+        application.add_api_route(
+            f"{self.api_base}/infer",
             infer_api,
             methods=["POST"],
             tags=self.tags,
-            # summary=...,
-            # description=...,
+            summary=self.summary,
+            description="<br>".join(description_sections),
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}
             },
         )
-
-    def _load(self) -> None:
-        return self.module.load(self)
-
-    def _unload(self) -> None:
-        return self.module.unload(self)
-
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
-        return self.module.infer(self, **kwargs)
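The `infer_api.__signature__` assignment above is what lets a generic `**kwargs` endpoint expose model-specific parameters: FastAPI builds the request schema from the function's `inspect` signature, and Python lets a function advertise an explicit one. A self-contained sketch of the trick, with made-up parameter names:

```python
import inspect

import fastapi

app = fastapi.FastAPI()


async def infer_api(**kwargs) -> dict:
    # FastAPI has already validated and injected the declared parameters
    return kwargs


# advertise an explicit signature even though the function only takes **kwargs
infer_api.__signature__ = inspect.Signature(parameters=[
    inspect.Parameter("prompt", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
    inspect.Parameter("temperature", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=float, default=1.0),
])

app.add_api_route("/models/demo/infer", infer_api, methods=["POST"])
```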
@@ -3,7 +3,9 @@ import gc
 import typing
 from pathlib import Path
 
-from source.manager import ModelManager
+import fastapi
+
+from source.registry import ModelRegistry
 
 
 class BaseModel(abc.ABC):
@@ -11,21 +13,43 @@ class BaseModel(abc.ABC):
     Represent a model.
     """
 
-    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
+    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
+        # the model registry
+        self.registry = registry
+
+        # get the documentation of the model
+        self.summary = configuration.get("summary")
+        self.description = configuration.get("description")
+
         # the environment directory of the model
         self.path = path
-        # the model manager
-        self.manager = manager
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")
         # get the tags of the model
         self.tags = configuration.get("tags", [])
 
+        # get the selected interface of the model
+        interface_name: typing.Optional[str] = configuration.get("interface", None)
+        self.interface = (
+            self.registry.interface_registry.interface_types[interface_name](self)
+            if interface_name is not None else None
+        )
+
         # is the model currently loaded
         self._loaded = False
 
     def __repr__(self):
         return f"<{self.__class__.__name__}: {self.name}>"
 
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the API routes
+        :return: the base for the API routes
+        """
+
+        return f"{self.registry.api_base}/{self.name}"
+
     @property
     def name(self):
         """
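The interface resolved here is later asked for a `route` and a `mount(application)` (see the last hunk below), so the classes in `interface_registry.interface_types` presumably satisfy that small contract. A minimal hypothetical interface class, purely to illustrate the implied shape:

```python
import fastapi


class DemoInterface:
    """hypothetical interface; the real types come from the interface registry"""

    def __init__(self, model) -> None:
        self.model = model

    @property
    def route(self) -> str:
        # route advertised in the model description
        return f"{self.model.api_base}/interface"

    def mount(self, application: fastapi.FastAPI) -> None:
        # serve a simple page for the model
        async def page() -> fastapi.responses.HTMLResponse:
            return fastapi.responses.HTMLResponse(f"<h1>{self.model.name}</h1>")

        application.add_api_route(self.route, page, methods=["GET"])
```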
@@ -44,6 +68,7 @@ class BaseModel(abc.ABC):
         return {
             "name": self.name,
+            "output_type": self.output_type,
             "tags": self.tags
         }
 
     def load(self) -> None:
@@ -51,22 +76,13 @@ class BaseModel(abc.ABC):
         Load the model within the model manager
         """
 
-        # if we are already loaded, stop
+        # if the model is already loaded, skip
         if self._loaded:
             return
 
-        # check if we are the current loaded model
-        if self.manager.current_loaded_model is not self:
-            # unload the previous model
-            if self.manager.current_loaded_model is not None:
-                self.manager.current_loaded_model.unload()
-
-        # model specific loading
+        # load the model depending on the implementation
         self._load()
 
-        # declare ourselves as the currently loaded model
-        self.manager.current_loaded_model = self
-
         # mark the model as loaded
         self._loaded = True
@@ -86,11 +102,7 @@ class BaseModel(abc.ABC):
         if not self._loaded:
             return
 
-        # if we were the currently loaded model of the manager, demote ourselves
-        if self.manager.current_loaded_model is self:
-            self.manager.current_loaded_model = None
-
-        # model specific unloading part
+        # unload the model depending on the implementation
         self._unload()
 
         # force the garbage collector to clean the memory
@@ -106,22 +118,42 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
         """
 
-    async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model within the model manager
         :return: the response of the model
         """
 
-        async with self.manager.inference_lock:
-            # make sure we are loaded before an inference
-            self.load()
+        # make sure we are loaded before an inference
+        self.load()
 
-            # model specific inference part
-            return self._infer(**kwargs)
+        # model specific inference part
+        return await self._infer(**kwargs)
 
     @abc.abstractmethod
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model
         :return: the response of the model
         """
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        :param application: the fastapi application
+        """
+
+        # mount the interface if selected
+        if self.interface is not None:
+            self.interface.mount(application)
+
+        # implementation specific mount
+        self._mount(application)
+
+    @abc.abstractmethod
+    def _mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        Do not call manually, use `mount` instead.
+        :param application: the fastapi application
+        """
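Once mounted, each model exposes a streaming POST endpoint at `{api_base}/infer`. A hypothetical client call, assuming the registry's API base is `/models` and a model named `demo` with a `prompt` parameter:

```python
# hypothetical client call against the mounted endpoint; the exact base path
# and parameters depend on the registry and model configuration
import httpx

with httpx.stream("POST", "http://localhost:8000/models/demo/infer",
                  params={"prompt": "hello"}) as response:
    for chunk in response.iter_bytes():
        print(chunk)
```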