Added support for input parameters that are recognised by the API.
Models are now loaded behind separate endpoints so that their inputs are easier to recognise.
This commit is contained in:
parent 900c58ffcb
commit 7bd84c8570

17 changed files with 163 additions and 128 deletions
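To illustrate the intent, a request against one of these per-model endpoints might look like the sketch below. The host, port, route shape ("/models/<name>/infer"), and model name are illustrative assumptions, not details taken from this commit.

import json
import urllib.request

# Hypothetical call to a per-model inference endpoint; the URL shape and
# server address are assumptions made for illustration only.
request = urllib.request.Request(
    "http://localhost:8000/models/tinybert/infer",
    data=json.dumps({"prompt": "Hello, world!"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    print(json.load(response)["data"])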
@@ -1,3 +1,6 @@
 {
-	"type": "dummy"
+	"type": "python",
+	"file": "model.py",
+
+	"inputs": {}
 }
12 samples/models/dummy/model.py Normal file
@@ -0,0 +1,12 @@
+import json
+import typing
+
+
+def load(model) -> None:
+    pass
+
+def unload(model) -> None:
+    pass
+
+def infer(model) -> typing.Iterator[bytes]:
+    yield json.dumps({"hello": "world!"}).encode("utf-8")
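These three functions form the contract a model module exposes. A minimal sketch of how a host could exercise it, assuming only that the module above is importable and that the model handle is any object the module may attach state to (the SimpleNamespace stand-in is an illustrative assumption, not the project's actual handle type):

import types

import model as module  # the dummy model.py above

# A bare object standing in for whatever handle the server passes around.
model = types.SimpleNamespace()

module.load(model)
for chunk in module.infer(model):
    print(chunk.decode("utf-8"))  # prints {"hello": "world!"}
module.unload(model)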
@@ -2,6 +2,10 @@
 	"type": "python",
 	"file": "model.py",
 
+	"inputs": {
+		"prompt": {"type": "str"}
+	},
+
 	"requirements": [
 		"transformers",
 		"torch",
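Declaring "prompt" in the config mirrors the new infer(model, prompt: str) signature in the diff below. One plausible way a server could use such a declaration is to bind each declared input from the request payload onto a keyword argument; this dispatch sketch is an assumption for illustration, not code from the commit.

import typing

def run_inference(module, model, payload: dict,
                  declared_inputs: dict) -> typing.Iterator[bytes]:
    # Pass each input declared in config.json as a keyword argument, e.g.
    # {"prompt": {"type": "str"}} binds payload["prompt"] to prompt=.
    kwargs = {name: payload[name] for name in declared_inputs}
    yield from module.infer(model, **kwargs)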
@@ -1,4 +1,5 @@
 import json
+import typing
 
 import torch
 import transformers
@@ -7,22 +8,22 @@ import transformers
 MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
 
 
-def load(model):
+def load(model) -> None:
     model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
     model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
 
-def unload(model):
+def unload(model) -> None:
     model.model = None
     model.tokenizer = None
 
-def infer(model, payload: dict) -> str:
-    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")
+def infer(model, prompt: str) -> typing.Iterator[bytes]:
+    inputs = model.tokenizer(prompt, return_tensors="pt")
 
     with torch.no_grad():
         outputs = model.model(**inputs)
 
     embeddings = outputs.last_hidden_state
 
-    return json.dumps({
+    yield json.dumps({
         "data": embeddings.tolist()
-    })
+    }).encode("utf-8")
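With infer now yielding encoded bytes instead of returning a str, a caller reassembles the stream before decoding. A short sketch, assuming load(model) has already run and that the response body is exactly the bytes yielded by infer:

import json

body = b"".join(infer(model, "Hello, world!"))
embeddings = json.loads(body)["data"]
# The hidden states come back as a nested list shaped
# [batch, tokens, hidden_size]; hidden_size is 312 for this TinyBERT.
print(len(embeddings[0][0]))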