from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import re
import os
import time
import joblib
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
from contextlib import asynccontextmanager
from dotenv import load_dotenv
import shutil
from huggingface_hub import hf_hub_download
# Transformers imports specifically for ProtBERT
from transformers import BertTokenizer, BertModel
# Import your custom model structure
from utils.model_classes import MHSA_GRU
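# Load environment variables (e.g. HF_TOKEN for the private model repo) from a local .env file, if present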
load_dotenv()
# ========================= CONFIGURATION ==========================
# Repository details (Where your trained classifier/scaler live)
MODEL_REPO = {
"repo_id": "camlas/toxicity",
"files": {
"classifier": "mhsa_gru_classifier.pth",
"scaler": "scaler.pkl"
}
}
# Feature Extraction Config - UPDATED FOR PROTBERT
TRANSFORMER_CONFIG = {
"model_name": "Rostlab/prot_bert",
"model_type": "ProtBERT",
"tokenizer_class": BertTokenizer,
"model_class": BertModel
}
CLASSES = ["Non-Toxic", "Toxic"]
API_VERSION = "2.0.0-protbert"
MODEL_VERSION = "ProtBERT-MHSA-GRU-v1"
# Global variables to hold loaded models
models = {
"transformer": None,
"tokenizer": None,
"classifier": None,
"scaler": None
}
# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ========================= HELPER FUNCTIONS =========================
def ensure_models_directory():
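    """Create (if needed) and return the local models/ cache directory."""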
models_dir = "models"
Path(models_dir).mkdir(exist_ok=True)
return models_dir
def download_model_from_hub(model_key: str) -> Optional[str]:
"""Download custom trained models (Classifier/Scaler) from Private HF Repo"""
try:
filename = MODEL_REPO["files"][model_key]
repo_id = MODEL_REPO["repo_id"]
models_dir = ensure_models_directory()
local_path = os.path.join(models_dir, filename)
# If file exists locally, use it
if os.path.exists(local_path):
print(f"βœ… Found {model_key} locally: {local_path}")
return local_path
print(f"πŸ“₯ Downloading {model_key} from {repo_id}...")
token = os.getenv("HF_TOKEN")
if not token:
print("⚠️ Warning: HF_TOKEN not found in .env. Private repos will fail.")
temp_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
repo_type="model",
token=token
)
shutil.copy2(temp_path, local_path)
return local_path
except Exception as e:
print(f"❌ Error downloading {model_key}: {e}")
return None
def load_feature_extractor():
"""Load the ProtBERT Model from HuggingFace"""
print(f"πŸ”„ Loading Transformer: {TRANSFORMER_CONFIG['model_name']}...")
try:
# Load specifically with do_lower_case=False for ProtBERT
tokenizer = TRANSFORMER_CONFIG['tokenizer_class'].from_pretrained(
TRANSFORMER_CONFIG['model_name'],
do_lower_case=False
)
model = TRANSFORMER_CONFIG['model_class'].from_pretrained(
TRANSFORMER_CONFIG['model_name']
)
model.to(device)
model.eval()
models["tokenizer"] = tokenizer
models["transformer"] = model
print("βœ… ProtBERT Transformer loaded successfully")
return True
except Exception as e:
print(f"❌ Error loading Transformer: {e}")
return False
def load_classifier_and_scaler():
"""Load the custom MHSA-GRU classifier and Scaler"""
try:
# 1. Load Scaler
scaler_path = download_model_from_hub("scaler")
if scaler_path:
models["scaler"] = joblib.load(scaler_path)
print("βœ… Scaler loaded")
# 2. Load Classifier
clf_path = download_model_from_hub("classifier")
if clf_path:
# ProtBERT output dimension is 1024
input_dim = 1024
print(f"ℹ️ Initializing MHSA_GRU with input_dim={input_dim} (ProtBERT)")
classifier = MHSA_GRU(
input_dim=input_dim,
hidden_dim=256, # Matching your training code
num_heads=8,
num_gru_layers=2,
dropout=0.3
)
            # weights_only=True (torch >= 1.13) is safe here: the checkpoint is a plain state dict
            state_dict = torch.load(clf_path, map_location=device, weights_only=True)
classifier.load_state_dict(state_dict)
classifier.to(device)
classifier.eval()
models["classifier"] = classifier
print("βœ… Classifier loaded")
return models["scaler"] is not None and models["classifier"] is not None
except Exception as e:
print(f"❌ Error loading custom models: {e}")
return False
def preprocess_sequence(sequence: str):
    """
    Preprocess a sequence for ProtBERT, which expects spaces between
    amino acids, e.g. 'MKTAY' -> 'M K T A Y'.
    """
    # Uppercase and strip ALL whitespace (spaces, tabs, newlines) so stray
    # spaces inside the input don't survive into the spaced string below
    sequence = "".join(sequence.upper().split())
    # Map rare/ambiguous residues to X, as the Rostlab/prot_bert model card
    # recommends (assumes the classifier was trained with the same mapping)
    sequence = re.sub(r"[UZOB]", "X", sequence)
    # Add spaces between residues
    spaced_sequence = " ".join(sequence)
    return spaced_sequence
def extract_features(sequence: str):
"""Run sequence through ProtBERT to get [CLS] embeddings"""
tokenizer = models["tokenizer"]
model = models["transformer"]
processed_seq = preprocess_sequence(sequence)
inputs = tokenizer(
[processed_seq],
return_tensors="pt",
padding=True,
truncation=True,
max_length=512 # ProtBERT max length
)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
# Extract [CLS] token embedding (Index 0)
# shape: (batch_size, hidden_dim) -> (1, 1024)
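        # (Mean pooling over residue embeddings is a common alternative for
        # ProtBERT; we assume the classifier was trained on [CLS] features.)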
features = outputs.last_hidden_state[:, 0, :]
return features.cpu().numpy()
# ========================= FASTAPI LIFESPAN =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
print("πŸš€ Starting Toxicity Detection API (ProtBERT Edition)...")
    # Sanity check for utils/model_classes.py. In practice the top-level
    # `from utils.model_classes import MHSA_GRU` would already have raised
    # ImportError if the file were missing, so this is purely a diagnostic.
    if not os.path.exists("utils/model_classes.py"):
        print("❌ Error: utils/model_classes.py not found. Please create it.")
success_tf = load_feature_extractor()
success_custom = load_classifier_and_scaler()
if not (success_tf and success_custom):
print("⚠️ Warning: Not all models loaded successfully")
yield
print("πŸ”„ Shutting down API...")
app = FastAPI(
title="Peptide Toxicity Detection API",
description="API using ProtBERT features + MHSA-GRU classifier",
version=API_VERSION,
lifespan=lifespan
)
# ========================= PYDANTIC MODELS =========================
class SequenceRequest(BaseModel):
sequence: str
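
# Response schema for /predict; `score` is the classifier's raw sigmoid output,
# i.e. the predicted probability of the "Toxic" class.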
class PredictionResponse(BaseModel):
sequence_preview: str
is_toxic: bool
label: str
score: float
confidence_level: str
model_used: str
processing_time_ms: float
timestamp: str
# ========================= ENDPOINTS =========================
@app.get("/")
async def root():
return {"message": "Toxicity Detection API is running. Use /predict to analyze sequences."}
@app.get("/health")
async def health_check():
loaded = all(v is not None for v in models.values())
return {
"status": "healthy" if loaded else "degraded",
"models_loaded": {k: v is not None for k, v in models.items()},
"device": str(device),
"model_version": MODEL_VERSION,
"feature_extractor": TRANSFORMER_CONFIG["model_name"]
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
start_time = time.time()
if not all(models.values()):
raise HTTPException(status_code=503, detail="Models are not fully initialized.")
    if not request.sequence.strip():
        raise HTTPException(status_code=400, detail="Empty sequence provided.")
try:
# 1. Extract Features (ProtBERT [CLS] Token)
# This handles the 'M K T' spacing internally
raw_features = extract_features(request.sequence)
# 2. Scale Features
# Use the scaler loaded from your repo
scaled_features = models["scaler"].transform(raw_features)
# 3. Predict (MHSA-GRU)
features_tensor = torch.FloatTensor(scaled_features).to(device)
with torch.no_grad():
# Get probability (sigmoid output)
probability = models["classifier"](features_tensor).item()
# 4. Interpret Results
# Threshold 0.5
prediction_class = 1 if probability > 0.5 else 0
predicted_label = CLASSES[prediction_class]
# Confidence calculation
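        # |p - 0.5| * 2 maps distance from the 0.5 decision boundary onto [0, 1]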
confidence_score = abs(probability - 0.5) * 2
confidence_level = "High" if confidence_score > 0.8 else "Medium" if confidence_score > 0.5 else "Low"
processing_time = round((time.time() - start_time) * 1000, 2)
return PredictionResponse(
            sequence_preview=(request.sequence[:20] + "...") if len(request.sequence) > 20 else request.sequence,
is_toxic=(prediction_class == 1),
label=predicted_label,
score=probability,
confidence_level=confidence_level,
model_used="ProtBERT + MHSA-GRU",
processing_time_ms=processing_time,
timestamp=datetime.now(timezone.utc).isoformat()
)
except Exception as e:
print(f"Error during prediction: {e}")
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
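
# Example request (illustrative; assumes the server is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"sequence": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"}'
#
# The JSON response follows PredictionResponse, e.g. (values illustrative):
#
#   {"sequence_preview": "MKTAYIAKQRQISFVKSHFS...", "is_toxic": false,
#    "label": "Non-Toxic", "score": 0.12, "confidence_level": "Medium", ...}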