Upload epoch=17-step=1836.ckpt
Browse files- README.md +72 -3
- epoch=17-step=1836.ckpt +3 -0
- model_config.json +26 -0
- models.py +155 -0
- predict_attention.py +171 -0
- state_dict.pt +3 -0
README.md
CHANGED
|
@@ -1,3 +1,72 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: cc-by-nc-4.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
datasets:
|
| 4 |
+
- ODELIA-AI/ODELIA-Challenge-2025
|
| 5 |
+
language:
|
| 6 |
+
- en
|
| 7 |
+
metrics:
|
| 8 |
+
- roc_auc
|
| 9 |
+
pipeline_tag: image-classification
|
| 10 |
+
tags:
|
| 11 |
+
- breast
|
| 12 |
+
- cancer
|
| 13 |
+
- odelia
|
| 14 |
+
extra_gated_prompt: >-
|
| 15 |
+
### 🛡️ Model Usage Agreement
|
| 16 |
+
|
| 17 |
+
By accessing or using this model (the “Model”), you acknowledge and agree to the following terms and conditions:
|
| 18 |
+
|
| 19 |
+
#### 1. Research-Only Use
|
| 20 |
+
|
| 21 |
+
The Model is provided strictly for non-commercial, academic, and research purposes. It must not be used for clinical decision-making, diagnosis, treatment, or any other application involving real patients or clinical care.
|
| 22 |
+
|
| 23 |
+
#### 2. No Clinical or Commercial Deployment
|
| 24 |
+
|
| 25 |
+
The Model is **not approved for clinical use** or any commercial application. Any deployment in healthcare settings or use for patient-related decision support is expressly prohibited.
|
| 26 |
+
|
| 27 |
+
#### 3. Redistribution and Modification
|
| 28 |
+
|
| 29 |
+
You may not copy, distribute, sublicense, or otherwise share the Model or any derivative works without prior written permission from the model authors or the ODELIA consortium.
|
| 30 |
+
|
| 31 |
+
#### 4. Privacy and Ethics Compliance
|
| 32 |
+
|
| 33 |
+
You must not attempt to identify, re-identify, or deanonymize any individual whose data may have contributed to the training or evaluation of the Model.
|
| 34 |
+
|
| 35 |
+
#### 5. Attribution Requirement
|
| 36 |
+
|
| 37 |
+
Any publication, presentation, or derivative work that uses or references this Model must include clear attribution to the **ODELIA consortium**, along with any citations specified in the accompanying documentation.
|
| 38 |
+
|
| 39 |
+
#### 6. Responsibility and Verification
|
| 40 |
+
|
| 41 |
+
You are solely responsible for verifying and validating the Model’s outputs and ensuring they are appropriate for your research context. The Model and its outputs are provided “as is,” without warranties of any kind.
|
| 42 |
+
|
| 43 |
+
#### 7. Inclusion of Third-Party Components
|
| 44 |
+
|
| 45 |
+
This Model incorporates or is derived from **DINOv3**, developed by **Meta Platforms**.
|
| 46 |
+
Use of the Model is therefore also subject to the **DINOv3 License Agreement**.
|
| 47 |
+
By using this Model, you agree to comply with both:
|
| 48 |
+
|
| 49 |
+
* This Model Usage Agreement, **and**
|
| 50 |
+
* The [DINOv3 License Terms](https://github.com/facebookresearch/dinov3).
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
# ODELIA Classification Baseline Model
|
| 54 |
+
For a comprehensive description of the model and its intended use, please refer to our paper: [Read the paper](https://arxiv.org/abs/2506.00474)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
## Get Probabilities and Attention
|
| 58 |
+
|
| 59 |
+
To use this model, first download the required files from this repository:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
from huggingface_hub import hf_hub_download
|
| 63 |
+
|
| 64 |
+
# Download model files to local directory
|
| 65 |
+
hf_hub_download(repo_id="ODELIA-AI/MST", filename="models.py", local_dir="./")
|
| 66 |
+
hf_hub_download(repo_id="ODELIA-AI/MST", filename="predict_attention.py", local_dir="./")
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Then execute `predict_attention.py --path_img path/to/Sub_1.nii.gz` to get probabilities and attention maps.
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
epoch=17-step=1836.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a42b9c83fa4ec1c9a7b0060df288ea6fd3c20a9c9f7002ee26be5fb27f320c71
|
| 3 |
+
size 277159866
|
model_config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint_source": "epoch=17-step=1836.ckpt",
|
| 3 |
+
"created_at": "2025-10-26T16:51:08.236236Z",
|
| 4 |
+
"hparams": {
|
| 5 |
+
"backbone_type": "dinov3",
|
| 6 |
+
"in_ch": 1,
|
| 7 |
+
"loss": "<class 'odelia.models.utils.losses.MulitCELoss'>",
|
| 8 |
+
"loss_kwargs": {
|
| 9 |
+
"class_labels_num": [
|
| 10 |
+
3
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
"lr_scheduler": null,
|
| 14 |
+
"lr_scheduler_kwargs": {},
|
| 15 |
+
"model_size": "s",
|
| 16 |
+
"optimizer": "<class 'torch.optim.adamw.AdamW'>",
|
| 17 |
+
"optimizer_kwargs": {
|
| 18 |
+
"lr": 1e-05
|
| 19 |
+
},
|
| 20 |
+
"out_ch": 3,
|
| 21 |
+
"save_hyperparameters": true,
|
| 22 |
+
"slice_fusion_type": "transformer",
|
| 23 |
+
"spatial_dims": 3
|
| 24 |
+
},
|
| 25 |
+
"model_class": "odelia.models.mst.MSTRegression"
|
| 26 |
+
}
|
models.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from einops import rearrange
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
from transformers import AutoModel
|
| 6 |
+
from x_transformers import Encoder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class _MST(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
out_ch=1,
|
| 13 |
+
backbone_type="dinov3",
|
| 14 |
+
model_size = "s", # 34, 50, ... or 's', 'b', 'l'
|
| 15 |
+
slice_fusion_type = "transformer", # transformer, linear, average, none
|
| 16 |
+
):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.backbone_type = backbone_type
|
| 19 |
+
self.slice_fusion_type = slice_fusion_type
|
| 20 |
+
|
| 21 |
+
if backbone_type == "dinov2":
|
| 22 |
+
model_size = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
|
| 23 |
+
self.backbone = AutoModel.from_pretrained(f"facebook/dinov2-with-registers-{model_size}")
|
| 24 |
+
emb_ch = self.backbone.config.hidden_size
|
| 25 |
+
elif backbone_type == "dinov3":
|
| 26 |
+
self.backbone = AutoModel.from_pretrained(f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m")
|
| 27 |
+
emb_ch = self.backbone.config.hidden_size
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError("Unknown backbone_type")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
self.emb_ch = emb_ch
|
| 33 |
+
if slice_fusion_type == "transformer":
|
| 34 |
+
self.slice_fusion = Encoder(
|
| 35 |
+
dim = emb_ch,
|
| 36 |
+
heads = 12 if emb_ch%12 == 0 else 8,
|
| 37 |
+
ff_mult = 1,
|
| 38 |
+
attn_dropout=0.0,
|
| 39 |
+
pre_norm = True,
|
| 40 |
+
depth = 1,
|
| 41 |
+
attn_flash = True,
|
| 42 |
+
ff_no_bias = True,
|
| 43 |
+
rotary_pos_emb=True,
|
| 44 |
+
)
|
| 45 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))
|
| 46 |
+
elif slice_fusion_type == 'average':
|
| 47 |
+
pass
|
| 48 |
+
elif slice_fusion_type == "none":
|
| 49 |
+
pass
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError("Unknown slice_fusion_type")
|
| 52 |
+
|
| 53 |
+
self.linear = nn.Linear(emb_ch, out_ch)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def forward(self, x, output_attentions=False):
|
| 58 |
+
B, *_ = x.shape
|
| 59 |
+
|
| 60 |
+
# Mask (Slices with constant padded values)
|
| 61 |
+
x_pad = torch.isclose(x.mean(dim=(-1,-2)), x[:, :, :, 0, 0]) # [B, C, D]
|
| 62 |
+
x_pad = rearrange(x_pad, 'b c d -> b (c d)')
|
| 63 |
+
|
| 64 |
+
x = rearrange(x, 'b c d h w -> (b c d) h w')
|
| 65 |
+
x = x[:, None]
|
| 66 |
+
x = x.repeat(1, 3, 1, 1) # Gray to RGB
|
| 67 |
+
|
| 68 |
+
# -------------- Backbone --------------
|
| 69 |
+
backbone_out = self.backbone(x, output_attentions=output_attentions)
|
| 70 |
+
x = backbone_out.pooler_output
|
| 71 |
+
x = rearrange(x, '(b d) e -> b d e', b=B)
|
| 72 |
+
|
| 73 |
+
# -------------- Slice Fusion --------------
|
| 74 |
+
if self.slice_fusion_type == 'none':
|
| 75 |
+
return x
|
| 76 |
+
elif self.slice_fusion_type == 'transformer':
|
| 77 |
+
cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
|
| 78 |
+
pad = torch.concat([x_pad, cls_pad], dim=1) # [B, D+1]
|
| 79 |
+
x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1) # [B, 1+D, E]
|
| 80 |
+
if output_attentions:
|
| 81 |
+
x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True) # [B, D+1, E]
|
| 82 |
+
else:
|
| 83 |
+
x = self.slice_fusion(x, mask=~pad) # [B, D+1, L]
|
| 84 |
+
elif self.slice_fusion_type == 'linear':
|
| 85 |
+
x = rearrange(x, 'b d e -> b e d')
|
| 86 |
+
x = self.slice_fusion(x) # -> [B, E, 1]
|
| 87 |
+
x = rearrange(x, 'b e d -> b d e') # -> [B, 1, E]
|
| 88 |
+
elif self.slice_fusion_type == 'average':
|
| 89 |
+
x = x.mean(dim=1, keepdim=True) # [B, D, E] -> [B, 1, E]
|
| 90 |
+
|
| 91 |
+
# -------------- Logits --------------
|
| 92 |
+
x = self.linear(x[:, -1])
|
| 93 |
+
if output_attentions:
|
| 94 |
+
slice_attn_layers = [
|
| 95 |
+
interm.post_softmax_attn
|
| 96 |
+
for interm in getattr(slice_hiddens, 'attn_intermediates', [])
|
| 97 |
+
if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
|
| 98 |
+
]
|
| 99 |
+
return x, backbone_out.attentions, slice_attn_layers
|
| 100 |
+
return x
|
| 101 |
+
|
| 102 |
+
def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 103 |
+
B, C, D, _, _ = x.shape
|
| 104 |
+
# Disable fast attention
|
| 105 |
+
attn_impl = self.backbone.config._attn_implementation
|
| 106 |
+
self.backbone.set_attn_implementation("eager")
|
| 107 |
+
flash_modules = []
|
| 108 |
+
for module in self.slice_fusion.modules():
|
| 109 |
+
if hasattr(module, 'flash'):
|
| 110 |
+
flash_modules.append((module, module.flash))
|
| 111 |
+
module.flash = False
|
| 112 |
+
|
| 113 |
+
out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)
|
| 114 |
+
|
| 115 |
+
# Restore previous attention implementation
|
| 116 |
+
for module, previous in flash_modules:
|
| 117 |
+
module.flash = previous
|
| 118 |
+
if hasattr(self.backbone, "set_attn_implementation"):
|
| 119 |
+
self.backbone.set_attn_implementation(attn_impl)
|
| 120 |
+
|
| 121 |
+
# Process attentions
|
| 122 |
+
slice_attn = torch.stack(slice_attn_layers)[-1]
|
| 123 |
+
slice_attn = slice_attn.mean(dim=1)
|
| 124 |
+
slice_attn = slice_attn[:, -1, :-1]
|
| 125 |
+
slice_attn = slice_attn.view(B, C, D).mean(dim=1)
|
| 126 |
+
|
| 127 |
+
plane_attn_layers = [att for att in backbone_attn if att is not None]
|
| 128 |
+
plane_attn = torch.stack(plane_attn_layers)[-1]
|
| 129 |
+
plane_attn = plane_attn.mean(dim=1)
|
| 130 |
+
num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
|
| 131 |
+
plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]
|
| 132 |
+
plane_attn = plane_attn.view(B, C * D, -1)
|
| 133 |
+
|
| 134 |
+
# Weight every slice by its slice attention
|
| 135 |
+
plane_attn = plane_attn * slice_attn.unsqueeze(-1)
|
| 136 |
+
|
| 137 |
+
num_patches = plane_attn.shape[-1]
|
| 138 |
+
side = int(math.sqrt(num_patches))
|
| 139 |
+
if side * side != num_patches:
|
| 140 |
+
raise RuntimeError("number of patches is not a perfect square")
|
| 141 |
+
plane_attn = plane_attn.reshape(B, C * D, side, side)
|
| 142 |
+
|
| 143 |
+
return out, plane_attn, slice_attn
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MSTRegression(nn.Module):
|
| 147 |
+
def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", optimizer_kwargs={'lr':1e-5}, **kwargs):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type)
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
return self.mst(x)
|
| 153 |
+
|
| 154 |
+
def forward_attention(self, x):
|
| 155 |
+
return self.mst.forward_attention(x)
|
predict_attention.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torchio as tio
|
| 9 |
+
from torchvision.utils import save_image
|
| 10 |
+
from matplotlib.pyplot import get_cmap
|
| 11 |
+
|
| 12 |
+
from models import MSTRegression
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def minmax_norm(x):
|
| 17 |
+
"""Normalizes input to [0, 1] for each batch and channel"""
|
| 18 |
+
return (x - x.min()) / (x.max() - x.min())
|
| 19 |
+
|
| 20 |
+
def tensor2image(tensor, batch=0):
|
| 21 |
+
"""Transform tensor into shape of multiple 2D RGB/gray images. """
|
| 22 |
+
return (tensor if tensor.ndim<5 else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:,None])
|
| 23 |
+
|
| 24 |
+
def tensor_cam2image(tensor, cam, batch=0, alpha=0.5, color_map=get_cmap('jet')):
|
| 25 |
+
"""Transform a tensor and a (grad) cam into multiple 2D RGB images."""
|
| 26 |
+
img = tensor2image(tensor, batch) # -> [B, C, H, W]
|
| 27 |
+
img = torch.cat([img for _ in range(3)], dim=1) if img.shape[1]!=3 else img # Ensure RGB [B, 3, H, W]
|
| 28 |
+
cam_img = tensor2image(cam, batch) # -> [B, 1, H, W]
|
| 29 |
+
cam_img = cam_img[:,0].cpu().numpy() # -> [B, H, W]
|
| 30 |
+
cam_img = torch.tensor(color_map(cam_img)) # -> [B, H, W, 4], color_map expects input to be [0.0, 1.0]
|
| 31 |
+
cam_img = torch.moveaxis(cam_img, -1, 1)[:, :3] # -> [B, 3, H, W]
|
| 32 |
+
overlay = (1-alpha)*img + alpha*cam_img
|
| 33 |
+
return overlay
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def crop_breast_height(image, margin_top=10) -> tio.Crop:
|
| 38 |
+
"""Crop height to 256 and try to cover breast based on intensity localization"""
|
| 39 |
+
threshold = int(np.quantile(image.data.float(), 0.9))
|
| 40 |
+
foreground = image.data>threshold
|
| 41 |
+
fg_rows = foreground[0].sum(axis=(0, 2))
|
| 42 |
+
top = min(max(512-int(torch.argwhere(fg_rows).max()) - margin_top, 0), 256)
|
| 43 |
+
bottom = 256-top
|
| 44 |
+
return tio.Crop((0,0, bottom, top, 0, 0))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_bilateral_transform(img:tio.ScalarImage, ref_img=None, target_spacing = (0.7, 0.7, 3), target_shape = (512, 512, 32)):
|
| 48 |
+
# -------- Settings --------------
|
| 49 |
+
ref_img = img if ref_img is None else ref_img
|
| 50 |
+
|
| 51 |
+
# Spacing
|
| 52 |
+
ref_img = tio.ToCanonical()(ref_img)
|
| 53 |
+
ref_img = tio.Resample(target_spacing)(ref_img)
|
| 54 |
+
resample = tio.Resample(ref_img)
|
| 55 |
+
|
| 56 |
+
# Crop
|
| 57 |
+
ref_img = tio.CropOrPad(target_shape, padding_mode='minimum')(ref_img)
|
| 58 |
+
crop_height = crop_breast_height(ref_img)
|
| 59 |
+
|
| 60 |
+
# Process input image
|
| 61 |
+
trans = tio.Compose([
|
| 62 |
+
resample,
|
| 63 |
+
tio.CropOrPad(target_shape, padding_mode='minimum'),
|
| 64 |
+
crop_height,
|
| 65 |
+
])
|
| 66 |
+
|
| 67 |
+
trans_inv = tio.Compose([
|
| 68 |
+
crop_height.inverse(),
|
| 69 |
+
tio.CropOrPad(img.spatial_shape, padding_mode='minimum'),
|
| 70 |
+
tio.Resample(img),
|
| 71 |
+
])
|
| 72 |
+
return trans(img), trans_inv
|
| 73 |
+
|
| 74 |
+
def get_unilateral_transform(img: tio.ScalarImage, target_shape=(224, 224, 32)):
|
| 75 |
+
transform = tio.Compose([
|
| 76 |
+
tio.Flip((1,0)),
|
| 77 |
+
tio.CropOrPad(target_shape),
|
| 78 |
+
tio.ZNormalization(masking_method=lambda x:(x>x.min()) & (x<x.max())),
|
| 79 |
+
])
|
| 80 |
+
inv_transform = tio.Compose([
|
| 81 |
+
tio.CropOrPad(img.spatial_shape),
|
| 82 |
+
tio.Flip((1,0)),
|
| 83 |
+
])
|
| 84 |
+
return transform(img), inv_transform
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def run_prediction(img: tio.ScalarImage, model: MSTRegression):
|
| 88 |
+
img_bil, bil_trans_rev = get_bilateral_transform(img)
|
| 89 |
+
split_side = {
|
| 90 |
+
'right': tio.Crop((256, 0, 0, 0, 0, 0)),
|
| 91 |
+
'left': tio.Crop((0, 256, 0, 0, 0, 0)),
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
weights, probs = {}, {}
|
| 95 |
+
for side, crop in split_side.items():
|
| 96 |
+
img_side = crop(img_bil)
|
| 97 |
+
img_side, uni_trans_inv = get_unilateral_transform(img_side)
|
| 98 |
+
img_side = img_side.data.swapaxes(1,-1)
|
| 99 |
+
img_side = img_side.unsqueeze(0) # Add batch dim -> [1, C, H, W, D]
|
| 100 |
+
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
device = next(model.parameters()).device
|
| 103 |
+
logits, weight, weight_slice = model.forward_attention(img_side.to(device))
|
| 104 |
+
|
| 105 |
+
weight = F.interpolate(weight.unsqueeze(1), size=img_side.shape[2:], mode='trilinear', align_corners=False).cpu()
|
| 106 |
+
# pred_prob = model.logits2probabilities(logits).cpu()
|
| 107 |
+
pred_prob = F.softmax(logits, dim=-1).cpu()
|
| 108 |
+
probs[side] = pred_prob.squeeze(0)
|
| 109 |
+
|
| 110 |
+
weight = weight.squeeze(0).swapaxes(1,-1) # ->[C, W, H, D]
|
| 111 |
+
weight = uni_trans_inv(weight)
|
| 112 |
+
weights[side] = weight
|
| 113 |
+
|
| 114 |
+
weight = torch.concat([weights['left'], weights['right']], dim=1) # C, W, H, D
|
| 115 |
+
weight = tio.ScalarImage(tensor=weight, affine=img_bil.affine)
|
| 116 |
+
weight = bil_trans_rev(weight)
|
| 117 |
+
weight.set_data(minmax_norm(weight.data))
|
| 118 |
+
return probs, weight
|
| 119 |
+
|
| 120 |
+
def load_model(repo_id= "ODELIA-AI/MST") -> MSTRegression:
|
| 121 |
+
# Download config and state dict
|
| 122 |
+
config_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="model_config.json")
|
| 123 |
+
with open(config_path, "r", encoding="utf-8") as fp:
|
| 124 |
+
config = json.load(fp)
|
| 125 |
+
|
| 126 |
+
hparams = config.get("hparams", {})
|
| 127 |
+
model = MSTRegression(**hparams)
|
| 128 |
+
|
| 129 |
+
state_dict_path = hf_hub_download(repo_id=repo_id, repo_type="model", filename="state_dict.pt")
|
| 130 |
+
state_dict = torch.load(state_dict_path, map_location="cpu")
|
| 131 |
+
model.load_state_dict(state_dict, strict=True)
|
| 132 |
+
return model
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
#------------ Get Arguments ----------------
|
| 137 |
+
parser = argparse.ArgumentParser()
|
| 138 |
+
parser.add_argument('--path_img', default='/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/UKA/data/UKA_2/Sub_1.nii.gz', type=str)
|
| 139 |
+
args = parser.parse_args()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
#------------ Settings/Defaults ----------------
|
| 143 |
+
path_out_dir = Path().cwd()/'results/test_attention'
|
| 144 |
+
path_out_dir.mkdir(parents=True, exist_ok=True)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ------------ Load Data ----------------
|
| 148 |
+
path_img = Path(args.path_img)
|
| 149 |
+
img = tio.ScalarImage(path_img)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ------------ Initialize Model ------------
|
| 153 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 154 |
+
model = load_model()
|
| 155 |
+
model.to(device)
|
| 156 |
+
model.eval()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ------------ Predict ----------------
|
| 160 |
+
probs, weight = run_prediction(img, model)
|
| 161 |
+
|
| 162 |
+
img.save(path_out_dir/f"input.nii.gz")
|
| 163 |
+
weight.save(path_out_dir/f"attention.nii.gz")
|
| 164 |
+
weight = weight.data.swapaxes(1,-1).unsqueeze(0) # C, D, H, W
|
| 165 |
+
img = img.data.swapaxes(1,-1).unsqueeze(0) # C, D, H, W
|
| 166 |
+
save_image(tensor_cam2image(minmax_norm(img), minmax_norm(weight), alpha=0.5),
|
| 167 |
+
path_out_dir/f"overlay.png", normalize=False)
|
| 168 |
+
|
| 169 |
+
for side in ['left', 'right']:
|
| 170 |
+
print(f"{side} breast predicted probabilities: {probs[side]}")
|
| 171 |
+
|
state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c25602fec82d90912ed6f2623639937a2fb44931cfcfb382aecd16d5647c8327
|
| 3 |
+
size 92379550
|