|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import torch |
|
|
import numpy as np |
|
|
import os |
|
|
import time |
|
|
import joblib |
|
|
from pathlib import Path |
|
|
from datetime import datetime, timezone |
|
|
from typing import Optional |
|
|
from contextlib import asynccontextmanager |
|
|
from dotenv import load_dotenv |
|
|
import shutil |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
from transformers import BertTokenizer, BertModel |
|
|
|
|
|
|
|
|
from utils.model_classes import MHSA_GRU |
|
|
|
|
|
load_dotenv()  # populate os.environ (e.g. HF_TOKEN) from a local .env file, if present

# Private Hugging Face Hub repo holding the custom-trained artifacts.
# The keys under "files" are the model_key values accepted by
# download_model_from_hub().
MODEL_REPO = dict(
    repo_id="camlas/toxicity",
    files=dict(
        classifier="mhsa_gru_classifier.pth",
        scaler="scaler.pkl",
    ),
)

# Pretrained encoder used as the feature extractor; the classes listed here are
# instantiated with from_pretrained() in load_feature_extractor().
TRANSFORMER_CONFIG = dict(
    model_name="Rostlab/prot_bert",
    model_type="ProtBERT",
    tokenizer_class=BertTokenizer,
    model_class=BertModel,
)

# Label names indexed by predicted class (0 = non-toxic, 1 = toxic).
CLASSES = ["Non-Toxic", "Toxic"]

API_VERSION = "2.0.0-protbert"
MODEL_VERSION = "ProtBERT-MHSA-GRU-v1"

# Process-wide registry of loaded components. Every slot starts as None and is
# filled during application startup; /health and /predict inspect it.
models = dict.fromkeys(("transformer", "tokenizer", "classifier", "scaler"))

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
def ensure_models_directory():
    """Create the local ``models`` directory if it is missing and return its name.

    Returns:
        The relative directory name ``"models"``, resolved against the current
        working directory by callers.
    """
    target = Path("models")
    target.mkdir(exist_ok=True)
    return str(target)
|
|
|
|
|
def download_model_from_hub(model_key: str) -> Optional[str]:
    """Return a local path to a custom model artifact, downloading it if needed.

    The file name is looked up as ``MODEL_REPO["files"][model_key]``. If that
    file already exists in the local ``models/`` directory it is reused as-is.
    Otherwise it is fetched from ``MODEL_REPO["repo_id"]`` on the Hugging Face
    Hub, using the ``HF_TOKEN`` environment variable for authentication, and
    copied into ``models/``.

    Args:
        model_key: Key into ``MODEL_REPO["files"]`` (e.g. "classifier", "scaler").

    Returns:
        The local file path, or ``None`` if the key is unknown or the download
        or copy failed. Errors are printed, not raised.
    """
    try:
        filename = MODEL_REPO["files"][model_key]
        repo_id = MODEL_REPO["repo_id"]
        local_path = os.path.join(ensure_models_directory(), filename)

        if os.path.exists(local_path):
            print(f"[OK] Found {model_key} locally: {local_path}")
            return local_path

        print(f"[DOWNLOAD] Downloading {model_key} from {repo_id}...")

        # load_dotenv() only fills os.environ, so the token may come from either
        # the real environment or a .env file.
        token = os.getenv("HF_TOKEN")
        if not token:
            print("[WARN] Warning: HF_TOKEN not found in environment/.env. Private repos will fail.")

        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token,
        )

        # Copy to a sibling temp file and atomically rename it into place. An
        # interrupted copy must never leave a truncated file at local_path,
        # because the exists() fast path above would accept it on the next start.
        partial_path = local_path + ".part"
        shutil.copy2(cached_path, partial_path)
        os.replace(partial_path, local_path)
        return local_path

    except Exception as e:
        print(f"[ERROR] Error downloading {model_key}: {e}")
        return None
|
|
|
|
|
def load_feature_extractor():
    """Load the ProtBERT tokenizer and encoder from the Hugging Face Hub.

    On success, the tokenizer is stored in ``models["tokenizer"]``. The encoder
    is moved to ``device``, switched to eval mode and stored in
    ``models["transformer"]``. Nothing is stored if either load fails.

    Returns:
        True on success, False on any error. The error is printed, not raised.
    """
    model_name = TRANSFORMER_CONFIG["model_name"]
    print(f"[LOAD] Loading Transformer: {model_name}...")
    try:
        # preprocess_sequence() upper-cases sequences; keep the tokenizer from
        # lower-casing them again.
        tokenizer = TRANSFORMER_CONFIG["tokenizer_class"].from_pretrained(
            model_name,
            do_lower_case=False,
        )
        model = TRANSFORMER_CONFIG["model_class"].from_pretrained(model_name)
        model.to(device)
        model.eval()  # inference only: disables dropout

        models["tokenizer"] = tokenizer
        models["transformer"] = model
        print("[OK] ProtBERT Transformer loaded successfully")
        return True

    except Exception as e:
        print(f"[ERROR] Error loading Transformer: {e}")
        return False
|
|
|
|
|
def load_classifier_and_scaler():
    """Load the custom MHSA-GRU classifier and the feature scaler.

    Both artifacts are resolved through ``download_model_from_hub``, which reuses
    a local copy in ``models/`` or downloads one. They are stored in the global
    ``models`` registry under ``"scaler"`` and ``"classifier"``.

    Returns:
        True only if both ``models["scaler"]`` and ``models["classifier"]`` are
        set afterwards. False otherwise, including on any exception, which is
        printed rather than raised.
    """
    try:
        scaler_path = download_model_from_hub("scaler")
        if scaler_path:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only point MODEL_REPO at a repository you trust.
            models["scaler"] = joblib.load(scaler_path)
            print("[OK] Scaler loaded")

        clf_path = download_model_from_hub("classifier")
        if clf_path:
            # Width of the ProtBERT [CLS] embedding fed to the classifier by
            # /predict. Presumably this must match the dimension the saved
            # weights were trained with.
            input_dim = 1024
            print(f"[INFO] Initializing MHSA_GRU with input_dim={input_dim} (ProtBERT)")

            classifier = MHSA_GRU(
                input_dim=input_dim,
                hidden_dim=256,
                num_heads=8,
                num_gru_layers=2,
                dropout=0.3,
            )

            # The file is a plain state_dict. weights_only=True restricts
            # unpickling to tensors and containers, so a tampered download
            # cannot execute code.
            state_dict = torch.load(clf_path, map_location=device, weights_only=True)
            classifier.load_state_dict(state_dict)
            classifier.to(device)
            classifier.eval()  # inference only: disables dropout
            models["classifier"] = classifier
            print("[OK] Classifier loaded")

        return models["scaler"] is not None and models["classifier"] is not None

    except Exception as e:
        print(f"[ERROR] Error loading custom models: {e}")
        return False
|
|
|
|
|
def preprocess_sequence(sequence: str):
    """Turn a raw amino-acid string into ProtBERT's space-separated input form.

    The sequence is upper-cased and stripped of surrounding whitespace. Any
    embedded ``\\n``/``\\r`` characters are then removed, and every remaining
    character is separated by a single space, e.g. ``"mkt"`` -> ``"M K T"``.
    """
    drop_line_breaks = {ord("\n"): None, ord("\r"): None}
    cleaned = sequence.upper().strip().translate(drop_line_breaks)
    # str.join iterates characters directly, so no intermediate list is needed.
    return " ".join(cleaned)
|
|
|
|
|
def extract_features(sequence: str):
    """Embed a single protein sequence with the loaded ProtBERT encoder.

    The sequence is spaced out by ``preprocess_sequence``, tokenized with
    truncation at 512 tokens, and run through the encoder without gradient
    tracking.

    Returns:
        A NumPy array holding the final-layer hidden state of the first token
        ([CLS]) for the one input, i.e. shape ``(1, hidden_size)``.
    """
    encoder = models["transformer"]
    batch = models["tokenizer"](
        [preprocess_sequence(sequence)],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    batch_on_device = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        hidden_states = encoder(**batch_on_device).last_hidden_state

    # Position 0 of every row is the [CLS] token prepended by the tokenizer.
    cls_embedding = hidden_states[:, 0, :]
    return cls_embedding.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model at startup and log shutdown.

    Load failures do not abort startup. The API still comes up in a degraded
    state, which /health reports and /predict answers with 503.
    """
    print("[START] Starting Toxicity Detection API (ProtBERT Edition)...")

    # Resolve relative to this file rather than the process's working directory,
    # so launching the server from another directory does not trigger a false alarm.
    model_classes_file = Path(__file__).resolve().parent / "utils" / "model_classes.py"
    if not model_classes_file.exists():
        print("[ERROR] Error: utils/model_classes.py not found. Please create it.")

    success_tf = load_feature_extractor()
    success_custom = load_classifier_and_scaler()

    if not (success_tf and success_custom):
        print("[WARN] Warning: Not all models loaded successfully")

    try:
        yield
    finally:
        # Runs on normal shutdown and when the app body raises.
        print("[STOP] Shutting down API...")
|
|
|
|
|
# ASGI application; models are loaded (and shutdown is logged) by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version=API_VERSION,
    title="Peptide Toxicity Detection API",
    description="API using ProtBERT features + MHSA-GRU classifier",
)
|
|
|
|
|
|
|
|
|
|
|
class SequenceRequest(BaseModel):
    # Request body for POST /predict.
    # Raw amino-acid sequence. Case, surrounding whitespace and line breaks are
    # normalized by preprocess_sequence() before tokenization.
    sequence: str
|
|
|
|
|
class PredictionResponse(BaseModel):
    # Response body of POST /predict.

    # First 20 characters of the submitted sequence, with "..." appended when longer.
    sequence_preview: str

    # True when the classifier score exceeds 0.5.
    is_toxic: bool

    # Entry of CLASSES for the predicted class ("Non-Toxic" or "Toxic").
    label: str

    # Raw classifier output that is thresholded at 0.5.
    score: float

    # "High" / "Medium" / "Low", based on the score's distance from 0.5.
    confidence_level: str

    # Human-readable description of the feature-extractor + classifier pipeline.
    model_used: str

    # Elapsed request handling time in milliseconds, rounded to 2 decimals.
    processing_time_ms: float

    # UTC ISO-8601 timestamp of when the response was built.
    timestamp: str
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    # Liveness banner that points clients at the prediction endpoint.
    banner = "Toxicity Detection API is running. Use /predict to analyze sequences."
    return {"message": banner}
|
|
|
|
|
@app.get("/health")
async def health_check():
    # Per-component load status. The service is "degraded" while any registry
    # slot is still None.
    component_status = {name: component is not None for name, component in models.items()}
    overall = "healthy" if all(component_status.values()) else "degraded"

    return {
        "status": overall,
        "models_loaded": component_status,
        "device": str(device),
        "model_version": MODEL_VERSION,
        "feature_extractor": TRANSFORMER_CONFIG["model_name"],
    }
|
|
|
|
|
def _run_inference(sequence: str) -> float:
    """Score one raw sequence with the ProtBERT -> scaler -> MHSA-GRU pipeline.

    This is blocking CPU/GPU work. The /predict handler runs it in a worker
    thread so it does not stall the event loop (and with it /health).

    Returns:
        The classifier's scalar output for the sequence as a Python float.
    """
    raw_features = extract_features(sequence)
    # Apply the scaler loaded from the model repo to the raw [CLS] embedding.
    scaled_features = models["scaler"].transform(raw_features)
    features_tensor = torch.FloatTensor(scaled_features).to(device)
    with torch.no_grad():
        # NOTE(review): /predict treats this value as P(toxic) in [0, 1], which
        # assumes MHSA_GRU ends in a sigmoid. Confirm in utils/model_classes.py.
        return models["classifier"](features_tensor).item()


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
    """Classify a peptide/protein sequence as toxic or non-toxic.

    Raises:
        HTTPException(503): if any model component is not loaded.
        HTTPException(400): if the sequence is empty or only whitespace.
        HTTPException(500): if feature extraction, scaling, classification or
            response construction fails.
    """
    from fastapi.concurrency import run_in_threadpool

    start_time = time.perf_counter()

    # Compare against None (as /health does) rather than relying on truthiness:
    # loaded objects that define __len__/__bool__ must not count as "missing".
    if any(component is None for component in models.values()):
        raise HTTPException(status_code=503, detail="Models are not fully initialized.")

    # Whitespace-only input would reduce to an empty sequence after preprocessing.
    if not request.sequence.strip():
        raise HTTPException(status_code=400, detail="Empty sequence provided.")

    try:
        probability = await run_in_threadpool(_run_inference, request.sequence)

        # Scores above 0.5 map to class 1 ("Toxic"); otherwise class 0 ("Non-Toxic").
        prediction_class = 1 if probability > 0.5 else 0
        predicted_label = CLASSES[prediction_class]

        # Distance from the 0.5 decision boundary, rescaled by 2; this lies in
        # [0, 1] whenever the score itself does.
        confidence_score = abs(probability - 0.5) * 2
        if confidence_score > 0.8:
            confidence_level = "High"
        elif confidence_score > 0.5:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        # perf_counter is monotonic, so the duration is immune to wall-clock jumps.
        processing_time = round((time.perf_counter() - start_time) * 1000, 2)

        if len(request.sequence) > 20:
            preview = request.sequence[:20] + "..."
        else:
            preview = request.sequence

        return PredictionResponse(
            sequence_preview=preview,
            is_toxic=(prediction_class == 1),
            label=predicted_label,
            score=probability,
            confidence_level=confidence_level,
            model_used="ProtBERT + MHSA-GRU",
            processing_time_ms=processing_time,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )

    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)