fix: uploaded file were closed automatically before the model could infer their data
This commit is contained in:
parent
639425ad7d
commit
b89fafdc96
6 changed files with 62 additions and 12 deletions
|
@ -2,5 +2,7 @@
|
|||
"type": "python",
|
||||
"file": "model.py",
|
||||
|
||||
"inputs": {}
|
||||
"inputs": {
|
||||
"file": {"type": "file"}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,5 +8,5 @@ def load(model) -> None:
|
|||
def unload(model) -> None:
|
||||
pass
|
||||
|
||||
def infer(model) -> typing.Iterator[bytes]:
|
||||
def infer(model, file) -> typing.Iterator[bytes]:
|
||||
yield json.dumps({"hello": "world!"}).encode("utf-8")
|
||||
|
|
|
@ -11,6 +11,7 @@ import fastapi
|
|||
from source import utils
|
||||
from source.manager import ModelManager
|
||||
from source.model import base
|
||||
from source.utils.fastapi import UploadFileFix
|
||||
|
||||
|
||||
class PythonModel(base.BaseModel):
|
||||
|
@ -49,9 +50,18 @@ class PythonModel(base.BaseModel):
|
|||
parameters = utils.parameters.load(configuration.get("inputs", {}))
|
||||
|
||||
# create an endpoint wrapping the inference inside a fastapi call
|
||||
async def infer_api(*args, **kwargs):
|
||||
async def infer_api(**kwargs):
|
||||
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
|
||||
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
|
||||
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
|
||||
# curiosity that may change in the future
|
||||
kwargs = {
|
||||
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
|
||||
for key, value in kwargs.items()
|
||||
}
|
||||
|
||||
return fastapi.responses.StreamingResponse(
|
||||
content=self.infer(*args, **kwargs),
|
||||
content=self.infer(**kwargs),
|
||||
media_type=self.output_type,
|
||||
)
|
||||
|
||||
|
@ -66,5 +76,5 @@ class PythonModel(base.BaseModel):
|
|||
def _unload(self) -> None:
|
||||
return self.module.unload(self)
|
||||
|
||||
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
return self.module.infer(self, *args, **kwargs)
|
||||
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
return self.module.infer(self, **kwargs)
|
||||
|
|
|
@ -104,7 +104,7 @@ class BaseModel(abc.ABC):
|
|||
Do not call manually, use `unload` instead.
|
||||
"""
|
||||
|
||||
def infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
def infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Infer our payload through the model within the model manager
|
||||
:return: the response of the model
|
||||
|
@ -114,10 +114,10 @@ class BaseModel(abc.ABC):
|
|||
self.load()
|
||||
|
||||
# model specific inference part
|
||||
return self._infer(*args, **kwargs)
|
||||
return self._infer(**kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Infer our payload through the model
|
||||
:return: the response of the model
|
||||
|
|
39
source/utils/fastapi.py
Normal file
39
source/utils/fastapi.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
"""
|
||||
fastapi.UploadFile currently have an issue where using an UploadFile and giving it to a StreamingResponse close it.
|
||||
|
||||
Fix from this comment :
|
||||
https://github.com/fastapi/fastapi/issues/10857#issuecomment-2079878117
|
||||
"""
|
||||
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
|
||||
class UploadFileFix(UploadFile):
|
||||
"""Patches `fastapi.UploadFile` due to buffer close issue.
|
||||
|
||||
See the related github issue:
|
||||
https://github.com/tiangolo/fastapi/issues/10857
|
||||
"""
|
||||
|
||||
def __init__(self, upload_file: UploadFile) -> None:
|
||||
"""Wraps and mutates input `fastapi.UploadFile`.
|
||||
|
||||
Swaps `close` method on the input instance so it's a no-op when called
|
||||
by the framework. Adds `close` method of input as `_close` here, to be
|
||||
called later with overridden `close` method.
|
||||
"""
|
||||
self.filename = upload_file.filename
|
||||
self.file = upload_file.file
|
||||
self.size = upload_file.size
|
||||
self.headers = upload_file.headers
|
||||
|
||||
_close = upload_file.close
|
||||
setattr(upload_file, "close", self._close)
|
||||
setattr(self, "_close", _close)
|
||||
|
||||
async def _close(self) -> None:
|
||||
pass
|
||||
|
||||
async def close(self) -> None: # noqa: D102
|
||||
await self._close()
|
|
@ -1,8 +1,7 @@
|
|||
import inspect
|
||||
from datetime import datetime
|
||||
|
||||
import fastapi
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
# the list of types and their name that can be used by the API
|
||||
types: dict[str, type] = {
|
||||
|
@ -16,7 +15,7 @@ types: dict[str, type] = {
|
|||
"set": set,
|
||||
"dict": dict,
|
||||
"datetime": datetime,
|
||||
"file": fastapi.UploadFile,
|
||||
"file": UploadFile,
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue