Spaces: Build error

```python
import streamlit as st
import streamlit.components.v1  # explicit import so st.components.v1.html is always available
import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import requests
import py3Dmol
from Bio import SeqIO
import io
from Bio.SeqUtils.ProtParam import ProteinAnalysis
import plotly.express as px
import plotly.graph_objects as go
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
# import shap

st.set_page_config(
    page_title="Parkinson's Protein Classifier",
    page_icon="🧬",
    layout="wide"
)


# Load ProtBERT model (cached so the weights are only downloaded/loaded once per session)
@st.cache_resource
def load_protbert():
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    model = BertModel.from_pretrained("Rostlab/prot_bert")
    model.eval()
    return tokenizer, model


# Embedding function: mean-pool ProtBERT's last hidden state over all tokens
def get_protbert_embedding(sequence, tokenizer, model):
    sequence = sequence.upper().replace(" ", "").replace("\n", "")
    sequence = ' '.join(list(sequence))  # ProtBERT expects space-separated residues
    tokens = tokenizer(sequence, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**tokens)
        embedding = torch.mean(outputs.last_hidden_state, dim=1)
    return embedding.squeeze().numpy()

# Protein analysis: basic physicochemical properties plus simple Parkinson's-related heuristics
def analyze_protein(sequence):
    sequence = sequence.upper().replace(" ", "").replace("\n", "")
    if not all(residue in "ACDEFGHIKLMNPQRSTVWY" for residue in sequence):
        return "Invalid amino acid sequence!", None

    analysis = ProteinAnalysis(sequence)
    length = len(sequence)
    mw = analysis.molecular_weight()
    aromaticity = analysis.aromaticity()
    instability = analysis.instability_index()
    gravy = analysis.gravy()
    aa_counts = analysis.count_amino_acids()
    aa_percent = {k: v / length * 100 for k, v in aa_counts.items()}
    # Secondary structure fractions (helix, turn, sheet)
    sec_struct = analysis.secondary_structure_fraction()
    # Isoelectric point
    pI = analysis.isoelectric_point()
    # Flexibility
    flexibility = analysis.flexibility()

    results = {
        'basic': {
            'Length': length,
            'Molecular Weight (Da)': mw,
            'Aromaticity': aromaticity,
            'Instability Index': instability,
            'GRAVY (Hydrophobicity)': gravy,
            'Isoelectric Point (pI)': pI
        },
        'aa_composition': aa_percent,
        'secondary_structure': {
            'Helix': sec_struct[0],
            'Turn': sec_struct[1],
            'Sheet': sec_struct[2]
        },
        'flexibility': flexibility
    }

    parkinsons_analysis = {
        'risk_factors': [],
        'notes': []
    }
    if length != 140:
        parkinsons_analysis['risk_factors'].append(f"Sequence length ({length}) deviates from wild-type (140)")
    if mw > 14660 or mw < 14400:
        parkinsons_analysis['risk_factors'].append(f"Molecular weight ({mw:.2f} Da) differs from wild-type (14.46 kDa)")
    if aromaticity > 0.05:
        parkinsons_analysis['risk_factors'].append("High aromaticity (potential aggregation risk)")
    if instability > 45:
        parkinsons_analysis['risk_factors'].append(f"High instability index ({instability:.2f}) suggests toxic form")
    if gravy > -0.3:
        parkinsons_analysis['risk_factors'].append(f"Hydrophobicity (GRAVY: {gravy:.3f}) suggests aggregation-prone variant")

    # Known pathogenic alpha-synuclein point mutations (position -> variant); informational, not yet used below
    key_positions = {
        30: 'A30P (known pathogenic)',
        46: 'E46K (known pathogenic)',
        53: 'A53T (known pathogenic)',
        83: 'E83Q (reported pathogenic)'
    }
    high_risk_aas = {
        'C': "Cysteine residues can promote aggregation",
        'G': "Glycine substitutions often pathogenic",
        'P': "Proline substitutions can disrupt structure"
    }
    for aa, risk in high_risk_aas.items():
        if aa_counts.get(aa, 0) > 0:
            parkinsons_analysis['notes'].append(f"{risk} ({aa_counts.get(aa, 0)} {aa} residues)")

    return results, parkinsons_analysis

def get_sample_data():
    data = {
        'sequence': [
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # A53T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVVNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Random (non-pathogenic)
            "MDVFMKGLSKAKEGVVAAAIKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Random (non-pathogenic)
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDVEPEA",
            "MDVFMKGLSGAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # K10G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGGVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # F94G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEACEMPSEEGYQDYEPEA",  # Y125C
            "MDVFMKGLSKHKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # A11C
            "MDVFMKGLSKAKEGVVAASEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # A19S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLTVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Y39T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # M5A
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGCQDYEPEA",  # Y133C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDGLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Q99G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDCEPEA",  # Y136C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLCVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Y39C
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDAPVDPDNEAYEMPSEEGYQDYEPEA",  # M116A
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTGEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # K45G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVGKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # K96G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDGPVDPDNEAYEMPSEEGYQDYEPEA",  # M116G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAPEMPSEEGYQDYEPEA",  # Y125P
            "GDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # M1G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGSVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # F94S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGVQDYEPEA",  # Y133V
            "MDVFMKGLSKAKEGVVAAAEKTKGGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # Q24G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEGGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # E105G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDTEPEA",  # Y136T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKGGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # E35G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGTQDYEPEA",  # Y133T
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTGEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # K60G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # F4G
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVFGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # E83F
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTKVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # N65K
            "MDVFMKGLSKSKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA",  # A11S
            "MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVATVAEKTKEQVTEVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA"   # N65E
        ],
        'label': [1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        'mutation': ['A53T', 'None', 'None', 'Unknown', 'K10G', 'F94G', 'Y125C', 'A11C', 'A19S', 'Y39T',
                     'M5A', 'Y133C', 'Q99G', 'Y136C', 'Y39C', 'M116A', 'K45G', 'K96G', 'M116G', 'Y125P',
                     'M1G', 'F94S', 'Y133V', 'Q24G', 'E105G', 'Y136T', 'E35G', 'Y133T', 'K60G', 'F4G',
                     'E83F', 'N65K', 'A11S', 'N65E']
    }
    return pd.DataFrame(data)

def train_classifier(X, y):
    # Stratify so the small train/test split keeps both classes represented
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    return clf, X_test, y_test

# Main app
def main():
    st.title("🧬 Parkinson's Disease Protein Sequence Classifier")
    st.markdown("""
    This app uses ProtBERT to generate protein sequence embeddings and a Random Forest classifier
    to predict whether a protein sequence is associated with Parkinson's disease.
    """)

    st.sidebar.header("About")
    st.sidebar.info("""
    This tool uses:
    - ProtBERT for protein sequence embeddings
    - Random Forest for classification
    - Sample dataset of known variants
    - 3D structure prediction via ESMFold API
    """)

    with st.spinner("Loading ProtBERT model..."):
        tokenizer, model = load_protbert()

    if 'classifier' not in st.session_state:
        st.session_state.classifier = None
        st.session_state.X_test = None
        st.session_state.y_test = None
        st.session_state.training_data = None

    tab1, tab2, tab3, tab4 = st.tabs(["Train Model", "Evaluate Model", "Predict New Sequence", "Data Exploration"])

    with tab1:
        st.header("Train Classification Model")
        if st.button("Train Model with Sample Data"):
            with st.spinner("Training in progress..."):
                df = get_sample_data()
                embeddings = []
                progress_bar = st.progress(0)
                status_text = st.empty()
                for i, seq in enumerate(df['sequence']):
                    try:
                        status_text.text(f"Processing sequence {i+1}/{len(df['sequence'])}...")
                        progress_bar.progress((i+1)/len(df['sequence']))
                        emb = get_protbert_embedding(seq, tokenizer, model)
                        embeddings.append(emb)
                    except Exception as e:
                        st.warning(f"Error with sequence {i+1}: {str(e)}")
                        embeddings.append(np.zeros(1024))

                X = np.array(embeddings)
                y = df['label'].values
                clf, X_test, y_test = train_classifier(X, y)

                st.session_state.classifier = clf
                st.session_state.X_test = X_test
                st.session_state.y_test = y_test
                st.session_state.training_data = df

                st.success("Model trained successfully!")

                st.subheader("Sample Training Data")
                st.dataframe(df)

                st.subheader("Class Distribution")
                class_counts = df['label'].value_counts()
                fig = px.pie(values=class_counts, names=class_counts.index.map({0: 'Non-Parkinson', 1: 'Parkinson'}))
                st.plotly_chart(fig, use_container_width=True)

    with tab2:
        st.header("Evaluate Model Performance")
        if st.session_state.classifier is not None:
            clf = st.session_state.classifier
            X_test = st.session_state.X_test
            y_test = st.session_state.y_test

            y_pred = clf.predict(X_test)
            y_proba = clf.predict_proba(X_test)[:, 1]

            st.subheader("Classification Report")
            report = classification_report(y_test, y_pred, output_dict=True)
            st.dataframe(pd.DataFrame(report).transpose())

            st.subheader("Confusion Matrix")
            cm = confusion_matrix(y_test, y_pred)
            fig = px.imshow(cm,
                            labels=dict(x="Predicted", y="Actual", color="Count"),
                            x=['Non-Parkinson', 'Parkinson'],
                            y=['Non-Parkinson', 'Parkinson'],
                            text_auto=True)
            st.plotly_chart(fig, use_container_width=True)

            st.subheader("Feature Importance")
            try:
                importances = clf.feature_importances_
                top_n = 20
                indices = np.argsort(importances)[-top_n:]
                fig = go.Figure()
                fig.add_trace(go.Bar(
                    y=[f"Feature {i}" for i in indices],
                    x=importances[indices],
                    orientation='h'
                ))
                fig.update_layout(title=f"Top {top_n} Important Features",
                                  xaxis_title="Importance Score")
                st.plotly_chart(fig, use_container_width=True)
            except Exception as e:
                st.warning(f"Could not display feature importance: {str(e)}")
        else:
            st.warning("Please train the model first using the 'Train Model' tab.")

    with tab3:
        def fetch_structure(sequence):
            url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
            headers = {"Content-Type": "text/plain"}
            try:
                response = requests.post(url, data=sequence, headers=headers, timeout=30)
                if response.status_code == 200:
                    return response.text
                else:
                    raise Exception(f"API returned status code {response.status_code}")
            except Exception as e:
                raise Exception(f"Failed to fetch structure: {str(e)}")

        def display_structure(pdb_data, color="chain", show_sidechains=True, show_mainchains=False):
            view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
            view.addModel(pdb_data, 'pdb')
            if color == "rainbow":
                view.setStyle({'cartoon': {'color': 'spectrum'}})
            elif color == "chain":
                view.setStyle({'cartoon': {'color': 'chain'}})
            elif color == "residue":
                view.setStyle({'cartoon': {'colorscheme': 'residue'}})
            else:
                view.setStyle({'cartoon': {'color': 'white'}})
            if show_sidechains:
                view.addStyle({'and': [{'atom': ['C', 'O', 'N'], 'invert': True}]},
                              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
            if show_mainchains:
                view.addStyle({'atom': ['C', 'O', 'N', 'CA']},
                              {'stick': {'colorscheme': "WhiteCarbon", 'radius': 0.3}})
            view.zoomTo()
            return view
| st.header("Predict New Sequence") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| uploaded_file = st.file_uploader("Upload a FASTA file:", type=["fasta", "fa"]) | |
| seq_input = st.text_area( | |
| "Or enter protein sequence manually:", | |
| value="MDVFMKGLSKAKEGVVAAAEKTKQGVAEAAGKTKEGVLYVGSKTKEGVVHGVTTVAEKTKEQVTNVGGAVVTGVTAVAQKTVEGAGSIAAATGFVKKDQLGKNEEGAPQEGILEDMPVDPDNEAYEMPSEEGYQDYEPEA", | |
| height=200 | |
| ) | |
| if uploaded_file is not None: | |
| try: | |
| fasta_content = uploaded_file.read().decode("utf-8") | |
| fasta_io = io.StringIO(fasta_content) | |
| record = next(SeqIO.parse(fasta_io, "fasta")) | |
| seq_input = str(record.seq) | |
| st.success(f"Sequence loaded from FASTA file: {record.id}") | |
| except Exception as e: | |
| st.error(f"Error reading FASTA file: {e}") | |
| with col2: | |
| if st.button("Analyze Sequence"): | |
| if not seq_input.strip(): | |
| st.error("Please enter a protein sequence.") | |
| else: | |
| with st.spinner("Analyzing sequence..."): | |
| try: | |
| analysis_results, parkinsons_analysis = analyze_protein(seq_input) | |
| if isinstance(analysis_results, str): | |
| st.error(analysis_results) | |
| else: | |
| st.subheader("Basic Properties") | |
| st.table(pd.DataFrame.from_dict(analysis_results['basic'], orient='index')) | |
| st.subheader("Amino Acid Composition") | |
| aa_df = pd.DataFrame.from_dict(analysis_results['aa_composition'], orient='index', columns=['Percentage']) | |
| st.bar_chart(aa_df) | |
| st.subheader("Secondary Structure") | |
| ss_df = pd.DataFrame.from_dict(analysis_results['secondary_structure'], orient='index', columns=['Fraction']) | |
| st.bar_chart(ss_df) | |
| st.subheader("Parkinson's Risk Analysis") | |
| if parkinsons_analysis['risk_factors']: | |
| st.warning("⚠️ Potential Parkinson's risk factors detected:") | |
| for factor in parkinsons_analysis['risk_factors']: | |
| st.write(f"- {factor}") | |
| else: | |
| st.success("No obvious Parkinson's risk factors detected") | |
| if parkinsons_analysis['notes']: | |
| st.info("Additional notes:") | |
| for note in parkinsons_analysis['notes']: | |
| st.write(f"- {note}") | |
| except Exception as e: | |
| st.error(f"Error analyzing sequence: {str(e)}") | |
| if st.button("Predict Parkinson's Association"): | |
| if st.session_state.classifier is None: | |
| st.error("Please train the model first using the 'Train Model' tab.") | |
| elif not seq_input.strip(): | |
| st.error("Please enter a protein sequence.") | |
| else: | |
| with st.spinner("Generating embedding and making prediction..."): | |
| try: | |
| new_emb = get_protbert_embedding(seq_input, tokenizer, model).reshape(1, -1) | |
| prediction = st.session_state.classifier.predict(new_emb) | |
| proba = st.session_state.classifier.predict_proba(new_emb) | |
| st.subheader("Prediction Result") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if prediction[0] == 1: | |
| st.error("**Prediction: Parkinson-related protein**") | |
| else: | |
| st.success("**Prediction: Not Parkinson-related**") | |
| st.write(f"Confidence: {max(proba[0])*100:.2f}%") | |
| proba_df = pd.DataFrame({ | |
| "Class": ["Not Parkinson-related", "Parkinson-related"], | |
| "Probability": proba[0] | |
| }) | |
| fig = px.bar(proba_df, x='Class', y='Probability', | |
| color='Class', | |
| color_discrete_map={ | |
| "Not Parkinson-related": "green", | |
| "Parkinson-related": "red" | |
| }) | |
| st.plotly_chart(fig, use_container_width=True) | |
| # with col2: | |
| # # Show SHAP values if available | |
| # try: | |
| # explainer = shap.TreeExplainer(st.session_state.classifier) | |
| # shap_values = explainer.shap_values(new_emb) | |
| # fig, ax = plt.subplots() | |
| # shap.summary_plot(shap_values, new_emb, | |
| # feature_names=[f"Feature {i}" for i in range(new_emb.shape[1])], | |
| # plot_type="bar", | |
| # show=False) | |
| # st.pyplot(fig) | |
| # plt.close() | |
| # except Exception as e: | |
| # st.warning(f"Could not generate SHAP explanation: {str(e)}") | |
| st.subheader("3D Protein Structure Prediction") | |
| try: | |
| pdb_data = fetch_structure(seq_input) | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.write("**Cartoon Representation**") | |
| view = display_structure(pdb_data, color="chain") | |
| st.components.v1.html(view._make_html(), height=500) | |
| with col2: | |
| st.write("**Residue Coloring**") | |
| view = display_structure(pdb_data, color="residue", show_sidechains=True) | |
| st.components.v1.html(view._make_html(), height=500) | |
| st.download_button( | |
| label="Download PDB File", | |
| data=pdb_data, | |
| file_name="predicted_structure.pdb", | |
| mime="chemical/x-pdb" | |
| ) | |
| except Exception as e: | |
| st.error(f"Could not fetch protein structure: {str(e)}") | |
| except Exception as e: | |
| st.error(f"Error processing sequence: {str(e)}") | |

    with tab4:
        st.header("Data Exploration")
        if st.session_state.training_data is not None:
            df = st.session_state.training_data

            st.subheader("Training Data Overview")
            st.dataframe(df)

            st.subheader("Mutation Analysis")
            mutation_counts = df['mutation'].value_counts().reset_index()
            mutation_counts.columns = ['Mutation', 'Count']
            fig = px.bar(mutation_counts, x='Mutation', y='Count')
            st.plotly_chart(fig, use_container_width=True)

            st.subheader("Label Distribution by Mutation")
            fig = px.histogram(df, x='mutation', color='label',
                               barmode='group',
                               color_discrete_map={0: 'green', 1: 'red'})
            st.plotly_chart(fig, use_container_width=True)

            st.subheader("Sequence Length Distribution")
            df['length'] = df['sequence'].apply(len)
            fig = px.histogram(df, x='length', color='label',
                               color_discrete_map={0: 'green', 1: 'red'})
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Please train the model first to explore the data.")


if __name__ == "__main__":
    main()
```
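
A common cause of a Spaces build error is a `requirements.txt` that does not cover every package imported above, rather than a problem in the script itself. Below is a minimal sketch of such a file, inferred only from the imports in this app; it is an assumption, not the Space's actual dependency list, and version pins (e.g. for `torch` and `transformers`) may still need to be added for compatibility.

```text
streamlit
torch
transformers
pandas
numpy
scikit-learn
requests
py3Dmol
biopython
plotly
matplotlib
seaborn
```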