Image Classification
English
breast
cancer
odelia
MST / models.py
mueller-franzes's picture
Upload epoch=17-step=1836.ckpt
fb68040 verified
from einops import rearrange
import torch.nn as nn
import torch
import math
from transformers import AutoModel, Dinov2WithRegistersModel, Dinov2WithRegistersConfig, DINOv3ViTConfig, DINOv3ViTModel
from x_transformers import Encoder
class _MST(nn.Module):
def __init__(
self,
out_ch=1,
backbone_type="dinov3",
model_size = "s", # 34, 50, ... or 's', 'b', 'l'
slice_fusion_type = "transformer", # transformer, linear, average, none
weights=True,
):
super().__init__()
self.backbone_type = backbone_type
self.slice_fusion_type = slice_fusion_type
if backbone_type == "dinov2":
model_size_key = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
model_name = f"facebook/dinov2-with-registers-{model_size_key}"
if weights:
self.backbone = AutoModel.from_pretrained(model_name)
else:
configs = {
'small': Dinov2WithRegistersConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
'base': Dinov2WithRegistersConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
'large': Dinov2WithRegistersConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
}
config = configs.get(model_size_key)
config.image_size = 518
config.patch_size = 14
self.backbone = Dinov2WithRegistersModel(config)
emb_ch = self.backbone.config.hidden_size
elif backbone_type == "dinov3":
model_name = f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
if weights:
self.backbone = AutoModel.from_pretrained(model_name)
else:
configs = {
's': DINOv3ViTConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536, patch_size=16, num_register_tokens=4),
'b': DINOv3ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, patch_size=16, num_register_tokens=4),
'l': DINOv3ViTConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, patch_size=16, num_register_tokens=4),
}
config = configs.get(model_size)
self.backbone = DINOv3ViTModel(config)
else:
raise ValueError("Unknown backbone_type")
emb_ch = self.backbone.config.hidden_size
self.emb_ch = emb_ch
if slice_fusion_type == "transformer":
self.slice_fusion = Encoder(
dim = emb_ch,
heads = 12 if emb_ch%12 == 0 else 8,
ff_mult = 1,
attn_dropout=0.0,
pre_norm = True,
depth = 1,
attn_flash = True,
ff_no_bias = True,
rotary_pos_emb=True,
)
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))
elif slice_fusion_type == 'average':
pass
elif slice_fusion_type == "none":
pass
else:
raise ValueError("Unknown slice_fusion_type")
self.linear = nn.Linear(emb_ch, out_ch)
def forward(self, x, output_attentions=False):
B, *_ = x.shape
# Mask (Slices with constant padded values)
x_pad = torch.isclose(x.mean(dim=(-1,-2)), x[:, :, :, 0, 0]) # [B, C, D]
x_pad = rearrange(x_pad, 'b c d -> b (c d)')
x = rearrange(x, 'b c d h w -> (b c d) h w')
x = x[:, None]
x = x.repeat(1, 3, 1, 1) # Gray to RGB
# -------------- Backbone --------------
backbone_out = self.backbone(x, output_attentions=output_attentions)
x = backbone_out.pooler_output
x = rearrange(x, '(b d) e -> b d e', b=B)
# -------------- Slice Fusion --------------
if self.slice_fusion_type == 'none':
return x
elif self.slice_fusion_type == 'transformer':
cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
pad = torch.concat([x_pad, cls_pad], dim=1) # [B, D+1]
x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1) # [B, 1+D, E]
if output_attentions:
x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True) # [B, D+1, E]
else:
x = self.slice_fusion(x, mask=~pad) # [B, D+1, L]
elif self.slice_fusion_type == 'linear':
x = rearrange(x, 'b d e -> b e d')
x = self.slice_fusion(x) # -> [B, E, 1]
x = rearrange(x, 'b e d -> b d e') # -> [B, 1, E]
elif self.slice_fusion_type == 'average':
x = x.mean(dim=1, keepdim=True) # [B, D, E] -> [B, 1, E]
# -------------- Logits --------------
x = self.linear(x[:, -1])
if output_attentions:
slice_attn_layers = [
interm.post_softmax_attn
for interm in getattr(slice_hiddens, 'attn_intermediates', [])
if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
]
return x, backbone_out.attentions, slice_attn_layers
return x
def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, C, D, _, _ = x.shape
# Disable fast attention
attn_impl = self.backbone.config._attn_implementation
self.backbone.set_attn_implementation("eager")
flash_modules = []
for module in self.slice_fusion.modules():
if hasattr(module, 'flash'):
flash_modules.append((module, module.flash))
module.flash = False
out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)
# Restore previous attention implementation
for module, previous in flash_modules:
module.flash = previous
if hasattr(self.backbone, "set_attn_implementation"):
self.backbone.set_attn_implementation(attn_impl)
# Process attentions
slice_attn = torch.stack(slice_attn_layers)[-1]
slice_attn = slice_attn.mean(dim=1)
slice_attn = slice_attn[:, -1, :-1]
slice_attn = slice_attn.view(B, C, D).mean(dim=1)
plane_attn_layers = [att for att in backbone_attn if att is not None]
plane_attn = torch.stack(plane_attn_layers)[-1]
plane_attn = plane_attn.mean(dim=1)
num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]
plane_attn = plane_attn.view(B, C * D, -1)
# Weight every slice by its slice attention
plane_attn = plane_attn * slice_attn.unsqueeze(-1)
num_patches = plane_attn.shape[-1]
side = int(math.sqrt(num_patches))
if side * side != num_patches:
raise RuntimeError("number of patches is not a perfect square")
plane_attn = plane_attn.reshape(B, C * D, side, side)
return out, plane_attn, slice_attn
class MSTRegression(nn.Module):
def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", weights=True, **kwargs):
super().__init__()
self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type, weights=weights)
def forward(self, x):
return self.mst(x)
def forward_attention(self, x):
return self.mst.forward_attention(x)