Added support for input parameters that are recognised by the API.
Models are now loaded in separate endpoints so that their inputs are easier to recognise.
This commit is contained in:
parent
900c58ffcb
commit
7bd84c8570
17 changed files with 163 additions and 128 deletions
|
@ -1,3 +1 @@
|
|||
from . import route
|
||||
|
||||
from .Application import Application
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
from . import models
|
|
@ -1,74 +0,0 @@
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
from source.api import Application
|
||||
from source import manager
|
||||
|
||||
|
||||
class InferenceRequest(pydantic.BaseModel):
    """Represent a request made when inferring a model"""

    # NOTE(review): the docstring wording is kept as-is; pydantic may publish
    # it as the schema description — confirm before rewording.

    # raw payload that the inference route hands over to the model's ``infer``
    request: dict
|
||||
|
||||
|
||||
def load(application: Application, model_manager: manager.ModelManager):
    """
    Register the model routes on the given application.

    Three routes are declared: listing the available models, describing a
    single model, and running an inference through a single model.

    :param application: the application on which the routes are registered
    :param model_manager: the manager whose ``models`` attribute backs the routes
    """

    # NOTE(review): the route docstrings below are kept word-for-word; if
    # Application is a FastAPI app they are exposed as endpoint descriptions —
    # confirm before rewording them.

    def _require_model(name: str):
        # shared lookup for the per-model routes: an unknown name becomes a 404
        found = model_manager.models.get(name)
        if found is not None:
            return found
        raise fastapi.HTTPException(status_code=404, detail="Model not found")

    @application.get("/models")
    async def get_models() -> list[str]:
        """
        Get the list of models available
        :return: the list of models available
        """

        # refresh the manager first so the listing reflects its current state
        model_manager.reload()

        names = model_manager.models.keys()
        return list(names)

    @application.get("/models/{model_name}")
    async def get_model(model_name: str) -> dict:
        """
        Get information about a specific model
        :param model_name: the name of the model
        :return: the information about the corresponding model
        """

        selected = _require_model(model_name)
        return selected.get_information()

    @application.post("/models/{model_name}/infer")
    async def infer_model(model_name: str, request: InferenceRequest) -> fastapi.Response:
        """
        Run an inference through the selected model
        :param model_name: the name of the model
        :param request: the data to infer to the model
        :return: the model response
        """

        selected = _require_model(model_name)

        try:
            payload = selected.infer(request.request)
        except Exception:
            # keep the full traceback on stderr; the client only sees a generic 500
            print(traceback.format_exc(), file=sys.stderr)
            raise fastapi.HTTPException(status_code=500, detail="An error occurred while inferring the model.")
        else:
            # wrap the raw model output using the mimetype declared by the model
            return fastapi.Response(
                content=payload,
                media_type=selected.response_mimetype,
            )
|
Loading…
Add table
Add a link
Reference in a new issue