Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -93,8 +93,89 @@ def predict(sequence):
|
|
| 93 |
return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
|
| 94 |
else:
|
| 95 |
return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
|
|
|
|
|
|
|
| 96 |
def predictmic(sequence):
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
# Gradio interface
|
| 99 |
iface = gr.Interface(
|
| 100 |
fn=predict,
|
|
|
|
| 93 |
return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
|
| 94 |
else:
|
| 95 |
return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
# Process-wide cache for the ProtBert (tokenizer, model, device) triple so the
# weights are loaded once per process instead of on every request.
_PROTBERT_CACHE = {}


def _load_protbert():
    """Return the cached ``(tokenizer, model, device)`` triple, loading ProtBert on first use."""
    if "bundle" not in _PROTBERT_CACHE:
        # Heavy imports are deferred so that callers rejected by input
        # validation never pay for them.
        import torch
        from transformers import BertTokenizer, BertModel

        tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
        _PROTBERT_CACHE["bundle"] = (tokenizer, model, device)
    return _PROTBERT_CACHE["bundle"]


def predictmic(sequence):
    """Predict a MIC value for a peptide sequence against each configured bacterium.

    The sequence is upper-cased and stripped of every character that is not
    one of the 20 standard amino-acid letters, embedded with ProtBert
    (mean over the token axis of the last hidden state), and then passed
    through a per-bacterium scaler, optional PCA, and regression model that
    are unpickled from files in the working directory.

    Args:
        sequence: Raw peptide sequence text. ``None`` is treated as empty.

    Returns:
        ``{"Error": <message>}`` when fewer than 10 valid residues remain
        after cleaning. Otherwise a dict keyed by bacterium name whose value
        is either the predicted MIC rounded to 3 decimals, or an
        ``"Error: ..."`` string if that bacterium's pipeline raised.
    """
    # === Preprocess input sequence ===
    # Validation runs before any model loading so invalid input fails fast.
    valid_residues = "ACDEFGHIKLMNPQRSTVWY"
    sequence = ''.join(aa for aa in (sequence or "").upper() if aa in valid_residues)
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}

    import pickle
    from math import expm1

    import torch

    # === Load ProtBert model (cached after the first call) ===
    tokenizer, model, device = _load_protbert()

    # === Tokenize & embed using mean pooling ===
    # ProtBert expects residues separated by single spaces.
    seq_spaced = ' '.join(sequence)
    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        outputs = model(**tokens)
        # NOTE(review): the mean is taken over all 512 positions, padding
        # included. Left unchanged because the downstream scalers/models were
        # presumably fitted on embeddings produced the same way - confirm
        # before switching to attention-mask-weighted pooling.
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)  # Shape: (1, 1024)

    # === MIC models and scalers for each bacterium ===
    bacteria_config = {
        "E.coli": {
            "model": "coli_xgboost_model.pkl",
            "scaler": "coli_scaler.pkl",
            "pca": None,
        },
        "S.aureus": {
            "model": "aur_xgboost_model.pkl",
            "scaler": "aur_scaler.pkl",
            "pca": None,
        },
        "P.aeruginosa": {
            "model": "arg_xgboost_model.pkl",
            "scaler": "arg_scaler.pkl",
            "pca": None,
        },
        "K.Pneumonia": {
            "model": "pne_mlp_model.pkl",
            "scaler": "pne_scaler.pkl",
            # NOTE(review): unlike the other artefacts this path has no
            # ".pkl" suffix - verify the file name on disk.
            "pca": "pne_pca",
        },
    }

    mic_results = {}

    for bacterium, cfg in bacteria_config.items():
        # Best effort per bacterium: a missing or broken artefact is reported
        # in that bacterium's slot without discarding the other predictions.
        try:
            # === Load scaler and transform ===
            with open(cfg["scaler"], "rb") as f:
                scaler = pickle.load(f)
            scaled = scaler.transform(embedding)

            # === Apply PCA if exists ===
            if cfg["pca"] is not None:
                with open(cfg["pca"], "rb") as f:
                    pca = pickle.load(f)
                transformed = pca.transform(scaled)
            else:
                transformed = scaled

            # === Load model and predict ===
            with open(cfg["model"], "rb") as f:
                mic_model = pickle.load(f)
            mic_log = mic_model.predict(transformed)[0]
            mic = round(expm1(mic_log), 3)  # Inverse of log1p used in training

            mic_results[bacterium] = mic

        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"

    return mic_results
|
| 178 |
+
|
| 179 |
# Gradio interface
|
| 180 |
iface = gr.Interface(
|
| 181 |
fn=predict,
|