# Protein Sequence-Level Prediction with Multiple Token Aggregation Methods
Extract residue embeddings from a frozen ESM2 backbone and aggregate them into sequence-level representations using 6 different strategies for downstream tasks like subcellular localization prediction.
## Aggregation Methods
| # | Method | Class | Output Dim | Description |
|---|--------|-------|------------|-------------|
| 1 | Mean | `MeanPooling` | d | Average over non-padded residue embeddings |
| 2 | Max | `MaxPooling` | d | Element-wise max over non-padded residue embeddings |
| 3 | CLS | `CLSPooling` | d | ESM2's `<cls>` token representation (position 0) |
| 4 | GLOT | `GLOTPooling` | p*(K+1) | Cosine-similarity token graph → GAT GNN → attention readout (arXiv:2603.03389) |
| 5 | GLOT-Residue | `GLOTResidueGraphPooling` | p*(K+1) | Protein 3D residue contact graph (via graphein) → GAT GNN → attention readout |
| 6 | Covariance | `CovariancePooling` | d_proj*(d_proj+1)/2 | Second-order covariance pooling with power normalization (see References) |
Where d = ESM2 hidden dimension (e.g. 480 for the 35M model), p = GNN hidden dim (default 128), and K = number of GNN layers (default 2).
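
For intuition, the first-order poolings (methods 1 and 2) are just masked reductions over the residue axis. A minimal sketch, assuming embeddings `h` of shape `[B, L, d]` and a 0/1 attention mask; this is illustrative code, not the package's actual classes:

```python
import torch

def mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over residues. h: [B, L, d], mask: [B, L] with 1 = real token."""
    m = mask.unsqueeze(-1).to(h.dtype)                         # [B, L, 1]
    return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-9)   # [B, d]

def max_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Masked element-wise max: padded positions are set to -inf before the max."""
    neg_inf = torch.finfo(h.dtype).min
    h = h.masked_fill(mask.unsqueeze(-1) == 0, neg_inf)
    return h.max(dim=1).values                                 # [B, d]
```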
## Supported ESM2 Backbones

The backbone is swappable: just pass a different model name.
| Model | Params | Hidden Dim |
|---|---|---|
| `facebook/esm2_t6_8M_UR50D` | 8M | 320 |
| `facebook/esm2_t12_35M_UR50D` | 35M | 480 (default) |
| `facebook/esm2_t30_150M_UR50D` | 150M | 640 |
| `facebook/esm2_t33_650M_UR50D` | 650M | 1280 |
| `facebook/esm2_t36_3B_UR50D` | 3B | 2560 |
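
For example, assuming the constructor shown in Quick Start below, moving to the 650M backbone only changes the model name; first-order poolings then output d = 1280:

```python
# Larger backbone: mean/max/cls pooling now outputs d = 1280.
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t33_650M_UR50D",
    aggregation="mean",
    num_classes=10,
)
```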
## Quick Start
```python
import torch

from protein_aggregator import ProteinSequenceClassifier

# Build model: frozen ESM2 + GLOT aggregation + 10-class head
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="glot",  # "mean", "max", "cls", "glot", "glot_residue", "covariance"
    num_classes=10,
    aggregator_kwargs={"p": 128, "K": 2, "tau": 0.6},
    classifier_hidden=256,
    dropout=0.1,
).cuda()

# Get sequence-level embeddings
embeddings = model.encode(["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"])
print(embeddings.shape)  # [2, 384] (p*(K+1) = 128*3)

# Or full forward pass with loss
inputs = model.tokenizer(
    ["MKTAYIAKQRQISFVK", "ACDEFGHIKLMNPQR"],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
loss = outputs["loss"]
logits = outputs["logits"]
```
## Using GLOT-Residue with PDB Files

When 3D structure is available, `glot_residue` builds the token graph from the protein's Cα-Cα contact map (8 Å threshold) using graphein:
```python
model = ProteinSequenceClassifier(
    aggregation="glot_residue",
    num_classes=10,
    aggregator_kwargs={
        "contact_threshold": 8.0,  # Cα-Cα distance in Å
        "seq_neighbor_k": 5,       # fallback: ±k sequence neighbors if no PDB
    },
)

# With PDB files
outputs = model(input_ids=ids, attention_mask=mask, pdb_paths=["1abc.pdb", "2def.pdb"])

# Without PDB files (falls back to sequence-distance graph)
outputs = model(input_ids=ids, attention_mask=mask)
```
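
Under the hood, the contact graph is just a thresholded Cα distance matrix. A minimal standalone sketch with Biopython; this is illustrative only, since the package uses graphein's construction, which may differ in residue filtering:

```python
import numpy as np
from Bio.PDB import PDBParser

def ca_contact_edges(pdb_path: str, threshold: float = 8.0) -> np.ndarray:
    """Edge list (i, j) for residue pairs whose Cα atoms lie within `threshold` Å."""
    structure = PDBParser(QUIET=True).get_structure("p", pdb_path)
    coords = np.array([res["CA"].coord
                       for res in structure.get_residues() if "CA" in res])
    # Pairwise Cα-Cα distances, [L, L]
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Keep pairs under the threshold, excluding self-contacts on the diagonal
    i, j = np.where((dists < threshold) & ~np.eye(len(coords), dtype=bool))
    return np.stack([i, j])  # shape [2, num_edges], usable as a PyG edge_index
```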
## Using Covariance Pooling
Captures second-order statistics (feature co-activations) across residue positions:
```python
model = ProteinSequenceClassifier(
    aggregation="covariance",
    num_classes=10,
    aggregator_kwargs={"d_proj": 64},  # output dim = 64*65/2 = 2080
)
```
The `d_proj` parameter controls the output dimensionality:

- `d_proj=32` → 528 dims
- `d_proj=64` → 2080 dims (default)
- `d_proj=128` → 8256 dims
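
For intuition, a minimal sketch of the computation: project residue features to `d_proj`, take a masked covariance across positions, apply signed-square-root power normalization, and flatten the upper triangle of the symmetric matrix, which is where the d_proj*(d_proj+1)/2 dimension comes from. Names and the exact normalization here are assumptions, not the package's `CovariancePooling`:

```python
import torch
import torch.nn as nn

class CovPoolSketch(nn.Module):
    def __init__(self, d: int, d_proj: int = 64):
        super().__init__()
        self.proj = nn.Linear(d, d_proj)  # reduce feature dim before the outer product

    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # h: [B, L, d], mask: [B, L] -> [B, d_proj*(d_proj+1)/2]
        z = self.proj(h)                                      # [B, L, p]
        m = mask.unsqueeze(-1).to(z.dtype)
        n = m.sum(dim=1).clamp(min=1.0)                       # residues per sequence
        mu = (z * m).sum(dim=1, keepdim=True) / n.unsqueeze(1)
        zc = (z - mu) * m                                     # center, re-zero padding
        cov = zc.transpose(1, 2) @ zc / n.unsqueeze(-1)       # [B, p, p]
        cov = torch.sign(cov) * torch.sqrt(cov.abs() + 1e-8)  # power normalization
        iu = torch.triu_indices(cov.size(1), cov.size(2))
        return cov[:, iu[0], iu[1]]                           # upper triangle incl. diagonal
```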
## Architecture Overview

```
Protein Sequence
        │
┌────────────────────┐
│   ESM2 (frozen)    │  Extracts per-residue embeddings [B, L, d]
└────────────────────┘
        │
┌────────────────────┐
│     Aggregator     │  Compresses token-level → sequence-level [B, agg_dim]
│     (one of 6)     │
└────────────────────┘
        │
┌────────────────────┐
│  Classifier Head   │  Linear (+ optional hidden layer) → [B, num_classes]
└────────────────────┘
```
Only the aggregator and classifier are trained. ESM2 is always frozen.
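
In training code, that split looks roughly as follows, continuing from the Quick Start objects. The attribute names (`model.esm2` in particular) are assumptions for illustration:

```python
# Freeze the backbone: it contributes features only, no gradient updates.
# (`model.esm2` is an assumed attribute name.)
for p in model.esm2.parameters():
    p.requires_grad_(False)

# Only aggregator + classifier parameters remain trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=2e-4)  # paper defaults: Adam, no weight decay

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=torch.tensor([0, 3]).cuda(),
)
outputs["loss"].backward()
optimizer.step()
optimizer.zero_grad()
```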
## GLOT Details

The GLOT aggregation (methods 4 & 5) follows arXiv:2603.03389:

- **Token Graph Construction**: for standard GLOT, pairwise cosine similarity between residue embeddings → threshold at τ (default 0.6) → binary adjacency. For GLOT-Residue, a Cα-Cα distance contact map from the 3D structure.
- **Token-GNN**: K layers of GATConv (Graph Attention Network) with ReLU, followed by Jumping Knowledge concatenation of all layer outputs.
- **Attention Readout**: learned per-token importance scores → softmax → weighted sum to produce the sequence vector.

Default hyperparameters (from the paper): p=128, K=2, tau=0.6, n_heads=4, lr=2e-4, Adam with no weight decay.
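
Putting the three stages together, a minimal single-sequence sketch with `torch_geometric`. The class and the exact Jumping Knowledge parameterization are assumptions: the concat here includes a projected copy of the input alongside the K layer outputs so the result has dim p*(K+1), which matches the table above but may differ from the paper's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GLOTPoolSketch(nn.Module):
    """Illustrative single-sequence sketch of the three GLOT stages."""
    def __init__(self, d: int, p: int = 128, K: int = 2,
                 tau: float = 0.6, n_heads: int = 4):
        super().__init__()
        dims = [d] + [p] * K
        self.gnn = nn.ModuleList(
            GATConv(dims[i], p // n_heads, heads=n_heads) for i in range(K)
        )
        self.inp = nn.Linear(d, p)            # project input for the JK concat
        self.att = nn.Linear(p * (K + 1), 1)  # attention readout scores
        self.tau = tau

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [L, d] residue embeddings for one sequence (padding already removed)
        # 1) Token graph: thresholded pairwise cosine similarity
        sim = F.cosine_similarity(h.unsqueeze(1), h.unsqueeze(0), dim=-1)  # [L, L]
        edge_index = (sim > self.tau).nonzero().t().contiguous()
        # 2) K GAT layers with ReLU, collecting every layer's output
        outs, x = [self.inp(h)], h
        for conv in self.gnn:
            x = F.relu(conv(x, edge_index))
            outs.append(x)
        jk = torch.cat(outs, dim=-1)                # [L, p*(K+1)] jumping knowledge
        # 3) Attention readout: per-token scores -> softmax -> weighted sum
        w = torch.softmax(self.att(jk), dim=0)      # [L, 1]
        return (w * jk).sum(dim=0)                  # [p*(K+1)]
```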
## Dependencies

```bash
pip install torch torch-geometric transformers

# For GLOT-Residue with PDB files:
pip install graphein biopython
```
## File Structure

```
protein_aggregator/
├── __init__.py              # Package exports
├── aggregators.py           # All 6 aggregation method implementations
└── model.py                 # ProteinSequenceClassifier (ESM2 + aggregator + head)
example_localization.py      # Usage example for subcellular localization
```
## References

- GLOT: Mantri et al., "Towards Improved Sentence Representations using Token Graphs", arXiv:2603.03389 (2025).
- ESM2: Lin et al., "Evolutionary-scale prediction of atomic-level protein structure with a language model", Science (2023).
- Covariance Pooling: Goodfire Research.
- Graphein: Jamasb et al., "Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Biomolecular Structures and Interaction Networks", NeurIPS 2022.