Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import re | |
| import PIL.Image | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| import matplotlib.pyplot as plt | |
| from sklearn.manifold import TSNE | |
| from sklearn.preprocessing import LabelEncoder | |
| import torch | |
| from torch import nn | |
| from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from pinecone import Pinecone | |
| import rasterio | |
| from rasterio.sample import sample_gen | |
| from config import DEFAULT_INPUTS, MODELS, DATASETS, ID_TO_GENUS_MAP, LAYER_NAMES | |
# Raster handling: per the original note, the ecological layers are too big
# for PIL's default MAX_IMAGE_PIXELS safety limit, so lift the limit entirely.
PIL.Image.MAX_IMAGE_PIXELS = None

# Inference-only app: switch autograd off (grad mode is thread-local, so this
# only covers the thread that imports this module).
torch.set_grad_enabled(False)

# Pinecone client and the "amazon" index queried by the "cosine" method.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
pc_index = pc.Index("amazon")
# Load models
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    """Genus classifier over a BERT DNA embedding plus environmental features.

    The DNA embedding is the mean of the BERT model's last hidden layer over
    dimension 1. It is concatenated with ``env_data`` and passed through a
    single linear layer producing ``num_classes`` logits.

    Args:
        bert_model: BERT-style model whose output exposes ``hidden_states``
            (e.g. ``BertForMaskedLM``). Its ``config.hidden_size`` sets the
            embedding width; 768 is used if the model exposes no config.
        env_dim: Number of environmental features per sample.
        num_classes: Number of output classes (genera).
    """

    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Take the embedding width from the wrapped model instead of
        # hard-coding 768 (the BERT-base default, kept as fallback).
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.fc = nn.Linear(hidden_size + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        """Return classification logits of shape (batch, num_classes).

        Args:
            bert_inputs: Mapping of keyword arguments for ``self.bert``
                (e.g. tokenizer output); must not already contain
                ``output_hidden_states``.
            env_data: 2-D tensor (batch, env_dim); it is concatenated with the
                DNA embedding along dim 1.
        """
        # Request hidden states explicitly so this does not depend on the
        # wrapped model's config having output_hidden_states=True.
        outputs = self.bert(**bert_inputs, output_hidden_states=True)
        # Mean over dim 1 (tokens in HF's (batch, seq, hidden) layout);
        # note this includes padding positions for padded batches.
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        return self.fc(combined)
# Tokenizer and BERT model used to embed raw DNA sequences.
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODELS["embeddings"])
# Downstream code reads `.hidden_states` from this model's output, so request
# them explicitly rather than relying on the hub config (no-op if already set).
embeddings_model = BertForMaskedLM.from_pretrained(
    MODELS["embeddings"],
    output_hidden_states=True,
)
# The classifier's BERT backbone is built from a fresh config and handed to the
# hub mixin, which loads the classification checkpoint into the whole module.
# vocab_size=259 must match that checkpoint — presumably the DNA tokenizer's
# vocabulary size (TODO confirm).
classification_model = DNASeqClassifier.from_pretrained(
    MODELS["classification"],
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)
# Scaler applied to the sampled environmental layer values before classification.
# NOTE: unpickling can execute arbitrary code — only ever load a trusted local file.
with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)
embeddings_model.eval()
classification_model.eval()
# Reference samples (with precomputed "embeddings") for the t-SNE plots;
# rows with no genus label are discarded.
amazon_ds = (
    load_dataset(DATASETS["amazon"])["train"]
    .to_pandas()
    .dropna(subset=["genus"])
)
def set_default_inputs():
    """Return the example (DNA sequence, latitude, longitude) from DEFAULT_INPUTS.

    Wired to the "I'm feeling lucky" button to pre-fill the three inputs.
    """
    field_names = ("dna_sequence", "latitude", "longitude")
    return tuple(DEFAULT_INPUTS[name] for name in field_names)
def preprocess(dna_sequence: str, latitude: float, longitude: float):
    """Prepare app input for downstream tasks.

    The DNA sequence is cleaned (whitespace removed, upper-cased, every
    character other than A/C/G/T replaced by ``N``, trailing ``N`` s stripped),
    truncated to 660 bases, split into space-separated 4-mers and embedded.

    Returns:
        ``(dna_embedding, latitude, longitude)``: the mean of the embedding
        model's last hidden layer over dimension 1 (squeezed), and the two
        coordinates converted with ``float``.

    Raises:
        ValueError: If a coordinate cannot be converted to ``float``.
    """
    cleaned = re.sub(r"\s+", "", dna_sequence).upper()
    cleaned = re.sub(r"[^ACGT]", "N", cleaned)
    # Strip trailing Ns from the substituted sequence (not the raw input),
    # so the substitution above is actually kept.
    cleaned = re.sub(r"N+$", "", cleaned)[:660]
    kmers = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))
    # Grad mode is thread-local; don't rely on the module-level switch.
    with torch.no_grad():
        dna_embedding: torch.Tensor = embeddings_model(
            **tokenizer(kmers, return_tensors="pt")
        ).hidden_states[-1].mean(1).squeeze()
    return dna_embedding, float(latitude), float(longitude)
def tokenize(dna_sequence: str) -> dict[str, torch.Tensor]:
    """Turn a raw DNA sequence into inputs for the BERT models.

    Cleaning: remove whitespace, upper-case, replace every character other
    than A/C/G/T with ``N``, strip trailing ``N`` s, keep at most 660 bases,
    then split into space-separated 4-mers for ``tokenizer``.

    Returns:
        The tokenizer output with PyTorch tensors (``return_tensors="pt"``).
    """
    cleaned = re.sub(r"\s+", "", dna_sequence).upper()
    cleaned = re.sub(r"[^ACGT]", "N", cleaned)
    # Strip trailing Ns from the substituted sequence (not the raw input),
    # so the substitution above is actually kept.
    cleaned = re.sub(r"N+$", "", cleaned)[:660]
    kmers = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))
    return tokenizer(kmers, return_tensors="pt")
def get_embedding(dna_sequence: str) -> torch.Tensor:
    """Embed a DNA sequence with ``embeddings_model``.

    Returns the last hidden layer averaged over dimension 1 (the token axis
    in Hugging Face's (batch, seq, hidden) layout), with size-1 dims squeezed.
    """
    # Grad mode is thread-local: the module-level set_grad_enabled(False) does
    # not apply in the worker threads Gradio runs sync handlers in, so disable
    # autograd explicitly to avoid building a graph on every call.
    with torch.no_grad():
        model_inputs = tokenize(dna_sequence)
        hidden_states = embeddings_model(**model_inputs).hidden_states
        return hidden_states[-1].mean(1).squeeze()
def _cosine_top_genera(dna_sequence: str, top_n: int) -> pd.Series:
    """Return the share of each genus among the ``top_n`` nearest Pinecone matches."""
    embedding = get_embedding(dna_sequence)
    result = pc_index.query(
        namespace="all",
        vector=embedding.tolist(),
        top_k=top_n,
        include_metadata=True,
    )
    genera = pd.Series([match["metadata"]["genus"] for match in result["matches"]])
    # normalize=True divides each count by the total (same as counts / counts.sum()).
    return genera.value_counts(normalize=True)


def _sample_env_layers(latitude: float, longitude: float) -> list:
    """Return one value per raster in LAYER_NAMES at the point (mean over bands)."""
    coords = (latitude, longitude)
    values = []
    for layer in LAYER_NAMES:
        with rasterio.open(layer) as dataset:
            # NOTE(review): rasterio samples (x, y) in the raster's CRS, i.e.
            # (longitude, latitude) for geographic rasters, but (lat, lon) is
            # passed here. Left as-is so features match whatever the model and
            # scaler were fitted on — verify against the training pipeline.
            band_values = next(sample_gen(dataset, [coords]))
            values.append(np.mean(band_values))
    return values


def _fine_tuned_top_genera(
    dna_sequence: str, latitude: str, longitude: str, top_n: int, temperature: float,
) -> pd.Series:
    """Return the ``top_n`` genus probabilities from ``classification_model``."""
    bert_inputs = tokenize(dna_sequence)
    env_values = _sample_env_layers(float(latitude), float(longitude))
    # scaler.transform is given a single-row 2-D (samples, features) input.
    env_data = torch.from_numpy(scaler.transform([env_values])).to(torch.float32)
    logits = classification_model(bert_inputs, env_data)
    # A temperature below 1 sharpens the softmax distribution.
    probs = torch.softmax(logits / temperature, dim=1).squeeze()
    # Never ask for more entries than there are classes.
    top = torch.topk(probs, min(top_n, probs.numel()))
    return pd.Series(
        top.values.detach().numpy(),
        index=[ID_TO_GENUS_MAP[i] for i in top.indices.detach().numpy()],
    )


def predict_genus(
    method: str,
    dna_sequence: str,
    latitude: str,
    longitude: str,
    *,
    top_n: int = 10,
    temperature: float = 0.2,
) -> pd.Series:
    """Predict the most likely genera for a DNA sample.

    Args:
        method: ``"cosine"`` (genus frequencies among the nearest neighbours in
            the Pinecone index; coordinates are ignored) or
            ``"fine_tuned_model"`` (classifier over the DNA sequence plus
            ecological layer values sampled at the coordinates).
        dna_sequence: Raw DNA sequence text.
        latitude: Sampling latitude as text; parsed only for ``"fine_tuned_model"``.
        longitude: Sampling longitude as text; parsed only for ``"fine_tuned_model"``.
        top_n: Neighbours to retrieve (cosine) or classes to keep (fine-tuned).
        temperature: Softmax temperature for ``"fine_tuned_model"``.

    Returns:
        Series mapping genus -> score in [0, 1], highest first.

    Raises:
        ValueError: If ``method`` is unknown, or a coordinate is not a number
            when ``method == "fine_tuned_model"``.
    """
    # Grad mode is thread-local, so the module-level torch.set_grad_enabled(False)
    # does not cover the thread pool Gradio runs sync handlers in.
    with torch.no_grad():
        if method == "cosine":
            return _cosine_top_genera(dna_sequence, top_n)
        if method == "fine_tuned_model":
            return _fine_tuned_top_genera(
                dna_sequence, latitude, longitude, top_n, temperature,
            )
    raise ValueError(
        f"Unknown method {method!r}; expected 'cosine' or 'fine_tuned_model'"
    )
def genus_hist(method: str, dna_sequence: str, latitude: str, longitude: str):
    """Render ``predict_genus``'s result as a bar chart.

    Args:
        method, dna_sequence, latitude, longitude: Forwarded to ``predict_genus``.

    Returns:
        ``PIL.Image.Image`` in RGB mode.
    """
    top_k = predict_genus(method, dna_sequence, latitude, longitude)
    fig, ax = plt.subplots()
    try:
        labels = top_k.index.astype(str)
        ax.bar(labels, top_k.values)
        ax.set_ylim(0, 1)
        ax.set_title("Genus Prediction")
        ax.set_xlabel("Genus")
        ax.set_ylabel("Probability")
        ax.set_xticks(range(len(top_k)))
        ax.set_xticklabels(labels, rotation=90)
        fig.subplots_adjust(bottom=0.3)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed later) and always matches the real pixel buffer size.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        # pyplot keeps figures alive until closed; close it so each request
        # does not leak a figure.
        plt.close(fig)
def cluster_dna(k: float):
    """Render a t-SNE plot of the stored embeddings of the ``k`` most common genera.

    Args:
        k: Number of genera to include (truncated with ``int``).

    Returns:
        ``PIL.Image.Image`` (RGB) scatter plot, one colour per genus.
    """
    k = int(k)
    top_genera = amazon_ds["genus"].value_counts().head(k).index
    df = amazon_ds[amazon_ds["genus"].isin(top_genera)]

    X = np.stack(df["embeddings"].tolist())
    tsne = TSNE(
        n_components=2,
        # t-SNE requires perplexity < n_samples; cap it for tiny selections.
        perplexity=min(30, len(X) - 1),
        learning_rate=200,
        # Iterations stay at the default (1000): the keyword was renamed from
        # n_iter to max_iter in scikit-learn 1.5, so naming either one breaks
        # some versions.
        random_state=0,
    )
    X_tsne = tsne.fit_transform(X)

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(df["genus"].tolist())

    fig, ax = plt.subplots()
    try:
        plot = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="tab20", alpha=0.7)
        # num=None: one handle per encoded label, in the same order as classes_.
        handles, _ = plot.legend_elements(prop="colors", num=None)
        ax.legend(handles, list(label_encoder.classes_))
        ax.set_title(f"DNA Embedding Space (of {k} most common genera)")
        # Reduce unnecessary whitespace
        ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        # Release the pyplot figure so repeated requests don't leak memory.
        plt.close(fig)
def cluster_dna2(k: float, method: str, dna_sequence: str, latitude: str, longitude: str):
    """Render a t-SNE plot placing the user's sequence among its ``k`` likeliest genera.

    The genera come from ``predict_genus``; their stored embeddings and the
    user's own embedding are projected together, and the user's point is drawn
    in red.

    Args:
        k: Number of predicted genera to include (truncated with ``int``).
        method, dna_sequence, latitude, longitude: Forwarded to ``predict_genus``.

    Returns:
        ``PIL.Image.Image`` (RGB) scatter plot.

    Raises:
        ValueError: If none of the selected genera have samples in ``amazon_ds``.
    """
    predicted = predict_genus(method, dna_sequence, latitude, longitude)
    embed = get_embedding(dna_sequence).tolist()
    k = int(k)
    top_genera = predicted.head(k).index
    df = amazon_ds[amazon_ds["genus"].isin(top_genera)]
    if df.empty:
        # With no reference points, t-SNE cannot run on the single query point.
        raise ValueError("None of the predicted genera have reference samples to plot.")

    # The user's embedding is appended as the last row so it is projected jointly.
    X = np.vstack([df["embeddings"].tolist(), embed])
    tsne = TSNE(
        n_components=2,
        # t-SNE requires perplexity < n_samples; rare genera can have few samples.
        perplexity=min(5, len(X) - 1),
        learning_rate=200,
        # Iterations stay at the default (1000): n_iter/max_iter naming differs
        # across scikit-learn versions.
        random_state=0,
    )
    X_tsne = tsne.fit_transform(X)
    tsne_embed_space = X_tsne[:-1]
    tsne_single = X_tsne[-1]

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(df["genus"].tolist())

    fig, ax = plt.subplots()
    try:
        plot = ax.scatter(
            tsne_embed_space[:, 0], tsne_embed_space[:, 1],
            c=y_encoded, cmap="tab20", alpha=0.7,
        )
        ax.scatter(tsne_single[0], tsne_single[1], color="red", edgecolor="black")
        # num=None: one handle per encoded label, in the same order as classes_.
        handles, _ = plot.legend_elements(prop="colors", num=None)
        ax.legend(handles, list(label_encoder.classes_))
        ax.text(tsne_single[0], tsne_single[1], "Your DNA Seq", fontsize=10, color="black")
        ax.set_title("DNA Embedding Space Around Your DNA's Embedding")
        # Reduce whitespace but pad BOTH sides; a "+ 0.1" lower bound would
        # clip the leftmost points (possibly the user's own).
        ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        # Release the pyplot figure so repeated requests don't leak memory.
        plt.close(fig)
with gr.Blocks() as demo:
    # Header section: title, usage instructions and README link.
    # (gr.Markdown dedents its value, so the indented text renders normally.)
    gr.Markdown("""
    # DNA Identifier Tool
    Welcome to Lofi Amazon Beats' DNA Identifier Tool. Please enter a DNA
    sequence and the coordinates at which its sample was taken to get
    started. Click 'I'm feeling lucky' to fill in an example sequence and
    its coordinates. For more information on how to use the tool, check out our
    [README](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/README.md)
    """)

    # Inputs shared by every tab: the DNA sequence and where it was sampled.
    with gr.Row():
        with gr.Column():
            inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
        with gr.Column():
            with gr.Row():
                inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. 2.009083")
            with gr.Row():
                inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -41.68281")
            with gr.Row():
                btn_defaults = gr.Button("I'm feeling lucky")
                btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])

    with gr.Tab("Genus Prediction"):
        gr.Markdown("""
        ## Genus prediction
        A demo of predicting the genus of a DNA sequence using multiple
        approaches (method dropdown):
        - **fine_tuned_model**: uses our
        `LofiAmazon/BarcodeBERT-Finetuned-Amazon` model which predicts the genus
        based on the DNA sequence and environmental data.
        - **cosine**: computes a cosine similarity between the DNA sequence
        embedding generated by our model and the embeddings of known samples
        that we precomputed and stored. This method DOES NOT use ecological layer data.
        """)
        with gr.Row():
            with gr.Column():
                # Labelled so it matches the "method dropdown" named above.
                method_dropdown = gr.Dropdown(
                    choices=["cosine", "fine_tuned_model"], value="fine_tuned_model",
                    label="Method",
                )
                predict_button = gr.Button("Predict Genus")
            with gr.Column():
                genus_output = gr.Image()
        predict_button.click(
            fn=genus_hist,
            inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
            outputs=genus_output,
        )

    with gr.Tab("DNA Embedding Space Visualizer"):
        gr.Markdown("""
        ## DNA Embedding Space Visualizer
        Use this tool to visualize how our DNA Transformer model
        learns to cluster similar DNA sequences together.
        """)
        with gr.Row():
            top_k_slider = gr.Slider(
                minimum=1, maximum=10, step=1, value=5,
                label="Choose **k**, the number of top genera to visualize",
            )
            visualize_button = gr.Button("Visualize Embedding Space")
        # One click refreshes both plots below.
        with gr.Row():
            with gr.Column():
                gr.Markdown("""
                t-SNE plot of the DNA embedding spaces of the **k** most common
                genera in our dataset.
                """)
                visualize_output = gr.Image()
                visualize_button.click(
                    fn=cluster_dna,
                    inputs=top_k_slider,
                    outputs=visualize_output,
                )
            with gr.Column():
                gr.Markdown("""
                t-SNE plot of the DNA embedding spaces of the **k** most likely
                genera for the DNA sequence you provided.
                """)
                visualize_output2 = gr.Image()
                visualize_button.click(
                    fn=cluster_dna2,
                    inputs=[top_k_slider, method_dropdown, inp_dna, inp_lat, inp_lng],
                    outputs=visualize_output2,
                )

demo.launch()