fixed the life cycle of the models (they couldn't unload anymore) and simplified the implementation of the Python models
This commit is contained in:
parent
f647c960dd
commit
8bf28e4c48
9 changed files with 96 additions and 111 deletions
|
@ -1,4 +1,5 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import gc
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
@ -38,6 +39,9 @@ class BaseModel(abc.ABC):
|
|||
# is the model currently loaded
|
||||
self._loaded = False
|
||||
|
||||
# lock to avoid loading and unloading at the same time
|
||||
self.load_lock = asyncio.Lock()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}: {self.name}>"
|
||||
|
||||
|
@ -71,48 +75,60 @@ class BaseModel(abc.ABC):
|
|||
"tags": self.tags
|
||||
}
|
||||
|
||||
def load(self) -> None:
|
||||
async def load(self) -> None:
|
||||
"""
|
||||
Load the model within the model manager
|
||||
"""
|
||||
|
||||
# if the model is already loaded, skip
|
||||
if self._loaded:
|
||||
return
|
||||
async with self.load_lock:
|
||||
# if the model is already loaded, skip
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
# load the model depending on the implementation
|
||||
self._load()
|
||||
# unload the currently loaded model if any
|
||||
if self.registry.current_loaded_model is not None:
|
||||
await self.registry.current_loaded_model.unload()
|
||||
|
||||
# mark the model as loaded
|
||||
self._loaded = True
|
||||
# load the model depending on the implementation
|
||||
await self._load()
|
||||
|
||||
# mark the model as loaded
|
||||
self._loaded = True
|
||||
# mark the model as the registry loaded model
|
||||
self.registry.current_loaded_model = self
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load(self):
|
||||
async def _load(self):
|
||||
"""
|
||||
Load the model
|
||||
Do not call manually, use `load` instead.
|
||||
"""
|
||||
|
||||
def unload(self) -> None:
|
||||
async def unload(self) -> None:
|
||||
"""
|
||||
Unload the model within the model manager
|
||||
"""
|
||||
|
||||
# if we are not already loaded, stop
|
||||
if not self._loaded:
|
||||
return
|
||||
async with self.load_lock:
|
||||
# if we are not already loaded, stop
|
||||
if not self._loaded:
|
||||
return
|
||||
|
||||
# unload the model depending on the implementation
|
||||
self._unload()
|
||||
# unload the model depending on the implementation
|
||||
await self._unload()
|
||||
|
||||
# force the garbage collector to clean the memory
|
||||
gc.collect()
|
||||
# force the garbage collector to clean the memory
|
||||
gc.collect()
|
||||
|
||||
# mark the model as unloaded
|
||||
self._loaded = False
|
||||
# mark the model as unloaded
|
||||
self._loaded = False
|
||||
|
||||
# if we are the registry current loaded model, remove this status
|
||||
if self.registry.current_loaded_model is self:
|
||||
self.registry.current_loaded_model = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _unload(self):
|
||||
async def _unload(self):
|
||||
"""
|
||||
Unload the model
|
||||
Do not call manually, use `unload` instead.
|
||||
|
@ -124,11 +140,12 @@ class BaseModel(abc.ABC):
|
|||
:return: the response of the model
|
||||
"""
|
||||
|
||||
# make sure we are loaded before an inference
|
||||
self.load()
|
||||
async with self.registry.infer_lock:
|
||||
# ensure that the model is loaded
|
||||
await self.load()
|
||||
|
||||
# model specific inference part
|
||||
return await self._infer(**kwargs)
|
||||
# model specific inference part
|
||||
return await self._infer(**kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue