""" Argus: multi-task perception on a single EUPE-ViT-B backbone. from transformers import AutoModel model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True) result = model.perceive(image) The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus task heads are inlined below. The backbone code is reproduced from facebookresearch/EUPE (Meta FAIR) under the FAIR Research License. """ import math import time from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F import torch.nn.init from PIL import Image from torch import Tensor, nn from torchvision.ops import nms from torchvision.transforms import v2 from transformers import PretrainedConfig, PreTrainedModel # =========================================================================== # EUPE backbone — vendored verbatim from facebookresearch/EUPE # =========================================================================== # ---------- utility helpers (from eupe/utils/utils.py) --------------------- def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]: shapes = [x.shape for x in x_list] num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list] flattened = torch.cat([x.flatten(0, -2) for x in x_list]) return flattened, shapes, num_tokens def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]: outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0) shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes] outputs_reshaped = [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)] return outputs_reshaped def named_apply( fn: Callable, module: nn.Module, name: str = "", depth_first: bool = True, include_root: bool = False, ) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for 
child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply( fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True, ) if depth_first and include_root: fn(module=module, name=name) return module # ---------- RMSNorm (from eupe/layers/rms_norm.py) ------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def reset_parameters(self) -> None: nn.init.constant_(self.weight, 1) def _norm(self, x: Tensor) -> Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x: Tensor) -> Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight # ---------- LayerScale (from eupe/layers/layer_scale.py) ------------------- class LayerScale(nn.Module): def __init__( self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, device=None, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(torch.empty(dim, device=device)) self.init_values = init_values def reset_parameters(self): nn.init.constant_(self.gamma, self.init_values) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma # ---------- PatchEmbed (from eupe/layers/patch_embed.py) ------------------- def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) self.img_size = image_HW 
self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape x = self.proj(x) H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) x = self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) return x def reset_parameters(self): k = 1 / (self.in_chans * (self.patch_size[0] ** 2)) nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k)) if self.proj.bias is not None: nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k)) # ---------- RoPE (from eupe/layers/rope_position_encoding.py) -------------- class RopePositionEmbedding(nn.Module): def __init__( self, embed_dim: int, *, num_heads: int, base: Optional[float] = 100.0, min_period: Optional[float] = None, max_period: Optional[float] = None, normalize_coords: Literal["min", "max", "separate"] = "separate", shift_coords: Optional[float] = None, jitter_coords: Optional[float] = None, rescale_coords: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): super().__init__() assert embed_dim % (4 * num_heads) == 0 both_periods = min_period is not None and max_period is not None if (base is None and not both_periods) or (base is not None and both_periods): raise ValueError("Either `base` or `min_period`+`max_period` must be provided.") D_head = embed_dim // num_heads self.base = base self.min_period = min_period self.max_period = max_period self.D_head = D_head self.normalize_coords = normalize_coords self.shift_coords = shift_coords self.jitter_coords = jitter_coords self.rescale_coords = rescale_coords self.dtype = dtype self.register_buffer( "periods", 
torch.empty(D_head // 4, device=device, dtype=dtype), persistent=True, ) self._init_weights() def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]: device = self.periods.device dtype = self.dtype dd = {"device": device, "dtype": dtype} if self.normalize_coords == "max": max_HW = max(H, W) coords_h = torch.arange(0.5, H, **dd) / max_HW coords_w = torch.arange(0.5, W, **dd) / max_HW elif self.normalize_coords == "min": min_HW = min(H, W) coords_h = torch.arange(0.5, H, **dd) / min_HW coords_w = torch.arange(0.5, W, **dd) / min_HW elif self.normalize_coords == "separate": coords_h = torch.arange(0.5, H, **dd) / H coords_w = torch.arange(0.5, W, **dd) / W else: raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) coords = coords.flatten(0, 1) coords = 2.0 * coords - 1.0 if self.training and self.shift_coords is not None: shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords) coords += shift_hw[None, :] if self.training and self.jitter_coords is not None: jitter_max = np.log(self.jitter_coords) jitter_min = -jitter_max jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() coords *= jitter_hw[None, :] if self.training and self.rescale_coords is not None: rescale_max = np.log(self.rescale_coords) rescale_min = -rescale_max rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() coords *= rescale_hw angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] angles = angles.flatten(1, 2) angles = angles.tile(2) cos = torch.cos(angles) sin = torch.sin(angles) return (sin, cos) def _init_weights(self): device = self.periods.device dtype = self.dtype if self.base is not None: periods = self.base ** ( 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) ) else: base = self.max_period / self.min_period exponents = torch.linspace(0, 1, self.D_head // 4, device=device, 
dtype=dtype) periods = base ** exponents periods = periods / base periods = periods * self.max_period self.periods.data = periods # ---------- FFN layers (from eupe/layers/ffn_layers.py) -------------------- class ListForwardMixin(object): def forward(self, x: Tensor): raise NotImplementedError def forward_list(self, x_list: List[Tensor]) -> List[Tensor]: x_flat, shapes, num_tokens = cat_keep_shapes(x_list) x_flat = self.forward(x_flat) return uncat_with_shapes(x_flat, shapes, num_tokens) class Mlp(nn.Module, ListForwardMixin): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, device=None, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class SwiGLUFFN(nn.Module, ListForwardMixin): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Optional[Callable[..., nn.Module]] = None, drop: float = 0.0, bias: bool = True, align_to: int = 8, device=None, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features d = int(hidden_features * 2 / 3) swiglu_hidden_features = d + (-d % align_to) self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device) def forward(self, x: Tensor) -> Tensor: x1 = self.w1(x) x2 = 
self.w2(x) hidden = F.silu(x1) * x2 return self.w3(hidden) # ---------- Attention (from eupe/layers/attention.py) ---------------------- def rope_rotate_half(x: Tensor) -> Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: return (x * cos) + (rope_rotate_half(x) * sin) class LinearKMaskedBias(nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) o = self.out_features assert o % 3 == 0 if self.bias is not None: self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan)) def forward(self, input: Tensor) -> Tensor: masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None return F.linear(input, self.weight, masked_bias) class SelfAttention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, mask_k_bias: bool = False, device=None, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device) self.proj_drop = nn.Dropout(proj_drop) def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]: q_dtype = q.dtype k_dtype = k.dtype sin, cos = rope rope_dtype = sin.dtype q = q.to(dtype=rope_dtype) k = k.to(dtype=rope_dtype) N = q.shape[-2] prefix = N - sin.shape[-2] assert prefix >= 0 q_prefix = q[:, :, :prefix, :] q = rope_apply(q[:, :, prefix:, :], sin, cos) q = torch.cat((q_prefix, q), dim=-2) k_prefix = k[:, :, :prefix, :] k = rope_apply(k[:, :, prefix:, :], sin, cos) k = torch.cat((k_prefix, k), dim=-2) q = q.to(dtype=q_dtype) k = k.to(dtype=k_dtype) return q, k def forward(self, x: Tensor, 
attn_bias=None, rope=None) -> Tensor: qkv = self.qkv(x) attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope) x = self.proj(attn_v) x = self.proj_drop(x) return x def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]: assert len(x_list) == len(rope_list) x_flat, shapes, num_tokens = cat_keep_shapes(x_list) qkv_flat = self.qkv(x_flat) qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens) att_out = [] for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)): att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)) x_flat, shapes, num_tokens = cat_keep_shapes(att_out) x_flat = self.proj(x_flat) return uncat_with_shapes(x_flat, shapes, num_tokens) def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor: assert attn_bias is None B, N, _ = qkv.shape C = self.qkv.in_features qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = torch.unbind(qkv, 2) q, k, v = [t.transpose(1, 2) for t in [q, k, v]] if rope is not None: q, k = self.apply_rope(q, k, rope) x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = x.transpose(1, 2) return x.reshape([B, N, C]) # ---------- Block (from eupe/layers/block.py) ------------------------------ class SelfAttentionBlock(nn.Module): def __init__( self, dim: int, num_heads: int, ffn_ratio: float = 4.0, qkv_bias: bool = False, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = SelfAttention, ffn_layer: Callable[..., nn.Module] = Mlp, mask_k_bias: bool = False, device=None, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, mask_k_bias=mask_k_bias, device=device, ) 
self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * ffn_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias, device=device, ) self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() self.sample_drop_ratio = drop_path @staticmethod def _maybe_index_rope(rope, indices: Tensor): if rope is None: return None sin, cos = rope assert sin.ndim == cos.ndim if sin.ndim == 4: return sin[indices], cos[indices] return sin, cos def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]: b_list = [x.shape[0] for x in x_list] sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list] if self.training and self.sample_drop_ratio > 0.0: residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)] indices_1_list = [ torch.randperm(b, device=x.device)[:s] for x, b, s in zip(x_list, b_list, sample_subset_sizes) ] x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)] if rope_list is not None: rope_subset_list = [ self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list) ] else: rope_subset_list = rope_list flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list) norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens) residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list) x_attn_list = [ torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf) for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors) ] indices_2_list = [ torch.randperm(b, device=x.device)[:s] for x, b, s in zip(x_list, b_list, sample_subset_sizes) ] x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)] flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list) norm2_list = 
uncat_with_shapes(self.norm2(flattened), shapes, num_tokens) residual_2_list = self.mlp.forward_list(norm2_list) x_ffn = [ torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf) for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors) ] else: x_out = [] for x, rope in zip(x_list, rope_list): x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) x_out.append(x_ffn) x_ffn = x_out return x_ffn def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]: if isinstance(x_or_x_list, Tensor): return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0] elif isinstance(x_or_x_list, list): if rope_or_rope_list is None: rope_or_rope_list = [None for _ in x_or_x_list] return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list) raise AssertionError # ---------- DinoVisionTransformer (from eupe/models/vision_transformer.py) ffn_layer_dict = { "mlp": Mlp, "swiglu": SwiGLUFFN, "swiglu32": partial(SwiGLUFFN, align_to=32), "swiglu64": partial(SwiGLUFFN, align_to=64), "swiglu128": partial(SwiGLUFFN, align_to=128), } norm_layer_dict = { "layernorm": partial(nn.LayerNorm, eps=1e-6), "layernormbf16": partial(nn.LayerNorm, eps=1e-5), "rmsnorm": RMSNorm, } dtype_dict = { "fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16, } def init_weights_vit(module: nn.Module, name: str = ""): if isinstance(module, nn.Linear): torch.nn.init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) if hasattr(module, "bias_mask") and module.bias_mask is not None: o = module.out_features module.bias_mask.fill_(1) module.bias_mask[o // 3 : 2 * o // 3].fill_(0) if isinstance(module, nn.LayerNorm): module.reset_parameters() if isinstance(module, LayerScale): module.reset_parameters() if isinstance(module, PatchEmbed): module.reset_parameters() if isinstance(module, RMSNorm): module.reset_parameters() class 
DinoVisionTransformer(nn.Module): def __init__( self, *, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, pos_embed_rope_base: float = 100.0, pos_embed_rope_min_period: Optional[float] = None, pos_embed_rope_max_period: Optional[float] = None, pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", pos_embed_rope_shift_coords: Optional[float] = None, pos_embed_rope_jitter_coords: Optional[float] = None, pos_embed_rope_rescale_coords: Optional[float] = None, pos_embed_rope_dtype: str = "bf16", embed_dim: int = 768, depth: int = 12, num_heads: int = 12, ffn_ratio: float = 4.0, qkv_bias: bool = True, drop_path_rate: float = 0.0, layerscale_init: Optional[float] = None, norm_layer: str = "layernorm", ffn_layer: str = "mlp", ffn_bias: bool = True, proj_bias: bool = True, n_storage_tokens: int = 0, mask_k_bias: bool = False, untie_cls_and_patch_norms: bool = False, untie_global_and_local_cls_norm: bool = False, device: Any = None, **ignored_kwargs, ): super().__init__() del ignored_kwargs norm_layer_cls = norm_layer_dict[norm_layer] self.num_features = self.embed_dim = embed_dim self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, flatten_embedding=False, ) self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device)) self.n_storage_tokens = n_storage_tokens if self.n_storage_tokens > 0: self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device)) self.rope_embed = RopePositionEmbedding( embed_dim=embed_dim, num_heads=num_heads, base=pos_embed_rope_base, min_period=pos_embed_rope_min_period, max_period=pos_embed_rope_max_period, normalize_coords=pos_embed_rope_normalize_coords, shift_coords=pos_embed_rope_shift_coords, jitter_coords=pos_embed_rope_jitter_coords, rescale_coords=pos_embed_rope_rescale_coords, dtype=dtype_dict[pos_embed_rope_dtype], 
device=device, ) ffn_layer_cls = ffn_layer_dict[ffn_layer] ffn_ratio_sequence = [ffn_ratio] * depth blocks_list = [ SelfAttentionBlock( dim=embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio_sequence[i], qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=drop_path_rate, norm_layer=norm_layer_cls, act_layer=nn.GELU, ffn_layer=ffn_layer_cls, init_values=layerscale_init, mask_k_bias=mask_k_bias, device=device, ) for i in range(depth) ] self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = norm_layer_cls(embed_dim) self.untie_cls_and_patch_norms = untie_cls_and_patch_norms self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None self.head = nn.Identity() self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device)) def init_weights(self): self.rope_embed._init_weights() nn.init.normal_(self.cls_token, std=0.02) if self.n_storage_tokens > 0: nn.init.normal_(self.storage_tokens, std=0.02) nn.init.zeros_(self.mask_token) named_apply(init_weights_vit, self) def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]: x = self.patch_embed(x) B, H, W, _ = x.shape x = x.flatten(1, 2) if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) cls_token = self.cls_token else: cls_token = self.cls_token + 0 * self.mask_token if self.n_storage_tokens > 0: storage_tokens = self.storage_tokens else: storage_tokens = torch.empty( 1, 0, cls_token.shape[-1], dtype=cls_token.dtype, device=cls_token.device, ) x = torch.cat( [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x], dim=1, ) return x, (H, W) def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]: x = [] rope = [] for t_x, t_masks in zip(x_list, 
masks_list): t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks) x.append(t2_x) rope.append(hw_tuple) for blk in self.blocks: if self.rope_embed is not None: rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope] else: rope_sincos = [None for _ in rope] x = blk(x, rope_sincos) all_x = x output = [] for idx, (x, masks) in enumerate(zip(all_x, masks_list)): if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm: if self.untie_global_and_local_cls_norm and self.training and idx == 1: x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1]) elif self.untie_cls_and_patch_norms: x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1]) else: x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1]) x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :]) else: x_norm = self.norm(x) x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1] x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :] output.append({ "x_norm_clstoken": x_norm_cls_reg[:, 0], "x_storage_tokens": x_norm_cls_reg[:, 1:], "x_norm_patchtokens": x_norm_patch, "x_prenorm": x, "masks": masks, }) return output def forward_features(self, x, masks: Optional[Tensor] = None): if isinstance(x, torch.Tensor): return self.forward_features_list([x], [masks])[0] return self.forward_features_list(x, masks) def forward(self, *args, is_training: bool = False, **kwargs): ret = self.forward_features(*args, **kwargs) if is_training: return ret return self.head(ret["x_norm_clstoken"]) def build_eupe_vitb16() -> DinoVisionTransformer: # qkv_bias=False, mask_k_bias=False: the upstream EUPE-ViT-B release shipped # with `qkv.bias_mask` filled with zeros, which makes the effective qkv bias # zero at every block (masked_bias = bias * 0 = 0). We drop the bias parameter # entirely here — the computation is bitwise-equivalent in fp32, bf16 output # drift is sub-ULP and absorbed by every head except DPT depth (where it # appears as ~2cm noise against a 39cm RMSE, i.e. 
below the head's own floor). return DinoVisionTransformer( img_size=224, patch_size=16, in_chans=3, pos_embed_rope_base=100, pos_embed_rope_normalize_coords="separate", pos_embed_rope_rescale_coords=2, pos_embed_rope_dtype="fp32", embed_dim=768, depth=12, num_heads=12, ffn_ratio=4, qkv_bias=False, drop_path_rate=0.0, layerscale_init=1.0e-05, norm_layer="layernormbf16", ffn_layer="mlp", ffn_bias=True, proj_bias=True, n_storage_tokens=4, mask_k_bias=False, ) # =========================================================================== # Argus task heads # =========================================================================== def make_eupe_transform(resize_size: int): return v2.Compose([ v2.ToImage(), v2.Resize((resize_size, resize_size), antialias=True), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) def _normalize_image_input(image_or_images) -> Tuple[bool, list]: """Returns (was_single, [images]). Accepts a PIL.Image or an iterable of them.""" if isinstance(image_or_images, Image.Image): return True, [image_or_images] images = list(image_or_images) if not images: raise ValueError("empty image list") for i, img in enumerate(images): if not isinstance(img, Image.Image): raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image") return False, images class _BackboneExportWrapper(nn.Module): """ONNX-friendly wrapper: returns (cls, spatial) instead of a dict.""" def __init__(self, backbone: nn.Module): super().__init__() self.backbone = backbone def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: out = self.backbone.forward_features(x) cls = out["x_norm_clstoken"] patches = out["x_norm_patchtokens"] B, N, D = patches.shape h = w = int(N ** 0.5) spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) return cls, spatial class _SegHeadExportWrapper(nn.Module): """ONNX-friendly wrapper: seg head + bilinear upsample to input resolution. The bare seg head emits stride-16 logits (e.g. 
[B, 150, 40, 40] at 640px input). model.segment() upsamples those to the input resolution before argmax. This wrapper folds the upsample into the graph so the ONNX seg output is already at input resolution — consumers argmax directly without a separate interpolation step. """ def __init__(self, seg_head: nn.Module, resolution: int): super().__init__() self.seg_head = seg_head self.resolution = resolution def forward(self, spatial_features: Tensor) -> Tensor: logits = self.seg_head(spatial_features) return F.interpolate(logits, size=(self.resolution, self.resolution), mode="bilinear", align_corners=False) class _DepthHeadExportWrapper(nn.Module): """ONNX-friendly wrapper for the DPT depth head. DPTDepthDecoder.forward takes (intermediates: List[Tensor], H: int, W: int), which torch.onnx.export cannot trace cleanly because the List contains four tensors and H/W are Python ints. The wrapper accepts the four intermediate ViT-block activations as separate positional tensor inputs and forwards them to the underlying decoder with the captured H and W. """ def __init__(self, depth_head: nn.Module, H: int, W: int): super().__init__() self.depth_head = depth_head self.H = H self.W = W def forward(self, inter0: Tensor, inter1: Tensor, inter2: Tensor, inter3: Tensor) -> Tensor: return self.depth_head([inter0, inter1, inter2, inter3], self.H, self.W) class _ClassifierExportWrapper(nn.Module): """ONNX-friendly wrapper for the ImageNet linear-softmax classifier. Takes the backbone's CLS token, L2-normalizes, applies the stored Linear(embed_dim, 1000) weight + bias, and returns a softmax distribution over the 1000 ImageNet classes. The weight and bias are captured as buffers so the graph is self-contained — no separate weight file needed for classification inference. 
""" def __init__(self, class_weight: Tensor, class_bias: Tensor): super().__init__() self.register_buffer("weight", class_weight.float().clone()) self.register_buffer("bias", class_bias.float().clone()) def forward(self, cls_token: Tensor) -> Tensor: x = F.normalize(cls_token, dim=-1) logits = F.linear(x, self.weight, self.bias) return F.softmax(logits, dim=-1) class _ONNXBatchedNMS(torch.autograd.Function): """Autograd wrapper that exports to ONNX NonMaxSuppression (opset >= 10). ONNX's NonMaxSuppression handles batched multi-class NMS natively: boxes [B, N, 4] in [y1, x1, y2, x2] order (center_point_box=0) scores [B, C, N] -> selected_indices [M, 3] where each row is [batch, class, box] The eager forward path reproduces this via torchvision.ops.nms so PyTorch tracing and verify=True both work without calling into ORT for the reference. """ @staticmethod def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): return g.op( "NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, center_point_box_i=0, ) @staticmethod def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): from torchvision.ops import nms as tv_nms B, N, _ = boxes.shape _, C, _ = scores.shape max_out = int(max_output_boxes_per_class.item()) iou_thr = float(iou_threshold.item()) score_thr = float(score_threshold.item()) results: List[List[int]] = [] for b in range(B): for c in range(C): sc = scores[b, c] mask = sc > score_thr if not mask.any(): continue idx = mask.nonzero(as_tuple=True)[0] # tv_nms expects [x1, y1, x2, y2]; our boxes are [y1, x1, y2, x2]. 
bx_xyxy = boxes[b, idx][:, [1, 0, 3, 2]] keep = tv_nms(bx_xyxy, sc[idx], iou_thr)[:max_out] for k in keep.tolist(): results.append([b, c, int(idx[k].item())]) if not results: return torch.zeros((0, 3), dtype=torch.long, device=boxes.device) return torch.tensor(results, dtype=torch.long, device=boxes.device) class _DetectionHeadExportWrapper(nn.Module): """ONNX-friendly wrapper for the detection head (simple FPN + FCOS). Takes backbone stride-16 spatial features and returns decoded per-location predictions concatenated across all five FPN levels. Without NMS (default): - boxes [B, N_total, 4] xyxy in input-resolution pixels, decoded as (location - exp(reg)) / (location + exp(reg)) and clamped. - scores [B, N_total, num_classes] sigmoid(cls_logits) * sigmoid(centerness). With NMS (include_nms=True): - boxes [M, 4] xyxy in input-resolution pixels - scores [M] - class_labels [M] int64 class index - batch_indices[M] int64 batch index N_total = sum(H_i * W_i) across strides [8, 16, 32, 64, 128]. At 640px input: 6400 + 1600 + 400 + 100 + 25 = 8525 locations/image. The NMS variant folds ONNX's NonMaxSuppression (opset >= 10) into the graph using the configured iou / score / max_detections parameters, producing a flat list of surviving detections across all batches and classes. Useful for single-shot TensorRT / mobile inference. Without NMS the consumer runs their own — hard vs soft, per-class vs global, threshold tuning — without re-exporting. 
""" def __init__(self, detection_head: nn.Module, resolution: int, include_nms: bool = False, nms_iou_threshold: float = 0.5, nms_score_threshold: float = 0.05, nms_max_detections: int = 100): super().__init__() self.detection_head = detection_head self.resolution = resolution self.num_classes = detection_head.num_classes self.include_nms = include_nms self.nms_iou_threshold = nms_iou_threshold self.nms_score_threshold = nms_score_threshold self.nms_max_detections = nms_max_detections # Compute per-level spatial sizes from the SimpleFeaturePyramid's actual # output shapes, not from resolution // stride. The pyramid starts at # stride-16 backbone features (H = resolution // 16) and produces: # P3 = 2*H via ConvTranspose2d(stride=2) # P4 = H via 1x1 + 3x3 convs (no stride) # P5 = (H+1)//2 via Conv2d(3x3, stride=2, padding=1) # P6 = (P5+1)//2 via Conv2d on P5 # P7 = (P6+1)//2 via Conv2d on P6 # When resolution is a multiple of 128, these match resolution // stride # exactly; at other resolutions the stride-2 convs round up via the # padding=1 kernel=3 formula, so P6/P7 are slightly larger than # nominal stride division suggests. Feature-pyramid-level locations # still use the nominal FPN_STRIDES for FCOS box decoding because # that's what eager `model.detect` does. 
H = resolution // 16 p3 = 2 * H p4 = H p5 = (H + 1) // 2 p6 = (p5 + 1) // 2 p7 = (p6 + 1) // 2 feat_sizes = [(p3, p3), (p4, p4), (p5, p5), (p6, p6), (p7, p7)] locs_per_level = [] for (h, w), s in zip(feat_sizes, FPN_STRIDES): ys = (torch.arange(h, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, dtype=torch.float32) + 0.5) * s gy, gx = torch.meshgrid(ys, xs, indexing="ij") locs_per_level.append(torch.stack([gx.flatten(), gy.flatten()], -1)) all_locs = torch.cat(locs_per_level, 0) self.register_buffer("all_locs", all_locs) def forward(self, spatial_features: Tensor): cls_logits, box_regs, centernesses = self.detection_head(spatial_features) B = spatial_features.shape[0] flat_cls = torch.cat( [c.permute(0, 2, 3, 1).reshape(B, -1, self.num_classes) for c in cls_logits], dim=1) flat_reg = torch.cat( [r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in box_regs], dim=1) flat_ctr = torch.cat( [c.permute(0, 2, 3, 1).reshape(B, -1, 1) for c in centernesses], dim=1) scores = torch.sigmoid(flat_cls) * torch.sigmoid(flat_ctr) locs = self.all_locs.unsqueeze(0).expand(B, -1, -1) x1 = (locs[..., 0:1] - flat_reg[..., 0:1]).clamp(0, self.resolution) y1 = (locs[..., 1:2] - flat_reg[..., 1:2]).clamp(0, self.resolution) x2 = (locs[..., 0:1] + flat_reg[..., 2:3]).clamp(0, self.resolution) y2 = (locs[..., 1:2] + flat_reg[..., 3:4]).clamp(0, self.resolution) boxes = torch.cat([x1, y1, x2, y2], dim=-1) if not self.include_nms: return boxes, scores # ONNX NMS expects boxes in [y1, x1, y2, x2] (center_point_box=0) and # scores with the class dim in the middle: [B, C, N]. 
boxes_yxyx = torch.cat([y1, x1, y2, x2], dim=-1) scores_bcn = scores.permute(0, 2, 1).contiguous() max_out = torch.tensor(self.nms_max_detections, dtype=torch.long, device=boxes.device) iou_thr = torch.tensor(self.nms_iou_threshold, dtype=torch.float32, device=boxes.device) score_thr = torch.tensor(self.nms_score_threshold, dtype=torch.float32, device=boxes.device) selected = _ONNXBatchedNMS.apply( boxes_yxyx, scores_bcn, max_out, iou_thr, score_thr, ) batch_idx = selected[:, 0].long() class_idx = selected[:, 1].long() box_idx = selected[:, 2].long() sel_boxes = boxes[batch_idx, box_idx] # [M, 4] xyxy sel_scores = scores[batch_idx, box_idx, class_idx] # [M] return sel_boxes, sel_scores, class_idx, batch_idx class SegmentationHead(nn.Module): def __init__(self, in_dim: int = 768, num_classes: int = 150): super().__init__() self.batchnorm_layer = nn.BatchNorm2d(in_dim) self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1) def forward(self, x: Tensor) -> Tensor: return self.conv(self.batchnorm_layer(x)) class DepthHead(nn.Module): def __init__(self, in_dim: int = 768, n_bins: int = 256, min_depth: float = 0.001, max_depth: float = 10.0): super().__init__() self.batchnorm_layer = nn.BatchNorm2d(in_dim) self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1) self.min_depth = min_depth self.max_depth = max_depth self.n_bins = n_bins def forward(self, x: Tensor) -> Tensor: logits = self.conv_depth(self.batchnorm_layer(x)) logit = torch.relu(logits) + 0.1 logit = logit / logit.sum(dim=1, keepdim=True) bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device) return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1) # =========================================================================== # Detection (FCOS with ViTDet-style simple feature pyramid) # =========================================================================== FPN_STRIDES = [8, 16, 32, 64, 128] COCO_CLASSES = [ "person", "bicycle", "car", "motorcycle", "airplane", "bus", 
"train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", ] def cofiber_decompose(f: Tensor, n_scales: int) -> List[Tensor]: """Iterated multi-scale decomposition. Each step subtracts the downsampled-then-upsampled component of the current residual and recurses on the remainder. Zero learned parameters. The final entry is the lowest-frequency remainder.""" cofibers: List[Tensor] = [] residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega) residual = omega cofibers.append(residual) return cofibers def make_sin_pos_emb(H: int, W: int, dim: int, device) -> Tensor: """2D sinusoidal positional encoding over an H x W grid. 
    Concatenated to the backbone patch features before the head stem.

    Returns a ``[1, dim, H, W]`` tensor on ``device``. The first ``dim // 2``
    channels encode the row index and the last ``dim // 2`` the column index;
    within each half, even channels hold sin and odd channels cos at
    frequencies ``10000 ** (-k / (dim // 4))`` for ``k in range(dim // 4)``.
    """
    # NOTE(review): assert is stripped under `python -O`; in this file dim comes
    # from SplitTowerHead.pos_emb_dim, so this guards misconfiguration only.
    assert dim % 4 == 0, "pos emb dim must be divisible by 4"
    d = dim // 4
    ys = torch.arange(H, device=device, dtype=torch.float32)
    xs = torch.arange(W, device=device, dtype=torch.float32)
    # Geometric frequency ladder, as in the Transformer sinusoidal encoding.
    omega = torch.exp(torch.arange(d, device=device, dtype=torch.float32) * -(math.log(10000.0) / d))
    pe_y = torch.zeros(H, d * 2, device=device)
    pe_y[:, 0::2] = torch.sin(ys[:, None] * omega[None, :])
    pe_y[:, 1::2] = torch.cos(ys[:, None] * omega[None, :])
    pe_x = torch.zeros(W, d * 2, device=device)
    pe_x[:, 0::2] = torch.sin(xs[:, None] * omega[None, :])
    pe_x[:, 1::2] = torch.cos(xs[:, None] * omega[None, :])
    pos = torch.zeros(dim, H, W, device=device)
    # Row code: [2d, H, 1] broadcast along W.
    pos[:d * 2] = pe_y.permute(1, 0)[:, :, None].expand(-1, H, W)
    # Column code: [1, 2d, W] -> [H, 2d, W] -> [2d, H, W], constant along H.
    pos[d * 2:] = pe_x.permute(1, 0)[None, :, :].expand(H, -1, W).permute(1, 0, 2)
    return pos.unsqueeze(0)


class ConvGNBlock(nn.Module):
    """3x3 conv (padding=1, size-preserving) -> GroupNorm -> GELU.

    Attribute names (conv / norm / act) define state_dict keys; do not rename.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        # At most 32 groups so small channel counts still divide evenly
        # (assumes channels is divisible by min(32, channels)).
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        return self.act(self.norm(self.conv(x)))


class DWResBlock(nn.Module):
    """Residual block: 1x1 conv -> GELU -> depthwise 3x3 conv -> GroupNorm, plus identity.

    Attribute names (pw / act / dw / norm) define state_dict keys; do not rename.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.pw = nn.Conv2d(channels, channels, 1)
        self.act = nn.GELU()
        self.dw = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        self.norm = nn.GroupNorm(min(32, channels), channels)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.norm(self.dw(self.act(self.pw(x))))


def make_tower(hidden: int, n_std: int, n_dw: int) -> nn.Sequential:
    """Stack ``n_std`` ConvGNBlocks followed by ``n_dw`` DWResBlocks at ``hidden`` channels."""
    layers: List[nn.Module] = [ConvGNBlock(hidden) for _ in range(n_std)]
    layers += [DWResBlock(hidden) for _ in range(n_dw)]
    # Sequential indices (0, 1, ...) become part of the state_dict keys.
    return nn.Sequential(*layers)


class SplitTowerHead(nn.Module):
    """Detection head operating on a cofiber decomposition of the frozen
    backbone features.
    Five prediction levels (strides 8, 16, 32, 64, 128): a stride-8 level
    synthesized by a transposed convolution from the stride-16 band and four
    cofiber bands at strides 16, 32, 64, 128. Separate classification and
    regression towers of depth (n_std_layers + n_dw_layers) with weights
    shared across levels. Classification via cosine similarity against
    frozen CLIP text-encoder embeddings of the COCO class names; regression
    via exponentiated LTRB distances with a learned per-level scale;
    centerness via a single 1x1 convolution. Inference-only within Argus:
    no DFL, no IoU-aware branch, no per-scale bias. The text_embed buffer
    is populated by from_pretrained's state_dict load.

    forward() returns three lists (one entry per level, finest first):
    class logits ``[B, num_classes, H_i, W_i]``, positive LTRB distances
    ``[B, 4, H_i, W_i]`` and raw centerness logits ``[B, 1, H_i, W_i]``.
    Level sizes are ``[2H, H, H//2, H//4, H//8]`` for an ``H``-sided input
    with the default ``n_scales=4`` (pooling floors odd sizes).

    Attribute names and construction order define state_dict keys; do not
    rename or reorder them.
    """

    def __init__(self, feat_dim: int = 768, hidden: int = 160,
                 n_std_layers: int = 5, n_dw_layers: int = 4,
                 n_scales: int = 4, pos_emb_dim: int = 64,
                 num_classes: int = 80, text_embed_dim: int = 768):
        super().__init__()
        self.n_scales = n_scales
        self.pos_emb_dim = pos_emb_dim
        self.num_classes = num_classes
        self.text_embed_dim = text_embed_dim
        # One extra prediction level: the synthesized stride-8 map.
        n_total = n_scales + 1
        # The stem sees backbone features concatenated with the positional code.
        input_dim = feat_dim + pos_emb_dim
        # Per-band normalization, then a stem shared by all bands.
        self.scale_norms = nn.ModuleList([nn.GroupNorm(1, input_dim) for _ in range(n_scales)])
        self.stem = nn.Conv2d(input_dim, hidden, 1)
        self.stem_act = nn.GELU()
        self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
        self.p3_norm = nn.GroupNorm(min(32, hidden), hidden)
        self.lateral_convs = nn.ModuleList([nn.Conv2d(hidden, hidden, 1) for _ in range(n_scales - 1)])
        self.lateral_norms = nn.ModuleList(
            [nn.GroupNorm(min(32, hidden), hidden) for _ in range(n_scales - 1)])
        self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        # CLIP text-aligned classifier. The text_embed buffer is filled from
        # the state dict at from_pretrained; the zero placeholder here only
        # exists so the module can be constructed before weights arrive.
        self.register_buffer("text_embed", torch.zeros(num_classes, text_embed_dim))
        self.cls_project = nn.Linear(hidden, text_embed_dim, bias=False)
        # CLIP-style temperature, initialized to 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / 0.07)))
        # sigmoid(-log 99) = 0.01: a 1% initial foreground prior per class.
        self.cls_bias = nn.Parameter(torch.full((num_classes,), -math.log(99)))
        self.reg_pred = nn.Conv2d(hidden, 4, 1)
        self.ctr_pred = nn.Conv2d(hidden, 1, 1)
        # Learned per-level multiplier on the regression logits.
        self.scale_params = nn.Parameter(torch.ones(n_total))

    def forward(self, spatial: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        B, C, H_, W_ = spatial.shape
        pos = make_sin_pos_emb(H_, W_, self.pos_emb_dim, spatial.device).expand(B, -1, -1, -1)
        spatial = torch.cat([spatial, pos], dim=1)
        cofibers = cofiber_decompose(spatial, self.n_scales)
        scale_features: List[Tensor] = []
        for i, cof in enumerate(cofibers):
            x = self.stem_act(self.stem(self.scale_norms[i](cof)))
            scale_features.append(x)
        # Top-down lateral fusion from coarser to finer scales.
        for i in range(len(scale_features) - 2, -1, -1):
            coarse_up = F.interpolate(scale_features[i + 1], size=scale_features[i].shape[2:],
                                      mode="bilinear", align_corners=False)
            scale_features[i] = self.lateral_norms[i](
                scale_features[i] + self.lateral_convs[i](coarse_up))
        # Stride-8 level: 2x transposed-conv upsample of the fused finest band.
        p3 = self.p3_norm(self.p3_upsample(scale_features[0]))
        all_features = [p3] + scale_features
        cls_l, reg_l, ctr_l = [], [], []
        for i, x in enumerate(all_features):
            cls_feat = self.cls_tower(x)
            reg_feat = self.reg_tower(x)
            B_, _, Hi, Wi = cls_feat.shape
            # Cosine similarity between L2-normalized projected features and
            # text_embed (used as stored; not re-normalized here).
            f = cls_feat.permute(0, 2, 3, 1).reshape(-1, cls_feat.shape[1])
            f_proj = self.cls_project(f)
            f_norm = F.normalize(f_proj, p=2, dim=-1)
            logits = f_norm @ self.text_embed.t()
            cls = (logits * self.logit_scale.exp() + self.cls_bias).reshape(
                B_, Hi, Wi, self.num_classes).permute(0, 3, 1, 2)
            # Clamp before exp to keep distances within [e^-10, e^10].
            reg_raw = (self.reg_pred(reg_feat) * self.scale_params[i]).clamp(-10, 10)
            reg = reg_raw.exp()
            # Centerness stays a logit; callers apply sigmoid.
            ctr = self.ctr_pred(reg_feat)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l


def _make_locations(feature_sizes: List[Tuple[int, int]], strides:
                    List[int], device) -> List[Tensor]:
    """Per-level center coordinates of feature-map locations in image space.

    For each ``(h, w)`` and stride ``s`` returns an ``[h*w, 2]`` float32
    tensor of ``(x, y)`` cell centres ``((i + 0.5) * s)`` in row-major order,
    matching ``permute(0, 2, 3, 1).reshape(...)`` flattening of the maps.
    """
    all_locs = []
    for (h, w), s in zip(feature_sizes, strides):
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
        all_locs.append(locs)
    return all_locs


@torch.inference_mode()
def _decode_detections(
    cls_logits_per_level: List[Tensor],
    box_regs_per_level: List[Tensor],
    centernesses_per_level: List[Tensor],
    locations_per_level: List[Tensor],
    image_sizes: List[Tuple[int, int]],
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_level: int = 1000,
    max_per_image: int = 100,
) -> List[Dict[str, Tensor]]:
    """Convert per-level logits/regs/centerness into per-image detections (xyxy boxes).

    Per image and level: score = sigmoid(cls) * sigmoid(centerness); keep
    (location, class) pairs above ``score_thresh``, capped at the top
    ``max_per_level``. Box regressions are used directly as LTRB distances
    (SplitTowerHead already exponentiates them). Boxes are clamped to
    ``image_sizes[i] = (H, W)``, then per-class NMS at ``nms_thresh`` runs and
    the top ``max_per_image`` survivors are kept. Returns one dict per image
    with ``boxes [K, 4]``, ``scores [K]``, ``labels [K]`` (int64). Output
    order is per-class NMS order unless the top-k cap is applied.
    """
    B = cls_logits_per_level[0].shape[0]
    num_classes = cls_logits_per_level[0].shape[1]
    device = cls_logits_per_level[0].device
    per_image_results = []
    for image_idx in range(B):
        all_boxes, all_scores, all_labels = [], [], []
        for cls_l, reg_l, ctr_l, locs_l in zip(
            cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
        ):
            # [C, H, W] -> [H*W, C], same row-major order as locs_l.
            cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes)
            reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4)
            ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1)
            cls_prob = torch.sigmoid(cls)
            ctr_prob = torch.sigmoid(ctr)
            scores = cls_prob * ctr_prob[:, None]
            mask = scores > score_thresh
            if not mask.any():
                continue
            # One candidate per (location, class) above threshold.
            cand_loc, cand_cls = mask.nonzero(as_tuple=True)
            cand_scores = scores[cand_loc, cand_cls]
            if cand_scores.numel() > max_per_level:
                top = cand_scores.topk(max_per_level)
                cand_scores = top.values
                idx = top.indices
                cand_loc = cand_loc[idx]
                cand_cls = cand_cls[idx]
            cand_locs_xy = locs_l[cand_loc]
            cand_reg = reg[cand_loc]
            # LTRB distances around the location -> xyxy.
            boxes = torch.stack([
                cand_locs_xy[:, 0] - cand_reg[:, 0],
                cand_locs_xy[:, 1] - cand_reg[:, 1],
                cand_locs_xy[:, 0] + cand_reg[:, 2],
                cand_locs_xy[:, 1] + cand_reg[:, 3],
            ], dim=-1)
            all_boxes.append(boxes)
            all_scores.append(cand_scores)
            all_labels.append(cand_cls)
        if all_boxes:
            boxes = torch.cat(all_boxes, dim=0)
            scores = torch.cat(all_scores, dim=0)
            labels = torch.cat(all_labels, dim=0)
            H, W = image_sizes[image_idx]
            boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W)
            boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H)
            # Per-class NMS: suppression never crosses class boundaries.
            keep_all = []
            for c in labels.unique():
                cm = labels == c
                keep = nms(boxes[cm], scores[cm], nms_thresh)
                # Map indices within the class subset back to the full set.
                keep_idx = cm.nonzero(as_tuple=True)[0][keep]
                keep_all.append(keep_idx)
            keep_all = torch.cat(keep_all, dim=0)
            boxes = boxes[keep_all]
            scores = scores[keep_all]
            labels = labels[keep_all]
            if scores.numel() > max_per_image:
                top = scores.topk(max_per_image)
                boxes = boxes[top.indices]
                scores = top.values
                labels = labels[top.indices]
        else:
            boxes = torch.zeros((0, 4), device=device)
            scores = torch.zeros((0,), device=device)
            labels = torch.zeros((0,), dtype=torch.long, device=device)
        per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels})
    return per_image_results


def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
    """Resize preserving aspect ratio and pad bottom/right with black.
Matches the training transform.""" W0, H0 = image.size scale = resolution / max(H0, W0) new_w = int(round(W0 * scale)) new_h = int(round(H0 * scale)) resized = image.resize((new_w, new_h), Image.BILINEAR) canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0)) canvas.paste(resized, (0, 0)) return canvas, scale, (W0, H0) # =========================================================================== # DPT depth decoder (multi-scale, hooks into ViT blocks [2, 5, 8, 11]) # =========================================================================== HOOK_BLOCK_INDICES = [2, 5, 8, 11] N_PREFIX_TOKENS = 5 # 1 CLS + 4 register/storage tokens class _ResidualConvUnit(nn.Module): """Two 3x3 conv + BatchNorm blocks with a residual connection. Padding mode is configurable: the Argus-B DPT depth head trains with reflect padding to avoid edge artifacts; Argus-Lite ships weights that were trained with zero padding (the PyTorch default), and switching pad modes at inference would create a small distribution shift in the edge regions. 
Variants pass `padding_mode` to keep their inference aligned with their training.""" def __init__(self, dim: int, padding_mode: str = "reflect"): super().__init__() self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode=padding_mode, bias=False) self.bn1 = nn.BatchNorm2d(dim) self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode=padding_mode, bias=False) self.bn2 = nn.BatchNorm2d(dim) self.act = nn.GELU() def forward(self, x: Tensor) -> Tensor: return x + self.bn2(self.conv2(self.act(self.bn1(self.conv1(x))))) class _FeatureFusionBlock(nn.Module): def __init__(self, dim: int, has_skip: bool = True, padding_mode: str = "reflect"): super().__init__() self.rcu1 = _ResidualConvUnit(dim, padding_mode=padding_mode) self.rcu2 = _ResidualConvUnit(dim, padding_mode=padding_mode) self.skip_proj = nn.Conv2d(dim, dim, 1) if has_skip else None def forward(self, x: Tensor, skip: Optional[Tensor] = None) -> Tensor: if skip is not None and self.skip_proj is not None: x = x + self.skip_proj(skip) x = self.rcu1(x) x = self.rcu2(x) return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) class _DPTReassemble(nn.Module): def __init__(self, in_dim: int = 768, out_dim: int = 256): super().__init__() self.projects = nn.ModuleList([ nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, out_dim)) for _ in range(4) ]) self.refine = nn.ModuleList([ nn.Sequential( nn.Conv2d(out_dim, out_dim, 3, padding=1, padding_mode="reflect", bias=False), nn.BatchNorm2d(out_dim), nn.GELU(), ) for _ in range(4) ]) def forward(self, intermediates: List[Tensor], H: int, W: int) -> List[Tensor]: out = [] for feat, proj, refine in zip(intermediates, self.projects, self.refine): patches = feat[:, N_PREFIX_TOKENS:, :] patches = proj(patches) B, N, D = patches.shape spatial = patches.permute(0, 2, 1).reshape(B, D, H, W) out.append(refine(spatial)) level_4 = F.interpolate(out[0], scale_factor=4, mode="bilinear", align_corners=False) level_8 = F.interpolate(out[1], scale_factor=2, 
mode="bilinear", align_corners=False) level_16 = out[2] level_32 = F.interpolate(out[3], scale_factor=0.5, mode="bilinear", align_corners=False) return [level_4, level_8, level_16, level_32] class DPTDepthDecoder(nn.Module): def __init__(self, in_dim: int = 768, decoder_dim: int = 256, n_bins: int = 256, min_depth: float = 0.001, max_depth: float = 10.0): super().__init__() self.n_bins = n_bins self.min_depth = min_depth self.max_depth = max_depth self.reassemble = _DPTReassemble(in_dim=in_dim, out_dim=decoder_dim) self.fusion_blocks = nn.ModuleList([ _FeatureFusionBlock(decoder_dim, has_skip=True), _FeatureFusionBlock(decoder_dim, has_skip=True), _FeatureFusionBlock(decoder_dim, has_skip=True), _FeatureFusionBlock(decoder_dim, has_skip=False), ]) self.head = nn.Sequential( nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, padding_mode="reflect", bias=False), nn.BatchNorm2d(decoder_dim), nn.GELU(), nn.Conv2d(decoder_dim, n_bins, 1), ) def forward(self, intermediates: List[Tensor], H: int, W: int, return_distribution: bool = False): levels = self.reassemble(intermediates, H, W) x = self.fusion_blocks[3](levels[3]) x = self.fusion_blocks[2](x, skip=levels[2]) x = self.fusion_blocks[1](x, skip=levels[1]) x = self.fusion_blocks[0](x, skip=levels[0]) logits = self.head(x) distribution = torch.relu(logits) + 0.1 distribution = distribution / distribution.sum(dim=1, keepdim=True) bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device) depth = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1) if return_distribution: return depth, distribution, bins return depth # =========================================================================== # Argus model (transformers-compatible) # =========================================================================== class ArgusConfig(PretrainedConfig): model_type = "argus" def __init__( self, embed_dim: int = 768, patch_size: int = 16, num_seg_classes: int = 150, depth_n_bins: int = 256, 
depth_min_depth: float = 0.001, depth_max_depth: float = 10.0, num_imagenet_classes: int = 1000, class_ids: Optional[list] = None, class_names: Optional[list] = None, detection_num_classes: int = 80, detection_hidden: int = 160, detection_n_std_layers: int = 5, detection_n_dw_layers: int = 4, detection_n_scales: int = 4, detection_pos_emb_dim: int = 64, detection_text_embed_dim: int = 768, detection_class_names: Optional[list] = None, **kwargs, ): super().__init__(**kwargs) self.embed_dim = embed_dim self.patch_size = patch_size self.num_seg_classes = num_seg_classes self.depth_n_bins = depth_n_bins self.depth_min_depth = depth_min_depth self.depth_max_depth = depth_max_depth self.num_imagenet_classes = num_imagenet_classes self.class_ids = class_ids or [] self.class_names = class_names or [] self.detection_num_classes = detection_num_classes self.detection_hidden = detection_hidden self.detection_n_std_layers = detection_n_std_layers self.detection_n_dw_layers = detection_n_dw_layers self.detection_n_scales = detection_n_scales self.detection_pos_emb_dim = detection_pos_emb_dim self.detection_text_embed_dim = detection_text_embed_dim self.detection_class_names = detection_class_names or list(COCO_CLASSES) class Argus(PreTrainedModel): config_class = ArgusConfig base_model_prefix = "argus" supports_gradient_checkpointing = False _tied_weights_keys: list = [] all_tied_weights_keys: dict = {} def __init__(self, config: ArgusConfig): super().__init__(config) self.backbone = build_eupe_vitb16() self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes) self.depth_head = DPTDepthDecoder( in_dim=config.embed_dim, decoder_dim=256, n_bins=config.depth_n_bins, min_depth=config.depth_min_depth, max_depth=config.depth_max_depth, ) self.register_buffer( "class_logit_weight", torch.zeros(config.num_imagenet_classes, config.embed_dim), persistent=True, ) self.register_buffer( "class_logit_bias", torch.zeros(config.num_imagenet_classes), persistent=True, ) 
self.detection_head = SplitTowerHead( feat_dim=config.embed_dim, hidden=config.detection_hidden, n_std_layers=config.detection_n_std_layers, n_dw_layers=config.detection_n_dw_layers, n_scales=config.detection_n_scales, pos_emb_dim=config.detection_pos_emb_dim, num_classes=config.detection_num_classes, text_embed_dim=config.detection_text_embed_dim, ) for p in self.backbone.parameters(): p.requires_grad = False self.backbone.eval() self.seg_head.eval() self.depth_head.eval() self.detection_head.eval() def _init_weights(self, module): # HF reallocates missing buffers and parameters with torch.empty() # (uninitialized memory) on from_pretrained. Populate sensible defaults # for the standard layer types used by the detection head, and zero any # Argus-level buffer that came back NaN. if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.GroupNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) if module is self: for name in ("class_logit_weight", "class_logit_bias"): if hasattr(self, name): buf = getattr(self, name) if torch.isnan(buf).any() or torch.isinf(buf).any(): buf.data.zero_() def _load_imagenet_classes(self): if getattr(self, "_imagenet_classes_loaded", False): return self._imagenet_classes_loaded = True import json import os as _os candidates = [] here = _os.path.dirname(_os.path.abspath(__file__)) candidates.append(_os.path.join(here, "imagenet_classes.json")) name_or_path = getattr(self.config, "_name_or_path", None) if name_or_path and _os.path.isdir(name_or_path): candidates.append(_os.path.join(name_or_path, "imagenet_classes.json")) for path in candidates: if _os.path.isfile(path): with open(path) as f: data = json.load(f) self.config.class_ids = data.get("class_ids", []) self.config.class_names = data.get("class_names", []) return if name_or_path and not _os.path.isdir(name_or_path): 
try: from huggingface_hub import hf_hub_download path = hf_hub_download(name_or_path, "imagenet_classes.json") with open(path) as f: data = json.load(f) self.config.class_ids = data.get("class_ids", []) self.config.class_names = data.get("class_names", []) except Exception: pass @property def class_ids(self): if not self.config.class_ids: self._load_imagenet_classes() return self.config.class_ids @property def class_names(self): if not self.config.class_names: self._load_imagenet_classes() return self.config.class_names def quantize_int8(self): """Apply INT8 weight-only quantization via torchao. Reduces VRAM by ~11% with negligible accuracy loss (<0.05 m depth drift, 100% classification agreement). Requires torchao: pip install torchao.""" try: from torchao.quantization import quantize_, Int8WeightOnlyConfig except ImportError as e: raise ImportError("torchao is required for INT8 quantization: pip install torchao") from e quantize_(self, Int8WeightOnlyConfig()) return self @torch.inference_mode() def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]: with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): out = self.backbone.forward_features(image_tensor) cls = out["x_norm_clstoken"].float() patches = out["x_norm_patchtokens"].float() B, N, D = patches.shape h = w = int(N ** 0.5) spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) return cls, spatial @torch.inference_mode() def classify(self, image_or_images, top_k: int = 5): single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(224) batch = torch.stack([transform(img) for img in images]).to(self.device) cls, _ = self._extract(batch) cls = F.normalize(cls, dim=-1) w = self.class_logit_weight.to(cls.dtype) b = self.class_logit_bias.to(cls.dtype) logits = F.linear(cls, w, b) scores_full = F.softmax(logits, dim=-1) topk = scores_full.topk(top_k, dim=-1) top2 = scores_full.topk(2, dim=-1) margins = (top2.values[:, 0] - 
top2.values[:, 1]).tolist() results = [] for b in range(len(images)): entries = [] for score, idx in zip(topk.values[b].tolist(), topk.indices[b].tolist()): entries.append({ "class_id": self.class_ids[idx], "class_name": self.class_names[idx], "score": float(score), }) entries[0]["margin"] = float(margins[b]) results.append(entries) return results[0] if single else results @torch.inference_mode() def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False): single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(resolution) batch = torch.stack([transform(img) for img in images]).to(self.device) _, spatial = self._extract(batch) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): logits = self.seg_head(spatial) logits = F.interpolate(logits, size=(resolution, resolution), mode="bilinear", align_corners=False) seg_maps = logits.argmax(dim=1) # [B, H, W] if return_confidence: probs = F.softmax(logits.float(), dim=1) conf_maps = probs.max(dim=1).values # [B, H, W] in [0, 1] if single: return seg_maps[0], conf_maps[0] return [(seg_maps[i], conf_maps[i]) for i in range(len(images))] if single: return seg_maps[0] return [seg_maps[i] for i in range(len(images))] @torch.inference_mode() def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False, crop_border: bool = False): """Run the DPT depth decoder. Returns metric depth in meters at the input resolution. ``crop_border=True`` strips a small border (``max(4, H/13)`` pixels per side) from the raw decoder output before bilinear-upsampling to the input resolution. Useful when this model is loaded with a backbone whose DPT decoder was trained with zero padding (the unshipped dev-fork behaviour), which leaves a systematic edge artifact. 
The canonical checkpoint uses reflect padding inside every DPT conv and does not need this crop, so the option defaults to ``False``.""" single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(resolution) batch = torch.stack([transform(img) for img in images]).to(self.device) # Hook into intermediate ViT blocks for multi-scale features intermediates = {} hooks = [] for idx in HOOK_BLOCK_INDICES: def _make_hook(block_idx): def _hook(module, inp, out): intermediates[block_idx] = out[0] if isinstance(out, list) else out return _hook hooks.append(self.backbone.blocks[idx].register_forward_hook(_make_hook(idx))) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): self.backbone.forward_features(batch) for h in hooks: h.remove() inter_list = [intermediates[idx].float() for idx in HOOK_BLOCK_INDICES] H = W = resolution // 16 if return_confidence: depth_b, distribution, bins = self.depth_head( inter_list, H, W, return_distribution=True) # Std of the 256-bin depth distribution: var = E[X^2] - E[X]^2. 
mean_sq = torch.einsum("bkhw,k->bhw", distribution, bins ** 2) variance = (mean_sq - depth_b.squeeze(1) ** 2).clamp(min=0) std_b = torch.sqrt(variance).unsqueeze(1) else: depth_b = self.depth_head(inter_list, H, W) std_b = None if crop_border: crop = max(4, depth_b.shape[2] // 13) depth_b = depth_b[:, :, crop:-crop, crop:-crop] if std_b is not None: std_b = std_b[:, :, crop:-crop, crop:-crop] depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False) if std_b is not None: std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False) depth_squeezed = depth_b[:, 0].float() if return_confidence: std_squeezed = std_b[:, 0].float() if single: return depth_squeezed[0], std_squeezed[0] return [(depth_squeezed[i], std_squeezed[i]) for i in range(len(images))] if single: return depth_squeezed[0] return [depth_squeezed[i] for i in range(len(images))] @torch.inference_mode() def correspond( self, src_image, tgt_image, resolution: int = 512, ): """Dense patch correspondence between two images. Single-pair form: pass two `PIL.Image` instances. Returns a dict with keys `matches` (numpy array of length grid*grid mapping each source patch to its argmax target patch), `scores` (cosine similarity at the match), and `grid` (the patch-grid side length). Batched form: pass two equally-sized lists/iterables of images. Returns a list of per-pair dicts in the same shape that a single call would produce. Both lists are forwarded through the backbone in two contiguous batches, so cross-pair throughput on GPU is much higher than calling `correspond` in a loop. 
""" single = isinstance(src_image, Image.Image) and isinstance(tgt_image, Image.Image) if single: srcs = [src_image] tgts = [tgt_image] else: srcs = list(src_image) tgts = list(tgt_image) if len(srcs) != len(tgts): raise ValueError( f"src_image and tgt_image must have the same length; " f"got {len(srcs)} and {len(tgts)}") if not srcs: raise ValueError("empty image list") for i, (a, b) in enumerate(zip(srcs, tgts)): if not isinstance(a, Image.Image) or not isinstance(b, Image.Image): raise TypeError(f"pair {i} must contain two PIL.Image instances") transform = make_eupe_transform(resolution) src_batch = torch.stack([transform(img) for img in srcs]).to(self.device) tgt_batch = torch.stack([transform(img) for img in tgts]).to(self.device) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): oa = self.backbone.forward_features(src_batch) ob = self.backbone.forward_features(tgt_batch) pa_batch = F.normalize(oa['x_norm_patchtokens'].float(), dim=-1) pb_batch = F.normalize(ob['x_norm_patchtokens'].float(), dim=-1) results = [] for pa, pb in zip(pa_batch, pb_batch): sim = pa @ pb.t() m = sim.argmax(dim=-1) s = sim.max(dim=-1).values grid = int(np.sqrt(pa.shape[0])) results.append({ "matches": m.cpu().numpy(), "scores": s.cpu().numpy(), "grid": grid, }) return results[0] if single else results @torch.inference_mode() def detect( self, image_or_images, resolution: int = 768, score_thresh: float = 0.05, nms_thresh: float = 0.5, max_per_image: int = 100, ): single, images = _normalize_image_input(image_or_images) # Letterbox each image to match the training transform (resize long side # to `resolution`, pad bottom/right with black). Box coordinates are # recovered after decoding by unscaling. 
canvases, scales, orig_sizes = [], [], [] for img in images: canvas, scale, orig = _letterbox_to_square(img, resolution) canvases.append(canvas) scales.append(scale) orig_sizes.append(orig) det_normalize = v2.Compose([ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device) _, spatial = self._extract(batch) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): cls_logits, box_regs, centernesses = self.detection_head(spatial) cls_logits = [c.float() for c in cls_logits] box_regs = [b.float() for b in box_regs] centernesses = [c.float() for c in centernesses] feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits] locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device) image_sizes = [(resolution, resolution)] * len(images) results = _decode_detections( cls_logits, box_regs, centernesses, locations, image_sizes=image_sizes, score_thresh=score_thresh, nms_thresh=nms_thresh, max_per_image=max_per_image, ) class_names = self.config.detection_class_names formatted = [] for i, r in enumerate(results): scale = scales[i] orig_w, orig_h = orig_sizes[i] boxes = r["boxes"].cpu().numpy() / scale boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w) boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h) detections = [] for box, score, label in zip( boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy() ): detections.append({ "box": [float(v) for v in box.tolist()], "score": float(score), "label": int(label), "class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}", }) formatted.append(detections) return formatted[0] if single else formatted def perceive(self, image_or_images, return_confidence: bool = False): single, images = _normalize_image_input(image_or_images) t0 = time.time() classif = self.classify(images, top_k=5) t1 = time.time() 
seg_out = self.segment(images, resolution=512, return_confidence=return_confidence) t2 = time.time() depth_out = self.depth(images, resolution=416, return_confidence=return_confidence) t3 = time.time() if return_confidence: seg_maps = [s for s, _ in seg_out] seg_confs = [c for _, c in seg_out] depth_maps = [d for d, _ in depth_out] depth_uncerts = [u for _, u in depth_out] else: seg_maps = seg_out depth_maps = depth_out seg_confs = depth_uncerts = None timings = { "classify": (t1 - t0) * 1000, "segment": (t2 - t1) * 1000, "depth": (t3 - t2) * 1000, "total": (t3 - t0) * 1000, } results = [] for i in range(len(images)): entry = { "classification": classif[i], "segmentation": seg_maps[i].cpu().numpy(), "depth": depth_maps[i].cpu().numpy(), "timings_ms": timings, } if return_confidence: entry["segmentation_confidence"] = seg_confs[i].cpu().numpy() entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy() results.append(entry) return results[0] if single else results def export_onnx( self, out_dir: str, backbone_resolution: int = 224, dynamic_batch: bool = True, verify: bool = True, tolerance: Union[float, Dict[str, float]] = 5e-2, opset_version: int = 17, include_nms: bool = False, nms_iou_threshold: float = 0.5, nms_score_threshold: float = 0.05, nms_max_detections: int = 100, ) -> dict: """Export backbone, classifier, seg head, depth head, and detection head to ONNX. Produces five graphs: - argus_backbone.onnx image[B,3,H,W] -> cls[B,D], spatial[B,D,H/16,W/16] - argus_classifier.onnx cls_token[B,D] -> probs[B,1000] - argus_seg_head.onnx spatial_features[B,D,h,w] -> seg_logits[B,150,H,W] - argus_depth_head.onnx intermediate_{0..3}[B,N+5,D] -> depth_map[B,1,~8h,~8w] - argus_detection_head.onnx spatial_features[B,D,h,w] -> boxes, scores (+ labels, batch_indices if include_nms) The seg graph folds bilinear upsample to input resolution into the graph, so consumers argmax directly without a separate interpolation step. 
Correspondence has no learned parameters — it runs as cosine-max on the backbone's spatial output and needs no graph. ``include_nms=True`` bakes an ONNX NonMaxSuppression (opset >= 10) op into the detection head. The detection graph then emits four post-NMS tensors (boxes [M,4], scores [M], class_labels [M], batch_indices [M]) instead of the raw (boxes, scores) pair. Useful for single-shot TensorRT / mobile inference. The default ``include_nms=False`` leaves NMS to the consumer so they can choose hard vs soft, per-class vs global, and tune thresholds without re-exporting. ``tolerance`` can be a float (applied uniformly to every ``*_max_diff`` check) or a dict keyed by verification output name (e.g. ``{"detection_boxes_max_diff": 3.2, "default": 5e-2}``). The ``"default"`` key covers outputs not otherwise listed. If a float is passed, detection box coordinates get a resolution-scaled tolerance (``max(tolerance, backbone_resolution * 5e-3)``) because exp() in the FCOS regression path amplifies FP kernel-dispatch differences to pixel-scale absolute diffs. """ import os os.makedirs(out_dir, exist_ok=True) if backbone_resolution % self.config.patch_size != 0: raise ValueError( f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})" ) spatial_resolution = backbone_resolution // self.config.patch_size if backbone_resolution < 320: import warnings warnings.warn( f"backbone_resolution={backbone_resolution} is below 320; the detection " f"head's coarsest FPN level (stride 128) collapses to <=2 locations per " f"side and the detection graph, while it exports and runs, cannot produce " f"useful detections at this resolution. Classifier, seg, and depth graphs " f"are unaffected. 
FCOS convention is 640-800px input; export at " f">= 512 for detection.", stacklevel=2, ) wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval() dummy_image = torch.randn( 1, 3, backbone_resolution, backbone_resolution, device=self.device, dtype=torch.float32, ) dummy_spatial = torch.randn( 1, self.config.embed_dim, spatial_resolution, spatial_resolution, device=self.device, dtype=torch.float32, ) backbone_path = os.path.join(out_dir, "argus_backbone.onnx") classifier_path = os.path.join(out_dir, "argus_classifier.onnx") seg_path = os.path.join(out_dir, "argus_seg_head.onnx") depth_path = os.path.join(out_dir, "argus_depth_head.onnx") detection_path = os.path.join(out_dir, "argus_detection_head.onnx") backbone_axes = None head_axes = None if dynamic_batch: backbone_axes = { "image": {0: "batch"}, "cls_token": {0: "batch"}, "spatial_features": {0: "batch"}, } head_axes = { "spatial_features": {0: "batch"}, "seg_logits": {0: "batch"}, "depth_map": {0: "batch"}, } # dynamo path crashes on EUPE's list-based forward; use legacy. 
with torch.inference_mode(): torch.onnx.export( wrapper, dummy_image, backbone_path, input_names=["image"], output_names=["cls_token", "spatial_features"], dynamic_axes=backbone_axes, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) seg_wrapper = _SegHeadExportWrapper(self.seg_head, backbone_resolution).to(self.device).eval() torch.onnx.export( seg_wrapper, dummy_spatial, seg_path, input_names=["spatial_features"], output_names=["seg_logits"], dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) depth_wrapper = _DepthHeadExportWrapper( self.depth_head, spatial_resolution, spatial_resolution ).to(self.device).eval() num_patch_tokens = spatial_resolution * spatial_resolution + N_PREFIX_TOKENS dummy_inter = tuple( torch.randn(1, num_patch_tokens, self.config.embed_dim, device=self.device, dtype=torch.float32) for _ in range(len(HOOK_BLOCK_INDICES)) ) depth_input_names = [f"intermediate_{i}" for i in range(len(HOOK_BLOCK_INDICES))] if dynamic_batch: depth_axes = {name: {0: "batch"} for name in depth_input_names} depth_axes["depth_map"] = {0: "batch"} else: depth_axes = None torch.onnx.export( depth_wrapper, dummy_inter, depth_path, input_names=depth_input_names, output_names=["depth_map"], dynamic_axes=depth_axes, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) classifier_wrapper = _ClassifierExportWrapper( self.class_logit_weight, self.class_logit_bias ).to(self.device).eval() dummy_cls = torch.randn( 1, self.config.embed_dim, device=self.device, dtype=torch.float32, ) if dynamic_batch: classifier_axes = {"cls_token": {0: "batch"}, "class_probs": {0: "batch"}} else: classifier_axes = None torch.onnx.export( classifier_wrapper, dummy_cls, classifier_path, input_names=["cls_token"], output_names=["class_probs"], dynamic_axes=classifier_axes, opset_version=opset_version, 
do_constant_folding=True, dynamo=False, ) detection_wrapper = _DetectionHeadExportWrapper( self.detection_head, backbone_resolution, include_nms=include_nms, nms_iou_threshold=nms_iou_threshold, nms_score_threshold=nms_score_threshold, nms_max_detections=nms_max_detections, ).to(self.device).eval() if include_nms: detection_output_names = ["boxes", "scores", "class_labels", "batch_indices"] # Post-NMS outputs are flat [M, ...]; no fixed batch axis to mark. # Spatial features input still has a dynamic batch dim so the graph # supports multi-image inference even with fused NMS. detection_axes = {"spatial_features": {0: "batch"}} if dynamic_batch else None else: detection_output_names = ["boxes", "scores"] if dynamic_batch: detection_axes = { "spatial_features": {0: "batch"}, "boxes": {0: "batch"}, "scores": {0: "batch"}, } else: detection_axes = None torch.onnx.export( detection_wrapper, dummy_spatial, detection_path, input_names=["spatial_features"], output_names=detection_output_names, dynamic_axes=detection_axes, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) result = { "backbone": backbone_path, "classifier": classifier_path, "seg_head": seg_path, "depth_head": depth_path, "detection_head": detection_path, } if verify: try: import onnxruntime as ort except ImportError as e: raise ImportError("onnxruntime not installed; pip install onnxruntime") from e providers = ["CPUExecutionProvider"] verify_image = torch.randn(2, 3, backbone_resolution, backbone_resolution, dtype=torch.float32) verify_spatial = torch.randn(2, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32) verify_cls = torch.randn(2, self.config.embed_dim, dtype=torch.float32) verify_inter = [ torch.randn(2, num_patch_tokens, self.config.embed_dim, dtype=torch.float32) for _ in range(len(HOOK_BLOCK_INDICES)) ] with torch.inference_mode(): ref_cls, ref_spatial = wrapper(verify_image.to(self.device)) ref_seg = 
seg_wrapper(verify_spatial.to(self.device)) ref_depth = depth_wrapper(*[v.to(self.device) for v in verify_inter]) ref_probs = classifier_wrapper(verify_cls.to(self.device)) ref_det = detection_wrapper(verify_spatial.to(self.device)) sess = ort.InferenceSession(backbone_path, providers=providers) ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()}) cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max()) spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max()) sess = ort.InferenceSession(seg_path, providers=providers) ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0] seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max()) sess = ort.InferenceSession(depth_path, providers=providers) ort_depth = sess.run(None, {f"intermediate_{i}": verify_inter[i].numpy() for i in range(len(HOOK_BLOCK_INDICES))})[0] depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max()) sess = ort.InferenceSession(classifier_path, providers=providers) ort_probs = sess.run(None, {"cls_token": verify_cls.numpy()})[0] classifier_diff = float(np.abs(ort_probs - ref_probs.cpu().numpy()).max()) sess = ort.InferenceSession(detection_path, providers=providers) ort_det = sess.run(None, {"spatial_features": verify_spatial.numpy()}) verification = { "backbone_cls_max_diff": cls_diff, "backbone_spatial_max_diff": spatial_diff, "classifier_max_diff": classifier_diff, "seg_head_max_diff": seg_diff, "depth_head_max_diff": depth_diff, "verified_batch_size": 2, } if include_nms: # NMS is inherently implementation-dependent: ONNX's # NonMaxSuppression and the torchvision eager fallback differ # on tie-breaking when multiple detections share a score or # when near-threshold boxes are right at the score cutoff. # Element-wise comparison of post-NMS outputs is the wrong # metric. The structural checks below verify the graph runs, # returns reasonable shapes, and agrees on the top detection. 
pt_boxes, pt_scores, pt_labels, _ = ref_det ort_boxes, ort_scores, ort_labels, _ = ort_det pt_n = int(pt_scores.shape[0]) ort_n = int(ort_scores.shape[0]) verification["detection_nms_ref_count"] = pt_n verification["detection_nms_ort_count"] = ort_n if pt_n > 0 and ort_n > 0: pt_top = int(pt_scores.cpu().numpy().argmax()) ort_top = int(ort_scores.argmax()) pt_top_box = pt_boxes[pt_top].cpu().numpy() ort_top_box = ort_boxes[ort_top] # IoU of the two top boxes x1 = max(pt_top_box[0], ort_top_box[0]) y1 = max(pt_top_box[1], ort_top_box[1]) x2 = min(pt_top_box[2], ort_top_box[2]) y2 = min(pt_top_box[3], ort_top_box[3]) inter = max(0.0, x2 - x1) * max(0.0, y2 - y1) pt_area = max(0.0, pt_top_box[2] - pt_top_box[0]) * max(0.0, pt_top_box[3] - pt_top_box[1]) ort_area = max(0.0, ort_top_box[2] - ort_top_box[0]) * max(0.0, ort_top_box[3] - ort_top_box[1]) union = max(1e-6, pt_area + ort_area - inter) verification["detection_nms_top_iou"] = float(inter / union) verification["detection_nms_top_class_match"] = bool( int(pt_labels[pt_top].cpu()) == int(ort_labels[ort_top]) ) verification["detection_nms_top_score_diff"] = float(abs( float(pt_scores[pt_top].cpu()) - float(ort_scores[ort_top]) )) else: verification["detection_nms_top_iou"] = None verification["detection_nms_top_class_match"] = None verification["detection_nms_top_score_diff"] = None else: ort_boxes, ort_scores = ort_det ref_boxes, ref_scores = ref_det verification["detection_boxes_max_diff"] = float( np.abs(ort_boxes - ref_boxes.cpu().numpy()).max()) verification["detection_scores_max_diff"] = float( np.abs(ort_scores - ref_scores.cpu().numpy()).max()) # Tolerance resolution: either a float applied uniformly, or a dict # keyed by verification output name (with optional "default" key). # Detection boxes get a resolution-scaled tolerance when only a # float is supplied — exp() in the FCOS regression path amplifies # FP kernel-dispatch differences to pixel-scale absolute diffs. 
if isinstance(tolerance, dict): default_tol = float(tolerance.get("default", 5e-2)) def _tol_for(key): return float(tolerance.get(key, default_tol)) verification["tolerance"] = dict(tolerance) else: base = float(tolerance) box_tol = max(base, backbone_resolution * 5e-3) def _tol_for(key): return box_tol if key == "detection_boxes_max_diff" else base verification["tolerance"] = base verification["detection_boxes_tolerance"] = box_tol for key, val in list(verification.items()): if not key.endswith("_max_diff"): continue t = _tol_for(key) if val > t: raise RuntimeError( f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {t:.2e}" ) result["verification"] = verification return result