ai-server/samples/models/python-bert-2/model.py

28 lines
629 B
Python

import json
import torch
import transformers
# Checkpoint identifier handed to the `transformers` `from_pretrained` calls
# in `load()` for both the encoder and its tokenizer.
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
def load(model):
    """Attach the pretrained encoder and tokenizer to *model*.

    Both objects are created from ``MODEL_NAME`` through the ``transformers``
    auto classes and stored as ``model.model`` and ``model.tokenizer``.

    Args:
        model: Object that receives the ``model`` and ``tokenizer``
            attributes (presumably the serving harness's model handle;
            confirm against the caller).
    """
    # The encoder is assigned first; if the tokenizer fails to load afterwards,
    # ``model.model`` stays set.
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
def unload(model):
    """Drop the encoder and tokenizer references held by *model*.

    Sets ``model.model`` and ``model.tokenizer`` to ``None``. It is safe to
    call on an object that was never loaded, because the attributes are simply
    (re)assigned.

    Args:
        model: Object whose ``model`` and ``tokenizer`` attributes are cleared.
    """
    model.model = None
    model.tokenizer = None
def infer(model, payload: dict) -> str:
    """Encode ``payload["prompt"]`` and return its hidden states as JSON.

    The prompt is tokenized into PyTorch tensors, run through ``model.model``
    with gradient tracking disabled, and the ``last_hidden_state`` of the
    output is serialized.

    Args:
        model: Object exposing the ``tokenizer`` and ``model`` attributes set
            by ``load()``.
        payload: Request body; must contain the key ``"prompt"``.

    Returns:
        A JSON document of the form ``{"data": <nested list>}``, where the
        nested list is ``last_hidden_state.tolist()`` (presumably indexed as
        batch, token, hidden unit; confirm against the encoder in use).

    Raises:
        KeyError: If ``payload`` has no ``"prompt"`` entry.
    """
    # Token positions beyond the encoder's position-embedding table make the
    # forward pass fail, so cap the sequence length when the limit is known.
    config = getattr(model.model, "config", None)
    max_length = getattr(config, "max_position_embeddings", None)

    inputs = model.tokenizer(
        payload["prompt"],
        return_tensors="pt",
        truncation=max_length is not None,
        max_length=max_length,
    )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model.model(**inputs)

    embeddings = outputs.last_hidden_state
    return json.dumps({"data": embeddings.tolist()})