import json

import torch
import transformers

# Hugging Face model identifier for a small, general-purpose TinyBERT encoder.
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


def load(model):
    # Download (or read from the local cache) the pretrained weights and tokenizer.
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)


def unload(model):
    # Drop references so the weights and tokenizer can be garbage-collected.
    model.model = None
    model.tokenizer = None


def infer(model, payload: dict) -> str:
    # Tokenize the incoming prompt into PyTorch tensors.
    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")

    # Run the encoder without tracking gradients (inference only).
    with torch.no_grad():
        outputs = model.model(**inputs)

    # Per-token embeddings from the final transformer layer,
    # shaped (batch, sequence_length, hidden_size).
    embeddings = outputs.last_hidden_state

    # Serialize the embeddings as JSON for the response body.
    return json.dumps({
        "data": embeddings.tolist()
    })
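

# --- Optional local test harness (illustrative sketch, not part of the original file) ---
# Assumes the hosting framework passes a mutable object exposing `model` and
# `tokenizer` attributes; here a SimpleNamespace stands in for that object so the
# module can be exercised directly with `python` on the command line.
if __name__ == "__main__":
    from types import SimpleNamespace

    ctx = SimpleNamespace(model=None, tokenizer=None)
    load(ctx)
    print(infer(ctx, {"prompt": "Hello, world!"}))
    unload(ctx)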