Compare commits: c6d779f591...1a49aa3779

2 commits:
- 1a49aa3779
- 775c78c6cb

7 changed files with 58 additions and 13 deletions
Model configuration for the dummy model:

```diff
@@ -3,6 +3,8 @@
     "tags": ["dummy"],
     "file": "model.py",
 
+    "output_type": "video/mp4",
+
     "inputs": {
         "file": {"type": "file"}
     }
```
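The two added keys declare the model's streamed output type and its input schema. As a rough illustration of what this implies for the generated endpoint, a client sends the declared "file" input as a multipart upload and gets a video/mp4 stream back. The route below is hypothetical; the compare view does not show how paths are built.

```python
# Hypothetical client call, assuming the dummy model is exposed under a
# route like /models/dummy/infer (the actual path is not shown in this diff).
import httpx

with open("input.mp4", "rb") as f:
    response = httpx.post(
        "http://localhost:8000/models/dummy/infer",  # hypothetical route
        files={"file": f},  # matches the "file": {"type": "file"} input
    )

print(response.headers.get("content-type"))  # expected: video/mp4, per "output_type"
```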
The dummy model module (model.py):

```diff
@@ -1,4 +1,3 @@
-import json
 import typing
 
 
@@ -8,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass
 
-def infer(model, file) -> typing.Iterator[bytes]:
-    yield json.dumps({"hello": "world!"}).encode("utf-8")
+async def infer(model, file) -> typing.AsyncIterator[bytes]:
+    yield await file.read()
```
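The demo module drops its canned JSON response and instead echoes the uploaded file back asynchronously. A minimal, self-contained sketch of how this new async contract behaves, using a hypothetical stand-in for the uploaded file:

```python
# Standalone sketch; FakeFile is a stand-in for the real upload object,
# which only needs to expose an async read() for this contract.
import asyncio
import typing


class FakeFile:
    def __init__(self, data: bytes) -> None:
        self._data = data

    async def read(self) -> bytes:
        return self._data


async def infer(model, file) -> typing.AsyncIterator[bytes]:
    # same shape as the new dummy model: yield the upload back as one chunk
    yield await file.read()


async def main() -> None:
    chunks = [chunk async for chunk in infer(None, FakeFile(b"hello"))]
    print(chunks)  # [b'hello']


asyncio.run(main())
```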
The model manager:

```diff
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import typing
@@ -10,6 +11,11 @@ from source import model, api
 
 
 class ModelManager:
+    """
+    The model manager
+    Load the list of models available, ensure that only one model is loaded at the same time.
+    """
+
    def __init__(self, application: api.Application, model_library: os.PathLike | str):
        self.application: api.Application = application
        self.model_library: Path = Path(model_library)
@@ -20,9 +26,14 @@ class ModelManager:
        self.models: dict[str, model.base.BaseModel] = {}
 
        # the currently loaded model
-       # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
+       # TODO(Faraphel): load more than one model at a time ?
+       #  would require a way more complex manager to handle memory issue
+       #  having two calculations at the same time might not be worth it either
        self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
 
+       # lock to avoid concurrent inference and concurrent model loading and unloading
+       self.inference_lock = asyncio.Lock()
+
        @self.application.get("/models")
        async def get_models() -> list[str]:
            """
```
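The new `inference_lock` is the piece that enforces the "only one model loaded at a time" guarantee under concurrency. A minimal sketch of the pattern, assuming names that mirror the diff (`Manager` here is a reduced stand-in for `ModelManager`):

```python
# Minimal sketch of the serialization pattern: one asyncio.Lock ensures
# load/unload and inference for concurrent requests never interleave.
import asyncio


class Manager:
    def __init__(self) -> None:
        self.inference_lock = asyncio.Lock()


async def run_inference(manager: Manager, name: str) -> None:
    async with manager.inference_lock:
        # load/unload and the actual inference happen here; two coroutines
        # can never be inside this block at the same time
        print(f"{name}: holding the lock")
        await asyncio.sleep(0.01)  # placeholder for real work


async def main() -> None:
    manager = Manager()
    await asyncio.gather(*(run_inference(manager, f"req-{i}") for i in range(3)))


asyncio.run(main())
```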
The Python model wrapper (note: the new `_infer` annotation read `typing.Iterator[bytes] | typing.Iterator[bytes]` in the commit; corrected below to the `typing.Iterator[bytes] | typing.AsyncIterator[bytes]` union used by the base class):

```diff
@@ -50,7 +50,7 @@ class PythonModel(base.BaseModel):
         parameters = utils.parameters.load(configuration.get("inputs", {}))
 
         # create an endpoint wrapping the inference inside a fastapi call
-        async def infer_api(**kwargs):
+        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
             # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
             # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
             # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
@@ -61,8 +61,12 @@ class PythonModel(base.BaseModel):
             }
 
             return fastapi.responses.StreamingResponse(
-                content=self.infer(**kwargs),
+                content=await self.infer(**kwargs),
                 media_type=self.output_type,
+                headers={
+                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
+                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
+                }
             )
 
         infer_api.__signature__ = inspect.Signature(parameters=parameters)
@@ -73,6 +77,12 @@ class PythonModel(base.BaseModel):
             infer_api,
             methods=["POST"],
             tags=self.tags,
+            # summary=...,
+            # description=...,
+            response_class=fastapi.responses.StreamingResponse,
+            responses={
+                200: {"content": {self.output_type: {}}}
+            },
         )
 
     def _load(self) -> None:
@@ -81,5 +91,5 @@ class PythonModel(base.BaseModel):
     def _unload(self) -> None:
         return self.module.unload(self)
 
-    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
+    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         return self.module.infer(self, **kwargs)
```
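The endpoint changes lean on two FastAPI behaviours: `StreamingResponse` accepts a generator as content, and a `content-disposition` header decides whether Swagger UI renders the body inline or offers it as a download. A reduced sketch, independent of the `PythonModel` plumbing; the `/demo` route and `generate()` payload are illustrative only:

```python
# Reduced sketch of the response construction added above.
import fastapi
import fastapi.responses

app = fastapi.FastAPI()


async def generate():
    yield b'{"hello": '
    yield b'"world!"}'


@app.post("/demo")
async def demo() -> fastapi.responses.StreamingResponse:
    output_type = "application/json"
    return fastapi.responses.StreamingResponse(
        content=generate(),
        media_type=output_type,
        # text-like output stays inline; binary types would use "attachment"
        headers={"content-disposition": "inline"},
    )
```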
The model base class:

```diff
@@ -106,12 +106,13 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
         """
 
-    def infer(self, **kwargs) -> typing.Iterator[bytes]:
+    async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model within the model manager
         :return: the response of the model
         """
 
-        # make sure we are loaded before an inference
-        self.load()
+        async with self.manager.inference_lock:
+            # make sure we are loaded before an inference
+            self.load()
 
@@ -119,7 +120,7 @@ class BaseModel(abc.ABC):
         return self._infer(**kwargs)
 
     @abc.abstractmethod
-    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
+    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model
         :return: the response of the model
```
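The widened `typing.Iterator[bytes] | typing.AsyncIterator[bytes]` return type works because Starlette's `StreamingResponse` (which FastAPI re-exports) accepts either a plain iterable or an async iterable as content, so model modules may implement `infer` as a regular or an async generator. A sketch of handling both shapes; `collect()` is illustrative, the real app lets `StreamingResponse` do this normalisation internally:

```python
import asyncio
import typing


def sync_infer() -> typing.Iterator[bytes]:
    yield b"sync chunk"


async def async_infer() -> typing.AsyncIterator[bytes]:
    yield b"async chunk"


async def collect(source: typing.Iterator[bytes] | typing.AsyncIterator[bytes]) -> bytes:
    # async generators expose __aiter__; plain generators do not
    if hasattr(source, "__aiter__"):
        return b"".join([chunk async for chunk in source])
    return b"".join(source)


print(asyncio.run(collect(sync_infer())))   # b'sync chunk'
print(asyncio.run(collect(async_infer())))  # b'async chunk'
```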
The utils package __init__:

```diff
@@ -1 +1,2 @@
 from . import parameters
+from . import mimetypes
```
source/utils/mimetypes.py (new file, 21 lines):

```diff
@@ -0,0 +1,21 @@
+def is_textlike(mimetype: str) -> bool:
+    """
+    Determine if a mimetype is considered as holding text
+    :param mimetype: the mimetype to check
+    :return: True if the mimetype represents text, False otherwise
+    """
+
+    # check the family of the mimetype
+    if mimetype.startswith("text/"):
+        return True
+
+    # check application formats that are text-based
+    if mimetype in [
+        "application/xml",
+        "application/json",
+        "application/javascript"
+    ]:
+        return True
+
+    # otherwise consider the file as non-text
+    return False
```
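A quick usage sketch for the new helper, assuming the package layout shown in the diff:

```python
from source.utils import mimetypes

print(mimetypes.is_textlike("text/plain"))        # True: text/ family
print(mimetypes.is_textlike("application/json"))  # True: whitelisted application format
print(mimetypes.is_textlike("video/mp4"))         # False: treated as binary
```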