fix: uploaded file were closed automatically before the model could infer their data

This commit is contained in:
faraphel 2025-01-10 11:05:41 +01:00
parent 639425ad7d
commit b89fafdc96
6 changed files with 62 additions and 12 deletions

View file

@ -11,6 +11,7 @@ import fastapi
from source import utils
from source.manager import ModelManager
from source.model import base
from source.utils.fastapi import UploadFileFix
class PythonModel(base.BaseModel):
@ -49,9 +50,18 @@ class PythonModel(base.BaseModel):
parameters = utils.parameters.load(configuration.get("inputs", {}))
# create an endpoint wrapping the inference inside a fastapi call
async def infer_api(*args, **kwargs):
async def infer_api(**kwargs):
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
# curiosity that may change in the future
kwargs = {
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
for key, value in kwargs.items()
}
return fastapi.responses.StreamingResponse(
content=self.infer(*args, **kwargs),
content=self.infer(**kwargs),
media_type=self.output_type,
)
@ -66,5 +76,5 @@ class PythonModel(base.BaseModel):
def _unload(self) -> None:
return self.module.unload(self)
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
return self.module.infer(self, *args, **kwargs)
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
return self.module.infer(self, **kwargs)

View file

@ -104,7 +104,7 @@ class BaseModel(abc.ABC):
Do not call manually, use `unload` instead.
"""
def infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
def infer(self, **kwargs) -> typing.Iterator[bytes]:
"""
Infer our payload through the model within the model manager
:return: the response of the model
@ -114,10 +114,10 @@ class BaseModel(abc.ABC):
self.load()
# model specific inference part
return self._infer(*args, **kwargs)
return self._infer(**kwargs)
@abc.abstractmethod
def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
def _infer(self, **kwargs) -> typing.Iterator[bytes]:
"""
Infer our payload through the model
:return: the response of the model