import sys
import traceback

import fastapi
import pydantic

from source.api import Application
from source import manager


class InferenceRequest(pydantic.BaseModel):
    """
    Represents the body of an inference request
    """

    # input payload forwarded to the model; named "data" to match the access
    # in infer_model below (the original field name did not match its usage)
    data: dict


def load(application: Application, model_manager: manager.ModelManager):
    @application.get("/models")
    async def get_models() -> list[str]:
        """
        Get the list of available models

        :return: the list of available model names
        """

        # reload the model list
        model_manager.reload()

        # list the models found
        return list(model_manager.models.keys())

    @application.get("/models/{model_name}")
    async def get_model(model_name: str) -> dict:
        """
        Get information about a specific model

        :param model_name: the name of the model
        :return: the information about the corresponding model
        """

        # get the corresponding model
        model = model_manager.models.get(model_name)
        if model is None:
            raise fastapi.HTTPException(status_code=404, detail="Model not found")

        # return the model information
        return model.get_information()

    @application.post("/models/{model_name}/infer")
    async def infer_model(model_name: str, request: InferenceRequest) -> fastapi.Response:
        """
        Run an inference through the selected model

        :param model_name: the name of the model
        :param request: the inference request carrying the input data
        :return: the model response
        """

        # get the corresponding model
        model = model_manager.models.get(model_name)
        if model is None:
            raise fastapi.HTTPException(status_code=404, detail="Model not found")

        # run the data through the model
        try:
            response = model.infer(request.data)
        except Exception:
            print(traceback.format_exc(), file=sys.stderr)
            raise fastapi.HTTPException(status_code=500, detail="An error occurred while running the inference.")

        # pack the model response into a fastapi response
        return fastapi.Response(
            content=response,
            media_type=model.response_mimetype,
        )
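# Usage sketch: how this module might be wired up and exercised. Neither
# constructor is defined in this module, so `Application()` and
# `ModelManager()` below are assumptions about `source.api` /
# `source.manager`; adapt them to the real signatures.
#
#     from source.api import Application
#     from source import manager
#
#     application = Application()             # assumed constructor
#     model_manager = manager.ModelManager()  # assumed constructor
#     load(application, model_manager)        # registers the three routes above
#
# Once the routes are registered, an inference call sends its payload under
# the "data" key expected by InferenceRequest, e.g.:
#
#     curl -X POST http://localhost:8000/models/my-model/infer \
#          -H "Content-Type: application/json" \
#          -d '{"data": {"text": "hello"}}'
#
# ("my-model", the payload shape, and the host/port are illustrative; they
# depend on the models available and on how Application is served.)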