ai-server/samples/models/python-bert-2/model.py
faraphel 7bd84c8570 added support of inputs parameters that are recognised by the API.
Models are now loaded in separate endpoints for the inputs to be easier to recognise
2025-01-09 23:12:54 +01:00

29 lines
680 B
Python

import json
import typing
import torch
import transformers
# Hugging Face hub identifier of the pretrained model this sample serves
# (TinyBERT General: a small distilled BERT — 4 layers, 312-dim hidden states).
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
def load(model) -> None:
    """Fetch the pretrained TinyBERT weights and tokenizer and attach them to *model*.

    Both objects are resolved from ``MODEL_NAME`` via the Hugging Face hub
    (cached locally after the first download) and stored as attributes so
    that ``infer`` can use them later.
    """
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
def unload(model) -> None:
    """Release the model and tokenizer held on *model*.

    Dropping both references lets the garbage collector reclaim the
    underlying weights and vocabulary.
    """
    model.tokenizer = None
    model.model = None
def infer(model, prompt: str) -> typing.Iterator[bytes]:
inputs = model.tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
outputs = model.model(**inputs)
embeddings = outputs.last_hidden_state
yield json.dumps({
"data": embeddings.tolist()
}).encode("utf-8")