diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py
index 1586d61..43be902 100644
--- a/samples/models/dummy/model.py
+++ b/samples/models/dummy/model.py
@@ -1,4 +1,3 @@
-import json
 import typing
 
 
@@ -8,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass
 
-def infer(model, file) -> typing.Iterator[bytes]:
-    yield json.dumps({"hello": "world!"}).encode("utf-8")
+async def infer(model, file) -> typing.AsyncIterator[bytes]:
+    yield await file.read()
\ No newline at end of file
diff --git a/source/manager/ModelManager.py b/source/manager/ModelManager.py
index 2c7106c..f6e7268 100644
--- a/source/manager/ModelManager.py
+++ b/source/manager/ModelManager.py
@@ -23,6 +23,9 @@ class ModelManager:
         # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
         self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
 
+        # lock to avoid concurrent inference and concurrent model loading and unloading
+        self.inference_lock = asyncio.Lock()
+
         @self.application.get("/models")
         async def get_models() -> list[str]:
             """
diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py
index 0d47fc6..2d91df1 100644
--- a/source/model/PythonModel.py
+++ b/source/model/PythonModel.py
@@ -50,7 +50,7 @@ class PythonModel(base.BaseModel):
         parameters = utils.parameters.load(configuration.get("inputs", {}))
 
         # create an endpoint wrapping the inference inside a fastapi call
-        async def infer_api(**kwargs):
+        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
             # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
             # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own
             #  fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
@@ -61,8 +61,12 @@ class PythonModel(base.BaseModel):
             }
 
             return fastapi.responses.StreamingResponse(
-                content=self.infer(**kwargs),
+                content=await self.infer(**kwargs),
                 media_type=self.output_type,
+                headers={
+                    # if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
+                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
+                }
             )
 
         infer_api.__signature__ = inspect.Signature(parameters=parameters)
@@ -81,5 +85,5 @@ class PythonModel(base.BaseModel):
     def _unload(self) -> None:
         return self.module.unload(self)
 
-    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
+    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         return self.module.infer(self, **kwargs)
diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py
index d39d7d2..9072e97 100644
--- a/source/model/base/BaseModel.py
+++ b/source/model/base/BaseModel.py
@@ -106,20 +106,21 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
""" - def infer(self, **kwargs) -> typing.Iterator[bytes]: + async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: """ Infer our payload through the model within the model manager :return: the response of the model """ - # make sure we are loaded before an inference - self.load() + async with self.manager.inference_lock: + # make sure we are loaded before an inference + self.load() - # model specific inference part - return self._infer(**kwargs) + # model specific inference part + return self._infer(**kwargs) @abc.abstractmethod - def _infer(self, **kwargs) -> typing.Iterator[bytes]: + def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]: """ Infer our payload through the model :return: the response of the model