fixed the model life cycle (models could no longer be unloaded) and simplified the implementation of the Python models

faraphel 2025-01-12 21:26:50 +01:00
parent f647c960dd
commit 8bf28e4c48
9 changed files with 96 additions and 111 deletions
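For context, the contract this commit establishes for Python model modules: instead of module-level load(model) / unload(model) / infer(model, **kwargs) functions, a module now defines a Model class that PythonModel instantiates on load, drops on unload, and delegates inference to. A minimal sketch of such a module (the resource and payload are hypothetical; only the Model class and the async-iterator infer contract are taken from the diff below):

# hypothetical model module, loaded by PythonModel via importlib
import typing


class Model:
    def __init__(self):
        # acquire heavy resources here: dropping the instance on unload
        # (self._model = None) lets the garbage collector reclaim them
        self.weights = b"..."  # hypothetical placeholder

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        # async generator: PythonModel._infer returns this iterator
        # and the API streams it back to the client
        yield b"result"  # hypothetical payload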


@@ -41,18 +41,21 @@ class PythonModel(base.BaseModel):
            self.path / file
        )
        # get the module
-       self.module = importlib.util.module_from_spec(module_spec)
+       module = importlib.util.module_from_spec(module_spec)
        # load the module
-       module_spec.loader.exec_module(self.module)
+       module_spec.loader.exec_module(module)

-   def _load(self) -> None:
-       return self.module.load(self)
+       # create the internal model from the class defined in the module
+       self._model_type = module.Model
+       self._model: typing.Optional[module.Model] = None

-   def _unload(self) -> None:
-       return self.module.unload(self)
+   async def _load(self) -> None:
+       self._model = self._model_type()
+
+   async def _unload(self) -> None:
+       self._model = None

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
-       return self.module.infer(self, **kwargs)
+       return self._model.infer(**kwargs)

    def _mount(self, application: fastapi.FastAPI):
        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
@@ -69,7 +72,7 @@ class PythonModel(base.BaseModel):
            }

            return fastapi.responses.StreamingResponse(
-               content=await self.registry.infer_model(self, **kwargs),
+               content=await self.infer(**kwargs),
                media_type=self.output_type,
                headers={
                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -82,9 +85,11 @@ class PythonModel(base.BaseModel):
        # format the description
        description_sections: list[str] = []
        if self.description is not None:
-           description_sections.append(self.description)
+           description_sections.append(f"# Description\n{self.description}")
        if self.interface is not None:
-           description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
+           description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
+
+       description: str = "\n".join(description_sections)

        # add the inference endpoint on the API
        application.add_api_route(
@@ -93,7 +98,7 @@ class PythonModel(base.BaseModel):
            methods=["POST"],
            tags=self.tags,
            summary=self.summary,
-           description="<br>".join(description_sections),
+           description=description,
            response_class=fastapi.responses.StreamingResponse,
            responses={
                200: {"content": {self.output_type: {}}}
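Since the endpoint above now streams content=await self.infer(**kwargs), a request goes through the model's own lifecycle (lock, lazy load, eviction) before the first byte is produced. A hedged client-side sketch of consuming the stream (the URL, route, and form field are hypothetical; httpx is just one client that supports streamed POST responses):

import asyncio

import httpx


async def main():
    async with httpx.AsyncClient(timeout=None) as client:
        # hypothetical route; the real path is whatever _mount registers
        async with client.stream(
            "POST", "http://localhost:8000/models/example/infer",
            data={"prompt": "hello"},  # hypothetical input field
        ) as response:
            async for chunk in response.aiter_bytes():
                print(chunk)


asyncio.run(main())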


@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import gc
 import typing
 from pathlib import Path
@@ -38,6 +39,9 @@ class BaseModel(abc.ABC):
        # is the model currently loaded
        self._loaded = False

+       # lock to avoid loading and unloading at the same time
+       self.load_lock = asyncio.Lock()
+
    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"
@@ -71,48 +75,60 @@ class BaseModel(abc.ABC):
            "tags": self.tags
        }

-   def load(self) -> None:
+   async def load(self) -> None:
        """
        Load the model within the model manager
        """

-       # if the model is already loaded, skip
-       if self._loaded:
-           return
+       async with self.load_lock:
+           # if the model is already loaded, skip
+           if self._loaded:
+               return

-       # load the model depending on the implementation
-       self._load()
+           # unload the currently loaded model if any
+           if self.registry.current_loaded_model is not None:
+               await self.registry.current_loaded_model.unload()

-       # mark the model as loaded
-       self._loaded = True
+           # load the model depending on the implementation
+           await self._load()
+
+           # mark the model as loaded
+           self._loaded = True
+
+           # mark the model as the registry loaded model
+           self.registry.current_loaded_model = self

    @abc.abstractmethod
-   def _load(self):
+   async def _load(self):
        """
        Load the model
        Do not call manually, use `load` instead.
        """

-   def unload(self) -> None:
+   async def unload(self) -> None:
        """
        Unload the model within the model manager
        """

-       # if we are not already loaded, stop
-       if not self._loaded:
-           return
+       async with self.load_lock:
+           # if we are not already loaded, stop
+           if not self._loaded:
+               return

-       # unload the model depending on the implementation
-       self._unload()
+           # unload the model depending on the implementation
+           await self._unload()

-       # force the garbage collector to clean the memory
-       gc.collect()
+           # force the garbage collector to clean the memory
+           gc.collect()

-       # mark the model as unloaded
-       self._loaded = False
+           # mark the model as unloaded
+           self._loaded = False
+
+           # if we are the registry current loaded model, remove this status
+           if self.registry.current_loaded_model is self:
+               self.registry.current_loaded_model = None

    @abc.abstractmethod
-   def _unload(self):
+   async def _unload(self):
        """
        Unload the model
        Do not call manually, use `unload` instead.
@@ -124,11 +140,12 @@ class BaseModel(abc.ABC):
        :return: the response of the model
        """

-       # make sure we are loaded before an inference
-       self.load()
+       async with self.registry.infer_lock:
+           # ensure that the model is loaded
+           await self.load()

-       # model specific inference part
-       return await self._infer(**kwargs)
+           # model specific inference part
+           return await self._infer(**kwargs)

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
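Taken together, the BaseModel changes implement a single-resident-model policy: load() evicts registry.current_loaded_model before loading itself, load_lock serialises load/unload per model, and registry.infer_lock serialises inferences. A self-contained sketch of the eviction behaviour (DummyModel and the SimpleNamespace registry are hypothetical stand-ins; the load/unload logic mirrors the diff above):

import asyncio
import types


class DummyModel:
    """Hypothetical stand-in reproducing BaseModel's load/unload logic."""

    def __init__(self, name: str, registry):
        self.name = name
        self.registry = registry
        self._loaded = False
        self.load_lock = asyncio.Lock()

    async def load(self) -> None:
        async with self.load_lock:
            if self._loaded:
                return
            # evict whichever model currently holds the registry slot
            if self.registry.current_loaded_model is not None:
                await self.registry.current_loaded_model.unload()
            self._loaded = True
            self.registry.current_loaded_model = self

    async def unload(self) -> None:
        async with self.load_lock:
            if not self._loaded:
                return
            self._loaded = False
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None


async def main():
    registry = types.SimpleNamespace(current_loaded_model=None,
                                     infer_lock=asyncio.Lock())
    a = DummyModel("a", registry)
    b = DummyModel("b", registry)

    await a.load()  # "a" becomes the resident model
    await b.load()  # loading "b" first unloads "a"
    assert registry.current_loaded_model is b and not a._loaded


asyncio.run(main())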