Compare commits

...

2 commits

7 changed files with 58 additions and 13 deletions

View file

@ -3,6 +3,8 @@
"tags": ["dummy"], "tags": ["dummy"],
"file": "model.py", "file": "model.py",
"output_type": "video/mp4",
"inputs": { "inputs": {
"file": {"type": "file"} "file": {"type": "file"}
} }

View file

@ -1,4 +1,3 @@
import json
import typing import typing
@ -8,5 +7,5 @@ def load(model) -> None:
def unload(model) -> None: def unload(model) -> None:
pass pass
def infer(model, file) -> typing.Iterator[bytes]: async def infer(model, file) -> typing.AsyncIterator[bytes]:
yield json.dumps({"hello": "world!"}).encode("utf-8") yield await file.read()

View file

@ -1,3 +1,4 @@
import asyncio
import json import json
import os import os
import typing import typing
@ -10,6 +11,11 @@ from source import model, api
class ModelManager: class ModelManager:
"""
The model manager
Load the list of models available, ensure that only one model is loaded at the same time.
"""
def __init__(self, application: api.Application, model_library: os.PathLike | str): def __init__(self, application: api.Application, model_library: os.PathLike | str):
self.application: api.Application = application self.application: api.Application = application
self.model_library: Path = Path(model_library) self.model_library: Path = Path(model_library)
@ -20,9 +26,14 @@ class ModelManager:
self.models: dict[str, model.base.BaseModel] = {} self.models: dict[str, model.base.BaseModel] = {}
# the currently loaded model # the currently loaded model
# TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue # TODO(Faraphel): load more than one model at a time ?
# would require a way more complex manager to handle memory issue
# having two calculations at the same time might not be worth it either
self.current_loaded_model: typing.Optional[model.base.BaseModel] = None self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
# lock to avoid concurrent inference and concurrent model loading and unloading
self.inference_lock = asyncio.Lock()
@self.application.get("/models") @self.application.get("/models")
async def get_models() -> list[str]: async def get_models() -> list[str]:
""" """

View file

@ -50,7 +50,7 @@ class PythonModel(base.BaseModel):
parameters = utils.parameters.load(configuration.get("inputs", {})) parameters = utils.parameters.load(configuration.get("inputs", {}))
# create an endpoint wrapping the inference inside a fastapi call # create an endpoint wrapping the inference inside a fastapi call
async def infer_api(**kwargs): async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
@ -61,8 +61,12 @@ class PythonModel(base.BaseModel):
} }
return fastapi.responses.StreamingResponse( return fastapi.responses.StreamingResponse(
content=self.infer(**kwargs), content=await self.infer(**kwargs),
media_type=self.output_type, media_type=self.output_type,
headers={
# if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
"content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
}
) )
infer_api.__signature__ = inspect.Signature(parameters=parameters) infer_api.__signature__ = inspect.Signature(parameters=parameters)
@ -73,6 +77,12 @@ class PythonModel(base.BaseModel):
infer_api, infer_api,
methods=["POST"], methods=["POST"],
tags=self.tags, tags=self.tags,
# summary=...,
# description=...,
response_class=fastapi.responses.StreamingResponse,
responses={
200: {"content": {self.output_type: {}}}
},
) )
def _load(self) -> None: def _load(self) -> None:
@ -81,5 +91,5 @@ class PythonModel(base.BaseModel):
def _unload(self) -> None: def _unload(self) -> None:
return self.module.unload(self) return self.module.unload(self)
def _infer(self, **kwargs) -> typing.Iterator[bytes]: def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
return self.module.infer(self, **kwargs) return self.module.infer(self, **kwargs)

View file

@ -106,12 +106,13 @@ class BaseModel(abc.ABC):
Do not call manually, use `unload` instead. Do not call manually, use `unload` instead.
""" """
def infer(self, **kwargs) -> typing.Iterator[bytes]: async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
""" """
Infer our payload through the model within the model manager Infer our payload through the model within the model manager
:return: the response of the model :return: the response of the model
""" """
async with self.manager.inference_lock:
# make sure we are loaded before an inference # make sure we are loaded before an inference
self.load() self.load()
@ -119,7 +120,7 @@ class BaseModel(abc.ABC):
return self._infer(**kwargs) return self._infer(**kwargs)
@abc.abstractmethod @abc.abstractmethod
def _infer(self, **kwargs) -> typing.Iterator[bytes]: def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
""" """
Infer our payload through the model Infer our payload through the model
:return: the response of the model :return: the response of the model

View file

@ -1 +1,2 @@
from . import parameters from . import parameters
from . import mimetypes

21
source/utils/mimetypes.py Normal file
View file

@ -0,0 +1,21 @@
def is_textlike(mimetype: str) -> bool:
    """
    Determine if a mimetype is considered as holding text
    :param mimetype: the mimetype to check (parameters such as "; charset=utf-8" are ignored)
    :return: True if the mimetype represents text, False otherwise
    """
    # drop any media type parameters (e.g. "application/json; charset=utf-8")
    # and normalise the case, since mimetypes are case-insensitive
    base_type = mimetype.split(";", 1)[0].strip().lower()

    # check the family of the mimetype
    if base_type.startswith("text/"):
        return True

    # check applications formats that are text formatted
    if base_type in (
        "application/xml",
        "application/json",
        "application/javascript",
    ):
        return True

    # structured syntax suffixes (RFC 6839) such as "application/ld+json"
    # or "image/svg+xml" are text formatted as well
    if base_type.endswith(("+xml", "+json")):
        return True

    # otherwise consider the file as non-text
    return False