fixed the model life cycle (models could no longer be unloaded) and simplified the implementation of the Python models

faraphel 2025-01-12 21:26:50 +01:00
parent f647c960dd
commit 8bf28e4c48
9 changed files with 96 additions and 111 deletions


@@ -33,8 +33,8 @@ class ModelRegistry:
         # having two calculations at the same time might not be worth it either
         self.current_loaded_model: typing.Optional[BaseModel] = None

-        # lock to avoid concurrent inference and concurrent model loading and unloading
-        self.inference_lock = asyncio.Lock()
+        # lock to control access to model inference
+        self.infer_lock = asyncio.Lock()

     @property
     def api_base(self) -> str:
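The hunk above renames the registry's single lock; the hunk below then drops the registry-level load_model/unload_model/infer_model methods. As a minimal sketch, inference might now be serialized through the remaining infer_lock with the load/swap handled per call; the infer method shown here is an assumption for illustration, not code from this commit:

import asyncio
import typing


class ModelRegistry:
    def __init__(self) -> None:
        self.current_loaded_model: typing.Optional["BaseModel"] = None
        # lock to control access to model inference
        self.infer_lock = asyncio.Lock()

    # hypothetical entry point: the commit removes this logic from the
    # registry, so the name and structure here are assumptions
    async def infer(self, model: "BaseModel", **kwargs) -> bytes:
        async with self.infer_lock:
            # swap out any other currently loaded model before running this one
            if self.current_loaded_model is not None and self.current_loaded_model is not model:
                self.current_loaded_model.unload()
            model.load()
            self.current_loaded_model = model
            return await model.infer(**kwargs)

Holding the one lock around the whole load-swap-infer sequence would avoid the race the three separate locked methods invited: a model being unloaded while another coroutine is mid-inference on it.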
@@ -48,34 +48,6 @@ class ModelRegistry:
     def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type

-    async def load_model(self, model: "BaseModel"):
-        # lock to avoid concurrent loading
-        async with self.inference_lock:
-            # if there is another currently loaded model, unload it
-            if self.current_loaded_model is not None and self.current_loaded_model is not model:
-                await self.unload_model(self.current_loaded_model)
-
-            # load the model
-            model.load()
-            # mark the model as the currently loaded model of the manager
-            self.current_loaded_model = model
-
-    async def unload_model(self, model: "BaseModel"):
-        # lock to avoid concurrent unloading
-        async with self.inference_lock:
-            # if we were the currently loaded model of the manager, demote ourselves
-            if self.current_loaded_model is model:
-                self.current_loaded_model = None
-
-            # model specific unloading part
-            model.unload()
-
-    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
-        # lock to avoid concurrent inference
-        async with self.inference_lock:
-            return await model.infer(**kwargs)

     def reload_models(self) -> None:
         """
         Reload the list of models available
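Per the commit message, the life-cycle fix lets models be unloaded again, presumably by the models managing their own state rather than the registry. A hedged sketch of a self-managing BaseModel consistent with that description (the loaded flag and the _load/_unload hooks are assumptions; only load and unload appear in the diff):

import abc


class BaseModel(abc.ABC):
    """Sketch of a model that tracks its own loaded state so it can
    always be unloaded again (assumed design, not taken from the diff)."""

    def __init__(self) -> None:
        self.loaded = False

    def load(self) -> None:
        # idempotent: loading an already-loaded model is a no-op
        if self.loaded:
            return
        self._load()
        self.loaded = True

    def unload(self) -> None:
        # idempotent: unloading an unloaded model is a no-op
        if not self.loaded:
            return
        self._unload()
        self.loaded = False

    @abc.abstractmethod
    def _load(self) -> None:
        """Model-specific loading (e.g. moving weights onto the GPU)."""

    @abc.abstractmethod
    def _unload(self) -> None:
        """Model-specific unloading (freeing the weights)."""

With the state check inside the model itself, the registry no longer needs to coordinate loading and unloading, which matches the roughly 28 lines this commit removes from ModelRegistry.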