Added support for input parameters that are recognised by the API.

Models are now exposed through separate endpoints so that their inputs are easier to recognise.
faraphel 2025-01-09 23:12:54 +01:00
parent 900c58ffcb
commit 7bd84c8570
17 changed files with 163 additions and 128 deletions
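In short, each model in the library now gets its own inference endpoint whose parameters mirror the "inputs" block of its config.json, instead of one generic payload route. A rough client-side sketch of the new flow (host and model name are placeholders; only the /models and /models/{name}/infer paths come from this commit):

import requests

BASE_URL = "http://localhost:8000"  # assumption: wherever the application is served

# list the available models (this endpoint is now registered by the ModelManager itself)
model_names = requests.get(f"{BASE_URL}/models").json()

# call a per-model inference endpoint; "tinybert" stands for whatever folder name the
# model has in the library. The parameters must match the "inputs" declared in its
# config.json, e.g. {"prompt": {"type": "str"}} becomes a required "prompt" parameter.
response = requests.post(
    f"{BASE_URL}/models/tinybert/infer",
    params={"prompt": "Hello world"},
)
print(response.json())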


@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+python-multipart

 # AI
 accelerate


@@ -1,3 +1,6 @@
 {
-    "type": "dummy"
+    "type": "python",
+    "file": "model.py",
+    "inputs": {}
 }


@@ -0,0 +1,12 @@
+import json
+import typing
+
+
+def load(model) -> None:
+    pass
+
+def unload(model) -> None:
+    pass
+
+def infer(model) -> typing.Iterator[bytes]:
+    yield json.dumps({"hello": "world!"}).encode("utf-8")


@@ -2,6 +2,10 @@
     "type": "python",
     "file": "model.py",
+    "inputs": {
+        "prompt": {"type": "str"}
+    },
     "requirements": [
         "transformers",
         "torch",


@@ -1,4 +1,5 @@
 import json
+import typing

 import torch
 import transformers
@@ -7,22 +8,22 @@ import transformers
 MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


-def load(model):
+def load(model) -> None:
     model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
     model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

-def unload(model):
+def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, payload: dict) -> str:
-    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")
+def infer(model, prompt: str) -> typing.Iterator[bytes]:
+    inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
         outputs = model.model(**inputs)

     embeddings = outputs.last_hidden_state

-    return json.dumps({
+    yield json.dumps({
         "data": embeddings.tolist()
-    })
+    }).encode("utf-8")


@@ -1,4 +1,5 @@
 import json
+import typing

 import torch
 import transformers
@@ -7,22 +8,22 @@ import transformers
 MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


-def load(model):
+def load(model) -> None:
     model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
     model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

-def unload(model):
+def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, payload: dict) -> str:
-    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")
+def infer(model, prompt: str) -> typing.Iterator[bytes]:
+    inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
         outputs = model.model(**inputs)

     embeddings = outputs.last_hidden_state

-    return json.dumps({
+    yield json.dumps({
         "data": embeddings.tolist()
-    })
+    }).encode("utf-8")
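Note that infer now yields bytes instead of returning a JSON string: an iterator of bytes is exactly what fastapi.responses.StreamingResponse expects, which is how PythonModel exposes these functions further down. A minimal, self-contained sketch of that pattern (not taken from the repository):

import json
import typing

import fastapi

app = fastapi.FastAPI()


def fake_infer(prompt: str) -> typing.Iterator[bytes]:
    # stand-in for a model's infer(): stream the result in one or more chunks
    yield json.dumps({"echo": prompt}).encode("utf-8")


@app.post("/demo/infer")
async def infer(prompt: str) -> fastapi.Response:
    # wrap the byte iterator in a streaming response, as PythonModel does
    return fastapi.responses.StreamingResponse(
        content=fake_infer(prompt),
        media_type="application/json",
    )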


@@ -7,13 +7,9 @@ application = api.Application()
 # create the model controller
-model_controller = manager.ModelManager(os.environ["MODEL_LIBRARY"])
+model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("dummy", model.DummyModel)
 model_controller.register_model_type("python", model.PythonModel)
 model_controller.reload()

-api.route.models.load(application, model_controller)
-
 # serve the application
 application.serve("0.0.0.0", 8000)


@@ -1,3 +1 @@
-from . import route
-
 from .Application import Application


@@ -1 +0,0 @@
-from . import models


@@ -1,74 +0,0 @@
-import sys
-import traceback
-
-import fastapi
-import pydantic
-
-from source.api import Application
-from source import manager
-
-
-class InferenceRequest(pydantic.BaseModel):
-    """
-    Represent a request made when inferring a model
-    """
-
-    request: dict
-
-
-def load(application: Application, model_manager: manager.ModelManager):
-    @application.get("/models")
-    async def get_models() -> list[str]:
-        """
-        Get the list of models available
-        :return: the list of models available
-        """
-
-        # reload the model list
-        model_manager.reload()
-
-        # list the models found
-        return list(model_manager.models.keys())
-
-    @application.get("/models/{model_name}")
-    async def get_model(model_name: str) -> dict:
-        """
-        Get information about a specific model
-        :param model_name: the name of the model
-        :return: the information about the corresponding model
-        """
-
-        # get the corresponding model
-        model = model_manager.models.get(model_name)
-        if model is None:
-            raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-        # return the model information
-        return model.get_information()
-
-    @application.post("/models/{model_name}/infer")
-    async def infer_model(model_name: str, request: InferenceRequest) -> fastapi.Response:
-        """
-        Run an inference through the selected model
-        :param model_name: the name of the model
-        :param request: the data to infer to the model
-        :return: the model response
-        """
-
-        # get the corresponding model
-        model = model_manager.models.get(model_name)
-        if model is None:
-            raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-        # infer the data through the model
-        try:
-            response = model.infer(request.request)
-        except Exception:
-            print(traceback.format_exc(), file=sys.stderr)
-            raise fastapi.HTTPException(status_code=500, detail="An error occurred while inferring the model.")
-
-        # pack the model response into a fastapi response
-        return fastapi.Response(
-            content=response,
-            media_type=model.response_mimetype,
-        )


@@ -4,11 +4,14 @@ import typing
 import warnings
 from pathlib import Path

-from source import model
+import fastapi
+
+from source import model, api


 class ModelManager:
-    def __init__(self, model_library: os.PathLike | str):
+    def __init__(self, application: api.Application, model_library: os.PathLike | str):
+        self.application: api.Application = application
         self.model_library: Path = Path(model_library)

         # the model types
@@ -20,10 +23,43 @@ class ModelManager:
         # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
         self.current_loaded_model: typing.Optional[model.base.BaseModel] = None

-    def register_model_type(self, name: str, model_type: typing.Type[model.base.BaseModel]):
+        @self.application.get("/models")
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @self.application.get("/models/{model_name}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
+
+            # return the model information
+            return model.get_information()
+
+    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
         self.model_types[name] = model_type

     def reload(self):
+        # reset the model list
+        for model in self.models.values():
+            model.unload()
+        self.models.clear()
+
+        # load all the models in the library
         for model_path in self.model_library.iterdir():
             model_name: str = model_path.name
             model_configuration_path: Path = model_path / "config.json"


@@ -1,19 +0,0 @@
-import json
-
-from source.model import base
-
-
-class DummyModel(base.BaseModel):
-    """
-    A dummy model, mainly used to test the API and the manager.
-    simply send back the request made to it.
-    """
-
-    def _load(self) -> None:
-        pass
-
-    def _unload(self) -> None:
-        pass
-
-    def _infer(self, payload: dict) -> str | bytes:
-        return json.dumps(payload)


@@ -1,9 +1,14 @@
 import importlib.util
 import subprocess
 import sys
+import typing
 import uuid
+import inspect
 from pathlib import Path

+import fastapi
+
+from source import utils
 from source.manager import ModelManager
 from source.model import base
@@ -16,6 +21,8 @@ class PythonModel(base.BaseModel):
     def __init__(self, manager: ModelManager, configuration: dict, path: Path):
         super().__init__(manager, configuration, path)

+        ## Configuration
+
         # get the name of the file containing the model code
         file = configuration.get("file")
         if file is None:
@@ -36,11 +43,28 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)

+        ## Api
+
+        # load the inputs data into the inference function signature (used by FastAPI)
+        parameters = utils.parameters.load(configuration.get("inputs", {}))
+
+        # create an endpoint wrapping the inference inside a fastapi call
+        async def infer_api(*args, **kwargs):
+            return fastapi.responses.StreamingResponse(
+                content=self.infer(*args, **kwargs),
+                media_type=self.output_type,
+            )
+
+        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+
+        # add the inference endpoint on the API
+        self.manager.application.add_api_route(f"/models/{self.name}/infer", infer_api, methods=["POST"])
+
     def _load(self) -> None:
         return self.module.load(self)

     def _unload(self) -> None:
         return self.module.unload(self)

-    def _infer(self, payload: dict) -> str | bytes:
-        return self.module.infer(self, payload)
+    def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
+        return self.module.infer(self, *args, **kwargs)
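The assignment to infer_api.__signature__ is what turns the configured inputs into real, validated endpoint parameters: FastAPI introspects route handlers with inspect.signature(), which honours a manually attached __signature__. A self-contained sketch of the same technique, with an illustrative inputs block (names are not from the repository):

import inspect

import fastapi

app = fastapi.FastAPI()

# hypothetical "inputs" block, in the shape a model's config.json would declare
inputs = {
    "prompt": {"type": str},
    "repeat": {"type": int, "default": 1},
}


async def infer_api(**kwargs):
    # echo the validated parameters back, standing in for a real inference
    return kwargs


# build the signature from the definitions; non-default parameters must come first,
# which is why utils.parameters.load sorts them
infer_api.__signature__ = inspect.Signature(parameters=[
    inspect.Parameter(
        name,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        annotation=definition["type"],
        default=definition.get("default", inspect.Parameter.empty),
    )
    for name, definition in inputs.items()
])

# register the dynamic endpoint: the OpenAPI schema now lists "prompt" (required)
# and "repeat" (optional), even though the function itself only takes **kwargs
app.add_api_route("/models/example/infer", infer_api, methods=["POST"])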


@@ -1,4 +1,3 @@
 from . import base

-from .DummyModel import DummyModel
 from .PythonModel import PythonModel


@@ -1,7 +1,9 @@
 import abc
 import gc
+import typing
 from pathlib import Path

+from source import api
 from source.manager import ModelManager
@@ -10,13 +12,13 @@ class BaseModel(abc.ABC):
     Represent a model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
+    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
         # the environment directory of the model
         self.path = path
         # the model manager
         self.manager = manager
         # the mimetype of the model responses
-        self.response_mimetype: str = configuration.get("response_mimetype", "application/json")
+        self.output_type: str = configuration.get("output_type", "application/json")

         self._loaded = False
@@ -101,13 +103,11 @@ class BaseModel(abc.ABC):
         """
         Unload the model
         Do not call manually, use `unload` instead.
-        :return:
         """

-    def infer(self, payload: dict) -> str | bytes:
+    def infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
         """
         Infer our payload through the model within the model manager
-        :param payload: the payload to give to the model
         :return: the response of the model
         """
@@ -115,12 +115,11 @@ class BaseModel(abc.ABC):
         self.load()

         # model specific inference part
-        return self._infer(payload)
+        return self._infer(*args, **kwargs)

     @abc.abstractmethod
-    def _infer(self, payload: dict) -> str | bytes:
+    def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
         """
         Infer our payload through the model
-        :param payload: the payload to give to the model
         :return: the response of the model
         """

source/utils/__init__.py

@@ -0,0 +1 @@
+from . import parameters


@@ -0,0 +1,54 @@
+import inspect
+from datetime import datetime
+
+import fastapi
+
+
+# the list of types and their name that can be used by the API
+types: dict[str, type] = {
+    "bool": bool,
+    "int": int,
+    "float": float,
+    "str": str,
+    "bytes": bytes,
+    "list": list,
+    "tuple": tuple,
+    "set": set,
+    "dict": dict,
+    "datetime": datetime,
+    "file": fastapi.UploadFile,
+}
+
+
+def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
+    """
+    Load a list of python function parameters from their definitions.
+    :param parameters_definition: the definitions of the parameters
+    :return: the python function parameters
+
+    Examples:
+        >>> parameters_definition = {
+        ...     "boolean": {"type": "bool", "default": False},
+        ...     "list": {"type": "list", "default": [1, 2, 3]},
+        ...     "datetime": {"type": "datetime"},
+        ...     "file": {"type": "file"},
+        ... }
+        >>> parameters = load(parameters_definition)
+    """
+
+    parameters: list[inspect.Parameter] = []
+
+    for name, definition in parameters_definition.items():
+        # deserialize the parameter
+        parameter = inspect.Parameter(
+            name,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            default=definition.get("default", inspect.Parameter.empty),
+            annotation=types[definition["type"]],
+        )
+        parameters.append(parameter)
+
+    # sort the parameters so that non-default arguments always end up before default ones
+    parameters.sort(key=lambda parameter: parameter.default is inspect.Parameter.empty, reverse=True)
+
+    return parameters
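A side note on the "file" entry: it maps to fastapi.UploadFile, which FastAPI receives as a multipart form field; that is why python-multipart was added to the requirements at the top of this commit. A hedged client-side sketch (model name and input names are made up for illustration):

import requests

# assuming a model whose config.json declares
#   "inputs": {"image": {"type": "file"}, "threshold": {"type": "float", "default": 0.5}}
response = requests.post(
    "http://localhost:8000/models/classifier/infer",
    files={"image": open("picture.png", "rb")},  # sent as multipart/form-data
    params={"threshold": 0.75},                  # scalar inputs arrive as query parameters
)
print(response.content)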