ai-server/source/manager/ModelManager.py
faraphel 7bd84c8570 added support of inputs parameters that are recognised by the API.
Models are now loaded in separate endpoints for the inputs to be easier to recognise
2025-01-09 23:12:54 +01:00

88 lines
3.1 KiB
Python

import json
import os
import typing
import warnings
from pathlib import Path
import fastapi
from source import model, api
class ModelManager:
def __init__(self, application: api.Application, model_library: os.PathLike | str):
self.application: api.Application = application
self.model_library: Path = Path(model_library)
# the model types
self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {}
# the models
self.models: dict[str, model.base.BaseModel] = {}
# the currently loaded model
# TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
@self.application.get("/models")
async def get_models() -> list[str]:
"""
Get the list of models available
:return: the list of models available
"""
# list the models found
return list(self.models.keys())
@self.application.get("/models/{model_name}")
async def get_model(model_name: str) -> dict:
"""
Get information about a specific model
:param model_name: the name of the model
:return: the information about the corresponding model
"""
# get the corresponding model
model = self.models.get(model_name)
if model is None:
raise fastapi.HTTPException(status_code=404, detail="Model not found")
# return the model information
return model.get_information()
def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
self.model_types[name] = model_type
def reload(self):
# reset the model list
for model in self.models.values():
model.unload()
self.models.clear()
# load all the models in the library
for model_path in self.model_library.iterdir():
model_name: str = model_path.name
model_configuration_path: Path = model_path / "config.json"
# check if the configuration file exists
if not model_configuration_path.exists():
warnings.warn(f"Model {model_name!r} is missing a config.json file.")
continue
# load the configuration file
model_configuration = json.loads(model_configuration_path.read_text())
# get the model type for this model
model_type_name: str = model_configuration.get("type")
if model_type_name not in self.model_types:
warnings.warn("Field 'type' missing from the model configuration file.")
continue
# get the class of this model type
model_type = self.model_types.get(model_type_name)
if model_type is None:
warnings.warn(f"Model type {model_type_name!r} does not exists. Has it been registered ?")
continue
# load the model
self.models[model_name] = model_type(self, model_configuration, model_path)