From 8bf28e4c4817a25c2b8d6588c853f5246ac2b970 Mon Sep 17 00:00:00 2001 From: faraphel Date: Sun, 12 Jan 2025 21:26:50 +0100 Subject: [PATCH] fixed the life cycle of the models (they couldn't unload anymore) and simplified the implementation of the Python models --- requirements.txt | 4 ++ samples/models/dummy/model.py | 11 +--- samples/models/python-bert-1/model.py | 30 +++++----- samples/models/python-bert-2/model.py | 30 +++++----- source/api/interface/ChatInterface.py | 2 +- source/api/interface/base/BaseInterface.py | 6 +- source/model/PythonModel.py | 27 +++++---- source/model/base/BaseModel.py | 65 ++++++++++++++-------- source/registry/ModelRegistry.py | 32 +---------- 9 files changed, 96 insertions(+), 111 deletions(-) diff --git a/requirements.txt b/requirements.txt index 35452a2..00110d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,9 @@ pydantic gradio python-multipart +# data manipulation +pillow +numpy + # AI accelerate diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py index 4661536..f41489c 100644 --- a/samples/models/dummy/model.py +++ b/samples/models/dummy/model.py @@ -1,11 +1,6 @@ import typing -def load(model) -> None: - pass - -def unload(model) -> None: - pass - -async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]: - yield messages[-1]["content"].encode("utf-8") +class Model: + async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]: + yield messages[-1]["content"].encode("utf-8") diff --git a/samples/models/python-bert-1/model.py b/samples/models/python-bert-1/model.py index b8dfaf9..5d5d225 100644 --- a/samples/models/python-bert-1/model.py +++ b/samples/models/python-bert-1/model.py @@ -5,25 +5,21 @@ import torch import transformers -MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D" +class Model: + NAME: str = "huawei-noah/TinyBERT_General_4L_312D" + def __init__(self) -> None: + self.model = transformers.AutoModel.from_pretrained(self.NAME) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME) -def load(model) -> None: - model.model = transformers.AutoModel.from_pretrained(MODEL_NAME) - model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME) + async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]: + inputs = self.tokenizer(prompt, return_tensors="pt") -def unload(model) -> None: - model.model = None - model.tokenizer = None + with torch.no_grad(): + outputs = self.model(**inputs) -async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]: - inputs = model.tokenizer(prompt, return_tensors="pt") + embeddings = outputs.last_hidden_state - with torch.no_grad(): - outputs = model.model(**inputs) - - embeddings = outputs.last_hidden_state - - yield json.dumps({ - "data": embeddings.tolist() - }).encode("utf-8") + yield json.dumps({ + "data": embeddings.tolist() + }).encode("utf-8") diff --git a/samples/models/python-bert-2/model.py b/samples/models/python-bert-2/model.py index b8dfaf9..5d5d225 100644 --- a/samples/models/python-bert-2/model.py +++ b/samples/models/python-bert-2/model.py @@ -5,25 +5,21 @@ import torch import transformers -MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D" +class Model: + NAME: str = "huawei-noah/TinyBERT_General_4L_312D" + def __init__(self) -> None: + self.model = transformers.AutoModel.from_pretrained(self.NAME) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME) -def load(model) -> None: - model.model = transformers.AutoModel.from_pretrained(MODEL_NAME) - 
model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME) + async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]: + inputs = self.tokenizer(prompt, return_tensors="pt") -def unload(model) -> None: - model.model = None - model.tokenizer = None + with torch.no_grad(): + outputs = self.model(**inputs) -async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]: - inputs = model.tokenizer(prompt, return_tensors="pt") + embeddings = outputs.last_hidden_state - with torch.no_grad(): - outputs = model.model(**inputs) - - embeddings = outputs.last_hidden_state - - yield json.dumps({ - "data": embeddings.tolist() - }).encode("utf-8") + yield json.dumps({ + "data": embeddings.tolist() + }).encode("utf-8") diff --git a/source/api/interface/ChatInterface.py b/source/api/interface/ChatInterface.py index 2cb7f27..a608bc5 100644 --- a/source/api/interface/ChatInterface.py +++ b/source/api/interface/ChatInterface.py @@ -44,7 +44,7 @@ class ChatInterface(base.BaseInterface): # send back the messages, clear the user prompt, disable the system prompt return assistant_message - def get_gradio_application(self): + def get_application(self): # create a gradio interface with gradio.Blocks(analytics_enabled=False) as application: # header diff --git a/source/api/interface/base/BaseInterface.py b/source/api/interface/base/BaseInterface.py index dc6ad71..425ccaf 100644 --- a/source/api/interface/base/BaseInterface.py +++ b/source/api/interface/base/BaseInterface.py @@ -7,7 +7,7 @@ import source class BaseInterface(abc.ABC): - def __init__(self, model: "source.model.base.BaseModel"): + def __init__(self, model: "source._model.base.BaseModel"): self.model = model @property @@ -20,7 +20,7 @@ class BaseInterface(abc.ABC): return f"{self.model.api_base}/interface" @abc.abstractmethod - def get_gradio_application(self) -> gradio.Blocks: + def get_application(self) -> gradio.Blocks: """ Get a gradio application :return: a gradio application @@ -35,6 +35,6 @@ class BaseInterface(abc.ABC): gradio.mount_gradio_app( application, - self.get_gradio_application(), + self.get_application(), self.route ) diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py index 9d8efab..d5d8853 100644 --- a/source/model/PythonModel.py +++ b/source/model/PythonModel.py @@ -41,18 +41,21 @@ class PythonModel(base.BaseModel): self.path / file ) # get the module - self.module = importlib.util.module_from_spec(module_spec) + module = importlib.util.module_from_spec(module_spec) # load the module - module_spec.loader.exec_module(self.module) + module_spec.loader.exec_module(module) - def _load(self) -> None: - return self.module.load(self) + # create the internal model from the class defined in the module + self._model_type = module.Model + self._model: typing.Optional[module.Model] = None - def _unload(self) -> None: - return self.module.unload(self) + async def _load(self) -> None: + self._model = self._model_type() + async def _unload(self) -> None: + self._model = None async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]: - return self.module.infer(self, **kwargs) + return self._model.infer(**kwargs) def _mount(self, application: fastapi.FastAPI): # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ? 
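For context, the reworked `PythonModel` relies on a small dynamic-import convention: each model directory ships a `model.py` defining a `Model` class, and that class is only instantiated when the model is actually loaded. The sketch below illustrates the convention in isolation; it is not code from this repository — the helper name `load_model_class` is invented — and only the `importlib.util.spec_from_file_location` / `module_from_spec` / `exec_module` calls mirror the hunk above.

```python
import importlib.util
from pathlib import Path


def load_model_class(file: Path):
    """Import a user-provided model.py and return its Model class (illustrative helper)."""
    # build a module spec from the file location, as PythonModel.__init__ does
    module_spec = importlib.util.spec_from_file_location(file.stem, file)
    module = importlib.util.module_from_spec(module_spec)
    # execute the module so that its top-level Model class gets defined
    module_spec.loader.exec_module(module)
    return module.Model


# usage sketch: keep the class around, instantiate it only when the model gets loaded
ModelClass = load_model_class(Path("samples/models/dummy/model.py"))
model = ModelClass()  # __init__ does the heavy lifting (weights, tokenizer, ...)
```

Holding on to the class instead of a live instance is what lets `_unload` simply drop the instance and hand the memory back to the garbage collector.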
@@ -69,7 +72,7 @@ class PythonModel(base.BaseModel): } return fastapi.responses.StreamingResponse( - content=await self.registry.infer_model(self, **kwargs), + content=await self.infer(**kwargs), media_type=self.output_type, headers={ # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI @@ -82,9 +85,11 @@ class PythonModel(base.BaseModel): # format the description description_sections: list[str] = [] if self.description is not None: - description_sections.append(self.description) + description_sections.append(f"# Description\n{self.description}") if self.interface is not None: - description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**") + description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**") + + description: str = "\n".join(description_sections) # add the inference endpoint on the API application.add_api_route( @@ -93,7 +98,7 @@ class PythonModel(base.BaseModel): methods=["POST"], tags=self.tags, summary=self.summary, - description="
".join(description_sections), + description=description, response_class=fastapi.responses.StreamingResponse, responses={ 200: {"content": {self.output_type: {}}} diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py index be939c2..3fc347e 100644 --- a/source/model/base/BaseModel.py +++ b/source/model/base/BaseModel.py @@ -1,4 +1,5 @@ import abc +import asyncio import gc import typing from pathlib import Path @@ -38,6 +39,9 @@ class BaseModel(abc.ABC): # is the model currently loaded self._loaded = False + # lock to avoid loading and unloading at the same time + self.load_lock = asyncio.Lock() + def __repr__(self): return f"<{self.__class__.__name__}: {self.name}>" @@ -71,48 +75,60 @@ class BaseModel(abc.ABC): "tags": self.tags } - def load(self) -> None: + async def load(self) -> None: """ Load the model within the model manager """ - # if the model is already loaded, skip - if self._loaded: - return + async with self.load_lock: + # if the model is already loaded, skip + if self._loaded: + return - # load the model depending on the implementation - self._load() + # unload the currently loaded model if any + if self.registry.current_loaded_model is not None: + await self.registry.current_loaded_model.unload() - # mark the model as loaded - self._loaded = True + # load the model depending on the implementation + await self._load() + + # mark the model as loaded + self._loaded = True + # mark the model as the registry loaded model + self.registry.current_loaded_model = self @abc.abstractmethod - def _load(self): + async def _load(self): """ Load the model Do not call manually, use `load` instead. """ - def unload(self) -> None: + async def unload(self) -> None: """ Unload the model within the model manager """ - # if we are not already loaded, stop - if not self._loaded: - return + async with self.load_lock: + # if we are not already loaded, stop + if not self._loaded: + return - # unload the model depending on the implementation - self._unload() + # unload the model depending on the implementation + await self._unload() - # force the garbage collector to clean the memory - gc.collect() + # force the garbage collector to clean the memory + gc.collect() - # mark the model as unloaded - self._loaded = False + # mark the model as unloaded + self._loaded = False + + # if we are the registry current loaded model, remove this status + if self.registry.current_loaded_model is self: + self.registry.current_loaded_model = None @abc.abstractmethod - def _unload(self): + async def _unload(self): """ Unload the model Do not call manually, use `unload` instead. 
@@ -124,11 +140,12 @@ class BaseModel(abc.ABC): :return: the response of the model """ - # make sure we are loaded before an inference - self.load() + async with self.registry.infer_lock: + # ensure that the model is loaded + await self.load() - # model specific inference part - return await self._infer(**kwargs) + # model specific inference part + return await self._infer(**kwargs) @abc.abstractmethod async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]: diff --git a/source/registry/ModelRegistry.py b/source/registry/ModelRegistry.py index 0f57197..ecba609 100644 --- a/source/registry/ModelRegistry.py +++ b/source/registry/ModelRegistry.py @@ -33,8 +33,8 @@ class ModelRegistry: # having two calculations at the same time might not be worth it either self.current_loaded_model: typing.Optional[BaseModel] = None - # lock to avoid concurrent inference and concurrent model loading and unloading - self.inference_lock = asyncio.Lock() + # lock to control access to model inference + self.infer_lock = asyncio.Lock() @property def api_base(self) -> str: @@ -48,34 +48,6 @@ class ModelRegistry: def register_type(self, name: str, model_type: "typing.Type[BaseModel]"): self.model_types[name] = model_type - async def load_model(self, model: "BaseModel"): - # lock to avoid concurrent loading - async with self.inference_lock: - # if there is another currently loaded model, unload it - if self.current_loaded_model is not None and self.current_loaded_model is not model: - await self.unload_model(self.current_loaded_model) - - # load the model - model.load() - - # mark the model as the currently loaded model of the manager - self.current_loaded_model = model - - async def unload_model(self, model: "BaseModel"): - # lock to avoid concurrent unloading - async with self.inference_lock: - # if we were the currently loaded model of the manager, demote ourselves - if self.current_loaded_model is model: - self.current_loaded_model = None - - # model specific unloading part - model.unload() - - async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]: - # lock to avoid concurrent inference - async with self.inference_lock: - return await model.infer(**kwargs) - def reload_models(self) -> None: """ Reload the list of models available
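To make the new life cycle concrete: loading a model now evicts whichever model the registry currently holds, and both operations are guarded by the per-model `load_lock`. The following is a minimal, self-contained sketch of that behaviour, not code from the repository — `DummyRegistry` and `DummyModel` are illustrative stand-ins, and only the lock-and-bookkeeping pattern mirrors the `BaseModel.load` / `BaseModel.unload` hunk above.

```python
import asyncio


class DummyRegistry:
    """Stand-in for ModelRegistry: only tracks the single currently loaded model."""
    def __init__(self):
        self.current_loaded_model = None


class DummyModel:
    """Stand-in reproducing the load/unload bookkeeping of BaseModel."""
    def __init__(self, name: str, registry: DummyRegistry):
        self.name = name
        self.registry = registry
        self._loaded = False
        self.load_lock = asyncio.Lock()

    async def load(self) -> None:
        async with self.load_lock:
            if self._loaded:
                return
            # evict whichever model currently occupies the registry slot
            if self.registry.current_loaded_model is not None:
                await self.registry.current_loaded_model.unload()
            self._loaded = True
            self.registry.current_loaded_model = self

    async def unload(self) -> None:
        async with self.load_lock:
            if not self._loaded:
                return
            self._loaded = False
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None


async def main():
    registry = DummyRegistry()
    a, b = DummyModel("a", registry), DummyModel("b", registry)
    await a.load()
    await b.load()  # loading b unloads a first
    print(a._loaded, b._loaded)  # False True


asyncio.run(main())
```

Keeping the registry bookkeeping inside the same critical section as the `_loaded` flag ensures that a later `load()` of another model can always find and unload this one.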
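With the `load_model` / `unload_model` / `infer_model` wrappers removed from `ModelRegistry`, callers drive everything through `model.infer(...)`, and the registry only contributes the renamed `infer_lock`. The stand-alone sketch below illustrates the intended request flow; `LockRegistry`, `EchoModel` and the sleep are invented for illustration, and only the lock usage mirrors `BaseModel.infer` above.

```python
import asyncio
import typing


class LockRegistry:
    """Stand-in for ModelRegistry: only the renamed infer_lock matters here."""
    def __init__(self):
        self.infer_lock = asyncio.Lock()


class EchoModel:
    """Stand-in following the new convention: infer() loads on demand and returns a byte stream."""
    def __init__(self, registry: LockRegistry):
        self.registry = registry
        self.loaded = False

    async def infer(self, content: str) -> typing.AsyncIterator[bytes]:
        async with self.registry.infer_lock:
            # the lock serialises the load-and-dispatch step, as in BaseModel.infer
            if not self.loaded:
                await asyncio.sleep(0.1)  # pretend to load weights
                self.loaded = True
            return self._infer(content)

    async def _infer(self, content: str) -> typing.AsyncIterator[bytes]:
        yield content.encode("utf-8")


async def main():
    registry = LockRegistry()
    model = EchoModel(registry)
    # two concurrent requests: their load-and-dispatch sections run one after the other
    streams = await asyncio.gather(model.infer("first"), model.infer("second"))
    for stream in streams:
        print([chunk async for chunk in stream])


asyncio.run(main())
```

As in the patch, the stream returned by `_infer` is created while the lock is held but consumed after the lock is released.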