
Protein Sequence-Level Prediction with Multiple Token Aggregation Methods

Extract residue embeddings from a frozen ESM2 backbone and aggregate them into sequence-level representations using 6 different strategies for downstream tasks like subcellular localization prediction.

Aggregation Methods

| # | Method | Class | Output Dim | Description |
|---|--------|-------|------------|-------------|
| 1 | Mean | MeanPooling | d | Average over non-padded residue embeddings |
| 2 | Max | MaxPooling | d | Element-wise max over non-padded residue embeddings |
| 3 | CLS | CLSPooling | d | ESM2's `<cls>` token representation (position 0) |
| 4 | GLOT | GLOTPooling | p*(K+1) | Cosine-similarity token graph → GAT GNN → attention readout (arXiv:2603.03389) |
| 5 | GLOT-Residue | GLOTResidueGraphPooling | p*(K+1) | Protein 3D residue contact graph (via graphein) → GAT GNN → attention readout |
| 6 | Covariance | CovariancePooling | d_proj*(d_proj+1)/2 | Second-order covariance pooling with power normalization |

Here d = ESM2 hidden dimension (e.g. 480 for the 35M model), p = GNN hidden dimension (default 128), and K = number of GNN layers (default 2).
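
For orientation, the first two methods reduce to ordinary masked reductions over the residue axis. A minimal sketch of the computation (the packaged MeanPooling/MaxPooling classes may differ in detail):

```python
import torch

def mean_pool(hidden, mask):
    # hidden: [B, L, d] residue embeddings; mask: [B, L], 1 = real residue
    mask = mask.unsqueeze(-1).float()                       # [B, L, 1]
    return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

def max_pool(hidden, mask):
    # Set padded positions to -inf so they never win the element-wise max
    hidden = hidden.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
    return hidden.max(dim=1).values                         # [B, d]
```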

Supported ESM2 Backbones

The backbone is interchangeable; just pass a different model name:

| Model | Params | Hidden Dim |
|-------|--------|------------|
| facebook/esm2_t6_8M_UR50D | 8M | 320 |
| facebook/esm2_t12_35M_UR50D | 35M | 480 (default) |
| facebook/esm2_t30_150M_UR50D | 150M | 640 |
| facebook/esm2_t33_650M_UR50D | 650M | 1280 |
| facebook/esm2_t36_3B_UR50D | 3B | 2560 |
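
If in doubt, a backbone's hidden dimension can be read off its transformers config:

```python
from transformers import AutoModel

# Sanity-check a backbone's hidden size against the table above
esm = AutoModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
print(esm.config.hidden_size)  # 480
```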

Quick Start

```python
import torch

from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",          # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384]  (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt"
).to("cuda")

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```
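
For completeness, a minimal training-loop sketch using the paper defaults reported below (Adam, lr=2e-4, no weight decay). Because ESM2 is frozen, filtering on requires_grad hands the optimizer only the aggregator and classifier-head parameters; the train_loader yielding tokenized batches is an assumption:

```python
import torch

# Only parameters with requires_grad=True (aggregator + head) are optimized
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=2e-4, weight_decay=0.0,
)

model.train()
for batch in train_loader:  # assumed to yield input_ids / attention_mask / labels
    outputs = model(
        input_ids=batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda(),
        labels=batch["labels"].cuda(),
    )
    optimizer.zero_grad()
    outputs["loss"].backward()
    optimizer.step()
```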

Using GLOT-Residue with PDB Files

When 3D structure is available, glot_residue builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using graphein:

```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
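
The package builds this graph through graphein, but the underlying computation is straightforward. An illustrative sketch using Biopython directly (the function below is ours for exposition, not part of the package API):

```python
import numpy as np
from Bio.PDB import PDBParser  # pip install biopython

def ca_contact_edges(pdb_path, threshold=8.0):
    """Edge list from Cα-Cα distances below `threshold` Å.
    Illustrative only: the package builds its graph via graphein."""
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    coords = np.array([res["CA"].get_coord()
                       for res in structure.get_residues() if "CA" in res])
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    contacts = (dist < threshold) & ~np.eye(len(coords), dtype=bool)
    src, dst = np.nonzero(contacts)
    return np.stack([src, dst])  # [2, num_edges], directly usable as edge_index
```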

Using Covariance Pooling

Captures second-order statistics (feature co-activations) across residue positions:

```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```

The d_proj parameter controls the output dimensionality (a sketch of the underlying computation follows the list):

  • d_proj=32 → 528 dims
  • d_proj=64 → 2080 dims (default)
  • d_proj=128 → 8256 dims
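
As a rough sketch of where those numbers come from, assuming a learned linear projection to d_proj, a masked covariance across residue positions, signed-square-root power normalization, and flattening of the upper triangle (the packaged CovariancePooling may differ in detail):

```python
import torch
import torch.nn as nn

class CovPoolSketch(nn.Module):
    """Illustrative second-order pooling; not the package's exact code."""
    def __init__(self, d_in, d_proj=64):
        super().__init__()
        self.proj = nn.Linear(d_in, d_proj)
        self.register_buffer("triu", torch.triu_indices(d_proj, d_proj))

    def forward(self, hidden, mask):
        # hidden: [B, L, d_in] residue embeddings; mask: [B, L], 1 = real residue
        mask = mask.unsqueeze(-1).float()                 # [B, L, 1]
        x = self.proj(hidden) * mask                      # zero out padding
        n = mask.sum(1).clamp(min=1.0)                    # [B, 1] residues per seq
        mu = x.sum(1) / n                                 # masked mean [B, d_proj]
        xc = (x - mu.unsqueeze(1)) * mask                 # re-mask after centering
        cov = xc.transpose(1, 2) @ xc / n.unsqueeze(-1)   # [B, d_proj, d_proj]
        cov = cov.sign() * (cov.abs() + 1e-9).sqrt()      # power normalization
        return cov[:, self.triu[0], self.triu[1]]         # [B, d_proj*(d_proj+1)/2]
```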

Architecture Overview

```
Protein Sequence
       ↓
┌──────────────────┐
│  ESM2 (frozen)   │  Extracts per-residue embeddings [B, L, d]
└──────────────────┘
       ↓
┌──────────────────┐
│   Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│  (one of 6)      │
└──────────────────┘
       ↓
┌──────────────────┐
│  Classifier Head │  Linear (+ optional hidden layer) → [B, num_classes]
└──────────────────┘
```

Only the aggregator and classifier are trained. ESM2 is always frozen.
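
In practice, freezing is the standard requires_grad pattern; a sketch of what the constructor presumably does (the esm attribute name is an assumption):

```python
# Assumed attribute name `esm`; the point is that no backbone parameter trains
for param in model.esm.parameters():
    param.requires_grad = False
model.esm.eval()  # keep backbone dropout in inference mode
```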

GLOT Details

The GLOT aggregation (methods 4 & 5) follows arXiv:2603.03389:

  1. Token Graph Construction: for standard GLOT, pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue, the Cα-Cα distance contact map from the 3D structure.
  2. Token-GNN: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
  3. Attention Readout: learned per-token importance scores → softmax → weighted sum to produce the sequence vector.

Default hyperparameters (from the paper): p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, no weight decay, Adam optimizer.
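
Putting the three steps together, a single-sequence sketch of GLOT pooling (the packaged GLOTPooling additionally handles batching and padding, and may differ in detail):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GLOTPoolSketch(nn.Module):
    """Single-sequence illustration of steps 1-3; not the package's exact code."""
    def __init__(self, d, p=128, K=2, tau=0.6, n_heads=4):
        super().__init__()
        self.tau = tau
        self.inp = nn.Linear(d, p)                 # initial projection, kept for JK
        self.gats = nn.ModuleList(
            [GATConv(p, p, heads=n_heads, concat=False) for _ in range(K)])
        self.score = nn.Linear(p * (K + 1), 1)     # attention-readout scores

    def forward(self, h):                          # h: [L, d] residue embeddings
        # 1. Token graph: cosine similarity thresholded at tau
        z = F.normalize(h, dim=-1)
        edge_index = ((z @ z.t()) > self.tau).nonzero().t()  # [2, num_edges]
        # 2. Token-GNN with Jumping Knowledge concatenation
        x = self.inp(h)
        outs = [x]
        for gat in self.gats:
            x = torch.relu(gat(x, edge_index))
            outs.append(x)
        x = torch.cat(outs, dim=-1)                # [L, p*(K+1)]
        # 3. Attention readout: softmax over tokens, weighted sum
        w = torch.softmax(self.score(x), dim=0)    # [L, 1]
        return (w * x).sum(dim=0)                  # [p*(K+1)]
```

With the defaults p=128 and K=2, this yields the 384-dim vector seen in the Quick Start example.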

Dependencies

```bash
pip install torch torch-geometric transformers
# For GLOT-Residue with PDB files:
pip install graphein biopython
```

File Structure

```
protein_aggregator/
├── __init__.py          # Package exports
├── aggregators.py       # All 6 aggregation method implementations
└── model.py             # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py  # Usage example for subcellular localization
```

References

  • GLOT: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025).
  • ESM2: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science (2023).
  • Covariance Pooling: Goodfire Research.
  • Graphein: Jamasb et al., "Graphein: a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022.