mueller-franzes committed (verified) · Commit fb68040 · Parent: 255fb0d

Upload epoch=17-step=1836.ckpt

Files changed (4):
  1. README.md +30 -0
  2. model_config.json +1 -1
  3. models.py +30 -8
  4. predict_attention.py +1 -1
README.md CHANGED
@@ -53,6 +53,36 @@ extra_gated_prompt: >-
  # ODELIA Classification Baseline Model
  For a comprehensive description of the model and its intended use, please refer to our paper: [Read the paper](https://arxiv.org/abs/2506.00474)
 
+ ## Setup
+
+ To run the code, we recommend creating a Python virtual environment.
+
+ ### Using venv
+
+ ```bash
+ # Create a virtual environment
+ python -m venv venv
+
+ # Activate the environment
+ # On Linux/Mac:
+ source venv/bin/activate
+ # On Windows:
+ # venv\Scripts\activate
+
+ # Install dependencies
+ pip install torch torchvision numpy huggingface_hub torchio matplotlib transformers einops x_transformers
+ ```
+
+ ### Using Conda
+
+ ```bash
+ # Create a conda environment
+ conda create -n odelia_hf python=3.10
+ conda activate odelia_hf
+
+ # Install dependencies
+ pip install torch torchvision numpy huggingface_hub torchio matplotlib transformers einops x_transformers
+ ```
 
  ## Get Probabilities and Attention
 
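The new Setup section only installs packages. As an optional sanity check (a sketch, not part of the repository), the snippet below imports the dependencies listed in the `pip install` line and reports whether a GPU is visible:

```python
# Optional post-install check; assumes the pip install line from the Setup section succeeded.
import torch
import torchio
import transformers
import einops
import x_transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```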
model_config.json CHANGED
@@ -1,6 +1,6 @@
  {
  "checkpoint_source": "epoch=17-step=1836.ckpt",
- "created_at": "2025-10-26T16:51:08.236236Z",
+ "created_at": "2025-11-23T17:24:34.909430Z",
  "hparams": {
  "backbone_type": "dinov3",
  "in_ch": 1,
models.py CHANGED
@@ -2,7 +2,7 @@ from einops import rearrange
  import torch.nn as nn
  import torch
  import math
- from transformers import AutoModel
+ from transformers import AutoModel, Dinov2WithRegistersModel, Dinov2WithRegistersConfig, DINOv3ViTConfig, DINOv3ViTModel
  from x_transformers import Encoder
 
 
@@ -13,22 +13,44 @@ class _MST(nn.Module):
  backbone_type="dinov3",
  model_size = "s", # 34, 50, ... or 's', 'b', 'l'
  slice_fusion_type = "transformer", # transformer, linear, average, none
+ weights=True,
  ):
  super().__init__()
  self.backbone_type = backbone_type
  self.slice_fusion_type = slice_fusion_type
 
  if backbone_type == "dinov2":
- model_size = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
- self.backbone = AutoModel.from_pretrained(f"facebook/dinov2-with-registers-{model_size}")
+ model_size_key = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
+ model_name = f"facebook/dinov2-with-registers-{model_size_key}"
+ if weights:
+ self.backbone = AutoModel.from_pretrained(model_name)
+ else:
+ configs = {
+ 'small': Dinov2WithRegistersConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
+ 'base': Dinov2WithRegistersConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
+ 'large': Dinov2WithRegistersConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
+ }
+ config = configs.get(model_size_key)
+ config.image_size = 518
+ config.patch_size = 14
+ self.backbone = Dinov2WithRegistersModel(config)
  emb_ch = self.backbone.config.hidden_size
  elif backbone_type == "dinov3":
- self.backbone = AutoModel.from_pretrained(f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m")
- emb_ch = self.backbone.config.hidden_size
+ model_name = f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
+ if weights:
+ self.backbone = AutoModel.from_pretrained(model_name)
+ else:
+ configs = {
+ 's': DINOv3ViTConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536, patch_size=16, num_register_tokens=4),
+ 'b': DINOv3ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, patch_size=16, num_register_tokens=4),
+ 'l': DINOv3ViTConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, patch_size=16, num_register_tokens=4),
+ }
+ config = configs.get(model_size)
+ self.backbone = DINOv3ViTModel(config)
  else:
  raise ValueError("Unknown backbone_type")
 
-
+ emb_ch = self.backbone.config.hidden_size
  self.emb_ch = emb_ch
  if slice_fusion_type == "transformer":
  self.slice_fusion = Encoder(
@@ -144,9 +166,9 @@ class _MST(nn.Module):
 
 
  class MSTRegression(nn.Module):
- def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", optimizer_kwargs={'lr':1e-5}, **kwargs):
+ def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", weights=True, **kwargs):
  super().__init__()
- self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type)
+ self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type, weights=weights)
 
  def forward(self, x):
  return self.mst(x)
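The net effect of the models.py change is the new `weights` flag: with `weights=True` (the previous behaviour) the DINOv2/DINOv3 backbone is downloaded pretrained from the Hub, while `weights=False` builds the same architecture from a hard-coded config with random initialization, so the network can be constructed without downloading the backbone and filled from a local state dict afterwards. A minimal sketch of the two paths, assuming `models.py` is importable and the installed `transformers` version provides the DINOv3 classes imported at the top of the file:

```python
from models import MSTRegression

# Previous behaviour (default): pull the pretrained ViT-S/16 backbone from the Hub.
model_online = MSTRegression(backbone_type="dinov3", model_size="s", weights=True)

# New offline path: same architecture, randomly initialized; the trained weights
# are expected to be loaded afterwards via load_state_dict (see predict_attention.py).
model_offline = MSTRegression(backbone_type="dinov3", model_size="s", weights=False)

# Both expose the same embedding width (384 for the small backbone).
assert model_online.mst.emb_ch == model_offline.mst.emb_ch == 384
```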
predict_attention.py CHANGED
@@ -124,7 +124,7 @@ def load_model(repo_id= "ODELIA-AI/MST") -> MSTRegression:
  config = json.load(fp)
 
  hparams = config.get("hparams", {})
- model = MSTRegression(**hparams)
+ model = MSTRegression(weights=False, **hparams)
 
  state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
  state_dict = torch.load(state_dict_path, map_location="cpu")
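Passing `weights=False` here avoids downloading the pretrained backbone a second time: `state_dict.pt` already contains the fine-tuned backbone weights, so `load_model` only needs to rebuild the architecture. A minimal usage sketch, assuming `predict_attention.py` is on the import path and access to the gated `ODELIA-AI/MST` repo has been granted on the Hub:

```python
from predict_attention import load_model

# Rebuilds MSTRegression(weights=False, **hparams) from model_config.json,
# then loads state_dict.pt downloaded from the Hub.
model = load_model(repo_id="ODELIA-AI/MST")
model.eval()

# Preprocessing of the MRI volume and the probability/attention extraction
# follow the rest of predict_attention.py and are not reproduced here.
```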