Added support for input parameters that are recognised by the API.

Models are now loaded via separate endpoints so that their inputs are easier to recognise.
This commit is contained in:
faraphel 2025-01-09 23:12:54 +01:00
parent 900c58ffcb
commit 7bd84c8570
17 changed files with 163 additions and 128 deletions

View file

@ -1,4 +1,5 @@
import json
import typing
import torch
import transformers
@ -7,22 +8,22 @@ import transformers
# Hugging Face hub identifier of the embedding model loaded by `load` below.
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
def load(model) -> None:
    """Populate *model* with the pretrained TinyBERT network and its tokenizer.

    Sets ``model.tokenizer`` and ``model.model`` from the Hugging Face hub
    checkpoint named by the module-level ``MODEL_NAME`` constant.
    """
    # The two downloads are independent; tokenizer first so a tokenizer
    # failure aborts before the (larger) model weights are fetched.
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
def unload(model) -> None:
    """Drop the references held on *model* so the network and tokenizer
    can be garbage-collected (counterpart of ``load``)."""
    for attribute in ("model", "tokenizer"):
        setattr(model, attribute, None)
def infer(model, prompt: str) -> typing.Iterator[bytes]:
    """Embed *prompt* with the loaded model and stream the result.

    Tokenizes the prompt, runs a no-grad forward pass, and yields a single
    UTF-8 encoded JSON document of the form ``{"data": [...]}`` holding the
    last hidden state as nested lists.
    """
    encoded = model.tokenizer(prompt, return_tensors="pt")
    # Inference only — no gradients needed, saves memory and time.
    with torch.no_grad():
        output = model.model(**encoded)
    payload = {"data": output.last_hidden_state.tolist()}
    yield json.dumps(payload).encode("utf-8")