import json
import typing

import torch
import transformers


MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


def load(model) -> None:
    """Load the TinyBERT encoder and its tokenizer onto *model*.

    Mutates *model* in place, setting ``model.model`` and
    ``model.tokenizer``. Downloads/caches weights on first call.
    """
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
    # from_pretrained returns the module in train mode; switch to eval so
    # dropout is disabled and infer() produces deterministic embeddings.
    model.model.eval()
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

def unload(model) -> None:
    """Drop the encoder and tokenizer references so they can be freed."""
    model.model = model.tokenizer = None

def infer(model, prompt: str) -> typing.Iterator[bytes]:
    """Encode *prompt* and yield its token embeddings as one JSON payload.

    Tokenizes the prompt, runs the encoder without gradient tracking,
    and yields a single UTF-8 encoded JSON document of the form
    ``{"data": [...]}`` containing the last hidden state.
    """
    encoded = model.tokenizer(prompt, return_tensors="pt")

    # Gradients are never needed for embedding extraction.
    with torch.no_grad():
        hidden = model.model(**encoded).last_hidden_state

    payload = {"data": hidden.tolist()}
    yield json.dumps(payload).encode("utf-8")