Upload epoch=17-step=1836.ckpt

Files changed:
- README.md +30 -0
- model_config.json +1 -1
- models.py +30 -8
- predict_attention.py +1 -1
README.md
CHANGED

@@ -53,6 +53,36 @@ extra_gated_prompt: >-
 # ODELIA Classification Baseline Model
 For a comprehensive description of the model and its intended use, please refer to our paper: [Read the paper](https://arxiv.org/abs/2506.00474)
 
+## Setup
+
+To run the code, we recommend creating a Python virtual environment.
+
+### Using venv
+
+```bash
+# Create a virtual environment
+python -m venv venv
+
+# Activate the environment
+# On Linux/Mac:
+source venv/bin/activate
+# On Windows:
+# venv\Scripts\activate
+
+# Install dependencies
+pip install torch torchvision numpy huggingface_hub torchio matplotlib transformers einops x_transformers
+```
+
+### Using Conda
+
+```bash
+# Create a conda environment
+conda create -n odelia_hf python=3.10
+conda activate odelia_hf
+
+# Install dependencies
+pip install torch torchvision numpy huggingface_hub torchio matplotlib transformers einops x_transformers
+```
 
 ## Get Probabilities and Attention
 
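The dependency list is identical for both setups. As a quick confirmation that the environment was created correctly, a short import test can be run after installation; this is only an illustrative sanity check, not part of the repository:

```python
# Illustrative sanity check for the freshly created environment (not part of the repo).
# All of these packages come from the single `pip install ...` line above.
import torch
import torchio
import transformers
import einops
import x_transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```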
model_config.json
CHANGED

@@ -1,6 +1,6 @@
 {
   "checkpoint_source": "epoch=17-step=1836.ckpt",
-  "created_at": "2025-
+  "created_at": "2025-11-23T17:24:34.909430Z",
   "hparams": {
     "backbone_type": "dinov3",
     "in_ch": 1,
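The `hparams` block in `model_config.json` mirrors the keyword arguments of `MSTRegression`, and the `created_at` field updated here records when the checkpoint was exported. A minimal sketch of reading the file back, assuming it sits in the current working directory (this is the same round trip `predict_attention.py` performs below):

```python
import json

# Read the metadata file exported alongside the checkpoint.
with open("model_config.json") as fp:
    config = json.load(fp)

print(config["checkpoint_source"])  # "epoch=17-step=1836.ckpt"
print(config["created_at"])         # ISO 8601 timestamp written in this commit

# These keys map directly onto MSTRegression keyword arguments,
# e.g. backbone_type="dinov3", in_ch=1, ...
hparams = config.get("hparams", {})
```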
models.py
CHANGED

@@ -2,7 +2,7 @@ from einops import rearrange
 import torch.nn as nn
 import torch
 import math
-from transformers import AutoModel
+from transformers import AutoModel, Dinov2WithRegistersModel, Dinov2WithRegistersConfig, DINOv3ViTConfig, DINOv3ViTModel
 from x_transformers import Encoder
 
 

@@ -13,22 +13,44 @@ class _MST(nn.Module):
         backbone_type="dinov3",
         model_size = "s", # 34, 50, ... or 's', 'b', 'l'
         slice_fusion_type = "transformer", # transformer, linear, average, none
+        weights=True,
     ):
         super().__init__()
         self.backbone_type = backbone_type
         self.slice_fusion_type = slice_fusion_type
 
         if backbone_type == "dinov2":
-
-
+            model_size_key = {'s': 'small', 'b': 'base', 'l': 'large'}.get(model_size)
+            model_name = f"facebook/dinov2-with-registers-{model_size_key}"
+            if weights:
+                self.backbone = AutoModel.from_pretrained(model_name)
+            else:
+                configs = {
+                    'small': Dinov2WithRegistersConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
+                    'base': Dinov2WithRegistersConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
+                    'large': Dinov2WithRegistersConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
+                }
+                config = configs.get(model_size_key)
+                config.image_size = 518
+                config.patch_size = 14
+                self.backbone = Dinov2WithRegistersModel(config)
             emb_ch = self.backbone.config.hidden_size
         elif backbone_type == "dinov3":
-
-
+            model_name = f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
+            if weights:
+                self.backbone = AutoModel.from_pretrained(model_name)
+            else:
+                configs = {
+                    's': DINOv3ViTConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536, patch_size=16, num_register_tokens=4),
+                    'b': DINOv3ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, patch_size=16, num_register_tokens=4),
+                    'l': DINOv3ViTConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, patch_size=16, num_register_tokens=4),
+                }
+                config = configs.get(model_size)
+                self.backbone = DINOv3ViTModel(config)
         else:
             raise ValueError("Unknown backbone_type")
 
-
+        emb_ch = self.backbone.config.hidden_size
         self.emb_ch = emb_ch
         if slice_fusion_type == "transformer":
             self.slice_fusion = Encoder(

@@ -144,9 +166,9 @@ class _MST(nn.Module):
 
 
 class MSTRegression(nn.Module):
-    def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer",
+    def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", weights=True, **kwargs):
         super().__init__()
-        self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type)
+        self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type, weights=weights)
 
     def forward(self, x):
         return self.mst(x)
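The new `weights` flag switches the backbone between `AutoModel.from_pretrained` (download the public DINOv2/DINOv3 weights) and a locally constructed `transformers` config (architecture only). A minimal sketch of the offline path, assuming `models.py` is importable and a compatible checkpoint is available locally (the file name here is only illustrative):

```python
import torch
from models import MSTRegression

# Build the DINOv3-small variant without downloading any pretrained backbone weights.
model = MSTRegression(backbone_type="dinov3", model_size="s",
                      slice_fusion_type="transformer", weights=False)

# Fill the randomly initialised architecture from a local checkpoint instead.
state_dict = torch.load("state_dict.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```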
predict_attention.py
CHANGED

@@ -124,7 +124,7 @@ def load_model(repo_id= "ODELIA-AI/MST") -> MSTRegression:
     config = json.load(fp)
 
     hparams = config.get("hparams", {})
-    model = MSTRegression(**hparams)
+    model = MSTRegression(weights=False, **hparams)
 
     state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
     state_dict = torch.load(state_dict_path, map_location="cpu")