fixed the life cycle of the models (they could no longer be unloaded) and simplified the implementation of the Python models

faraphel 2025-01-12 21:26:50 +01:00
parent f647c960dd
commit 8bf28e4c48
9 changed files with 96 additions and 111 deletions

View file

@@ -5,5 +5,9 @@ pydantic
 gradio
 python-multipart

+# data manipulation
+pillow
+numpy
+
 # AI
 accelerate

View file

@ -1,11 +1,6 @@
import typing import typing
def load(model) -> None: class Model:
pass async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]:
def unload(model) -> None:
pass
async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
yield messages[-1]["content"].encode("utf-8") yield messages[-1]["content"].encode("utf-8")
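The commit replaces the module-level load/unload/infer hooks with a single Model class per module. A minimal sketch of how such a module could be imported and queried; the file path and the asyncio driver below are illustrative, not part of the commit:

import asyncio
import importlib.util

def load_model_module(path: str):
    # import the model module from its file path (hypothetical path below)
    spec = importlib.util.spec_from_file_location("example_model", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

async def main() -> None:
    module = load_model_module("models/example/model.py")  # hypothetical location
    model = module.Model()  # instantiating the class is what "loads" the model now
    async for chunk in model.infer(messages=[{"role": "user", "content": "hello"}]):
        print(chunk)  # b"hello"

asyncio.run(main())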

View file

@@ -5,22 +5,18 @@ import torch
 import transformers


-MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-
-def load(model) -> None:
-    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
-    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
-
-
-def unload(model) -> None:
-    model.model = None
-    model.tokenizer = None
-
-
-async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
-    inputs = model.tokenizer(prompt, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.model(**inputs)
-    embeddings = outputs.last_hidden_state
+class Model:
+    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
+
+    def __init__(self) -> None:
+        self.model = transformers.AutoModel.from_pretrained(self.NAME)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
+
+    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        embeddings = outputs.last_hidden_state

View file

@@ -5,22 +5,18 @@ import torch
 import transformers


-MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-
-def load(model) -> None:
-    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
-    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
-
-
-def unload(model) -> None:
-    model.model = None
-    model.tokenizer = None
-
-
-async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
-    inputs = model.tokenizer(prompt, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.model(**inputs)
-    embeddings = outputs.last_hidden_state
+class Model:
+    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
+
+    def __init__(self) -> None:
+        self.model = transformers.AutoModel.from_pretrained(self.NAME)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
+
+    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        embeddings = outputs.last_hidden_state
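The visible hunk stops at the embeddings line, so only the embedding step itself is shown. As a rough standalone sketch of that step, assuming the same TinyBERT checkpoint; the final serialisation to bytes is an illustration, not something the diff shows:

import torch
import transformers

NAME = "huawei-noah/TinyBERT_General_4L_312D"
model = transformers.AutoModel.from_pretrained(NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(NAME)

inputs = tokenizer("an example prompt", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
embeddings = outputs.last_hidden_state  # shape: (batch, tokens, hidden_size)

# assumption: one possible way to turn the tensor into bytes for a streaming response
payload = embeddings.numpy().tobytes()
print(embeddings.shape, len(payload))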

View file

@@ -44,7 +44,7 @@ class ChatInterface(base.BaseInterface):
         # send back the messages, clear the user prompt, disable the system prompt
         return assistant_message

-    def get_gradio_application(self):
+    def get_application(self):
         # create a gradio interface
         with gradio.Blocks(analytics_enabled=False) as application:
             # header

View file

@ -7,7 +7,7 @@ import source
class BaseInterface(abc.ABC): class BaseInterface(abc.ABC):
def __init__(self, model: "source.model.base.BaseModel"): def __init__(self, model: "source._model.base.BaseModel"):
self.model = model self.model = model
@property @property
@ -20,7 +20,7 @@ class BaseInterface(abc.ABC):
return f"{self.model.api_base}/interface" return f"{self.model.api_base}/interface"
@abc.abstractmethod @abc.abstractmethod
def get_gradio_application(self) -> gradio.Blocks: def get_application(self) -> gradio.Blocks:
""" """
Get a gradio application Get a gradio application
:return: a gradio application :return: a gradio application
@ -35,6 +35,6 @@ class BaseInterface(abc.ABC):
gradio.mount_gradio_app( gradio.mount_gradio_app(
application, application,
self.get_gradio_application(), self.get_application(),
self.route self.route
) )
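With the hook renamed from get_gradio_application to get_application, a concrete interface only has to return a gradio.Blocks that the base class mounts at self.route. A hedged sketch of what a minimal subclass might look like; EchoInterface, the route value and the widgets are illustrative, not taken from the repository:

import gradio

class EchoInterface:  # illustrative stand-in for a BaseInterface subclass
    route = "/models/example/interface"  # hypothetical route

    def get_application(self) -> gradio.Blocks:
        # build the gradio application that the base class will mount at self.route
        with gradio.Blocks(analytics_enabled=False) as application:
            prompt = gradio.Textbox(label="prompt")
            output = gradio.Textbox(label="output")
            prompt.submit(lambda text: text, inputs=prompt, outputs=output)
        return application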

View file

@@ -41,18 +41,21 @@ class PythonModel(base.BaseModel):
             self.path / file
         )

         # get the module
-        self.module = importlib.util.module_from_spec(module_spec)
+        module = importlib.util.module_from_spec(module_spec)
         # load the module
-        module_spec.loader.exec_module(self.module)
+        module_spec.loader.exec_module(module)

-    def _load(self) -> None:
-        return self.module.load(self)
+        # create the internal model from the class defined in the module
+        self._model_type = module.Model
+        self._model: typing.Optional[module.Model] = None

-    def _unload(self) -> None:
-        return self.module.unload(self)
+    async def _load(self) -> None:
+        self._model = self._model_type()
+
+    async def _unload(self) -> None:
+        self._model = None

     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
-        return self.module.infer(self, **kwargs)
+        return self._model.infer(**kwargs)

     def _mount(self, application: fastapi.FastAPI):
         # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
@@ -69,7 +72,7 @@ class PythonModel(base.BaseModel):
             }

             return fastapi.responses.StreamingResponse(
-                content=await self.registry.infer_model(self, **kwargs),
+                content=await self.infer(**kwargs),
                 media_type=self.output_type,
                 headers={
                     # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -82,9 +85,11 @@ class PythonModel(base.BaseModel):
         # format the description
         description_sections: list[str] = []
         if self.description is not None:
-            description_sections.append(self.description)
+            description_sections.append(f"# Description\n{self.description}")
         if self.interface is not None:
-            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
+            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
+
+        description: str = "\n".join(description_sections)

         # add the inference endpoint on the API
         application.add_api_route(
@@ -93,7 +98,7 @@ class PythonModel(base.BaseModel):
             methods=["POST"],
             tags=self.tags,
             summary=self.summary,
-            description="<br>".join(description_sections),
+            description=description,
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}

View file

@ -1,4 +1,5 @@
import abc import abc
import asyncio
import gc import gc
import typing import typing
from pathlib import Path from pathlib import Path
@ -38,6 +39,9 @@ class BaseModel(abc.ABC):
# is the model currently loaded # is the model currently loaded
self._loaded = False self._loaded = False
# lock to avoid loading and unloading at the same time
self.load_lock = asyncio.Lock()
def __repr__(self): def __repr__(self):
return f"<{self.__class__.__name__}: {self.name}>" return f"<{self.__class__.__name__}: {self.name}>"
@ -71,39 +75,47 @@ class BaseModel(abc.ABC):
"tags": self.tags "tags": self.tags
} }
def load(self) -> None: async def load(self) -> None:
""" """
Load the model within the model manager Load the model within the model manager
""" """
async with self.load_lock:
# if the model is already loaded, skip # if the model is already loaded, skip
if self._loaded: if self._loaded:
return return
# unload the currently loaded model if any
if self.registry.current_loaded_model is not None:
await self.registry.current_loaded_model.unload()
# load the model depending on the implementation # load the model depending on the implementation
self._load() await self._load()
# mark the model as loaded # mark the model as loaded
self._loaded = True self._loaded = True
# mark the model as the registry loaded model
self.registry.current_loaded_model = self
@abc.abstractmethod @abc.abstractmethod
def _load(self): async def _load(self):
""" """
Load the model Load the model
Do not call manually, use `load` instead. Do not call manually, use `load` instead.
""" """
def unload(self) -> None: async def unload(self) -> None:
""" """
Unload the model within the model manager Unload the model within the model manager
""" """
async with self.load_lock:
# if we are not already loaded, stop # if we are not already loaded, stop
if not self._loaded: if not self._loaded:
return return
# unload the model depending on the implementation # unload the model depending on the implementation
self._unload() await self._unload()
# force the garbage collector to clean the memory # force the garbage collector to clean the memory
gc.collect() gc.collect()
@ -111,8 +123,12 @@ class BaseModel(abc.ABC):
# mark the model as unloaded # mark the model as unloaded
self._loaded = False self._loaded = False
# if we are the registry current loaded model, remove this status
if self.registry.current_loaded_model is self:
self.registry.current_loaded_model = None
@abc.abstractmethod @abc.abstractmethod
def _unload(self): async def _unload(self):
""" """
Unload the model Unload the model
Do not call manually, use `unload` instead. Do not call manually, use `unload` instead.
@ -124,8 +140,9 @@ class BaseModel(abc.ABC):
:return: the response of the model :return: the response of the model
""" """
# make sure we are loaded before an inference async with self.registry.infer_lock:
self.load() # ensure that the model is loaded
await self.load()
# model specific inference part # model specific inference part
return await self._infer(**kwargs) return await self._infer(**kwargs)
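The new life cycle relies on two locks: the registry-wide infer_lock taken by infer(), and the per-model load_lock taken by load() and unload(); load() now also evicts whatever model the registry currently holds. A reduced sketch of that interaction, with TinyRegistry and TinyModel as illustrative stand-ins that skip the actual model work:

import asyncio

class TinyRegistry:  # illustrative stand-in for ModelRegistry
    def __init__(self) -> None:
        self.current_loaded_model = None
        self.infer_lock = asyncio.Lock()

class TinyModel:  # illustrative stand-in for BaseModel
    def __init__(self, name: str, registry: TinyRegistry) -> None:
        self.name = name
        self.registry = registry
        self._loaded = False
        self.load_lock = asyncio.Lock()

    async def load(self) -> None:
        async with self.load_lock:
            if self._loaded:
                return
            # unload the currently loaded model if any, then take its place
            if self.registry.current_loaded_model is not None:
                await self.registry.current_loaded_model.unload()
            self._loaded = True
            self.registry.current_loaded_model = self

    async def unload(self) -> None:
        async with self.load_lock:
            if not self._loaded:
                return
            self._loaded = False
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None

    async def infer(self, prompt: str) -> bytes:
        async with self.registry.infer_lock:
            await self.load()
            return prompt.encode("utf-8")

async def main() -> None:
    registry = TinyRegistry()
    a, b = TinyModel("a", registry), TinyModel("b", registry)
    print(await a.infer("first"))               # loads a
    print(await b.infer("second"))              # unloads a, loads b
    print(registry.current_loaded_model.name)   # -> "b"

asyncio.run(main())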

View file

@@ -33,8 +33,8 @@ class ModelRegistry:
         # having two calculations at the same time might not be worth it either
         self.current_loaded_model: typing.Optional[BaseModel] = None

-        # lock to avoid concurrent inference and concurrent model loading and unloading
-        self.inference_lock = asyncio.Lock()
+        # lock to control access to model inference
+        self.infer_lock = asyncio.Lock()

     @property
     def api_base(self) -> str:
@@ -48,34 +48,6 @@ class ModelRegistry:
     def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type

-    async def load_model(self, model: "BaseModel"):
-        # lock to avoid concurrent loading
-        async with self.inference_lock:
-            # if there is another currently loaded model, unload it
-            if self.current_loaded_model is not None and self.current_loaded_model is not model:
-                await self.unload_model(self.current_loaded_model)
-
-            # load the model
-            model.load()
-
-            # mark the model as the currently loaded model of the manager
-            self.current_loaded_model = model
-
-    async def unload_model(self, model: "BaseModel"):
-        # lock to avoid concurrent unloading
-        async with self.inference_lock:
-            # if we were the currently loaded model of the manager, demote ourselves
-            if self.current_loaded_model is model:
-                self.current_loaded_model = None
-
-            # model specific unloading part
-            model.unload()
-
-    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
-        # lock to avoid concurrent inference
-        async with self.inference_lock:
-            return await model.infer(**kwargs)
-
     def reload_models(self) -> None:
         """
         Reload the list of models available
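With load_model, unload_model and infer_model removed from the registry, callers now go through the model itself, which handles the locking shown above. The call-site change, sketched with the names used in the PythonModel diff:

# old call path (removed by this commit)
content = await self.registry.infer_model(self, **kwargs)

# new call path: the model serialises its own loading and inference
content = await self.infer(**kwargs)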