fix: uploaded files were closed automatically before the model could infer their data
This commit is contained in:
parent
639425ad7d
commit
b89fafdc96
6 changed files with 62 additions and 12 deletions
|
@ -11,6 +11,7 @@ import fastapi
|
|||
from source import utils
|
||||
from source.manager import ModelManager
|
||||
from source.model import base
|
||||
from source.utils.fastapi import UploadFileFix
|
||||
|
||||
|
||||
class PythonModel(base.BaseModel):
|
||||
|
@ -49,9 +50,18 @@ class PythonModel(base.BaseModel):
|
|||
parameters = utils.parameters.load(configuration.get("inputs", {}))
|
||||
|
||||
# create an endpoint wrapping the inference inside a fastapi call
|
||||
async def infer_api(*args, **kwargs):
|
||||
async def infer_api(**kwargs):
|
||||
# NOTE: fix an issue where it is not possible to pass an UploadFile to a StreamingResponse
|
||||
# NOTE: perform a naive type(value).__name__ == "type_name" check because fastapi does not use its own
|
||||
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
|
||||
# curiosity that may change in the future
|
||||
kwargs = {
|
||||
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
|
||||
for key, value in kwargs.items()
|
||||
}
|
||||
|
||||
return fastapi.responses.StreamingResponse(
|
||||
content=self.infer(*args, **kwargs),
|
||||
content=self.infer(**kwargs),
|
||||
media_type=self.output_type,
|
||||
)
|
||||
|
||||
|
@ -66,5 +76,5 @@ class PythonModel(base.BaseModel):
|
|||
def _unload(self) -> None:
|
||||
return self.module.unload(self)
|
||||
|
||||
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
return self.module.infer(self, *args, **kwargs)
|
||||
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
return self.module.infer(self, **kwargs)
|
||||
|
|
|
@ -104,7 +104,7 @@ class BaseModel(abc.ABC):
|
|||
Do not call manually, use `unload` instead.
|
||||
"""
|
||||
|
||||
def infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
def infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Infer our payload through the model within the model manager
|
||||
:return: the response of the model
|
||||
|
@ -114,10 +114,10 @@ class BaseModel(abc.ABC):
|
|||
self.load()
|
||||
|
||||
# model specific inference part
|
||||
return self._infer(*args, **kwargs)
|
||||
return self._infer(**kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
|
||||
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Infer our payload through the model
|
||||
:return: the response of the model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue