diff --git a/samples/models/dummy/config.json b/samples/models/dummy/config.json index 25be088..d5eabc8 100644 --- a/samples/models/dummy/config.json +++ b/samples/models/dummy/config.json @@ -2,5 +2,7 @@ "type": "python", "file": "model.py", - "inputs": {} + "inputs": { + "file": {"type": "file"} + } } diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py index d22c850..1586d61 100644 --- a/samples/models/dummy/model.py +++ b/samples/models/dummy/model.py @@ -8,5 +8,5 @@ def load(model) -> None: def unload(model) -> None: pass -def infer(model) -> typing.Iterator[bytes]: +def infer(model, file) -> typing.Iterator[bytes]: yield json.dumps({"hello": "world!"}).encode("utf-8") diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py index 8a85d7f..bb2531f 100644 --- a/source/model/PythonModel.py +++ b/source/model/PythonModel.py @@ -11,6 +11,7 @@ import fastapi from source import utils from source.manager import ModelManager from source.model import base +from source.utils.fastapi import UploadFileFix class PythonModel(base.BaseModel): @@ -49,9 +50,18 @@ class PythonModel(base.BaseModel): parameters = utils.parameters.load(configuration.get("inputs", {})) # create an endpoint wrapping the inference inside a fastapi call - async def infer_api(*args, **kwargs): + async def infer_api(**kwargs): + # NOTE: work around an issue where an UploadFile cannot be given to a StreamingResponse + # NOTE: perform a naive type(value).__name__ == "type_name" check because fastapi does not use its own + # fastapi.UploadFile class, but instead the starlette UploadFile class, which is an implementation + # detail that may change in the future + kwargs = { + key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value + for key, value in kwargs.items() + } + return fastapi.responses.StreamingResponse( - content=self.infer(*args, **kwargs), + content=self.infer(**kwargs), media_type=self.output_type, ) @@ -66,5 +76,5 @@ class 
PythonModel(base.BaseModel): def _unload(self) -> None: return self.module.unload(self) - def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]: - return self.module.infer(self, *args, **kwargs) + def _infer(self, **kwargs) -> typing.Iterator[bytes]: + return self.module.infer(self, **kwargs) diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py index d94dfca..4978843 100644 --- a/source/model/base/BaseModel.py +++ b/source/model/base/BaseModel.py @@ -104,7 +104,7 @@ class BaseModel(abc.ABC): Do not call manually, use `unload` instead. """ - def infer(self, *args, **kwargs) -> typing.Iterator[bytes]: + def infer(self, **kwargs) -> typing.Iterator[bytes]: """ Infer our payload through the model within the model manager :return: the response of the model @@ -114,10 +114,10 @@ class BaseModel(abc.ABC): self.load() # model specific inference part - return self._infer(*args, **kwargs) + return self._infer(**kwargs) @abc.abstractmethod - def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]: + def _infer(self, **kwargs) -> typing.Iterator[bytes]: """ Infer our payload through the model :return: the response of the model diff --git a/source/utils/fastapi.py b/source/utils/fastapi.py new file mode 100644 index 0000000..44df1e3 --- /dev/null +++ b/source/utils/fastapi.py @@ -0,0 +1,39 @@ +""" +fastapi.UploadFile currently has an issue where passing an UploadFile to a StreamingResponse closes it. + +Fix taken from this comment: +https://github.com/fastapi/fastapi/issues/10857#issuecomment-2079878117 +""" + + +from fastapi import UploadFile + + +class UploadFileFix(UploadFile): + """Patches `fastapi.UploadFile` due to buffer close issue. + + See the related github issue: + https://github.com/tiangolo/fastapi/issues/10857 + """ + + def __init__(self, upload_file: UploadFile) -> None: + """Wraps and mutates input `fastapi.UploadFile`. + + Swaps `close` method on the input instance so it's a no-op when called + by the framework. 
Stores the original `close` method of the input as `_close` here, to be + awaited later by the overridden `close` method. + """ + self.filename = upload_file.filename + self.file = upload_file.file + self.size = upload_file.size + self.headers = upload_file.headers + + _close = upload_file.close + setattr(upload_file, "close", self._close) + setattr(self, "_close", _close) + + async def _close(self) -> None: + pass + + async def close(self) -> None: # noqa: D102 + await self._close() diff --git a/source/utils/parameters.py b/source/utils/parameters.py index 5304bb1..35542ae 100644 --- a/source/utils/parameters.py +++ b/source/utils/parameters.py @@ -1,8 +1,7 @@ import inspect from datetime import datetime -import fastapi - +from fastapi import UploadFile # the list of types and their name that can be used by the API types: dict[str, type] = { @@ -16,7 +15,7 @@ types: dict[str, type] = { "set": set, "dict": dict, "datetime": datetime, - "file": fastapi.UploadFile, + "file": UploadFile, }