added support for additional, more user-friendly interfaces; simplified part of the application loading process
parent 1a49aa3779
commit f647c960dd
20 changed files with 353 additions and 107 deletions
@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+gradio
 python-multipart

 # AI
@@ -2,10 +2,14 @@
 "type": "python",
 "tags": ["dummy"],
 "file": "model.py",
+"interface": "chat",

-"output_type": "video/mp4",
+"summary": "Echo model",
 "description": "The most basic example model, simply echo the input",

 "inputs": {
-"file": {"type": "file"}
-}
+"messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
+},

+"output_type": "text/markdown"
 }
@@ -7,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass

-async def infer(model, file) -> typing.AsyncIterator[bytes]:
-    yield await file.read()
+async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+    yield messages[-1]["content"].encode("utf-8")
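Pieced together, the updated echo example module looks roughly like this (a sketch: only the infer function is visible in the hunk; the load body is assumed to be a no-op like unload):

import typing


def load(model) -> None:
    pass


def unload(model) -> None:
    pass


async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
    # echo the content of the last message back as UTF-8 bytes
    yield messages[-1]["content"].encode("utf-8")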
@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
@@ -1,3 +1,3 @@
 from . import api
 from . import model
-from . import manager
+from . import registry
@@ -1,15 +1,25 @@
 import os

-from source import manager, model, api
+from source import registry, model, api
+from source.api import interface

 # create a fastapi application
 application = api.Application()


-# create the model controller
-model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("python", model.PythonModel)
-model_controller.reload()
+# create the interface registry
+interface_registry = registry.InterfaceRegistry()
+interface_registry.register_type("chat", interface.ChatInterface)
+
+
+# create the model registry
+model_registry = registry.ModelRegistry(os.environ["MODEL_LIBRARY"], "/models", interface_registry)
+model_registry.register_type("python", model.PythonModel)
+model_registry.reload_models()
+
+# add the model registry routes to the fastapi
+model_registry.mount(application)


 # serve the application
 application.serve("0.0.0.0", 8000)
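Once served, the registry exposes the model endpoints under the configured base path. A minimal client sketch (assuming the server runs locally on port 8000; the model name "echo" is illustrative):

import requests

base = "http://localhost:8000"

# list the available models
print(requests.get(f"{base}/models").json())

# get the information of a single model
print(requests.get(f"{base}/models/echo").json())

# call the generated inference endpoint; the exact body shape depends on how
# utils.parameters builds the endpoint signature from the "inputs" configuration
response = requests.post(
    f"{base}/models/echo/infer",
    json=[{"role": "user", "content": "who are you ?"}],
)
print(response.text)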
@@ -8,7 +8,8 @@ class Application(fastapi.FastAPI):
     def __init__(self):
         super().__init__(
             title=meta.name,
-            description=meta.description
+            description=meta.description,
+            redoc_url=None,
         )

     def serve(self, host: str = "0.0.0.0", port: int = 8080):
@@ -1 +1,3 @@
+from . import interface
+
 from .Application import Application
source/api/interface/ChatInterface.py (new file, 75 lines)

import textwrap

import gradio

from source import meta
from source.api.interface import base
from source.model.base import BaseModel


class ChatInterface(base.BaseInterface):
    """
    An interface for Chat-like models.
    Use the OpenAI convention (list of dict with roles and content)
    """

    def __init__(self, model: "BaseModel"):
        # Function to send and receive chat messages
        super().__init__(model)

    async def send_message(self, user_message, old_messages: list[dict], system_message: str):
        # normalize the user message (the type can be wrong, especially when "edited")
        if isinstance(user_message, str):
            user_message: dict = {"files": [], "text": user_message}

        # copy the history to avoid modifying it
        messages: list[dict] = old_messages.copy()

        # add the system instruction
        if system_message:
            messages.insert(0, {"role": "system", "content": system_message})

        # add the user message
        # NOTE: gradio.ChatInterface add our message and the assistant message
        # TODO(Faraphel): add support for files
        messages.append({
            "role": "user",
            "content": user_message["text"],
        })

        # infer the message through the model
        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
        assistant_message: str = b"".join(chunks).decode("utf-8")

        # send back the messages, clear the user prompt, disable the system prompt
        return assistant_message

    def get_gradio_application(self):
        # create a gradio interface
        with gradio.Blocks(analytics_enabled=False) as application:
            # header
            gradio.Markdown(textwrap.dedent(f"""
                # {meta.name}
                ## {self.model.name}
            """))

            # additional settings
            with gradio.Accordion("Advanced Settings") as advanced_settings:
                system_prompt = gradio.Textbox(
                    label="System prompt",
                    placeholder="You are an expert in C++...",
                    lines=2,
                )

            # chat interface
            gradio.ChatInterface(
                fn=self.send_message,
                type="messages",
                multimodal=True,
                editable=True,
                save_history=True,
                additional_inputs=[system_prompt],
                additional_inputs_accordion=advanced_settings,
            )

        return application
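For reference, the message list that send_message builds and forwards to model.infer follows the OpenAI chat convention, for example:

messages = [
    {"role": "system", "content": "You are an expert in C++..."},  # optional system instruction
    {"role": "user", "content": "who are you ?"},                  # the latest user message
]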
source/api/interface/__init__.py (new file, 3 lines)

from . import base

from .ChatInterface import ChatInterface
source/api/interface/base/BaseInterface.py (new file, 40 lines)

import abc

import fastapi
import gradio

import source


class BaseInterface(abc.ABC):
    def __init__(self, model: "source.model.base.BaseModel"):
        self.model = model

    @property
    def route(self) -> str:
        """
        The route to the interface
        :return: the route to the interface
        """

        return f"{self.model.api_base}/interface"

    @abc.abstractmethod
    def get_gradio_application(self) -> gradio.Blocks:
        """
        Get a gradio application
        :return: a gradio application
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the interface on an application
        :param application: the application to mount the interface on
        :param path: the path where to mount the application
        """

        gradio.mount_gradio_app(
            application,
            self.get_gradio_application(),
            self.route
        )
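To illustrate what the interface registry expects, a new interface type only needs to subclass BaseInterface and implement get_gradio_application. A minimal, hypothetical example (not part of this commit; the EchoInterface name is made up):

import gradio

from source.api.interface import base


class EchoInterface(base.BaseInterface):
    """A hypothetical interface that simply echoes the text typed into a textbox."""

    def get_gradio_application(self) -> gradio.Blocks:
        with gradio.Blocks(analytics_enabled=False) as application:
            text = gradio.Textbox(label="Input")
            output = gradio.Textbox(label="Output")
            # echo the input into the output field on submit
            text.submit(lambda value: value, inputs=text, outputs=output)

        return application


# it would then be registered next to the chat interface in main.py:
# interface_registry.register_type("echo", EchoInterface)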
source/api/interface/base/__init__.py (new file, 1 line)

from .BaseInterface import BaseInterface
@@ -1 +0,0 @@
-from .ModelManager import ModelManager
@@ -9,8 +9,8 @@ from pathlib import Path
 import fastapi

 from source import utils
-from source.manager import ModelManager
 from source.model import base
+from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix


@@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
     A model running a custom python model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
-        super().__init__(manager, configuration, path)
+    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
+        super().__init__(registry, configuration, path)

-        ## Configuration
-
-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
+        # get the parameters of the model
+        self.parameters = utils.parameters.load(configuration.get("inputs", {}))

         # install custom requirements
         requirements = configuration.get("requirements", [])
         if len(requirements) > 0:
             subprocess.run([sys.executable, "-m", "pip", "install", *requirements])

+        # get the name of the file containing the model code
+        file = configuration.get("file")
+        if file is None:
+            raise ValueError("Field 'file' is missing from the configuration")
+
         # create the module specification
         module_spec = importlib.util.spec_from_file_location(
             f"model-{uuid.uuid4()}",
@@ -44,10 +45,17 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)

-        ## Api
+    def _load(self) -> None:
+        return self.module.load(self)

-        # load the inputs data into the inference function signature (used by FastAPI)
-        parameters = utils.parameters.load(configuration.get("inputs", {}))
+    def _unload(self) -> None:
+        return self.module.unload(self)
+
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
+        return self.module.infer(self, **kwargs)
+
+    def _mount(self, application: fastapi.FastAPI):
+        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?

         # create an endpoint wrapping the inference inside a fastapi call
         async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
@@ -61,7 +69,7 @@ class PythonModel(base.BaseModel):
             }

             return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
+                content=await self.registry.infer_model(self, **kwargs),
                 media_type=self.output_type,
                 headers={
                     # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -69,27 +77,25 @@ class PythonModel(base.BaseModel):
                 }
             )

-        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
+
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(self.description)
+        if self.interface is not None:
+            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")

         # add the inference endpoint on the API
-        self.manager.application.add_api_route(
-            f"/models/{self.name}/infer",
+        application.add_api_route(
+            f"{self.api_base}/infer",
             infer_api,
             methods=["POST"],
             tags=self.tags,
-            # summary=...,
-            # description=...,
+            summary=self.summary,
+            description="<br>".join(description_sections),
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}
             },
         )
-
-    def _load(self) -> None:
-        return self.module.load(self)
-
-    def _unload(self) -> None:
-        return self.module.unload(self)
-
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]:
-        return self.module.infer(self, **kwargs)
@@ -3,7 +3,9 @@ import gc
 import typing
 from pathlib import Path

-from source.manager import ModelManager
+import fastapi
+
+from source.registry import ModelRegistry


 class BaseModel(abc.ABC):
@@ -11,21 +13,43 @@ class BaseModel(abc.ABC):
     Represent a model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
+    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
+        # the model registry
+        self.registry = registry
+
+        # get the documentation of the model
+        self.summary = configuration.get("summary")
+        self.description = configuration.get("description")
+
         # the environment directory of the model
         self.path = path
-        # the model manager
-        self.manager = manager
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")
         # get the tags of the model
         self.tags = configuration.get("tags", [])

+        # get the selected interface of the model
+        interface_name: typing.Optional[str] = configuration.get("interface", None)
+        self.interface = (
+            self.registry.interface_registry.interface_types[interface_name](self)
+            if interface_name is not None else None
+        )
+
         # is the model currently loaded
         self._loaded = False

     def __repr__(self):
         return f"<{self.__class__.__name__}: {self.name}>"

+    @property
+    def api_base(self) -> str:
+        """
+        Base for the API routes
+        :return: the base for the API routes
+        """
+
+        return f"{self.registry.api_base}/{self.name}"
+
     @property
     def name(self):
         """
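To illustrate the resulting routing scheme (the registry base "/models" comes from main.py; the model name "echo" is only an example):

# model_registry.api_base                  -> "/models"
# model.api_base                           -> "/models/echo"
# interface route (BaseInterface.route)    -> "/models/echo/interface"
# inference endpoint (PythonModel._mount)  -> "/models/echo/infer"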
@@ -44,6 +68,7 @@ class BaseModel(abc.ABC):
         return {
             "name": self.name,
+            "output_type": self.output_type,
             "tags": self.tags
         }

     def load(self) -> None:
@@ -51,22 +76,13 @@ class BaseModel(abc.ABC):
         Load the model within the model manager
         """

-        # if we are already loaded, stop
+        # if the model is already loaded, skip
         if self._loaded:
             return

-        # check if we are the current loaded model
-        if self.manager.current_loaded_model is not self:
-            # unload the previous model
-            if self.manager.current_loaded_model is not None:
-                self.manager.current_loaded_model.unload()
-
-        # model specific loading
+        # load the model depending on the implementation
         self._load()

-        # declare ourselves as the currently loaded model
-        self.manager.current_loaded_model = self
-
         # mark the model as loaded
         self._loaded = True

@@ -86,11 +102,7 @@ class BaseModel(abc.ABC):
         if not self._loaded:
             return

-        # if we were the currently loaded model of the manager, demote ourselves
-        if self.manager.current_loaded_model is self:
-            self.manager.current_loaded_model = None
-
-        # model specific unloading part
+        # unload the model depending on the implementation
         self._unload()

         # force the garbage collector to clean the memory
@@ -106,22 +118,42 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
         """

-    async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model within the model manager
         :return: the response of the model
         """

-        async with self.manager.inference_lock:
-            # make sure we are loaded before an inference
-            self.load()
+        # make sure we are loaded before an inference
+        self.load()

-            # model specific inference part
-            return self._infer(**kwargs)
+        # model specific inference part
+        return await self._infer(**kwargs)

     @abc.abstractmethod
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model
         :return: the response of the model
         """

+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        :param application: the fastapi application
+        """
+
+        # mount the interface if selected
+        if self.interface is not None:
+            self.interface.mount(application)
+
+        # implementation specific mount
+        self._mount(application)
+
+    @abc.abstractmethod
+    def _mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        Do not call manually, use `unload` instead.
+        :param application: the fastapi application
+        """
source/registry/InterfaceRegistry.py (new file, 16 lines)

import typing

from source.api.interface import base


class InterfaceRegistry:
    """
    The interface registry
    Store the list of other interface available
    """

    def __init__(self):
        self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}

    def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
        self.interface_types[name] = interface_type
@@ -7,64 +7,80 @@ from pathlib import Path

 import fastapi

-from source import model, api
+from source.model.base import BaseModel
+from source.registry import InterfaceRegistry


-class ModelManager:
+class ModelRegistry:
     """
-    The model manager
+    The model registry
     Load the list of models available, ensure that only one model is loaded at the same time.
     """

-    def __init__(self, application: api.Application, model_library: os.PathLike | str):
-        self.application: api.Application = application
+    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
         self.model_library: Path = Path(model_library)
+        self.interface_registry = interface_registry
+        self._api_base = api_base

         # the model types
-        self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {}
+        self.model_types: dict[str, typing.Type[BaseModel]] = {}
         # the models
-        self.models: dict[str, model.base.BaseModel] = {}
+        self.models: dict[str, BaseModel] = {}

         # the currently loaded model
         # TODO(Faraphel): load more than one model at a time ?
         # would require a way more complex manager to handle memory issue
         # having two calculations at the same time might not be worth it either
-        self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
+        self.current_loaded_model: typing.Optional[BaseModel] = None

         # lock to avoid concurrent inference and concurrent model loading and unloading
         self.inference_lock = asyncio.Lock()

-        @self.application.get("/models")
-        async def get_models() -> list[str]:
-            """
-            Get the list of models available
-            :return: the list of models available
-            """
-
-            # list the models found
-            return list(self.models.keys())
-
-        @self.application.get("/models/{model_name}")
-        async def get_model(model_name: str) -> dict:
-            """
-            Get information about a specific model
-            :param model_name: the name of the model
-            :return: the information about the corresponding model
-            """
-
-            # get the corresponding model
-            model = self.models.get(model_name)
-            if model is None:
-                raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-            # return the model information
-            return model.get_information()
-
-
-    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the api routes
+        :return: the base for the api routes
+        """
+
+        return self._api_base
+
+    def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type

-    def reload(self):
+    async def load_model(self, model: "BaseModel"):
+        # lock to avoid concurrent loading
+        async with self.inference_lock:
+            # if there is another currently loaded model, unload it
+            if self.current_loaded_model is not None and self.current_loaded_model is not model:
+                await self.unload_model(self.current_loaded_model)
+
+            # load the model
+            model.load()
+
+            # mark the model as the currently loaded model of the manager
+            self.current_loaded_model = model
+
+    async def unload_model(self, model: "BaseModel"):
+        # lock to avoid concurrent unloading
+        async with self.inference_lock:
+            # if we were the currently loaded model of the manager, demote ourselves
+            if self.current_loaded_model is model:
+                self.current_loaded_model = None
+
+            # model specific unloading part
+            model.unload()
+
+    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
+        # lock to avoid concurrent inference
+        async with self.inference_lock:
+            return await model.infer(**kwargs)
+
+    def reload_models(self) -> None:
         """
         Reload the list of models available
         """

         # reset the model list
         for model in self.models.values():
             model.unload()
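Taken together, every load, unload and inference now goes through the registry's inference_lock. A usage sketch (the model name "echo" and the messages payload are illustrative; model_registry is the instance created in main.py):

import asyncio

async def example() -> None:
    # look up a model registered by reload_models()
    model = model_registry.models["echo"]

    # the registry holds inference_lock for the duration of the call, so an inference can
    # never overlap with load_model() / unload_model(); model.infer() loads the model if needed
    stream = await model_registry.infer_model(model, messages=[{"role": "user", "content": "hi"}])
    async for chunk in stream:
        print(chunk.decode("utf-8"))

asyncio.run(example())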
@@ -97,3 +113,39 @@ class ModelManager:

             # load the model
             self.models[model_name] = model_type(self, model_configuration, model_path)
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the models endpoints into a fastapi application
+        :param application: the fastapi application
+        """
+
+        @application.get(self.api_base)
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @application.get(f"{self.api_base}/{{model_name}}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
+
+            # return the model information
+            return model.get_information()
+
+        # mount all the models in the registry
+        for model_name, model in self.models.items():
+            model.mount(application)
source/registry/__init__.py (new file, 2 lines)

from .ModelRegistry import ModelRegistry
from .InterfaceRegistry import InterfaceRegistry
@@ -10,12 +10,14 @@ types: dict[str, type] = {
     "float": float,
     "str": str,
     "bytes": bytes,
     "list": list,
-    "tuple": tuple,
-    "set": set,
     "dict": dict,
     "datetime": datetime,
     "file": UploadFile,
+
+    # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
+    "list[dict]": list[dict],
+    # "tuple": tuple,
+    # "set": set,
 }


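These type names are what the "inputs" block of a model configuration refers to. The utils.parameters.load helper is not shown in this diff; a sketch of how such an "inputs" block could be mapped onto keyword parameters for the generated infer endpoint (an assumption, not the actual implementation):

import inspect
import json

def load(inputs: dict[str, dict]) -> list[inspect.Parameter]:
    parameters = []
    for name, options in inputs.items():
        annotation = types[options["type"]]  # e.g. "list[dict]" -> list[dict]
        default = options.get("default", inspect.Parameter.empty)
        if isinstance(default, str):
            # defaults are stored as JSON strings in the configuration
            default = json.loads(default)
        parameters.append(inspect.Parameter(
            name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=annotation,
            default=default,
        ))
    return parameters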