import math

import torch
import torch.nn as nn
from einops import rearrange
from transformers import AutoModel, Dinov2WithRegistersConfig, Dinov2WithRegistersModel, DINOv3ViTConfig, DINOv3ViTModel
from x_transformers import Encoder


class _MST(nn.Module):
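    """Encodes a 3D volume by running a 2D DINO backbone over every slice
    independently, then fusing the per-slice embeddings (via a small
    transformer, a linear layer, or averaging) into a volume-level output.

    Expects input of shape [B, C, D, H, W]; grayscale slices are replicated
    to 3 channels before entering the backbone.
    """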
def __init__(
self,
        out_ch=1,
        backbone_type="dinov3",  # "dinov2" or "dinov3"
        model_size="s",  # 's' (small), 'b' (base), or 'l' (large)
        slice_fusion_type="transformer",  # "transformer", "linear", "average", or "none"
        weights=True,  # load pretrained weights from the Hugging Face Hub
    ):
super().__init__()
self.backbone_type = backbone_type
self.slice_fusion_type = slice_fusion_type
if backbone_type == "dinov2":
model_size_key = {'s':'small', 'b':'base', 'l':'large'}.get(model_size)
model_name = f"facebook/dinov2-with-registers-{model_size_key}"
if weights:
self.backbone = AutoModel.from_pretrained(model_name)
else:
configs = {
'small': Dinov2WithRegistersConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6),
'base': Dinov2WithRegistersConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12),
'large': Dinov2WithRegistersConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16),
}
config = configs.get(model_size_key)
config.image_size = 518
config.patch_size = 14
self.backbone = Dinov2WithRegistersModel(config)
emb_ch = self.backbone.config.hidden_size
elif backbone_type == "dinov3":
model_name = f"facebook/dinov3-vit{model_size}16-pretrain-lvd1689m"
if weights:
self.backbone = AutoModel.from_pretrained(model_name)
else:
configs = {
's': DINOv3ViTConfig(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, intermediate_size=1536, patch_size=16, num_register_tokens=4),
'b': DINOv3ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, patch_size=16, num_register_tokens=4),
'l': DINOv3ViTConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, patch_size=16, num_register_tokens=4),
}
config = configs.get(model_size)
self.backbone = DINOv3ViTModel(config)
else:
raise ValueError("Unknown backbone_type")
emb_ch = self.backbone.config.hidden_size
self.emb_ch = emb_ch
if slice_fusion_type == "transformer":
self.slice_fusion = Encoder(
dim = emb_ch,
heads = 12 if emb_ch%12 == 0 else 8,
ff_mult = 1,
attn_dropout=0.0,
pre_norm = True,
depth = 1,
attn_flash = True,
ff_no_bias = True,
rotary_pos_emb=True,
)
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch))  # volume-level CLS token, appended after the slice tokens
        elif slice_fusion_type == "linear":
            # LazyLinear infers the slice count D on its first forward pass and maps D -> 1
            self.slice_fusion = nn.LazyLinear(1)
        elif slice_fusion_type == "average":
            pass
        elif slice_fusion_type == "none":
            pass
        else:
            raise ValueError(f"Unknown slice_fusion_type: {slice_fusion_type}")
self.linear = nn.Linear(emb_ch, out_ch)
def forward(self, x, output_attentions=False):
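        """Encode a volume and return logits.

        Args:
            x: input volume of shape [B, C, D, H, W].
            output_attentions: if True, also return the backbone attentions
                and the slice-fusion attention maps.

        Returns:
            Logits of shape [B, out_ch]; per-slice embeddings [B, C*D, E] if
            slice_fusion_type == "none"; a (logits, backbone_attentions,
            slice_attention_layers) tuple if output_attentions is True.
        """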
B, *_ = x.shape
        # Mask padded slices (heuristic: a constant slice's mean equals its top-left pixel)
        x_pad = torch.isclose(x.mean(dim=(-1, -2)), x[:, :, :, 0, 0])  # [B, C, D]
        x_pad = rearrange(x_pad, 'b c d -> b (c d)')
        x = rearrange(x, 'b c d h w -> (b c d) h w')  # flatten slices into a batch of 2D images
        x = x[:, None]  # [(B C D), 1, H, W]
        x = x.repeat(1, 3, 1, 1)  # replicate grayscale to the 3 channels the backbone expects
# -------------- Backbone --------------
        backbone_out = self.backbone(x, output_attentions=output_attentions)
        x = backbone_out.pooler_output  # one embedding per slice: [(B C D), E]
        x = rearrange(x, '(b d) e -> b d e', b=B)  # regroup per volume: [B, C*D, E]
# -------------- Slice Fusion --------------
        slice_hiddens = None  # only populated by the transformer fusion below
        if self.slice_fusion_type == 'none':
            return x
        elif self.slice_fusion_type == 'transformer':
            cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=x.device)
            pad = torch.concat([x_pad, cls_pad], dim=1)  # [B, D+1]
            x = torch.concat([x, self.cls_token.repeat(B, 1, 1)], dim=1)  # CLS token appended last: [B, D+1, E]
            if output_attentions:
                x, slice_hiddens = self.slice_fusion(x, mask=~pad, return_hiddens=True)  # [B, D+1, E]
            else:
                x = self.slice_fusion(x, mask=~pad)  # [B, D+1, E]
elif self.slice_fusion_type == 'linear':
x = rearrange(x, 'b d e -> b e d')
x = self.slice_fusion(x) # -> [B, E, 1]
x = rearrange(x, 'b e d -> b d e') # -> [B, 1, E]
        elif self.slice_fusion_type == 'average':
            # Average only over non-padded slices: [B, D, E] -> [B, 1, E]
            valid = (~x_pad).unsqueeze(-1)  # [B, D, 1]
            x = (x * valid).sum(dim=1, keepdim=True) / valid.sum(dim=1, keepdim=True).clamp(min=1)
# -------------- Logits --------------
        x = self.linear(x[:, -1])  # last token: CLS for transformer fusion, the single fused token otherwise
if output_attentions:
slice_attn_layers = [
interm.post_softmax_attn
for interm in getattr(slice_hiddens, 'attn_intermediates', [])
if interm is not None and getattr(interm, 'post_softmax_attn', None) is not None
]
return x, backbone_out.attentions, slice_attn_layers
return x
def forward_attention(self, x) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
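        """Compute logits together with attention maps for visualization.

        Temporarily switches the backbone and the slice-fusion transformer to
        eager (non-flash) attention so attention weights are materialized.
        Assumes slice_fusion_type == "transformer".

        Returns:
            out: logits of shape [B, out_ch].
            plane_attn: per-patch attention maps, [B, C*D, side, side].
            slice_attn: per-slice attention weights, [B, D].
        """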
B, C, D, _, _ = x.shape
        # Disable fused/flash attention so attention weights are materialized
        attn_impl = getattr(self.backbone.config, "_attn_implementation", None)
        if hasattr(self.backbone, "set_attn_implementation"):
            self.backbone.set_attn_implementation("eager")
flash_modules = []
for module in self.slice_fusion.modules():
if hasattr(module, 'flash'):
flash_modules.append((module, module.flash))
module.flash = False
out, backbone_attn, slice_attn_layers = self.forward(x, output_attentions=True)
        # Restore the previous attention implementations
        for module, previous in flash_modules:
            module.flash = previous
        if attn_impl is not None and hasattr(self.backbone, "set_attn_implementation"):
            self.backbone.set_attn_implementation(attn_impl)
        if not slice_attn_layers:
            raise RuntimeError("forward_attention requires slice_fusion_type == 'transformer'")
        # -------------- Slice attention (fusion transformer) --------------
        slice_attn = slice_attn_layers[-1]  # last fusion layer: [B, heads, C*D+1, C*D+1]
        slice_attn = slice_attn.mean(dim=1)  # average over heads
        slice_attn = slice_attn[:, -1, :-1]  # CLS row without the CLS column: [B, C*D]
        # -------------- Plane (patch) attention from the backbone --------------
        plane_attn_layers = [att for att in backbone_attn if att is not None]
        plane_attn = plane_attn_layers[-1]  # last backbone layer: [(B C D), heads, T, T]
        plane_attn = plane_attn.mean(dim=1)  # average over heads
        num_reg_tokens = getattr(self.backbone.config, 'num_register_tokens', 0)
        plane_attn = plane_attn[:, 0, 1 + num_reg_tokens:]  # CLS row, patch columns only
        plane_attn = plane_attn.view(B, C * D, -1)  # [B, C*D, num_patches]
        # Weight every slice's patch map by its slice attention (before averaging
        # over channels, so this also holds for C > 1)
        plane_attn = plane_attn * slice_attn.unsqueeze(-1)
        slice_attn = slice_attn.view(B, C, D).mean(dim=1)  # per-slice weights: [B, D]
num_patches = plane_attn.shape[-1]
side = int(math.sqrt(num_patches))
if side * side != num_patches:
raise RuntimeError("number of patches is not a perfect square")
plane_attn = plane_attn.reshape(B, C * D, side, side)
return out, plane_attn, slice_attn


class MSTRegression(nn.Module):
    """Thin wrapper exposing _MST for regression tasks."""

    def __init__(self, in_ch=1, out_ch=1, spatial_dims=3, backbone_type="dinov3", model_size="s", slice_fusion_type="transformer", weights=True, **kwargs):
        # in_ch, spatial_dims, and **kwargs are accepted for interface compatibility but are unused
        super().__init__()
        self.mst = _MST(out_ch=out_ch, backbone_type=backbone_type, model_size=model_size, slice_fusion_type=slice_fusion_type, weights=weights)
def forward(self, x):
return self.mst(x)
def forward_attention(self, x):
        return self.mst.forward_attention(x)
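

if __name__ == "__main__":
    # Minimal smoke test; a sketch with illustrative (assumed) shapes, not part
    # of the original interface. weights=False builds a randomly initialized
    # DINOv3-S backbone, so no checkpoint download is needed.
    model = MSTRegression(out_ch=1, backbone_type="dinov3", model_size="s", weights=False)
    volume = torch.randn(2, 1, 4, 224, 224)  # [B, C, D, H, W]
    logits = model(volume)
    print(logits.shape)  # expected: torch.Size([2, 1])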