fixed the life cycle of the models (they couldn't unload anymore) and simplified the implementation of the Python models

This commit is contained in:
faraphel 2025-01-12 21:26:50 +01:00
parent f647c960dd
commit 8bf28e4c48
9 changed files with 96 additions and 111 deletions

View file

@@ -5,25 +5,21 @@ import torch
import transformers
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
class Model:
NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
def __init__(self) -> None:
self.model = transformers.AutoModel.from_pretrained(self.NAME)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
def load(model) -> None:
model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
inputs = self.tokenizer(prompt, return_tensors="pt")
def unload(model) -> None:
model.model = None
model.tokenizer = None
with torch.no_grad():
outputs = self.model(**inputs)
async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
inputs = model.tokenizer(prompt, return_tensors="pt")
embeddings = outputs.last_hidden_state
with torch.no_grad():
outputs = model.model(**inputs)
embeddings = outputs.last_hidden_state
yield json.dumps({
"data": embeddings.tolist()
}).encode("utf-8")
yield json.dumps({
"data": embeddings.tolist()
}).encode("utf-8")