Image Classification
English
breast
cancer
odelia
mueller-franzes committed on
Commit
255fb0d
·
verified ·
1 Parent(s): c85293d

Upload epoch=17-step=1836.ckpt

Browse files
Files changed (6) hide show
  1. README.md +72 -3
  2. epoch=17-step=1836.ckpt +3 -0
  3. model_config.json +26 -0
  4. models.py +155 -0
  5. predict_attention.py +171 -0
  6. state_dict.pt +3 -0
README.md CHANGED
@@ -1,3 +1,72 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ datasets:
4
+ - ODELIA-AI/ODELIA-Challenge-2025
5
+ language:
6
+ - en
7
+ metrics:
8
+ - roc_auc
9
+ pipeline_tag: image-classification
10
+ tags:
11
+ - breast
12
+ - cancer
13
+ - odelia
14
+ extra_gated_prompt: >-
15
+ ### 🛡️ Model Usage Agreement
16
+
17
+ By accessing or using this model (the “Model”), you acknowledge and agree to the following terms and conditions:
18
+
19
+ #### 1. Research-Only Use
20
+
21
+ The Model is provided strictly for non-commercial, academic, and research purposes. It must not be used for clinical decision-making, diagnosis, treatment, or any other application involving real patients or clinical care.
22
+
23
+ #### 2. No Clinical or Commercial Deployment
24
+
25
+ The Model is **not approved for clinical use** or any commercial application. Any deployment in healthcare settings or use for patient-related decision support is expressly prohibited.
26
+
27
+ #### 3. Redistribution and Modification
28
+
29
+ You may not copy, distribute, sublicense, or otherwise share the Model or any derivative works without prior written permission from the model authors or the ODELIA consortium.
30
+
31
+ #### 4. Privacy and Ethics Compliance
32
+
33
+ You must not attempt to identify, re-identify, or deanonymize any individual whose data may have contributed to the training or evaluation of the Model.
34
+
35
+ #### 5. Attribution Requirement
36
+
37
+ Any publication, presentation, or derivative work that uses or references this Model must include clear attribution to the **ODELIA consortium**, along with any citations specified in the accompanying documentation.
38
+
39
+ #### 6. Responsibility and Verification
40
+
41
+ You are solely responsible for verifying and validating the Model’s outputs and ensuring they are appropriate for your research context. The Model and its outputs are provided “as is,” without warranties of any kind.
42
+
43
+ #### 7. Inclusion of Third-Party Components
44
+
45
+ This Model incorporates or is derived from **DINOv3**, developed by **Meta Platforms**.
46
+ Use of the Model is therefore also subject to the **DINOv3 License Agreement**.
47
+ By using this Model, you agree to comply with both:
48
+
49
+ * This Model Usage Agreement, **and**
50
+ * The [DINOv3 License Terms](https://github.com/facebookresearch/dinov3).
51
+ ---
52
+
53
+ # ODELIA Classification Baseline Model
54
+ For a comprehensive description of the model and its intended use, please refer to our paper: [Read the paper](https://arxiv.org/abs/2506.00474)
55
+
56
+
57
+ ## Get Probabilities and Attention
58
+
59
+ To use this model, first download the required files from this repository:
60
+
61
+ ```python
62
+ from huggingface_hub import hf_hub_download
63
+
64
+ # Download model files to local directory
65
+ hf_hub_download(repo_id="ODELIA-AI/MST", filename="models.py", local_dir="./")
66
+ hf_hub_download(repo_id="ODELIA-AI/MST", filename="predict_attention.py", local_dir="./")
67
+ ```
68
+
69
+ Then execute `python predict_attention.py --path_img path/to/Sub_1.nii.gz` to get probabilities and attention maps.
70
+
71
+
72
+
epoch=17-step=1836.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a42b9c83fa4ec1c9a7b0060df288ea6fd3c20a9c9f7002ee26be5fb27f320c71
3
+ size 277159866
model_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "checkpoint_source": "epoch=17-step=1836.ckpt",
3
+ "created_at": "2025-10-26T16:51:08.236236Z",
4
+ "hparams": {
5
+ "backbone_type": "dinov3",
6
+ "in_ch": 1,
7
+ "loss": "<class 'odelia.models.utils.losses.MulitCELoss'>",
8
+ "loss_kwargs": {
9
+ "class_labels_num": [
10
+ 3
11
+ ]
12
+ },
13
+ "lr_scheduler": null,
14
+ "lr_scheduler_kwargs": {},
15
+ "model_size": "s",
16
+ "optimizer": "<class 'torch.optim.adamw.AdamW'>",
17
+ "optimizer_kwargs": {
18
+ "lr": 1e-05
19
+ },
20
+ "out_ch": 3,
21
+ "save_hyperparameters": true,
22
+ "slice_fusion_type": "transformer",
23
+ "spatial_dims": 3
24
+ },
25
+ "model_class": "odelia.models.mst.MSTRegression"
26
+ }
models.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ import torch.nn as nn
3
+ import torch
4
+ import math
5
+ from transformers import AutoModel
6
+ from x_transformers import Encoder
7
+
8
+
9
class _MST(nn.Module):
    """Multi-Slice Transformer.

    Applies a 2D DINO backbone to every slice of a 3D volume and fuses the
    per-slice embeddings into a single representation that is mapped to
    `out_ch` logits. forward() expects input of shape [B, C, D, H, W]
    (slices stacked along D).
    """

    def __init__(
        self,
        out_ch=1,
        backbone_type="dinov3",
        model_size = "s", # 34, 50, ... or 's', 'b', 'l'
        slice_fusion_type = "transformer", # transformer, linear, average, none
    ):
        super().__init__()
        self.backbone_type = backbone_type
        self.slice_fusion_type = slice_fusion_type

        # 2D feature extractor, shared by all slices.
        if backbone_type == "dinov2":
            model_size = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
            self.backbone = AutoModel.from_pretrained(f"facebook/dinov2-with-registers-{model_size}")
            emb_ch = self.backbone.config.hidden_size
        elif backbone_type == "dinov3":
            self.backbone = AutoModel.from_pretrained(f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m")
            emb_ch = self.backbone.config.hidden_size
        else:
            raise ValueError("Unknown backbone_type")


        self.emb_ch = emb_ch
        if slice_fusion_type == "transformer":
            # Single-layer transformer encoder mixing slice embeddings;
            # a learnable CLS token collects the fused representation.
            self.slice_fusion = Encoder(
                dim = emb_ch,
                heads = 12 if emb_ch%12 == 0 else 8, # head count must divide emb_ch
                ff_mult = 1,
                attn_dropout=0.0,
                pre_norm = True,
                depth = 1,
                attn_flash = True,
                ff_no_bias = True,
                rotary_pos_emb=True,
            )
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))
        elif slice_fusion_type == 'average':
            pass
        elif slice_fusion_type == "none":
            pass
        else:
            # NOTE(review): 'linear' is advertised in the parameter comment and
            # handled in forward(), but falls into this branch and raises -
            # confirm whether a linear fusion module was meant to be built here.
            raise ValueError("Unknown slice_fusion_type")

        self.linear = nn.Linear(emb_ch, out_ch)



    def forward(self, x, output_attentions=False):
        """Return logits [B, out_ch] for volume x of shape [B, C, D, H, W].

        When output_attentions is True, returns a tuple
        (logits, backbone attentions, slice-fusion attention layers).
        """
        B, *_ = x.shape

        # Mask (Slices with constant padded values): a slice whose mean equals
        # its corner voxel is assumed to be constant padding - heuristic.
        x_pad = torch.isclose(x.mean(dim=(-1,-2)), x[:, :, :, 0, 0]) # [B, C, D]
        x_pad = rearrange(x_pad, 'b c d -> b (c d)')

        # Flatten batch/channel/depth so each 2D slice is one backbone sample.
        x = rearrange(x, 'b c d h w -> (b c d) h w')
        x = x[:, None]
        x = x.repeat(1, 3, 1, 1) # Gray to RGB

        # -------------- Backbone --------------
        backbone_out = self.backbone(x, output_attentions=output_attentions)
        x = backbone_out.pooler_output
        x = rearrange(x, '(b d) e -> b d e', b=B)

        # -------------- Slice Fusion --------------
        if self.slice_fusion_type == 'none':
            return x
        elif self.slice_fusion_type == 'transformer':
            # The CLS token is appended LAST; padded slices are masked out.
            cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
            pad = torch.concat([x_pad, cls_pad], dim=1) # [B, D+1]
            x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1) # [B, D+1, E], cls token last
            if output_attentions:
                x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True) # [B, D+1, E]
            else:
                x = self.slice_fusion(x, mask=~pad) # [B, D+1, E]
        elif self.slice_fusion_type == 'linear':
            x = rearrange(x, 'b d e -> b e d')
            x = self.slice_fusion(x) # -> [B, E, 1]
            x = rearrange(x, 'b e d -> b d e') # -> [B, 1, E]
        elif self.slice_fusion_type == 'average':
            x = x.mean(dim=1, keepdim=True) # [B, D, E] -> [B, 1, E]

        # -------------- Logits --------------
        # The last token is the CLS token (transformer) or the single fused token.
        x = self.linear(x[:, -1])
        if output_attentions:
            slice_attn_layers = [
                interm.post_softmax_attn
                for interm in getattr(slice_hiddens, 'attn_intermediates', [])
                if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
            ]
            return x, backbone_out.attentions, slice_attn_layers
        return x

    def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (logits, plane attention, slice attention).

        plane attention: [B, C*D, side, side] patch-level maps per slice, each
        weighted by that slice's fusion attention.
        slice attention: [B, D] CLS-to-slice weights, averaged over channels.
        """
        B, C, D, _, _ = x.shape
        # Disable fast attention (fused kernels do not expose attention weights)
        attn_impl = self.backbone.config._attn_implementation
        self.backbone.set_attn_implementation("eager")
        flash_modules = []
        for module in self.slice_fusion.modules():
            if hasattr(module, 'flash'):
                flash_modules.append((module, module.flash))
                module.flash = False

        out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)

        # Restore previous attention implementation
        for module, previous in flash_modules:
            module.flash = previous
        # NOTE(review): the getter above is unguarded while this restore is
        # guarded - confirm which transformers versions need the hasattr check.
        if hasattr(self.backbone, "set_attn_implementation"):
            self.backbone.set_attn_implementation(attn_impl)

        # Process attentions
        slice_attn = torch.stack(slice_attn_layers)[-1] # last fusion layer
        slice_attn = slice_attn.mean(dim=1) # average over heads
        slice_attn = slice_attn[:, -1, :-1] # CLS query -> slice keys
        slice_attn = slice_attn.view(B, C, D).mean(dim=1)

        plane_attn_layers = [att for att in backbone_attn if att is not None]
        plane_attn = torch.stack(plane_attn_layers)[-1] # last backbone layer
        plane_attn = plane_attn.mean(dim=1) # average over heads
        # Attention of the backbone CLS token to patch tokens (skip registers).
        num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
        plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]
        plane_attn = plane_attn.view(B, C * D, -1)

        # Weight every slice by its slice attention
        plane_attn = plane_attn * slice_attn.unsqueeze(-1)

        num_patches = plane_attn.shape[-1]
        side = int(math.sqrt(num_patches))
        if side * side != num_patches:
            raise RuntimeError("number of patches is not a perfect square")
        plane_attn = plane_attn.reshape(B, C * D, side, side)

        return out, plane_attn, slice_attn
144
+
145
+
146
class MSTRegression(nn.Module):
    """Thin inference wrapper around `_MST` matching the training-time interface.

    Extra constructor arguments (`in_ch`, `spatial_dims`, `optimizer_kwargs`,
    `**kwargs`) are accepted but ignored so the model can be instantiated
    directly from the saved hyperparameters in `model_config.json`.
    """

    def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", optimizer_kwargs=None, **kwargs):
        # optimizer_kwargs defaults to None instead of a mutable dict literal
        # (shared mutable-default pitfall); it is unused at inference time.
        super().__init__()
        self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type)

    def forward(self, x):
        """Return logits for input volume x of shape [B, C, D, H, W]."""
        return self.mst(x)

    def forward_attention(self, x):
        """Return (logits, plane attention, slice attention) for x."""
        return self.mst.forward_attention(x)
predict_attention.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from huggingface_hub import hf_hub_download
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ import torchio as tio
9
+ from torchvision.utils import save_image
10
+ from matplotlib.pyplot import get_cmap
11
+
12
+ from models import MSTRegression
13
+
14
+
15
+
16
def minmax_norm(x):
    """Normalize a tensor to [0, 1] over all its elements.

    A constant input (max == min) is mapped to all zeros instead of
    producing NaN from a 0/0 division.
    """
    lo, hi = x.min(), x.max()
    rng = hi - lo
    if rng == 0:
        return x - lo  # constant tensor -> zeros, avoids NaN
    return (x - lo) / rng
19
+
20
def tensor2image(tensor, batch=0):
    """Reshape a 5D volume tensor [B, C, D, H, W] into a stack of 2D
    single-channel images [C*D, 1, H, W] for the selected batch element.

    Tensors with fewer than five dimensions are returned unchanged.
    """
    if tensor.ndim < 5:
        return tensor
    h, w = tensor.shape[-2:]
    volume = torch.swapaxes(tensor[batch], 0, 1)  # [C, D, H, W] -> [D, C, H, W]
    return volume.reshape(-1, h, w)[:, None]
23
+
24
def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')):
    """Overlay a (grad) cam heatmap on an image tensor as 2D RGB images.

    Both inputs are flattened to image stacks via tensor2image; the cam is
    colorized with *color_map* (which expects values in [0, 1]) and alpha-
    blended over the (gray-replicated) image.
    """
    img = tensor2image(tensor, batch)  # [B, C, H, W]
    if img.shape[1] != 3:
        # Replicate the gray channel so the blend happens in RGB.
        img = torch.cat((img, img, img), dim=1)
    heat = tensor2image(cam, batch)[:, 0].cpu().numpy()  # [B, H, W]
    heat = color_map(heat)  # [B, H, W, 4] RGBA
    heat = torch.moveaxis(torch.tensor(heat), -1, 1)[:, :3]  # [B, 3, H, W]
    return (1 - alpha) * img + alpha * heat
34
+
35
+
36
+
37
def crop_breast_height(image, margin_top=10) -> tio.Crop:
    """Build a tio.Crop that reduces the height axis to 256 voxels, placed so
    the bright (breast) region stays inside the crop.

    Assumes an input height of 512 (see get_bilateral_transform's
    target_shape of (512, 512, 32)).
    """
    # Voxels above the 90th intensity percentile count as foreground.
    cutoff = int(np.quantile(image.data.float(), 0.9))
    row_counts = (image.data > cutoff)[0].sum(axis=(0, 2))  # foreground per height row
    highest_fg = int(torch.argwhere(row_counts).max())
    # Amount cropped from the top, clamped to [0, 256]; the remainder of the
    # 256 voxels to remove comes off the bottom.
    top = min(max(512 - highest_fg - margin_top, 0), 256)
    return tio.Crop((0, 0, 256 - top, top, 0, 0))
45
+
46
+
47
def get_bilateral_transform(img:tio.ScalarImage, ref_img=None, target_spacing = (0.7, 0.7, 3), target_shape = (512, 512, 32)):
    """Bring *img* into a canonical bilateral view and build the inverse.

    The reference image (by default *img* itself) is reoriented to canonical
    axes, resampled to *target_spacing*, padded/cropped to *target_shape*,
    and its height is cropped to the breast region. Returns the transformed
    image together with a transform that maps results back to *img*'s space.
    """
    # -------- Settings --------------
    ref_img = img if ref_img is None else ref_img

    # Spacing
    ref_img = tio.ToCanonical()(ref_img)
    ref_img = tio.Resample(target_spacing)(ref_img)
    resample = tio.Resample(ref_img)

    # Crop: the height crop is computed once on the reference so the same
    # crop is applied deterministically to the input image below.
    ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img)
    crop_height = crop_breast_height(ref_img)

    # Process input image
    trans = tio.Compose([
        resample,
        tio.CropOrPad(target_shape, padding_mode='minimum'),
        crop_height,
    ])

    # Inverse: undo the steps in reverse order (crop -> pad/crop -> resample).
    trans_inv = tio.Compose([
        crop_height.inverse(),
        tio.CropOrPad(img.spatial_shape, padding_mode='minimum'),
        tio.Resample(img),
    ])
    return trans(img), trans_inv
73
+
74
def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)):
    """Preprocess a single breast side for the network and return its inverse.

    Returns (transformed image, inverse transform). The inverse undoes the
    crop/pad and the flips, but not the intensity normalization.
    """
    # Exclude the extreme min/max intensities from the z-norm statistics.
    mask = lambda x: (x > x.min()) & (x < x.max())
    forward = tio.Compose([
        tio.Flip((1, 0)),
        tio.CropOrPad(target_shape),
        tio.ZNormalization(masking_method=mask),
    ])
    backward = tio.Compose([
        tio.CropOrPad(img.spatial_shape),
        tio.Flip((1, 0)),
    ])
    return forward(img), backward
85
+
86
+
87
def run_prediction(img: tio.ScalarImage, model: MSTRegression):
    """Predict class probabilities and attention maps for both breasts.

    The bilateral image is split into left/right halves, each half is
    preprocessed and scored independently, and the attention maps are mapped
    back into the coordinate space of the original input image.

    Returns:
        probs:  dict side -> probability tensor (softmax over the logits).
        weight: tio.ScalarImage attention map, min-max normalized to [0, 1],
                aligned with the original input image.
    """
    img_bil, bil_trans_rev = get_bilateral_transform(img)
    # After the bilateral transform the width is 512, so each half is 256.
    split_side = {
        'right': tio.Crop((256, 0, 0, 0, 0, 0)),
        'left': tio.Crop((0, 256, 0, 0, 0, 0)),
    }

    weights, probs = {}, {}
    for side, crop in split_side.items():
        img_side = crop(img_bil)
        img_side, uni_trans_inv = get_unilateral_transform(img_side)
        # tio stores data as [C, W, H, D]; the network wants slices along dim 1.
        img_side = img_side.data.swapaxes(1,-1)
        img_side = img_side.unsqueeze(0) # Add batch dim -> [1, C, D, H, W]

        with torch.no_grad():
            device = next(model.parameters()).device
            logits, weight, weight_slice = model.forward_attention(img_side.to(device))

        # Upsample the patch-level attention to the input resolution.
        weight = F.interpolate(weight.unsqueeze(1), size=img_side.shape[2:], mode='trilinear', align_corners=False).cpu()
        pred_prob = F.softmax(logits, dim=-1).cpu()
        probs[side] = pred_prob.squeeze(0)

        # Undo the unilateral preprocessing so the two halves can be merged.
        weight = weight.squeeze(0).swapaxes(1,-1) # ->[C, W, H, D]
        weight = uni_trans_inv(weight)
        weights[side] = weight

    weight = torch.concat([weights['left'], weights['right']], dim=1) # C, W, H, D
    weight = tio.ScalarImage(tensor=weight, affine=img_bil.affine)
    weight = bil_trans_rev(weight) # back to the original image space
    weight.set_data(minmax_norm(weight.data))
    return probs, weight
119
+
120
def load_model(repo_id= "ODELIA-AI/MST") -> MSTRegression:
    """Download config and weights from the Hugging Face Hub and build the model.

    Reads the saved hyperparameters from `model_config.json`, instantiates
    MSTRegression from them, and loads `state_dict.pt` strictly.
    """
    # Download config and state dict
    config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json")
    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)

    hparams = config.get("hparams", {})
    model = MSTRegression(**hparams)

    state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
    # weights_only=True restricts unpickling to tensors/containers, preventing
    # arbitrary code execution from a tampered downloaded checkpoint.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    return model
133
+
134
+
135
if __name__ == "__main__":
    #------------ Get Arguments ----------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_img', default='/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/UKA/data/UKA_2/Sub_1.nii.gz', type=str)
    args = parser.parse_args()


    #------------ Settings/Defaults ----------------
    path_out_dir = Path().cwd()/'results/test_attention'
    path_out_dir.mkdir(parents=True, exist_ok=True)


    # ------------ Load Data ----------------
    path_img = Path(args.path_img)
    img = tio.ScalarImage(path_img)


    # ------------ Initialize Model ------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()
    model.to(device)
    model.eval()


    # ------------ Predict ----------------
    probs, weight = run_prediction(img, model)

    # Save the raw input and attention map as NIfTI for inspection.
    # (plain string literals: the previous f-strings had no placeholders)
    img.save(path_out_dir/"input.nii.gz")
    weight.save(path_out_dir/"attention.nii.gz")
    # Reorder tio data [C, W, H, D] -> [1, C, D, H, W] for tiling slices.
    weight = weight.data.swapaxes(1,-1).unsqueeze(0)
    img = img.data.swapaxes(1,-1).unsqueeze(0)
    save_image(tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5),
               path_out_dir/"overlay.png", normalize=False)

    for side in ['left', 'right']:
        print(f"{side} breast predicted probabilities: {probs[side]}")
171
+
state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c25602fec82d90912ed6f2623639937a2fb44931cfcfb382aecd16d5647c8327
3
+ size 92379550