added support of inputs parameters that are recognised by the API.

Models are now loaded in separate endpoints for the inputs to be easier to recognise
This commit is contained in:
faraphel 2025-01-09 23:12:54 +01:00
parent 900c58ffcb
commit 7bd84c8570
17 changed files with 163 additions and 128 deletions

View file

@ -4,11 +4,14 @@ import typing
import warnings
from pathlib import Path
from source import model
import fastapi
from source import model, api
class ModelManager:
def __init__(self, model_library: os.PathLike | str):
def __init__(self, application: api.Application, model_library: os.PathLike | str):
self.application: api.Application = application
self.model_library: Path = Path(model_library)
# the model types
@ -20,10 +23,43 @@ class ModelManager:
# TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
def register_model_type(self, name: str, model_type: typing.Type[model.base.BaseModel]):
@self.application.get("/models")
async def get_models() -> list[str]:
"""
Get the list of models available
:return: the list of models available
"""
# list the models found
return list(self.models.keys())
@self.application.get("/models/{model_name}")
async def get_model(model_name: str) -> dict:
"""
Get information about a specific model
:param model_name: the name of the model
:return: the information about the corresponding model
"""
# get the corresponding model
model = self.models.get(model_name)
if model is None:
raise fastapi.HTTPException(status_code=404, detail="Model not found")
# return the model information
return model.get_information()
def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
self.model_types[name] = model_type
def reload(self):
# reset the model list
for model in self.models.values():
model.unload()
self.models.clear()
# load all the models in the library
for model_path in self.model_library.iterdir():
model_name: str = model_path.name
model_configuration_path: Path = model_path / "config.json"