diff --git a/samples/models/dummy/config.json b/samples/models/dummy/config.json index 09be48c..c12d549 100644 --- a/samples/models/dummy/config.json +++ b/samples/models/dummy/config.json @@ -3,6 +3,8 @@ "tags": ["dummy"], "file": "model.py", + "output_type": "video/mp4", + "inputs": { "file": {"type": "file"} } diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py index 1586d61..43be902 100644 --- a/samples/models/dummy/model.py +++ b/samples/models/dummy/model.py @@ -1,4 +1,3 @@ -import json import typing @@ -8,5 +7,5 @@ def load(model) -> None: def unload(model) -> None: pass -def infer(model, file) -> typing.Iterator[bytes]: - yield json.dumps({"hello": "world!"}).encode("utf-8") +async def infer(model, file) -> typing.AsyncIterator[bytes]: + yield await file.read() \ No newline at end of file diff --git a/source/manager/ModelManager.py b/source/manager/ModelManager.py index 2c7106c..6313ab9 100644 --- a/source/manager/ModelManager.py +++ b/source/manager/ModelManager.py @@ -1,3 +1,4 @@ +import asyncio import json import os import typing @@ -10,6 +11,11 @@ from source import model, api class ModelManager: + """ + The model manager + Load the list of models available, ensure that only one model is loaded at the same time. + """ + def __init__(self, application: api.Application, model_library: os.PathLike | str): self.application: api.Application = application self.model_library: Path = Path(model_library) @@ -20,9 +26,14 @@ class ModelManager: self.models: dict[str, model.base.BaseModel] = {} # the currently loaded model - # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue + # TODO(Faraphel): load more than one model at a time ? + # would require a way more complex manager to handle memory issue + # having two calculations at the same time might not be worth it either self.current_loaded_model: typing.Optional[model.base.BaseModel] = None + # lock to avoid concurrent inference and concurrent model loading and unloading + self.inference_lock = asyncio.Lock() + @self.application.get("/models") async def get_models() -> list[str]: """ diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py index 0d47fc6..c23bceb 100644 --- a/source/model/PythonModel.py +++ b/source/model/PythonModel.py @@ -50,7 +50,7 @@ class PythonModel(base.BaseModel): parameters = utils.parameters.load(configuration.get("inputs", {})) # create an endpoint wrapping the inference inside a fastapi call - async def infer_api(**kwargs): + async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse: # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation @@ -61,8 +61,12 @@ class PythonModel(base.BaseModel): } return fastapi.responses.StreamingResponse( - content=self.infer(**kwargs), + content=await self.infer(**kwargs), media_type=self.output_type, + headers={ + # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI + "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment" + } ) infer_api.__signature__ = inspect.Signature(parameters=parameters) @@ -73,6 +77,12 @@ class PythonModel(base.BaseModel): infer_api, methods=["POST"], tags=self.tags, + # summary=..., + # description=..., + response_class=fastapi.responses.StreamingResponse, + responses={ + 200: {"content": {self.output_type: {}}} + }, ) def _load(self) -> None: @@ -81,5 +91,5 @@ class PythonModel(base.BaseModel): def _unload(self) -> None: return self.module.unload(self) - def _infer(self, **kwargs) -> typing.Iterator[bytes]: + def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]: return self.module.infer(self, **kwargs) diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py index d39d7d2..9072e97 100644 --- a/source/model/base/BaseModel.py +++ b/source/model/base/BaseModel.py @@ -106,20 +106,21 @@ class BaseModel(abc.ABC): Do not call manually, use `unload` instead. """ - def infer(self, **kwargs) -> typing.Iterator[bytes]: + async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: """ Infer our payload through the model within the model manager :return: the response of the model """ - # make sure we are loaded before an inference - self.load() + async with self.manager.inference_lock: + # make sure we are loaded before an inference + self.load() - # model specific inference part - return self._infer(**kwargs) + # model specific inference part + return self._infer(**kwargs) @abc.abstractmethod - def _infer(self, **kwargs) -> typing.Iterator[bytes]: + def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: """ Infer our payload through the model :return: the response of the model diff --git a/source/utils/__init__.py b/source/utils/__init__.py index f6bd50e..1b48283 100644 --- a/source/utils/__init__.py +++ b/source/utils/__init__.py @@ -1 +1,2 @@ from . import parameters +from . import mimetypes diff --git a/source/utils/mimetypes.py b/source/utils/mimetypes.py new file mode 100644 index 0000000..2b54dd8 --- /dev/null +++ b/source/utils/mimetypes.py @@ -0,0 +1,21 @@ +def is_textlike(mimetype: str) -> bool: + """ + Determinate if a mimetype is considered as holding text + :param mimetype: the mimetype to check + :return: True if the mimetype represent text, False otherwise + """ + + # check the family of the mimetype + if mimetype.startswith("text/"): + return True + + # check applications formats that are text formatted + if mimetype in [ + "application/xml", + "application/json", + "application/javascript" + ]: + return True + + # otherwise consider the file as non-text + return False \ No newline at end of file