74 lines
2.2 KiB
Python
74 lines
2.2 KiB
Python
import sys
|
|
import traceback
|
|
|
|
import fastapi
|
|
import pydantic
|
|
|
|
from source.api import Application
|
|
from source import manager
|
|
|
|
|
|
class InferenceRequest(pydantic.BaseModel):
    """
    Body of a POST request asking a model to run an inference

    The payload to infer through the model is carried by the ``request`` field.
    """

    # payload for the model; pydantic only checks that it is a dictionary,
    # its keys and values are not validated any further
    request: dict
|
|
|
|
|
|
def load(application: Application, model_manager: manager.ModelManager):
    """
    Register the model routes on the application

    :param application: the application the routes are attached to
    :param model_manager: the manager giving access to the available models
    """

    def get_model_or_404(model_name: str):
        """
        Look up a model by name, aborting the request with a 404 if it is unknown

        :param model_name: the name of the model
        :return: the corresponding model
        :raises fastapi.HTTPException: 404 when no model has this name
        """

        model = model_manager.models.get(model_name)
        if model is None:
            raise fastapi.HTTPException(status_code=404, detail="Model not found")

        return model

    @application.get("/models")
    async def get_models() -> list[str]:
        """
        Get the list of models available

        :return: the list of models available
        """

        # reload the model list
        model_manager.reload()

        # list the models found
        return list(model_manager.models.keys())

    @application.get("/models/{model_name}")
    async def get_model(model_name: str) -> dict:
        """
        Get information about a specific model

        :param model_name: the name of the model
        :return: the information about the corresponding model
        :raises fastapi.HTTPException: 404 when the model does not exist
        """

        # get the corresponding model, then return its information
        return get_model_or_404(model_name).get_information()

    @application.post("/models/{model_name}/infer")
    async def infer_model(model_name: str, request: InferenceRequest) -> fastapi.Response:
        """
        Run an inference through the selected model

        :param model_name: the name of the model
        :param request: the data to infer to the model
        :return: the model response
        :raises fastapi.HTTPException: 404 when the model does not exist,
            500 when the model fails during inference
        """

        # get the corresponding model
        model = get_model_or_404(model_name)

        # infer the data through the model.
        # The payload lives in the `request` field of InferenceRequest
        # (the model declares no `data` attribute).
        # NOTE(review): model.infer is called synchronously inside an async
        # handler, so it runs on the event loop; consider a threadpool if
        # inference is slow and the model is confirmed to be thread-safe.
        try:
            response = model.infer(request.request)
        except Exception as exc:
            # boundary handler: log the full traceback, expose only a generic message
            print(traceback.format_exc(), file=sys.stderr)
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred while inferring the model.",
            ) from exc

        # pack the model response into a fastapi response
        return fastapi.Response(
            content=response,
            media_type=model.response_mimetype,
        )
|