25 lines
663 B
Python
25 lines
663 B
Python
import json
|
|
import typing
|
|
|
|
import torch
|
|
import transformers
|
|
|
|
|
|
class Model:
    """Text-embedding model that returns TinyBERT hidden states as JSON.

    The Hugging Face checkpoint named by ``NAME`` is loaded once at
    construction. Each call to :meth:`infer` tokenizes one prompt, runs a
    forward pass with autograd disabled, and yields a single UTF-8 JSON chunk.
    """

    # Hugging Face Hub id used for both the model weights and the tokenizer.
    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"

    def __init__(self) -> None:
        """Load the model and its matching tokenizer from ``NAME``."""
        self.model = transformers.AutoModel.from_pretrained(self.NAME)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)

    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
        """Encode *prompt* and yield its last-layer hidden states.

        Args:
            prompt: Text to encode.

        Yields:
            Exactly one ``bytes`` chunk: the UTF-8 encoding of
            ``{"data": <outputs.last_hidden_state as nested lists>}``.
        """
        # return_tensors="pt" makes the tokenizer emit PyTorch tensors that are
        # passed straight into the model as keyword arguments.
        inputs = self.tokenizer(prompt, return_tensors="pt")

        # inference_mode() disables autograd like no_grad() and additionally
        # skips view/version-counter tracking. It is safe here because the
        # result is turned into plain Python lists below and never re-enters
        # autograd.
        with torch.inference_mode():
            outputs = self.model(**inputs)

        # NOTE(review): the forward pass above is synchronous, so it runs on
        # the event-loop thread and blocks other coroutines while it executes.
        # Consider offloading it if requests are long or concurrent.
        # Presumably shaped (batch=1, seq_len, hidden) for one prompt;
        # TODO confirm.
        embeddings = outputs.last_hidden_state

        yield json.dumps({"data": embeddings.tolist()}).encode("utf-8")
|