import math

import torch
import torch.nn as nn
from einops import rearrange
from transformers import (
    AutoModel,
    Dinov2WithRegistersModel,
    Dinov2WithRegistersConfig,
    DINOv3ViTConfig,
    DINOv3ViTModel,
)
from x_transformers import Encoder

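# _MST encodes a 3D volume slice by slice with a 2D DINOv2/DINOv3 backbone,
# fuses the per-slice embeddings (transformer with a CLS token, averaging, or
# no fusion), and maps the fused representation to `out_ch` outputs.
# MSTRegression is a thin wrapper exposing the same forward interface.
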
class _MST(nn.Module):
    def __init__(
        self,
        out_ch=1,
        backbone_type="dinov3",
        model_size="s",
        slice_fusion_type="transformer",
        weights=True,
    ):
        super().__init__()
        self.backbone_type = backbone_type
        self.slice_fusion_type = slice_fusion_type

        if backbone_type == "dinov2":
            model_size_key = {'s': 'small', 'b': 'base', 'l': 'large'}.get(model_size)
            model_name = f"facebook/dinov2-with-registers-{model_size_key}"
            if weights:
                self.backbone = AutoModel.from_pretrained(model_name)
            else:
                # Randomly initialized backbone when pretrained weights are not requested.
                configs = {
                    'small': Dinov2WithRegistersConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
                    'base': Dinov2WithRegistersConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
                    'large': Dinov2WithRegistersConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
                }
                config = configs.get(model_size_key)
                config.image_size = 518
                config.patch_size = 14
                self.backbone = Dinov2WithRegistersModel(config)
        elif backbone_type == "dinov3":
            model_name = f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
            if weights:
                self.backbone = AutoModel.from_pretrained(model_name)
            else:
                # Randomly initialized backbone when pretrained weights are not requested.
                configs = {
                    's': DINOv3ViTConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536, patch_size=16, num_register_tokens=4),
                    'b': DINOv3ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, patch_size=16, num_register_tokens=4),
                    'l': DINOv3ViTConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, patch_size=16, num_register_tokens=4),
                }
                config = configs.get(model_size)
                self.backbone = DINOv3ViTModel(config)
        else:
            raise ValueError(f"Unknown backbone_type: {backbone_type}")

        emb_ch = self.backbone.config.hidden_size
        self.emb_ch = emb_ch

        if slice_fusion_type == "transformer":
            self.slice_fusion = Encoder(
                dim=emb_ch,
                heads=12 if emb_ch % 12 == 0 else 8,
                ff_mult=1,
                attn_dropout=0.0,
                pre_norm=True,
                depth=1,
                attn_flash=True,
                ff_no_bias=True,
                rotary_pos_emb=True,
            )
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))
        elif slice_fusion_type == "average":
            pass
        elif slice_fusion_type == "none":
            pass
        else:
            raise ValueError(f"Unknown slice_fusion_type: {slice_fusion_type}")

        self.linear = nn.Linear(emb_ch, out_ch)

    def forward(self, x, output_attentions=False):
        # x: (B, C, D, H, W) volume; each of the B*C*D slices is encoded independently.
        B, *_ = x.shape

        # Flag padded slices: a slice whose mean equals its top-left pixel is treated as constant padding.
        x_pad = torch.isclose(x.mean(dim=(-1, -2)), x[:, :, :, 0, 0])
        x_pad = rearrange(x_pad, 'b c d -> b (c d)')

        # Flatten the volume into a batch of 2D slices and replicate the single channel to RGB for the ViT backbone.
        x = rearrange(x, 'b c d h w -> (b c d) h w')
        x = x[:, None]
        x = x.repeat(1, 3, 1, 1)

        backbone_out = self.backbone(x, output_attentions=output_attentions)
        x = backbone_out.pooler_output             # one embedding per slice
        x = rearrange(x, '(b d) e -> b d e', b=B)  # regroup slices per volume

        if self.slice_fusion_type == 'none':
            return x
        elif self.slice_fusion_type == 'transformer':
            # Append a learnable CLS token and mask out padded slices during fusion.
            cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
            pad = torch.concat([x_pad, cls_pad], dim=1)
            x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1)
            if output_attentions:
                x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True)
            else:
                x = self.slice_fusion(x, mask=~pad)
        elif self.slice_fusion_type == 'linear':
            # Note: no 'linear' module is created in __init__, so this branch expects
            # self.slice_fusion to be assigned externally.
            x = rearrange(x, 'b d e -> b e d')
            x = self.slice_fusion(x)
            x = rearrange(x, 'b e d -> b d e')
        elif self.slice_fusion_type == 'average':
            x = x.mean(dim=1, keepdim=True)

        # Predict from the last token (the CLS token for transformer fusion, the mean embedding for 'average').
        x = self.linear(x[:, -1])
        if output_attentions:
            slice_attn_layers = [
                interm.post_softmax_attn
                for interm in getattr(slice_hiddens, 'attn_intermediates', [])
                if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
            ]
            return x, backbone_out.attentions, slice_attn_layers
        return x

    def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, C, D, _, _ = x.shape

        # Temporarily switch to implementations that expose attention weights:
        # eager attention in the backbone and non-flash attention in the slice fusion.
        attn_impl = getattr(self.backbone.config, "_attn_implementation", None)
        if hasattr(self.backbone, "set_attn_implementation"):
            self.backbone.set_attn_implementation("eager")
        flash_modules = []
        for module in self.slice_fusion.modules():
            if hasattr(module, 'flash'):
                flash_modules.append((module, module.flash))
                module.flash = False

        out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)

        # Restore the previous attention settings.
        for module, previous in flash_modules:
            module.flash = previous
        if hasattr(self.backbone, "set_attn_implementation") and attn_impl is not None:
            self.backbone.set_attn_implementation(attn_impl)

        # Slice attention: CLS-token row of the last fusion layer, averaged over heads,
        # then averaged over channels -> one weight per slice, shape (B, D).
        slice_attn = torch.stack(slice_attn_layers)[-1]
        slice_attn = slice_attn.mean(dim=1)
        slice_attn = slice_attn[:, -1, :-1]
        slice_attn = slice_attn.view(B, C, D).mean(dim=1)

        # In-plane attention: CLS-token row of the last backbone layer, averaged over heads,
        # dropping the CLS and register tokens -> one weight per patch, shape (B, C*D, num_patches).
        plane_attn_layers = [att for att in backbone_attn if att is not None]
        plane_attn = torch.stack(plane_attn_layers)[-1]
        plane_attn = plane_attn.mean(dim=1)
        num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
        plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]
        plane_attn = plane_attn.view(B, C * D, -1)

        # Weight each slice's patch attention by its slice attention.
        # Note: this broadcast assumes a single input channel (C == 1).
        plane_attn = plane_attn * slice_attn.unsqueeze(-1)

        # Fold the flat patch dimension back into a square grid.
        num_patches = plane_attn.shape[-1]
        side = int(math.sqrt(num_patches))
        if side * side != num_patches:
            raise RuntimeError("number of patches is not a perfect square")
        plane_attn = plane_attn.reshape(B, C * D, side, side)

        return out, plane_attn, slice_attn


class MSTRegression(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", weights=True, **kwargs):
        # in_ch and spatial_dims are accepted but currently unused by _MST.
        super().__init__()
        self.mst = _MST(
            out_ch=out_ch,
            backbone_type=backbone_type,
            model_size=model_size,
            slice_fusion_type=slice_fusion_type,
            weights=weights,
        )

    def forward(self, x):
        return self.mst(x)

    def forward_attention(self, x):
        return self.mst.forward_attention(x)
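

# Minimal usage sketch (an illustration, not part of the original module): assumes a
# single-channel volume of shape (B, C, D, H, W) with H and W divisible by the
# backbone's patch size; 224x224 and weights=False are arbitrary choices to keep the
# example self-contained and avoid downloading pretrained checkpoints.
if __name__ == "__main__":
    model = MSTRegression(out_ch=1, backbone_type="dinov3", model_size="s", weights=False)
    volume = torch.randn(2, 1, 4, 224, 224)   # (B, C, D, H, W)
    prediction = model(volume)                # -> (B, out_ch)
    print(prediction.shape)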