fixed the life cycle of the models (they could no longer be unloaded) and simplified the implementation of the Python models
This commit is contained in:
parent f647c960dd
commit 8bf28e4c48

9 changed files with 96 additions and 111 deletions
@@ -1,11 +1,6 @@
 import typing
 
 
-def load(model) -> None:
-    pass
-
-def unload(model) -> None:
-    pass
-
-async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
-    yield messages[-1]["content"].encode("utf-8")
+class Model:
+    async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+        yield messages[-1]["content"].encode("utf-8")
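With the class-based interface, a host simply constructs a `Model` and iterates its `infer()` stream; there are no `load()`/`unload()` hooks left to call. A minimal sketch of how a caller might drive the new echo model above (the module name and the driver code are assumptions, not part of the commit):

```python
import asyncio

from model import Model  # hypothetical module name for the echo model above


async def main() -> None:
    model = Model()  # constructing the instance is now the whole "load" step

    messages = [{"role": "user", "content": "hello"}]
    chunks = [chunk async for chunk in model.infer(messages)]
    print(b"".join(chunks))  # b'hello': the echo model yields the last message

    del model  # dropping the last reference is now the whole "unload" step


asyncio.run(main())
```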
@@ -5,25 +5,21 @@ import torch
 import transformers
 
 
-MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-
-def load(model) -> None:
-    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
-    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
-
-def unload(model) -> None:
-    model.model = None
-    model.tokenizer = None
-
-async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
-    inputs = model.tokenizer(prompt, return_tensors="pt")
+class Model:
+    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
+
+    def __init__(self) -> None:
+        self.model = transformers.AutoModel.from_pretrained(self.NAME)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
+
+    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
+        inputs = self.tokenizer(prompt, return_tensors="pt")
 
-    with torch.no_grad():
-        outputs = model.model(**inputs)
+        with torch.no_grad():
+            outputs = self.model(**inputs)
 
-    embeddings = outputs.last_hidden_state
+        embeddings = outputs.last_hidden_state
 
-    yield json.dumps({
-        "data": embeddings.tolist()
-    }).encode("utf-8")
+        yield json.dumps({
+            "data": embeddings.tolist()
+        }).encode("utf-8")
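Tying the weights to the instance is what fixes the life cycle: loading happens in `__init__`, and unloading happens when the instance is released, instead of relying on an explicit `unload()` that callers could skip. A sketch of the resulting life cycle, assuming the same hypothetical module layout as above:

```python
import asyncio
import gc

from model import Model  # hypothetical module name for the TinyBERT model above


async def main() -> None:
    model = Model()  # __init__ loads both the weights and the tokenizer

    async for payload in model.infer("an example prompt"):
        print(len(payload))  # one JSON payload with the raw embeddings

    # Releasing the last reference frees the weights through normal garbage
    # collection; there is no separate unload() left to forget.
    del model
    gc.collect()


asyncio.run(main())
```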