ML6-UniKP / main.py
from fastapi import FastAPI
from typing import Dict, List, Any, Tuple
import pickle
import math
import re
import gc
from utils import split
import torch
from build_vocab import WordVocab
from pretrain_trfm import TrfmSeq2seq
from transformers import T5EncoderModel, T5Tokenizer
import numpy as np
import pydantic
app = FastAPI()
"""
tokenizer = T5Tokenizer.from_pretrained(
"Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False, torch_dtype=torch.float16)
model = T5EncoderModel.from_pretrained(
"Rostlab/prot_t5_xl_half_uniref50-enc")
"""


class Item(pydantic.BaseModel):
    sequence: str
    smiles: str
@app.post("/predict")
def predict(item: Item):
endpointHandler = EndpointHandler()
result = endpointHandler.predict({
"inputs": {
"sequence": item.sequence,
"smiles": item.smiles
}
})
return result
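

# Example request against this endpoint (illustrative only; host, port and the
# sequence/SMILES values below are placeholders, not values shipped with this repo):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"sequence": "MTEYKLVVVGAGGVGKSALTIQLIQ", "smiles": "CCO"}'
#
# The response is a JSON object with the keys "Km", "Kcat" and "Vmax".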


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = tokenizer
        self.model = model

        # paths to the vocab_content and trfm model
        vocab_content_path = "vocab_content.txt"
        trfm_path = "trfm_12_23000.pkl"

        # load the vocab_content instead of the pickle file
        with open(vocab_content_path, "r", encoding="utf-8") as f:
            vocab_content = f.read().strip().split("\n")

        # load the vocab and trfm model
        self.vocab = WordVocab(vocab_content)
        self.trfm = TrfmSeq2seq(len(self.vocab), 256, len(self.vocab), 4)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.trfm.load_state_dict(torch.load(trfm_path, map_location=device))
        self.trfm.eval()

        # paths to the pretrained models
        self.Km_model_path = "Km.pkl"
        self.Kcat_model_path = "Kcat.pkl"
        self.Kcat_over_Km_model_path = "Kcat_over_Km.pkl"

        # vocab indices
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
    def predict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Function where the endpoint logic is implemented.

        Args:
            data (Dict[str, Any]): The input data for the endpoint. It contains a single
                key "inputs", which maps to a dictionary with the following keys:
                - sequence (str): Amino acid sequence.
                - smiles (str): SMILES representation of the molecule.

        Returns:
            Dict[str, Any]: The output data for the endpoint. The dictionary contains the
                following keys:
                - Km (float): Predicted Km value.
                - Kcat (float): Predicted Kcat value.
                - Vmax (float): Predicted Vmax value (from the Kcat/Km model).
        """
        sequence = data["inputs"]["sequence"]
        smiles = data["inputs"]["smiles"]

        seq_vec = self.Seq_to_vec(sequence)
        smiles_vec = self.smiles_to_vec(smiles)
        fused_vector = np.concatenate((smiles_vec, seq_vec), axis=1)

        pred_Km = self.predict_feature_using_model(
            fused_vector, self.Km_model_path)
        pred_Kcat = self.predict_feature_using_model(
            fused_vector, self.Kcat_model_path)
        pred_Vmax = self.predict_feature_using_model(
            fused_vector, self.Kcat_over_Km_model_path)

        result = {
            "Km": pred_Km,
            "Kcat": pred_Kcat,
            "Vmax": pred_Vmax,
        }
        return result
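
    # Illustrative direct use of the handler outside FastAPI (the sequence and SMILES
    # values are placeholders):
    #
    #   handler = EndpointHandler()
    #   out = handler.predict({"inputs": {"sequence": "MTEYKLVVVGAGG", "smiles": "CCO"}})
    #   # out -> {"Km": ..., "Kcat": ..., "Vmax": ...}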
    def predict_feature_using_model(self, X: np.ndarray, model_path: str) -> float:
        """
        Function to predict a kinetic parameter using a pretrained, pickled model.
        The model outputs a log10-scaled value, which is mapped back with 10 ** x.
        """
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        pred_feature = model.predict(X)
        # model.predict returns a one-element array; take the scalar before exponentiating
        pred_feature_pow = math.pow(10, float(pred_feature[0]))
        return pred_feature_pow
    def smiles_to_vec(self, Smiles: str) -> np.ndarray:
        """
        Function to convert the SMILES string to a vector using the pretrained model.
        """
        Smiles = [Smiles]
        x_split = [split(sm) for sm in Smiles]
        xid, xseg = self.get_array(x_split, self.vocab)
        X = self.trfm.encode(torch.t(xid))
        return X
    def get_inputs(self, sm: str, vocab: WordVocab) -> Tuple[List[int], List[int]]:
        """
        Convert a tokenised SMILES string to token-id and segment lists,
        padded up to the character length of the input string.
        """
        seq_len = len(sm)
        sm = sm.split()
        ids = [vocab.stoi.get(token, self.unk_index) for token in sm]
        ids = [self.sos_index] + ids + [self.eos_index]
        seg = [1] * len(ids)
        padding = [self.pad_index] * (seq_len - len(ids))
        ids.extend(padding)
        seg.extend(padding)
        return ids, seg
    def get_array(self, smiles: List[str], vocab: WordVocab) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a list of tokenised SMILES strings to id and segment tensors.
        """
        x_id, x_seg = [], []
        for sm in smiles:
            a, b = self.get_inputs(sm, vocab)
            x_id.append(a)
            x_seg.append(b)
        return torch.tensor(x_id), torch.tensor(x_seg)
    def Seq_to_vec(self, Sequence: str) -> np.ndarray:
        """
        Function to convert the amino acid sequence to a vector using the pretrained model.
        """
        Sequence = [Sequence]
        # insert spaces between residues, e.g. "MTE" -> "M T E", as expected by ProtT5
        sequences_Example = []
        for i in range(len(Sequence)):
            zj = ''
            for j in range(len(Sequence[i]) - 1):
                zj += Sequence[i][j] + ' '
            zj += Sequence[i][-1]
            sequences_Example.append(zj)

        gc.collect()
        print(torch.cuda.is_available())
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        self.model = self.model.eval()

        features = []
        for i in range(len(sequences_Example)):
            sequences_Example_i = sequences_Example[i]
            # map rare/ambiguous amino acids to the unknown token "X"
            sequences_Example_i = [re.sub(r"[UZOB]", "X", sequences_Example_i)]
            ids = self.tokenizer.batch_encode_plus(
                sequences_Example_i, add_special_tokens=True, padding=True)
            input_ids = torch.tensor(ids['input_ids']).to(device)
            attention_mask = torch.tensor(ids['attention_mask']).to(device)
            with torch.no_grad():
                embedding = self.model(input_ids=input_ids, attention_mask=attention_mask)
            embedding = embedding.last_hidden_state.cpu().numpy()
            for seq_num in range(len(embedding)):
                seq_len = (attention_mask[seq_num] == 1).sum()
                # drop the trailing special token, keeping one embedding per residue
                seq_emd = embedding[seq_num][:seq_len - 1]
                features.append(seq_emd)

        # mean-pool the per-residue embeddings into one fixed-length vector per sequence
        features_normalize = np.zeros([len(features), len(features[0][0])], dtype=float)
        for i in range(len(features)):
            for k in range(len(features[0][0])):
                for j in range(len(features[i])):
                    features_normalize[i][k] += features[i][j][k]
                features_normalize[i][k] /= len(features[i])
        return features_normalize
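

# To serve this app locally (uvicorn is assumed to be installed; host and port are examples):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000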