import json
import typing

import torch
import transformers


class Model:
    # Hugging Face model id for a small, general-purpose TinyBERT encoder.
    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"

    def __init__(self) -> None:
        self.model = transformers.AutoModel.from_pretrained(self.NAME)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)

    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
        # Tokenize the prompt into PyTorch tensors and run the encoder
        # without tracking gradients (inference only).
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # last_hidden_state has shape [batch, seq_len, hidden_size].
        embeddings = outputs.last_hidden_state
        # Stream back a single JSON-encoded chunk of the raw embeddings.
        yield json.dumps({"data": embeddings.tolist()}).encode("utf-8")
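

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of consuming the async `infer` generator above, assuming
# the class is run from this same file; the prompt string, helper name
# `_example`, and the asyncio driver are placeholder choices, not prescribed
# by the original code.
import asyncio


async def _example() -> None:
    model = Model()
    async for chunk in model.infer("TinyBERT turns text into embeddings."):
        payload = json.loads(chunk.decode("utf-8"))
        # "data" is a nested list shaped [batch, seq_len, hidden_size].
        print(len(payload["data"][0]), "token embeddings of width",
              len(payload["data"][0][0]))


if __name__ == "__main__":
    asyncio.run(_example())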