kungchuking committed · Commit 2c76547 · 1 Parent(s): 70bca94

Copied from GitHub repository.

README.md ADDED
@@ -0,0 +1,139 @@
1
+ # [CVPR 2023] DynamicStereo: Consistent Dynamic Depth from Stereo Videos.
2
+
3
+ **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
4
+
5
+ [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
6
+
7
+ [[`Paper`](https://research.facebook.com/publications/dynamicstereo-consistent-dynamic-depth-from-stereo-videos/)] [[`Project`](https://dynamic-stereo.github.io/)] [[`BibTeX`](#citing-dynamicstereo)]
8
+
9
+ ![nikita-reading](https://user-images.githubusercontent.com/37815420/236242052-e72d5605-1ab2-426c-ae8d-5c8a86d5252c.gif)
10
+
11
+ **DynamicStereo** is a transformer-based architecture for temporally consistent depth estimation from stereo videos. It has been trained on a combination of two datasets: [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) and **Dynamic Replica** that we present below.
12
+
13
+ ## Dataset
14
+
15
+ https://user-images.githubusercontent.com/37815420/236239579-7877623c-716b-4074-a14e-944d095f1419.mp4
16
+
17
+ The dataset consists of 145,200 *stereo* frames (524 videos) with humans and animals in motion.
18
+
19
+ We provide annotations for both *left and right* views; see [this notebook](https://github.com/facebookresearch/dynamic_stereo/blob/main/notebooks/Dynamic_Replica_demo.ipynb):
20
+ - camera intrinsics and extrinsics
21
+ - image depth (can be converted to disparity using the camera intrinsics; see the sketch below)
22
+ - instance segmentation masks
23
+ - binary foreground / background segmentation masks
24
+ - optical flow (released!)
25
+ - long-range pixel trajectories (released!)
26
+
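+ As a minimal illustration (not part of the repository code), depth can be converted to disparity for a rectified stereo pair given the focal length in pixels and the stereo baseline; the helper name and arguments below are illustrative assumptions:
+
+ ```
+ import numpy as np
+
+ def depth_to_disparity(depth, focal_px, baseline, eps=1e-5):
+     # disparity = f * B / Z for a rectified pair; clamp depth to avoid division by zero
+     depth = np.maximum(depth, eps)
+     return focal_px * baseline / depth
+ ```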
27
+
28
+ ### Download the Dynamic Replica dataset
29
+ Because the original dataset is very large, we provide `links_lite.json` so that you can quickly test the pipeline by downloading only a small portion of the data.
30
+
31
+ ```
32
+ python ./scripts/download_dynamic_replica.py --link_list_file links_lite.json --download_folder ./dynamic_replica_data --download_splits test train valid real
33
+ ```
34
+
35
+ To download the full dataset, please visit [the original site](https://github.com/facebookresearch/dynamic_stereo) created by Meta.
36
+
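+ Once downloaded, sequences can be loaded with the dataset class added in this commit. A minimal sketch, assuming the default `./dynamic_replica_data` layout and that the installation steps below have been completed:
+
+ ```
+ from datasets.dynamic_stereo_datasets import DynamicReplicaDataset
+
+ # Load short validation clips; only_first_n_samples keeps one clip per sequence.
+ ds = DynamicReplicaDataset(split="valid", sample_len=20, only_first_n_samples=1)
+ sample = ds[0]
+ print(sample["img"].shape)  # (T, 2, 3, H, W): time x {left, right} x channels x height x width
+ ```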
37
+ ## Installation
38
+
39
+ The following steps install DynamicStereo with the latest PyTorch3D, PyTorch 1.12.1, and CUDA 11.3.
40
+
41
+ ### Set up the root for all source files:
42
+ ```
43
+ git clone https://github.com/facebookresearch/dynamic_stereo
44
+ cd dynamic_stereo
45
+ export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
46
+ ```
47
+ ### Create a conda env:
48
+ ```
49
+ conda create -n dynamicstereo python=3.8
50
+ conda activate dynamicstereo
51
+ ```
52
+ ### Install requirements
53
+ ```
54
+ conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
55
+ # It will require some time to install PyTorch3D. In the meantime, you may want to take a break and enjoy a cup of coffee.
56
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
57
+ pip install -r requirements.txt
58
+ ```
59
+
60
+ ### (Optional) Install RAFT-Stereo
61
+ ```
62
+ mkdir third_party
63
+ cd third_party
64
+ git clone https://github.com/princeton-vl/RAFT-Stereo
65
+ cd RAFT-Stereo
66
+ bash download_models.sh
67
+ cd ../..
68
+ ```
69
+
70
+
71
+
72
+ ## Evaluation
73
+ To download the checkpoints, run:
74
+ ```
75
+ mkdir checkpoints
76
+ cd checkpoints
77
+ wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth
78
+ wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth
79
+ cd ..
80
+ ```
81
+ Alternatively, you can download the checkpoints manually via the links below and copy them to `./dynamic_stereo/checkpoints`.
82
+
83
+ - [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth) trained on SceneFlow
84
+ - [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth) trained on SceneFlow and *Dynamic Replica*
85
+
86
+ To evaluate DynamicStereo:
87
+ ```
88
+ python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \
89
+ MODEL.model_name=DynamicStereoModel exp_dir=./outputs/test_dynamic_replica_ds \
90
+ MODEL.DynamicStereoModel.model_weights=./checkpoints/dynamic_stereo_sf.pth
91
+ ```
92
+ Due to the high image resolution, evaluation on *Dynamic Replica* requires a 32GB GPU. If you don't have enough GPU memory, you can decrease `kernel_size` from 20 to 10 by adding `MODEL.DynamicStereoModel.kernel_size=10` to the above command (see the example below). Another option is to decrease the dataset resolution.
93
+
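+ For example, a reduced-memory variant of the evaluation command above might look like this (the only change is the appended kernel-size override):
+ ```
+ python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \
+ MODEL.model_name=DynamicStereoModel exp_dir=./outputs/test_dynamic_replica_ds \
+ MODEL.DynamicStereoModel.model_weights=./checkpoints/dynamic_stereo_sf.pth \
+ MODEL.DynamicStereoModel.kernel_size=10
+ ```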
94
+ With the default `kernel_size=20`, you should reproduce the numbers from *Table 5* in the [paper](https://arxiv.org/pdf/2305.02296.pdf).
95
+
96
+ Reconstructions of all the *Dynamic Replica* splits (including *real*) will be visualized and saved to `exp_dir`.
97
+
98
+ If you installed [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo), you can run:
99
+ ```
100
+ python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \
101
+ MODEL.model_name=RAFTStereoModel exp_dir=./outputs/test_dynamic_replica_raft
102
+ ```
103
+
104
+ Other public datasets we use:
105
+ - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
106
+ - [Sintel](http://sintel.is.tue.mpg.de/stereo)
107
+ - [Middlebury](https://vision.middlebury.edu/stereo/data/)
108
+ - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data)
109
+ - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_stereo.php)
110
+
111
+ ## Training
112
+ Training requires a 32GB GPU. You can decrease `image_size` and / or `sample_len` if you don't have enough GPU memory.
113
+ You need to download SceneFlow before training. Alternatively, you can train on *Dynamic Replica* only.
114
+ ```
115
+ python train.py --batch_size 1 \
116
+ --spatial_scale -0.2 0.4 --image_size 384 512 --saturation_range 0 1.4 --num_steps 200000 \
117
+ --ckpt_path dynamicstereo_sf_dr \
118
+ --sample_len 5 --lr 0.0003 --train_iters 10 --valid_iters 20 \
119
+ --num_workers 28 --save_freq 100 --update_block_3d --different_update_blocks \
120
+ --attention_type self_stereo_temporal_update_time_update_space --train_datasets dynamic_replica things monkaa driving
121
+ ```
122
+ If you want to train on SceneFlow only, remove `dynamic_replica` from `--train_datasets` (i.e. pass `--train_datasets things monkaa driving`).
123
+
124
+
125
+
126
+ ## License
127
+ The majority of dynamic_stereo is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo) is licensed under the MIT license, while [LoFTR](https://github.com/zju3dv/LoFTR) and [CREStereo](https://github.com/megvii-research/CREStereo) are licensed under the Apache 2.0 license.
128
+
129
+
130
+ ## Citing DynamicStereo
131
+ If you use DynamicStereo or Dynamic Replica in your research, please use the following BibTeX entry.
132
+ ```
133
+ @article{karaev2023dynamicstereo,
134
+ title={DynamicStereo: Consistent Dynamic Depth from Stereo Videos},
135
+ author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
136
+ journal={CVPR},
137
+ year={2023}
138
+ }
139
+ ```
datasets/augmentor.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import random
9
+ from PIL import Image
10
+
11
+ import cv2
12
+
13
+ cv2.setNumThreads(0)
14
+ cv2.ocl.setUseOpenCL(False)
15
+
16
+ from torchvision.transforms import ColorJitter, functional, Compose
17
+
18
+
19
+ class AdjustGamma(object):
20
+ def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
21
+ self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = (
22
+ gamma_min,
23
+ gamma_max,
24
+ gain_min,
25
+ gain_max,
26
+ )
27
+
28
+ def __call__(self, sample):
29
+ gain = random.uniform(self.gain_min, self.gain_max)
30
+ gamma = random.uniform(self.gamma_min, self.gamma_max)
31
+ return functional.adjust_gamma(sample, gamma, gain)
32
+
33
+ def __repr__(self):
34
+ return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
35
+
36
+
37
+ class SequenceDispFlowAugmentor:
38
+ def __init__(
39
+ self,
40
+ crop_size,
41
+ min_scale=-0.2,
42
+ max_scale=0.5,
43
+ do_flip=True,
44
+ yjitter=False,
45
+ saturation_range=[0.6, 1.4],
46
+ gamma=[1, 1, 1, 1],
47
+ ):
48
+ # spatial augmentation params
49
+ self.crop_size = crop_size
50
+ self.min_scale = min_scale
51
+ self.max_scale = max_scale
52
+ self.spatial_aug_prob = 1.0
53
+ self.stretch_prob = 0.8
54
+ self.max_stretch = 0.2
55
+
56
+ # flip augmentation params
57
+ self.yjitter = yjitter
58
+ self.do_flip = do_flip
59
+ self.h_flip_prob = 0.5
60
+ self.v_flip_prob = 0.1
61
+
62
+ # photometric augmentation params
63
+ self.photo_aug = Compose(
64
+ [
65
+ ColorJitter(
66
+ brightness=0.4,
67
+ contrast=0.4,
68
+ saturation=saturation_range,
69
+ hue=0.5 / 3.14,
70
+ ),
71
+ AdjustGamma(*gamma),
72
+ ]
73
+ )
74
+ self.asymmetric_color_aug_prob = 0.2
75
+ self.eraser_aug_prob = 0.5
76
+
77
+ def color_transform(self, seq):
78
+ """Photometric augmentation"""
79
+
80
+ # asymmetric
81
+ if np.random.rand() < self.asymmetric_color_aug_prob:
82
+ for i in range(len(seq)):
83
+ for cam in (0, 1):
84
+ seq[i][cam] = np.array(
85
+ self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8
86
+ )
87
+ # symmetric
88
+ else:
89
+ image_stack = np.concatenate(
90
+ [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
91
+ )
92
+ image_stack = np.array(
93
+ self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
94
+ )
95
+ split = np.split(image_stack, len(seq) * 2, axis=0)
96
+ for i in range(len(seq)):
97
+ seq[i][0] = split[2 * i]
98
+ seq[i][1] = split[2 * i + 1]
99
+ return seq
100
+
101
+ def eraser_transform(self, seq, bounds=[50, 100]):
102
+ """Occlusion augmentation"""
103
+ ht, wd = seq[0][0].shape[:2]
104
+ for i in range(len(seq)):
105
+ for cam in (0, 1):
106
+ if np.random.rand() < self.eraser_aug_prob:
107
+ mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
108
+ for _ in range(np.random.randint(1, 3)):
109
+ x0 = np.random.randint(0, wd)
110
+ y0 = np.random.randint(0, ht)
111
+ dx = np.random.randint(bounds[0], bounds[1])
112
+ dy = np.random.randint(bounds[0], bounds[1])
113
+ seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
114
+
115
+ return seq
116
+
117
+ def spatial_transform(self, img, disp):
118
+ # randomly sample scale
119
+ ht, wd = img[0][0].shape[:2]
120
+ min_scale = np.maximum(
121
+ (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
122
+ )
123
+
124
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
125
+ scale_x = scale
126
+ scale_y = scale
127
+ if np.random.rand() < self.stretch_prob:
128
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
129
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
130
+
131
+ scale_x = np.clip(scale_x, min_scale, None)
132
+ scale_y = np.clip(scale_y, min_scale, None)
133
+
134
+ if np.random.rand() < self.spatial_aug_prob:
135
+ # rescale the images
136
+ for i in range(len(img)):
137
+ for cam in (0, 1):
138
+ img[i][cam] = cv2.resize(
139
+ img[i][cam],
140
+ None,
141
+ fx=scale_x,
142
+ fy=scale_y,
143
+ interpolation=cv2.INTER_LINEAR,
144
+ )
145
+ if len(disp[i]) > 0:
146
+ disp[i][cam] = cv2.resize(
147
+ disp[i][cam],
148
+ None,
149
+ fx=scale_x,
150
+ fy=scale_y,
151
+ interpolation=cv2.INTER_LINEAR,
152
+ )
153
+ disp[i][cam] = disp[i][cam] * [scale_x, scale_y]
154
+
155
+ if self.yjitter:
156
+ y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2)
157
+ x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2)
158
+
159
+ for i in range(len(img)):
160
+ y1 = y0 + np.random.randint(-2, 2 + 1)
161
+ img[i][0] = img[i][0][
162
+ y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
163
+ ]
164
+ img[i][1] = img[i][1][
165
+ y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
166
+ ]
167
+ if len(disp[i]) > 0:
168
+ disp[i][0] = disp[i][0][
169
+ y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
170
+ ]
171
+ disp[i][1] = disp[i][1][
172
+ y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
173
+ ]
174
+ else:
175
+ y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
176
+ x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
177
+ for i in range(len(img)):
178
+ for cam in (0, 1):
179
+ img[i][cam] = img[i][cam][
180
+ y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
181
+ ]
182
+ if len(disp[i]) > 0:
183
+ disp[i][cam] = disp[i][cam][
184
+ y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
185
+ ]
186
+
187
+ return img, disp
188
+
189
+ def __call__(self, img, disp):
190
+ img = self.color_transform(img)
191
+ img = self.eraser_transform(img)
192
+ img, disp = self.spatial_transform(img, disp)
193
+
194
+ for i in range(len(img)):
195
+ for cam in (0, 1):
196
+ img[i][cam] = np.ascontiguousarray(img[i][cam])
197
+ if len(disp[i]) > 0:
198
+ disp[i][cam] = np.ascontiguousarray(disp[i][cam])
199
+
200
+ return img, disp
datasets/dynamic_stereo_datasets.py ADDED
@@ -0,0 +1,743 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
7
+
8
+ # -- Added by Chu King on 16th November 2025 for debugging purposes.
9
+ import torch.distributed as dist
10
+ import signal
11
+
12
+ import os
13
+ import copy
14
+ import gzip
15
+ import logging
16
+ import torch
17
+ import numpy as np
18
+ import torch.utils.data as data
19
+ import torch.nn.functional as F
20
+ import os.path as osp
21
+ from glob import glob
22
+
23
+ from collections import defaultdict
24
+ from PIL import Image
25
+ from dataclasses import dataclass
26
+ from typing import List, Optional
27
+ from pytorch3d.renderer.cameras import PerspectiveCameras
28
+ from pytorch3d.implicitron.dataset.types import (
29
+ FrameAnnotation as ImplicitronFrameAnnotation,
30
+ load_dataclass,
31
+ )
32
+
33
+ from datasets import frame_utils
34
+ from evaluation.utils.eval_utils import depth2disparity_scale
35
+ from datasets.augmentor import SequenceDispFlowAugmentor
36
+
37
+
38
+ @dataclass
39
+ class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):
40
+ """A dataclass used to load annotations from json."""
41
+
42
+ camera_name: Optional[str] = None
43
+
44
+
45
+ class StereoSequenceDataset(data.Dataset):
46
+ def __init__(self, aug_params=None, sparse=False, reader=None):
47
+ self.augmentor = None
48
+ self.sparse = sparse
49
+ self.img_pad = (
50
+ aug_params.pop("img_pad", None) if aug_params is not None else None
51
+ )
52
+ if aug_params is not None and "crop_size" in aug_params:
53
+ if sparse:
54
+ raise ValueError("Sparse augmentor is not implemented")
55
+ else:
56
+ self.augmentor = SequenceDispFlowAugmentor(**aug_params)
57
+
58
+ if reader is None:
59
+ self.disparity_reader = frame_utils.read_gen
60
+ else:
61
+ self.disparity_reader = reader
62
+ self.depth_reader = self._load_16big_png_depth
63
+ self.is_test = False
64
+ self.sample_list = []
65
+ self.extra_info = []
66
+ self.depth_eps = 1e-5
67
+
68
+ def _load_16big_png_depth(self, depth_png):
69
+ with Image.open(depth_png) as depth_pil:
70
+ # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
71
+ # we cast it to uint16, then reinterpret as float16, then cast to float32
72
+ depth = (
73
+ np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
74
+ .astype(np.float32)
75
+ .reshape((depth_pil.size[1], depth_pil.size[0]))
76
+ )
77
+ return depth
78
+
79
+ def _get_pytorch3d_camera(
80
+ self, entry_viewpoint, image_size, scale: float
81
+ ) -> PerspectiveCameras:
82
+ assert entry_viewpoint is not None
83
+ # principal point and focal length
84
+ principal_point = torch.tensor(
85
+ entry_viewpoint.principal_point, dtype=torch.float
86
+ )
87
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
88
+
89
+ half_image_size_wh_orig = (
90
+ torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
91
+ )
92
+
93
+ # first, we convert from the dataset's NDC convention to pixels
94
+ format = entry_viewpoint.intrinsics_format
95
+ if format.lower() == "ndc_norm_image_bounds":
96
+ # this is e.g. currently used in CO3D for storing intrinsics
97
+ rescale = half_image_size_wh_orig
98
+ elif format.lower() == "ndc_isotropic":
99
+ rescale = half_image_size_wh_orig.min()
100
+ else:
101
+ raise ValueError(f"Unknown intrinsics format: {format}")
102
+
103
+ # principal point and focal length in pixels
104
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
105
+ focal_length_px = focal_length * rescale
106
+
107
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
108
+ # if self.image_height is None or self.image_width is None:
109
+ out_size = list(reversed(image_size))
110
+
111
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
112
+ half_min_image_size_output = half_image_size_output.min()
113
+
114
+ # rescaled principal point and focal length in ndc
115
+ principal_point = (
116
+ half_image_size_output - principal_point_px * scale
117
+ ) / half_min_image_size_output
118
+ focal_length = focal_length_px * scale / half_min_image_size_output
119
+
120
+ return PerspectiveCameras(
121
+ focal_length=focal_length[None],
122
+ principal_point=principal_point[None],
123
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
124
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
125
+ )
126
+
127
+ def _get_output_tensor(self, sample):
128
+ output_tensor = defaultdict(list)
129
+ sample_size = len(sample["image"]["left"])
130
+ output_tensor_keys = ["img", "disp", "valid_disp", "mask"]
131
+ add_keys = ["viewpoint", "metadata"]
132
+ for add_key in add_keys:
133
+ if add_key in sample:
134
+ output_tensor_keys.append(add_key)
135
+
136
+ for key in output_tensor_keys:
137
+ output_tensor[key] = [[] for _ in range(sample_size)]
138
+
139
+ if "viewpoint" in sample:
140
+ viewpoint_left = self._get_pytorch3d_camera(
141
+ sample["viewpoint"]["left"][0],
142
+ sample["metadata"]["left"][0][1],
143
+ scale=1.0,
144
+ )
145
+ viewpoint_right = self._get_pytorch3d_camera(
146
+ sample["viewpoint"]["right"][0],
147
+ sample["metadata"]["right"][0][1],
148
+ scale=1.0,
149
+ )
150
+ depth2disp_scale = depth2disparity_scale(
151
+ viewpoint_left,
152
+ viewpoint_right,
153
+ torch.Tensor(sample["metadata"]["left"][0][1])[None],
154
+ )
155
+
156
+ for i in range(sample_size):
157
+ for cam in ["left", "right"]:
158
+ if "mask" in sample and cam in sample["mask"]:
159
+ mask = frame_utils.read_gen(sample["mask"][cam][i])
160
+ mask = np.array(mask) / 255.0
161
+ output_tensor["mask"][i].append(mask)
162
+
163
+ if "viewpoint" in sample and cam in sample["viewpoint"]:
164
+ viewpoint = self._get_pytorch3d_camera(
165
+ sample["viewpoint"][cam][i],
166
+ sample["metadata"][cam][i][1],
167
+ scale=1.0,
168
+ )
169
+ output_tensor["viewpoint"][i].append(viewpoint)
170
+
171
+ if "metadata" in sample and cam in sample["metadata"]:
172
+ metadata = sample["metadata"][cam][i]
173
+ output_tensor["metadata"][i].append(metadata)
174
+
175
+ if cam in sample["image"]:
176
+
177
+ img = frame_utils.read_gen(sample["image"][cam][i])
178
+ img = np.array(img).astype(np.uint8)
179
+
180
+ # grayscale images
181
+ if len(img.shape) == 2:
182
+ img = np.tile(img[..., None], (1, 1, 3))
183
+ else:
184
+ img = img[..., :3]
185
+ output_tensor["img"][i].append(img)
186
+
187
+ if cam in sample["disparity"]:
188
+ disp = self.disparity_reader(sample["disparity"][cam][i])
189
+ if isinstance(disp, tuple):
190
+ disp, valid_disp = disp
191
+ else:
192
+ valid_disp = disp < 512
193
+ disp = np.array(disp).astype(np.float32)
194
+
195
+ disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
196
+
197
+ output_tensor["disp"][i].append(disp)
198
+ output_tensor["valid_disp"][i].append(valid_disp)
199
+
200
+ elif "depth" in sample and cam in sample["depth"]:
201
+ depth = self.depth_reader(sample["depth"][cam][i])
202
+
203
+ depth_mask = depth < self.depth_eps
204
+ depth[depth_mask] = self.depth_eps
205
+
206
+ disp = depth2disp_scale / depth
207
+ disp[depth_mask] = 0
208
+ valid_disp = (disp < 512) * (1 - depth_mask)
209
+
210
+ disp = np.array(disp).astype(np.float32)
211
+ disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
212
+ output_tensor["disp"][i].append(disp)
213
+ output_tensor["valid_disp"][i].append(valid_disp)
214
+
215
+ return output_tensor
216
+
217
+ def __getitem__(self, index):
218
+ im_tensor = {"img": None}
219
+ sample = self.sample_list[index]
220
+ if self.is_test:
221
+ sample_size = len(sample["image"]["left"])
222
+ im_tensor["img"] = [[] for _ in range(sample_size)]
223
+ for i in range(sample_size):
224
+ for cam in ["left", "right"]:
225
+ img = frame_utils.read_gen(sample["image"][cam][i])
226
+ img = np.array(img).astype(np.uint8)[..., :3]
227
+ img = torch.from_numpy(img).permute(2, 0, 1).float()
228
+ im_tensor["img"][i].append(img)
229
+ im_tensor["img"] = torch.stack(im_tensor["img"])
230
+ return im_tensor, self.extra_info[index]
231
+
232
+ index = index % len(self.sample_list)
233
+
234
+ try:
235
+ output_tensor = self._get_output_tensor(sample)
236
+ except Exception:
237
+ logging.warning(f"Exception in loading sample {index}!")
238
+ index = np.random.randint(len(self.sample_list))
239
+ logging.info(f"New index is {index}")
240
+ sample = self.sample_list[index]
241
+ output_tensor = self._get_output_tensor(sample)
242
+ sample_size = len(sample["image"]["left"])
243
+
244
+ if self.augmentor is not None:
245
+ output_tensor["img"], output_tensor["disp"] = self.augmentor(
246
+ output_tensor["img"], output_tensor["disp"]
247
+ )
248
+ for i in range(sample_size):
249
+ for cam in (0, 1):
250
+ if cam < len(output_tensor["img"][i]):
251
+ img = (
252
+ torch.from_numpy(output_tensor["img"][i][cam])
253
+ .permute(2, 0, 1)
254
+ .float()
255
+ )
256
+ if self.img_pad is not None:
257
+ padH, padW = self.img_pad
258
+ img = F.pad(img, [padW] * 2 + [padH] * 2)
259
+ output_tensor["img"][i][cam] = img
260
+
261
+ if cam < len(output_tensor["disp"][i]):
262
+ disp = (
263
+ torch.from_numpy(output_tensor["disp"][i][cam])
264
+ .permute(2, 0, 1)
265
+ .float()
266
+ )
267
+
268
+ if self.sparse:
269
+ valid_disp = torch.from_numpy(
270
+ output_tensor["valid_disp"][i][cam]
271
+ )
272
+ else:
273
+ valid_disp = (
274
+ (disp[0].abs() < 512)
275
+ & (disp[1].abs() < 512)
276
+ & (disp[0].abs() != 0)
277
+ )
278
+ disp = disp[:1]
279
+
280
+ output_tensor["disp"][i][cam] = disp
281
+ output_tensor["valid_disp"][i][cam] = valid_disp.float()
282
+
283
+ if "mask" in output_tensor and cam < len(output_tensor["mask"][i]):
284
+ mask = torch.from_numpy(output_tensor["mask"][i][cam]).float()
285
+ output_tensor["mask"][i][cam] = mask
286
+
287
+ if "viewpoint" in output_tensor and cam < len(
288
+ output_tensor["viewpoint"][i]
289
+ ):
290
+ viewpoint = output_tensor["viewpoint"][i][cam]
291
+ output_tensor["viewpoint"][i][cam] = viewpoint
292
+
293
+ res = {}
294
+ if "viewpoint" in output_tensor and self.split != "train":
295
+ res["viewpoint"] = output_tensor["viewpoint"]
296
+ if "metadata" in output_tensor and self.split != "train":
297
+ res["metadata"] = output_tensor["metadata"]
298
+
299
+ for k, v in output_tensor.items():
300
+ if k != "viewpoint" and k != "metadata":
301
+ for i in range(len(v)):
302
+ if len(v[i]) > 0:
303
+ v[i] = torch.stack(v[i])
304
+ if len(v) > 0 and (len(v[0]) > 0):
305
+ res[k] = torch.stack(v)
306
+ return res
307
+
308
+ def __mul__(self, v):
309
+ copy_of_self = copy.deepcopy(self)
310
+ copy_of_self.sample_list = v * copy_of_self.sample_list
311
+ copy_of_self.extra_info = v * copy_of_self.extra_info
312
+ return copy_of_self
313
+
314
+ def __len__(self):
315
+ return len(self.sample_list)
316
+
317
+
318
+ class DynamicReplicaDataset(StereoSequenceDataset):
319
+ def __init__(
320
+ self,
321
+ aug_params=None,
322
+ root="./dynamic_replica_data",
323
+ split="train",
324
+ sample_len=-1,
325
+ only_first_n_samples=-1,
326
+ t_step_validation=1, # -- Added by Chu King on 24th November 2025 to control the separation between consecutive samples in validation
327
+ VERBOSE=False # -- Added by Chu King on 16th November 2025 for debugging purposes
328
+ ):
329
+ super(DynamicReplicaDataset, self).__init__(aug_params)
330
+ self.root = root
331
+ self.sample_len = sample_len
332
+ self.split = split
333
+
334
+ frame_annotations_file = f"frame_annotations_{split}.jgz"
335
+
336
+ with gzip.open(
337
+ osp.join(root, split, frame_annotations_file), "rt", encoding="utf8"
338
+ ) as zipfile:
339
+ frame_annots_list = load_dataclass(
340
+ zipfile, List[DynamicReplicaFrameAnnotation]
341
+ )
342
+ seq_annot = defaultdict(lambda: defaultdict(list))
343
+ for frame_annot in frame_annots_list:
344
+ seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append(
345
+ frame_annot
346
+ )
347
+
348
+ # -- Added by Chu King on 16th November 2025 for debugging purposes
349
+ if VERBOSE:
350
+ rank = dist.get_rank() if dist.is_initialized() else 0
351
+ with open(f"debug_rank_{rank}.txt", "a") as f:
352
+ f.write("[INFO] seq_annot: {}\n".format(seq_annot))
353
+ # -- os.kill(os.getpid(), signal.SIGABRT)
354
+
355
+ for seq_name in seq_annot.keys():
356
+
357
+ # -- Added by Chu King on 16th November 2025 for debugging purposes
358
+ if VERBOSE:
359
+ rank = dist.get_rank() if dist.is_initialized() else 0
360
+ with open(f"debug_rank_{rank}.txt", "a") as f:
361
+ f.write("---- ----\n")
362
+ f.write("[INFO] seq_name: {}\n".format(seq_name))
363
+
364
+ try:
365
+ filenames = defaultdict(lambda: defaultdict(list))
366
+ for cam in ["left", "right"]:
367
+ for framedata in seq_annot[seq_name][cam]:
368
+ im_path = osp.join(root, split, framedata.image.path)
369
+ depth_path = osp.join(root, split, framedata.depth.path)
370
+ mask_path = osp.join(root, split, framedata.mask.path)
371
+
372
+ # -- Added by Chu King on 16th November 2025 for debugging purposes
373
+ if VERBOSE:
374
+ rank = dist.get_rank() if dist.is_initialized() else 0
375
+ with open(f"debug_rank_{rank}.txt", "a") as f:
376
+ f.write("[INFO] cam: {}\n".format(cam))
377
+ f.write("[INFO] framedata: {}\n".format(framedata))
378
+ f.write("[INFO] framedata.viewpoint: {}\n".format(framedata.viewpoint))
379
+ f.write("[INFO] im_path: {}\n".format(im_path))
380
+ f.write("[INFO] depth_path: {}\n".format(depth_path))
381
+ f.write("[INFO] mask_path: {}\n".format(mask_path))
382
+
383
+ # -- Modified by Chu King on 16th November 2025 to clarify the nature of assertion errors.
384
+ assert os.path.isfile(im_path), "[ERROR] Rectified image path {} doesn't exist.".format(im_path)
385
+
386
+ tokens = root.split("/")
387
+ # -- if split != "test" and "real" not in tokens:
388
+ # -- assert os.path.isfile(depth_path), "[ERROR] Depth path {} doesn't exist. ".format(depth_path)
389
+ if not os.path.isfile(depth_path):
390
+ if split != "test" or "real" not in tokens:
391
+ print ("[WARNING] Depth path {} doesn't exist.".format(depth_path))
392
+
393
+ assert os.path.isfile(mask_path), "[ERROR] Mask path {} doesn't exist.".format(mask_path)
394
+
395
+ filenames["image"][cam].append(im_path)
396
+ filenames["mask"][cam].append(mask_path)
397
+ filenames["depth"][cam].append(depth_path)
398
+ filenames["viewpoint"][cam].append(framedata.viewpoint)
399
+ filenames["metadata"][cam].append(
400
+ [framedata.sequence_name, framedata.image.size]
401
+ )
402
+
403
+ for k in filenames.keys():
404
+ assert (
405
+ len(filenames[k][cam])
406
+ == len(filenames["image"][cam])
407
+ > 0
408
+ ), framedata.sequence_name
409
+
410
+ if not os.path.isfile(depth_path):
411
+ del filenames["depth"]
412
+
413
+ seq_len = len(filenames["image"][cam])
414
+
415
+ print("seq_len", seq_name, seq_len)
416
+ if split == "train":
417
+ for ref_idx in range(0, seq_len, 3):
418
+ # -- step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
419
+ # -- Modified by Chu King on 24th November 2025 to handle high-speed motion.
420
+ step = 1 if self.sample_len == 1 else np.random.randint(1, 12)
421
+ if ref_idx + step * self.sample_len < seq_len:
422
+ sample = defaultdict(lambda: defaultdict(list))
423
+ for cam in ["left", "right"]:
424
+ for idx in range(
425
+ ref_idx, ref_idx + step * self.sample_len, step
426
+ ):
427
+ for k in filenames.keys():
428
+ if "mask" not in k:
429
+ sample[k][cam].append(
430
+ filenames[k][cam][idx]
431
+ )
432
+
433
+ self.sample_list.append(sample)
434
+ else:
435
+ step = self.sample_len if self.sample_len > 0 else seq_len
436
+ counter = 0
437
+
438
+ for ref_idx in range(0, seq_len, step):
439
+ sample = defaultdict(lambda: defaultdict(list))
440
+ for cam in ["left", "right"]:
441
+ # -- Modified by Chu King on 24th November 2025 to control the separation between samples during validation.
442
+ # -- for idx in range(ref_idx, ref_idx + step):
443
+ for idx in range(ref_idx, ref_idx + step * t_step_validation, t_step_validation):
444
+ for k in filenames.keys():
445
+ sample[k][cam].append(filenames[k][cam][idx])
446
+
447
+ self.sample_list.append(sample)
448
+ counter += 1
449
+ if only_first_n_samples > 0 and counter >= only_first_n_samples:
450
+ break
451
+ except Exception as e:
452
+ print(e)
453
+ print("Skipping sequence", seq_name)
454
+
455
+ assert len(self.sample_list) > 0, "No samples found"
456
+ print(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
457
+ logging.info(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
458
+
459
+
460
+ class SequenceSceneFlowDataset(StereoSequenceDataset):
461
+ def __init__(
462
+ self,
463
+ aug_params=None,
464
+ root="./datasets",
465
+ dstype="frames_cleanpass",
466
+ sample_len=1,
467
+ things_test=False,
468
+ add_things=True,
469
+ add_monkaa=True,
470
+ add_driving=True,
471
+ ):
472
+ super(SequenceSceneFlowDataset, self).__init__(aug_params)
473
+ self.root = root
474
+ self.dstype = dstype
475
+ self.sample_len = sample_len
476
+ if things_test:
477
+ self._add_things("TEST")
478
+ else:
479
+ if add_things:
480
+ self._add_things("TRAIN")
481
+ if add_monkaa:
482
+ self._add_monkaa()
483
+ if add_driving:
484
+ self._add_driving()
485
+
486
+ def _add_things(self, split="TRAIN"):
487
+ """Add FlyingThings3D data"""
488
+
489
+ original_length = len(self.sample_list)
490
+ root = osp.join(self.root, "FlyingThings3D")
491
+ image_paths = defaultdict(list)
492
+ disparity_paths = defaultdict(list)
493
+
494
+ for cam in ["left", "right"]:
495
+ image_paths[cam] = sorted(
496
+ glob(osp.join(root, self.dstype, split, f"*/*/{cam}/"))
497
+ )
498
+ disparity_paths[cam] = [
499
+ path.replace(self.dstype, "disparity") for path in image_paths[cam]
500
+ ]
501
+
502
+ # Choose a random subset of 400 images for validation
503
+ state = np.random.get_state()
504
+ np.random.seed(1000)
505
+ val_idxs = set(np.random.permutation(len(image_paths["left"]))[:40])
506
+ np.random.set_state(state)
507
+ np.random.seed(0)
508
+ num_seq = len(image_paths["left"])
509
+
510
+ for seq_idx in range(num_seq):
511
+ if (split == "TEST" and seq_idx in val_idxs) or (
512
+ split == "TRAIN" and not seq_idx in val_idxs
513
+ ):
514
+ images, disparities = defaultdict(list), defaultdict(list)
515
+ for cam in ["left", "right"]:
516
+ images[cam] = sorted(
517
+ glob(osp.join(image_paths[cam][seq_idx], "*.png"))
518
+ )
519
+ disparities[cam] = sorted(
520
+ glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
521
+ )
522
+
523
+ self._append_sample(images, disparities)
524
+
525
+ assert len(self.sample_list) > 0, "No samples found"
526
+ print(
527
+ f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
528
+ )
529
+ logging.info(
530
+ f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
531
+ )
532
+
533
+ def _add_monkaa(self):
534
+ """Add Monkaa data"""
535
+
536
+ original_length = len(self.sample_list)
537
+ root = osp.join(self.root, "Monkaa")
538
+ image_paths = defaultdict(list)
539
+ disparity_paths = defaultdict(list)
540
+
541
+ for cam in ["left", "right"]:
542
+ image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f"*/{cam}/")))
543
+ disparity_paths[cam] = [
544
+ path.replace(self.dstype, "disparity") for path in image_paths[cam]
545
+ ]
546
+
547
+ num_seq = len(image_paths["left"])
548
+
549
+ for seq_idx in range(num_seq):
550
+ images, disparities = defaultdict(list), defaultdict(list)
551
+ for cam in ["left", "right"]:
552
+ images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
553
+ disparities[cam] = sorted(
554
+ glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
555
+ )
556
+
557
+ self._append_sample(images, disparities)
558
+
559
+ assert len(self.sample_list) > 0, "No samples found"
560
+ print(
561
+ f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
562
+ )
563
+ logging.info(
564
+ f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
565
+ )
566
+
567
+ def _add_driving(self):
568
+ """Add Driving data"""
569
+
570
+ original_length = len(self.sample_list)
571
+ root = osp.join(self.root, "Driving")
572
+ image_paths = defaultdict(list)
573
+ disparity_paths = defaultdict(list)
574
+
575
+ for cam in ["left", "right"]:
576
+ image_paths[cam] = sorted(
577
+ glob(osp.join(root, self.dstype, f"*/*/*/{cam}/"))
578
+ )
579
+ disparity_paths[cam] = [
580
+ path.replace(self.dstype, "disparity") for path in image_paths[cam]
581
+ ]
582
+
583
+ num_seq = len(image_paths["left"])
584
+ for seq_idx in range(num_seq):
585
+ images, disparities = defaultdict(list), defaultdict(list)
586
+ for cam in ["left", "right"]:
587
+ images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
588
+ disparities[cam] = sorted(
589
+ glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
590
+ )
591
+
592
+ self._append_sample(images, disparities)
593
+
594
+ assert len(self.sample_list) > 0, "No samples found"
595
+ print(
596
+ f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
597
+ )
598
+ logging.info(
599
+ f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
600
+ )
601
+
602
+ def _append_sample(self, images, disparities):
603
+ seq_len = len(images["left"])
604
+ for ref_idx in range(0, seq_len - self.sample_len):
605
+ sample = defaultdict(lambda: defaultdict(list))
606
+ for cam in ["left", "right"]:
607
+ for idx in range(ref_idx, ref_idx + self.sample_len):
608
+ sample["image"][cam].append(images[cam][idx])
609
+ sample["disparity"][cam].append(disparities[cam][idx])
610
+ self.sample_list.append(sample)
611
+
612
+ sample = defaultdict(lambda: defaultdict(list))
613
+ for cam in ["left", "right"]:
614
+ for idx in range(ref_idx, ref_idx + self.sample_len):
615
+ sample["image"][cam].append(images[cam][seq_len - idx - 1])
616
+ sample["disparity"][cam].append(disparities[cam][seq_len - idx - 1])
617
+ self.sample_list.append(sample)
618
+
619
+
620
+ class SequenceSintelStereo(StereoSequenceDataset):
621
+ def __init__(
622
+ self,
623
+ dstype="clean",
624
+ aug_params=None,
625
+ root="./datasets",
626
+ ):
627
+ super().__init__(
628
+ aug_params, sparse=True, reader=frame_utils.readDispSintelStereo
629
+ )
630
+ self.dstype = dstype
631
+ original_length = len(self.sample_list)
632
+ image_root = osp.join(root, "sintel_stereo", "training")
633
+
634
+ image_paths = defaultdict(list)
635
+ disparity_paths = defaultdict(list)
636
+
637
+ for cam in ["left", "right"]:
638
+ image_paths[cam] = sorted(
639
+ glob(osp.join(image_root, f"{self.dstype}_{cam}/*"))
640
+ )
641
+
642
+ cam = "left"
643
+ disparity_paths[cam] = [
644
+ path.replace(f"{self.dstype}_{cam}", "disparities")
645
+ for path in image_paths[cam]
646
+ ]
647
+
648
+ num_seq = len(image_paths["left"])
649
+ # for each sequence
650
+ for seq_idx in range(num_seq):
651
+ sample = defaultdict(lambda: defaultdict(list))
652
+ for cam in ["left", "right"]:
653
+ sample["image"][cam] = sorted(
654
+ glob(osp.join(image_paths[cam][seq_idx], "*.png"))
655
+ )
656
+ cam = "left"
657
+ sample["disparity"][cam] = sorted(
658
+ glob(osp.join(disparity_paths[cam][seq_idx], "*.png"))
659
+ )
660
+ for im1, disp in zip(sample["image"][cam], sample["disparity"][cam]):
661
+ assert (
662
+ im1.split("/")[-1].split(".")[0]
663
+ == disp.split("/")[-1].split(".")[0]
664
+ ), (im1.split("/")[-1].split(".")[0], disp.split("/")[-1].split(".")[0])
665
+ self.sample_list.append(sample)
666
+
667
+ logging.info(
668
+ f"Added {len(self.sample_list) - original_length} from SintelStereo {self.dstype}"
669
+ )
670
+
671
+
672
+ def fetch_dataloader(args):
673
+ """Create the data loader for the corresponding training set"""
674
+
675
+ aug_params = {
676
+ "crop_size": args.image_size,
677
+ "min_scale": args.spatial_scale[0],
678
+ "max_scale": args.spatial_scale[1],
679
+ "do_flip": False,
680
+ "yjitter": not args.noyjitter,
681
+ }
682
+ if hasattr(args, "saturation_range") and args.saturation_range is not None:
683
+ aug_params["saturation_range"] = args.saturation_range
684
+ if hasattr(args, "img_gamma") and args.img_gamma is not None:
685
+ aug_params["gamma"] = args.img_gamma
686
+ if hasattr(args, "do_flip") and args.do_flip is not None:
687
+ aug_params["do_flip"] = args.do_flip
688
+
689
+ train_dataset = None
690
+
691
+ add_monkaa = "monkaa" in args.train_datasets
692
+ add_driving = "driving" in args.train_datasets
693
+ add_things = "things" in args.train_datasets
694
+ add_dynamic_replica = "dynamic_replica" in args.train_datasets
695
+
696
+ new_dataset = None
697
+
698
+ if add_monkaa or add_driving or add_things:
699
+ clean_dataset = SequenceSceneFlowDataset(
700
+ aug_params,
701
+ dstype="frames_cleanpass",
702
+ sample_len=args.sample_len,
703
+ add_monkaa=add_monkaa,
704
+ add_driving=add_driving,
705
+ add_things=add_things,
706
+ )
707
+
708
+ final_dataset = SequenceSceneFlowDataset(
709
+ aug_params,
710
+ dstype="frames_finalpass",
711
+ sample_len=args.sample_len,
712
+ add_monkaa=add_monkaa,
713
+ add_driving=add_driving,
714
+ add_things=add_things,
715
+ )
716
+
717
+ new_dataset = clean_dataset + final_dataset
718
+
719
+ if add_dynamic_replica:
720
+ dr_dataset = DynamicReplicaDataset(
721
+ aug_params, split="train", sample_len=args.sample_len
722
+ )
723
+ if new_dataset is None:
724
+ new_dataset = dr_dataset
725
+ else:
726
+ new_dataset = new_dataset + dr_dataset
727
+
728
+ logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
729
+ train_dataset = (
730
+ new_dataset if train_dataset is None else train_dataset + new_dataset
731
+ )
732
+
733
+ train_loader = data.DataLoader(
734
+ train_dataset,
735
+ batch_size=args.batch_size,
736
+ pin_memory=True,
737
+ shuffle=True,
738
+ num_workers=args.num_workers,
739
+ drop_last=True,
740
+ )
741
+
742
+ logging.info("Training with %d image pairs" % len(train_dataset))
743
+ return train_loader
datasets/frame_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from os.path import *
10
+ import re
11
+ import imageio
12
+ import cv2
13
+
14
+ cv2.setNumThreads(0)
15
+ cv2.ocl.setUseOpenCL(False)
16
+
17
+ TAG_CHAR = np.array([202021.25], np.float32)
18
+
19
+
20
+ def readFlow(fn):
21
+ """Read .flo file in Middlebury format"""
22
+ # Code adapted from:
23
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
24
+
25
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
26
+ # print 'fn = %s'%(fn)
27
+ with open(fn, "rb") as f:
28
+ magic = np.fromfile(f, np.float32, count=1)
29
+ if 202021.25 != magic:
30
+ print("Magic number incorrect. Invalid .flo file")
31
+ return None
32
+ else:
33
+ w = np.fromfile(f, np.int32, count=1)
34
+ h = np.fromfile(f, np.int32, count=1)
35
+ # print 'Reading %d x %d flo file\n' % (w, h)
36
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
37
+ # Reshape data into 3D array (columns, rows, bands)
38
+ # The reshape here is for visualization, the original code is (w,h,2)
39
+ return np.resize(data, (int(h), int(w), 2))
40
+
41
+
42
+ def readPFM(file):
43
+ file = open(file, "rb")
44
+
45
+ color = None
46
+ width = None
47
+ height = None
48
+ scale = None
49
+ endian = None
50
+
51
+ header = file.readline().rstrip()
52
+ if header == b"PF":
53
+ color = True
54
+ elif header == b"Pf":
55
+ color = False
56
+ else:
57
+ raise Exception("Not a PFM file.")
58
+
59
+ dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
60
+ if dim_match:
61
+ width, height = map(int, dim_match.groups())
62
+ else:
63
+ raise Exception("Malformed PFM header.")
64
+
65
+ scale = float(file.readline().rstrip())
66
+ if scale < 0: # little-endian
67
+ endian = "<"
68
+ scale = -scale
69
+ else:
70
+ endian = ">" # big-endian
71
+
72
+ data = np.fromfile(file, endian + "f")
73
+ shape = (height, width, 3) if color else (height, width)
74
+
75
+ data = np.reshape(data, shape)
76
+ data = np.flipud(data)
77
+ return data
78
+
79
+
80
+ def readDispSintelStereo(file_name):
81
+ """Return disparity read from filename."""
82
+ f_in = np.array(Image.open(file_name))
83
+ d_r = f_in[:, :, 0].astype("float64")
84
+ d_g = f_in[:, :, 1].astype("float64")
85
+ d_b = f_in[:, :, 2].astype("float64")
86
+
87
+ disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
88
+ mask = np.array(Image.open(file_name.replace("disparities", "occlusions")))
89
+ valid = (mask == 0) & (disp > 0)
90
+ return disp, valid
91
+
92
+
93
+ def readDispMiddlebury(file_name):
94
+ assert basename(file_name) == "disp0GT.pfm"
95
+ disp = readPFM(file_name).astype(np.float32)
96
+ assert len(disp.shape) == 2
97
+ nocc_pix = file_name.replace("disp0GT.pfm", "mask0nocc.png")
98
+ assert exists(nocc_pix)
99
+ nocc_pix = imageio.imread(nocc_pix) == 255
100
+ assert np.any(nocc_pix)
101
+ return disp, nocc_pix
102
+
103
+
104
+ def read_gen(file_name, pil=False):
105
+ ext = splitext(file_name)[-1]
106
+ if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
107
+ return Image.open(file_name)
108
+ elif ext == ".bin" or ext == ".raw":
109
+ return np.load(file_name)
110
+ elif ext == ".flo":
111
+ return readFlow(file_name).astype(np.float32)
112
+ elif ext == ".pfm":
113
+ flow = readPFM(file_name).astype(np.float32)
114
+ if len(flow.shape) == 2:
115
+ return flow
116
+ else:
117
+ return flow[:, :, :-1]
118
+ return []
evaluation/configs/eval_dynamic_replica_150_frames.yaml ADDED
@@ -0,0 +1,8 @@
1
+ defaults:
2
+ - default_config_eval
3
+ visualize_interval: 0
4
+ exp_dir: ./outputs/dynamic_stereo_DR
5
+ sample_len: 150
6
+ MODEL:
7
+ model_name: DynamicStereoModel
8
+
evaluation/configs/eval_dynamic_replica_40_frames.yaml ADDED
@@ -0,0 +1,8 @@
1
+ defaults:
2
+ - default_config_eval
3
+ visualize_interval: 0
4
+ exp_dir: ./outputs/dynamic_stereo_DR
5
+ sample_len: 40
6
+ MODEL:
7
+ model_name: DynamicStereoModel
8
+
evaluation/configs/eval_real_data.yaml ADDED
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - default_config_eval
3
+ visualize_interval: 1
4
+ exp_dir: ./outputs/dynamic_stereo_real
5
+ dataset_name: real
6
+ sample_len: 40
7
+ MODEL:
8
+ model_name: DynamicStereoModel
9
+
evaluation/configs/eval_sintel_clean.yaml ADDED
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - default_config_eval
3
+ visualize_interval: -1
4
+ exp_dir: ./outputs/dynamic_stereo_sintel_clean
5
+ sample_len: 30
6
+ dataset_name: sintel
7
+ dstype: clean
8
+ MODEL:
9
+ model_name: DynamicStereoModel
evaluation/configs/eval_sintel_final.yaml ADDED
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - default_config_eval
3
+ visualize_interval: -1
4
+ exp_dir: ./outputs/dynamic_stereo_sintel_final
5
+ sample_len: 30
6
+ dataset_name: sintel
7
+ dstype: final
8
+ MODEL:
9
+ model_name: DynamicStereoModel
evaluation/core/evaluator.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from collections import defaultdict
9
+ import torch.nn.functional as F
10
+ import torch
11
+ from tqdm import tqdm
12
+ from omegaconf import DictConfig
13
+ from pytorch3d.implicitron.tools.config import Configurable
14
+
15
+ from evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
16
+ from evaluation.utils.utils import (
17
+ PerceptionPrediction,
18
+ pretty_print_perception_metrics,
19
+ visualize_batch,
20
+ )
21
+
22
+
23
+ class Evaluator(Configurable):
24
+ """
25
+ A class defining the DynamicStereo evaluator.
26
+
27
+ Args:
28
+ eps: Threshold for converting disparity to depth.
29
+ """
30
+
31
+ eps = 1e-5
32
+
33
+ def setup_visualization(self, cfg: DictConfig) -> None:
34
+ # Visualization
35
+ self.visualize_interval = cfg.visualize_interval
36
+ self.exp_dir = cfg.exp_dir
37
+ if self.visualize_interval > 0:
38
+ self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")
39
+
40
+ @torch.no_grad()
41
+ def evaluate_sequence(
42
+ self,
43
+ sci_enc_L,
44
+ sci_enc_R,
45
+ model,
46
+ test_dataloader: torch.utils.data.DataLoader,
47
+ is_real_data: bool = False,
48
+ step=None,
49
+ writer=None,
50
+ train_mode=False,
51
+ interp_shape=None,
52
+ resolution=[480, 640]
53
+ ):
54
+ # -- Modified by Chu King on 20th November 2025 for SCI Stereo.
55
+ # -- model.eval()
56
+
57
+ per_batch_eval_results = []
58
+
59
+ if self.visualize_interval > 0:
60
+ os.makedirs(self.visualize_dir, exist_ok=True)
61
+
62
+ for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
63
+ batch_dict = defaultdict(list)
64
+ batch_dict["stereo_video"] = sequence["img"]
65
+ if not is_real_data:
66
+ batch_dict["disparity"] = sequence["disp"][:, 0].abs()
67
+ batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1] # ~ (T, 1, 720, 1280)
68
+
69
+ if "mask" in sequence:
70
+ batch_dict["fg_mask"] = sequence["mask"][:, :1]
71
+ else:
72
+ batch_dict["fg_mask"] = torch.ones_like(
73
+ batch_dict["disparity_mask"]
74
+ )
75
+ elif interp_shape is not None:
76
+ left_video = batch_dict["stereo_video"][:, 0]
77
+ left_video = F.interpolate(
78
+ left_video, tuple(interp_shape), mode="bilinear"
79
+ )
80
+ right_video = batch_dict["stereo_video"][:, 1]
81
+ right_video = F.interpolate(
82
+ right_video, tuple(interp_shape), mode="bilinear"
83
+ )
84
+ batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)
85
+
86
+ # -- This method is always invoked with train_mode=True.
87
+ if train_mode:
88
+ # -- Modified by Chu King on 20th November 2025.
89
+ # -- predictions = model.forward_batch_test(batch_dict)
90
+ predictions = model.forward_batch_test(batch_dict, sci_enc_L, sci_enc_R)
91
+ else:
92
+ predictions = model(batch_dict)
93
+
94
+ assert "disparity" in predictions
95
+ predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()
96
+
97
+ # -- print ("[INFO] predictions[\"disparity\"].shape", predictions["disparity"].shape)
98
+ # -- print ("[INFO] batch_dict[\"disparity_mask\"][..., :resolution[0], :resolution[1]].shape", batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].shape)
99
+ # -- print ("[INFO] batch_dict[\"disparity_mask\"][..., :resolution[0], :resolution[1]].round().shape", batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].round().shape)
100
+
101
+ if not is_real_data:
102
+ predictions["disparity"] = predictions["disparity"] * (
103
+ # -- Modified by Chu King on 22nd November 2025
104
+ # -- batch_dict["disparity_mask"].round()
105
+ batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].round()
106
+ )
107
+
108
+ batch_eval_result, seq_length = eval_batch(batch_dict, predictions)
109
+
110
+ per_batch_eval_results.append((batch_eval_result, seq_length))
111
+ pretty_print_perception_metrics(batch_eval_result)
112
+
113
+ if (self.visualize_interval > 0) and (
114
+ batch_idx % self.visualize_interval == 0
115
+ ):
116
+ perception_prediction = PerceptionPrediction()
117
+
118
+ pred_disp = predictions["disparity"]
119
+ pred_disp[pred_disp < self.eps] = self.eps
120
+
121
+ scale = depth2disparity_scale(
122
+ sequence["viewpoint"][0][0],
123
+ sequence["viewpoint"][0][1],
124
+ torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None],
125
+ )
126
+
127
+ perception_prediction.depth_map = (scale / pred_disp).cuda()
128
+ perspective_cameras = []
129
+ for cam in sequence["viewpoint"]:
130
+ perspective_cameras.append(cam[0])
131
+
132
+ perception_prediction.perspective_cameras = perspective_cameras
133
+
134
+ # -- Modified by Chu King on 22nd November 2025 to fix image resolution during training.
135
+ if "stereo_original_video" in batch_dict:
136
+ batch_dict["stereo_video"] = batch_dict["stereo_original_video"][..., :resolution[0], :resolution[1]].clone()
137
+
138
+ for k, v in batch_dict.items():
139
+ if isinstance(v, torch.Tensor):
140
+ batch_dict[k] = v.cuda()
141
+
142
+ visualize_batch(
143
+ batch_dict,
144
+ perception_prediction,
145
+ output_dir=self.visualize_dir,
146
+ sequence_name=sequence["metadata"][0][0][0],
147
+ step=step,
148
+ writer=writer,
149
+ # -- Added by Chu King on 22nd November 2025 to fix image resolution during evaluation.
150
+ resolution=resolution
151
+ )
152
+ return per_batch_eval_results
evaluation/evaluate.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, Dict, Optional
11
+
12
+ import hydra
13
+ import numpy as np
14
+
15
+ import torch
16
+ from omegaconf import OmegaConf
17
+
18
+ from dynamic_stereo.evaluation.utils.utils import aggregate_and_print_results
19
+
20
+ import dynamic_stereo.datasets.dynamic_stereo_datasets as datasets
21
+
22
+ from dynamic_stereo.models.core.model_zoo import (
23
+ get_all_model_default_configs,
24
+ model_zoo,
25
+ )
26
+ from pytorch3d.implicitron.tools.config import get_default_args_field
27
+ from dynamic_stereo.evaluation.core.evaluator import Evaluator
28
+
29
+
30
+ @dataclass(eq=False)
31
+ class DefaultConfig:
32
+ exp_dir: str = "./outputs"
33
+
34
+ # one of [sintel, dynamicreplica, things]
35
+ dataset_name: str = "dynamicreplica"
36
+
37
+ sample_len: int = -1
38
+ dstype: Optional[str] = None
39
+ # clean, final
40
+ MODEL: Dict[str, Any] = field(
41
+ default_factory=lambda: get_all_model_default_configs()
42
+ )
43
+ EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator)
44
+
45
+ seed: int = 42
46
+ gpu_idx: int = 0
47
+
48
+ visualize_interval: int = 0 # Use 0 for no visualization
49
+
50
+ # Override hydra's working directory to current working dir,
51
+ # also disable storing the .hydra logs:
52
+ hydra: dict = field(
53
+ default_factory=lambda: {
54
+ "run": {"dir": "."},
55
+ "output_subdir": None,
56
+ }
57
+ )
58
+
59
+
60
+ def run_eval(cfg: DefaultConfig):
61
+ """
62
+ Evaluates disparity estimation metrics of a specified model
63
+ on a benchmark dataset.
64
+ """
65
+ # make the experiment directory
66
+ os.makedirs(cfg.exp_dir, exist_ok=True)
67
+
68
+ # dump the experiment config to the exp_dir
69
+ cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
70
+ with open(cfg_file, "w") as f:
71
+ OmegaConf.save(config=cfg, f=f)
72
+
73
+ torch.manual_seed(cfg.seed)
74
+ np.random.seed(cfg.seed)
75
+ evaluator = Evaluator(**cfg.EVALUATOR)
76
+
77
+ model = model_zoo(**cfg.MODEL)
78
+ model.cuda(0)
79
+ evaluator.setup_visualization(cfg)
80
+
81
+ if cfg.dataset_name == "dynamicreplica":
82
+ test_dataloader = datasets.DynamicReplicaDataset(
83
+ split="valid", sample_len=cfg.sample_len, only_first_n_samples=1
84
+ )
85
+ elif cfg.dataset_name == "sintel":
86
+ test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype)
87
+ elif cfg.dataset_name == "things":
88
+ test_dataloader = datasets.SequenceSceneFlowDataset(
89
+ {},
90
+ dstype=cfg.dstype,
91
+ sample_len=cfg.sample_len,
92
+ add_monkaa=False,
93
+ add_driving=False,
94
+ things_test=True,
95
+ )
96
+ elif cfg.dataset_name == "real":
97
+ for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]:
98
+ ds_path = f"./dynamic_replica_data/real/{real_sequence_name}"
99
+ # seq_len_real = 20
100
+ real_dataset = datasets.DynamicReplicaDataset(
101
+ split="test",
102
+ sample_len=cfg.sample_len,
103
+ root=ds_path,
104
+ only_first_n_samples=1,
105
+ )
106
+
107
+ evaluator.evaluate_sequence(
108
+ model=model,
109
+ test_dataloader=real_dataset,
110
+ is_real_data=True,
111
+ train_mode=False,
112
+ )
113
+ return
114
+
115
+ print()
116
+
117
+ evaluate_result = evaluator.evaluate_sequence(
118
+ model,
119
+ test_dataloader,
120
+ )
121
+
122
+ aggregate_result = aggregate_and_print_results(evaluate_result)
123
+
124
+ result_file = os.path.join(cfg.exp_dir, "result_eval.json")
125
+
126
+ print(f"Dumping eval results to {result_file}.")
127
+ with open(result_file, "w") as f:
128
+ json.dump(aggregate_result, f)
129
+
130
+
131
+ cs = hydra.core.config_store.ConfigStore.instance()
132
+ cs.store(name="default_config_eval", node=DefaultConfig)
133
+
134
+
135
+ @hydra.main(config_path="./configs/", config_name="default_config_eval")
136
+ def evaluate(cfg: DefaultConfig) -> None:
137
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
138
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
139
+ run_eval(cfg)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ evaluate()
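For programmatic runs, the same entry point can be driven through hydra's compose API instead of the `@hydra.main` decorator. A hedged sketch: the override values are assumptions, and it presumes `./configs/default_config_eval.yaml` fills in the remaining model settings.

```python
from hydra import compose, initialize

# Build the same config the CLI would compose, then call run_eval directly.
with initialize(config_path="./configs/"):
    cfg = compose(
        config_name="default_config_eval",
        overrides=["dataset_name=sintel", "dstype=clean", "exp_dir=./outputs/sintel_eval"],
    )
run_eval(cfg)
```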
evaluation/utils/eval_utils.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Dict, Optional, Union
9
+
10
+ import torch
11
+ from pytorch3d.utils import opencv_from_cameras_projection
12
+
13
+
14
+ @dataclass(eq=True, frozen=True)
15
+ class PerceptionMetric:
16
+ metric: str
17
+ depth_scaling_norm: Optional[str] = None
18
+ suffix: str = ""
19
+ index: str = ""
20
+
21
+ def __str__(self):
22
+ return (
23
+ self.metric
24
+ + self.index
25
+ + (
26
+ ("_norm_" + self.depth_scaling_norm)
27
+ if self.depth_scaling_norm is not None
28
+ else ""
29
+ )
30
+ + self.suffix
31
+ )
32
+
33
+
34
+ def eval_endpoint_error_sequence(
35
+ x: torch.Tensor,
36
+ y: torch.Tensor,
37
+ mask: torch.Tensor,
38
+ crop: int = 0,
39
+ mask_thr: float = 0.5,
40
+ clamp_thr: float = 1e-5,
41
+ ) -> Dict[str, torch.Tensor]:
42
+
43
+ assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
44
+ x.shape,
45
+ y.shape,
46
+ mask.shape,
47
+ )
48
+ assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
49
+
50
+ # chuck out the border
51
+ if crop > 0:
52
+ if crop > min(y.shape[2:]) - crop:
53
+ raise ValueError("Incorrect crop size.")
54
+ y = y[:, :, crop:-crop, crop:-crop]
55
+ x = x[:, :, crop:-crop, crop:-crop]
56
+ mask = mask[:, :, crop:-crop, crop:-crop]
57
+
58
+ y = y * (mask > mask_thr).float()
59
+ x = x * (mask > mask_thr).float()
60
+ y[torch.isnan(y)] = 0
61
+
62
+ results = {}
63
+ for epe_name in ("epe", "temp_epe"):
64
+ if epe_name == "epe":
65
+ endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt()
66
+ elif epe_name == "temp_epe":
67
+ delta_mask = mask[:-1] * mask[1:]
68
+ endpoint_error = (
69
+ (delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2)
70
+ .sum(dim=1)
71
+ .sqrt()
72
+ )
73
+
74
+ # epe_nonzero = endpoint_error != 0
75
+ nonzero = torch.count_nonzero(endpoint_error)
76
+
77
+ epe_mean = endpoint_error.sum() / torch.clamp(
78
+ nonzero, clamp_thr
79
+ ) # average error for all the sequence pixels
80
+ epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp(
81
+ nonzero, clamp_thr
82
+ )
83
+ epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp(
84
+ nonzero, clamp_thr
85
+ )
86
+ epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp(
87
+ nonzero, clamp_thr
88
+ )
89
+ epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp(
90
+ nonzero, clamp_thr
91
+ )
92
+
93
+ results[f"{epe_name}_mean"] = epe_mean[None]
94
+ results[f"{epe_name}_bad_0.5px"] = epe_inv_accuracy_05px[None] * 100
95
+ results[f"{epe_name}_bad_1px"] = epe_inv_accuracy_1px[None] * 100
96
+ results[f"{epe_name}_bad_2px"] = epe_inv_accuracy_2px[None] * 100
97
+ results[f"{epe_name}_bad_3px"] = epe_inv_accuracy_3px[None] * 100
98
+ return results
99
+
100
+
101
+ def depth2disparity_scale(left_camera, right_camera, image_size_tensor):
102
+ # opencv camera matrices
103
+ (_, T1, K1), (_, T2, _) = [
104
+ opencv_from_cameras_projection(
105
+ f,
106
+ image_size_tensor,
107
+ )
108
+ for f in (left_camera, right_camera)
109
+ ]
110
+ fix_baseline = T1[0][0] - T2[0][0]
111
+ focal_length_px = K1[0][0][0]
112
+ # following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth
113
+ return focal_length_px * fix_baseline
114
+
115
+
116
+ def depth_to_pcd(
117
+ depth_map,
118
+ img,
119
+ focal_length,
120
+ cx,
121
+ cy,
122
+ step: Optional[int] = None,
123
+ inv_extrinsic=None,
124
+ mask=None,
125
+ filter=False,
126
+ ):
127
+ __, w, __ = img.shape
128
+ if step is None:
129
+ step = int(w / 100)
130
+ Z = depth_map[::step, ::step]
131
+ colors = img[::step, ::step, :]
132
+
133
+ Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step
134
+ Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step
135
+
136
+ X = (Pixels_X[None, :] - cx) * Z / focal_length
137
+ Y = (Pixels_Y[:, None] - cy) * Z / focal_length
138
+
139
+ inds = Z > 0
140
+
141
+ if mask is not None:
142
+ inds = inds * (mask[::step, ::step] > 0)
143
+
144
+ X = X[inds].reshape(-1)
145
+ Y = Y[inds].reshape(-1)
146
+ Z = Z[inds].reshape(-1)
147
+ colors = colors[inds]
148
+ pcd = torch.stack([X, Y, Z]).T
149
+
150
+ if inv_extrinsic is not None:
151
+ pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)])
152
+ pcd = (inv_extrinsic @ pcd_ext)[:3, :].T
153
+
154
+ if filter:
155
+ pcd, filt_inds = filter_outliers(pcd)
156
+ colors = colors[filt_inds]
157
+ return pcd, colors
158
+
159
+
160
+ def filter_outliers(pcd, sigma=3):
161
+ mean = pcd.mean(0)
162
+ std = pcd.std(0)
163
+ inds = ((pcd - mean).abs() < sigma * std)[:, 2]
164
+ pcd = pcd[inds]
165
+ return pcd, inds
166
+
167
+ # -- Modified by Chu King on 22nd November 2025 to fix the resolution during evaluation.
168
+ def eval_batch(batch_dict, predictions, resolution=[480, 640]) -> Dict[str, Union[float, torch.Tensor]]:
169
+ """
170
+ Produce performance metrics for a single batch of perception
171
+ predictions.
172
+ Args:
173
+ batch_dict: A dictionary holding the ground-truth disparity, masks and stereo video.
174
+ predictions: A dictionary holding the predicted disparity sequence.
175
+ resolution: [H, W] crop applied to the ground truth before comparison.
176
+ Returns:
177
+ results: A dictionary holding evaluation metrics.
178
+ """
179
+ results = {}
180
+
181
+ assert "disparity" in predictions
182
+ mask_now = torch.ones_like(batch_dict["fg_mask"][..., :resolution[0], :resolution[1]])
183
+
184
+ mask_now = mask_now * batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]]
185
+
186
+ eval_flow_traj_output = eval_endpoint_error_sequence(
187
+ predictions["disparity"], batch_dict["disparity"][..., :resolution[0], :resolution[1]], mask_now
188
+ )
189
+ for epe_name in ("epe", "temp_epe"):
190
+ results[PerceptionMetric(f"disp_{epe_name}_mean")] = eval_flow_traj_output[
191
+ f"{epe_name}_mean"
192
+ ]
193
+
194
+ results[PerceptionMetric(f"disp_{epe_name}_bad_3px")] = eval_flow_traj_output[
195
+ f"{epe_name}_bad_3px"
196
+ ]
197
+
198
+ results[PerceptionMetric(f"disp_{epe_name}_bad_2px")] = eval_flow_traj_output[
199
+ f"{epe_name}_bad_2px"
200
+ ]
201
+
202
+ results[PerceptionMetric(f"disp_{epe_name}_bad_1px")] = eval_flow_traj_output[
203
+ f"{epe_name}_bad_1px"
204
+ ]
205
+
206
+ results[PerceptionMetric(f"disp_{epe_name}_bad_0.5px")] = eval_flow_traj_output[
207
+ f"{epe_name}_bad_0.5px"
208
+ ]
209
+ if "endpoint_error_per_pixel" in eval_flow_traj_output:
210
+ results["disp_endpoint_error_per_pixel"] = eval_flow_traj_output[
211
+ "endpoint_error_per_pixel"
212
+ ]
213
+ return (results, len(predictions["disparity"]))
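A quick way to sanity-check `eval_endpoint_error_sequence` is to feed it random tensors in the `[T, C, H, W]` layout it asserts on. This is only a smoke-test sketch; all shapes are assumptions.

```python
import torch

pred = torch.rand(5, 1, 64, 64)   # predicted disparity sequence [T, C, H, W]
gt = torch.rand(5, 1, 64, 64)     # ground-truth disparity sequence
mask = torch.ones(5, 1, 64, 64)   # valid-pixel mask

metrics = eval_endpoint_error_sequence(pred, gt, mask)
print({k: v.item() for k, v in metrics.items()})  # epe_mean, epe_bad_1px, temp_epe_mean, ...
```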
evaluation/utils/utils.py ADDED
@@ -0,0 +1,351 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ import configparser
9
+ import os
10
+ import math
11
+ from typing import Optional, List
12
+ import torch
13
+ import cv2
14
+ import numpy as np
15
+ from dataclasses import dataclass
16
+ from tabulate import tabulate
17
+
18
+
19
+ from pytorch3d.structures import Pointclouds
20
+ from pytorch3d.transforms import RotateAxisAngle
21
+ from pytorch3d.utils import (
22
+ opencv_from_cameras_projection,
23
+ )
24
+ from pytorch3d.renderer import (
25
+ AlphaCompositor,
26
+ PointsRasterizationSettings,
27
+ PointsRasterizer,
28
+ PointsRenderer,
29
+ )
30
+ from evaluation.utils.eval_utils import depth_to_pcd
31
+
32
+
33
+ @dataclass
34
+ class PerceptionPrediction:
35
+ """
36
+ Holds the tensors that describe a result of any perception module.
37
+ """
38
+
39
+ depth_map: Optional[torch.Tensor] = None
40
+ disparity: Optional[torch.Tensor] = None
41
+ image_rgb: Optional[torch.Tensor] = None
42
+ fg_probability: Optional[torch.Tensor] = None
43
+
44
+
45
+ def aggregate_eval_results(per_batch_eval_results, reduction="mean"):
46
+
47
+ total_length = 0
48
+ aggregate_results = defaultdict(list)
49
+ for result in per_batch_eval_results:
50
+ if isinstance(result, tuple):
51
+ reduction = "sum"
52
+ length = result[1]
53
+ total_length += length
54
+ result = result[0]
55
+ for metric, val in result.items():
56
+ if reduction == "sum":
57
+ val = val * length
+ aggregate_results[metric].append(val)
58
+
59
+ if reduction == "mean":
60
+ return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()}
61
+ elif reduction == "sum":
62
+ return {
63
+ k: torch.cat(v).sum().item() / float(total_length)
64
+ for k, v in aggregate_results.items()
65
+ }
66
+
67
+
68
+ def aggregate_and_print_results(
69
+ per_batch_eval_results: List[dict],
70
+ ):
71
+ print("")
72
+ result = aggregate_eval_results(
73
+ per_batch_eval_results,
74
+ )
75
+ pretty_print_perception_metrics(result)
76
+ result = {str(k): v for k, v in result.items()}
77
+
78
+ print("")
79
+ return result
80
+
81
+
82
+ def pretty_print_perception_metrics(results):
83
+
84
+ metrics = sorted(list(results.keys()), key=lambda x: x.metric)
85
+
86
+ print("===== Perception results =====")
87
+ print(
88
+ tabulate(
89
+ [[metric, results[metric]] for metric in metrics],
90
+ )
91
+ )
92
+
93
+
94
+ def read_calibration(calibration_file, resolution_string):
95
+ # ported from https://github.com/stereolabs/zed-open-capture/
96
+ # blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172
97
+
98
+ zed_resolutions = {
99
+ "2K": (1242, 2208),
100
+ "FHD": (1080, 1920),
101
+ "HD": (720, 1280),
102
+ # "qHD": (540, 960),
103
+ "VGA": (376, 672),
104
+ }
105
+ assert resolution_string in zed_resolutions.keys()
106
+ image_height, image_width = zed_resolutions[resolution_string]
107
+
108
+ # Open camera configuration file
109
+ assert os.path.isfile(calibration_file)
110
+ calib = configparser.ConfigParser()
111
+ calib.read(calibration_file)
112
+
113
+ # Get translations
114
+ T = np.zeros((3, 1))
115
+ T[0] = float(calib["STEREO"]["baseline"])
116
+ T[1] = float(calib["STEREO"]["ty"])
117
+ T[2] = float(calib["STEREO"]["tz"])
118
+
119
+ baseline = T[0]
120
+
121
+ # Get left parameters
122
+ left_cam_cx = float(calib[f"LEFT_CAM_{resolution_string}"]["cx"])
123
+ left_cam_cy = float(calib[f"LEFT_CAM_{resolution_string}"]["cy"])
124
+ left_cam_fx = float(calib[f"LEFT_CAM_{resolution_string}"]["fx"])
125
+ left_cam_fy = float(calib[f"LEFT_CAM_{resolution_string}"]["fy"])
126
+ left_cam_k1 = float(calib[f"LEFT_CAM_{resolution_string}"]["k1"])
127
+ left_cam_k2 = float(calib[f"LEFT_CAM_{resolution_string}"]["k2"])
128
+ left_cam_p1 = float(calib[f"LEFT_CAM_{resolution_string}"]["p1"])
129
+ left_cam_p2 = float(calib[f"LEFT_CAM_{resolution_string}"]["p2"])
130
+ left_cam_k3 = float(calib[f"LEFT_CAM_{resolution_string}"]["k3"])
131
+
132
+ # Get right parameters
133
+ right_cam_cx = float(calib[f"RIGHT_CAM_{resolution_string}"]["cx"])
134
+ right_cam_cy = float(calib[f"RIGHT_CAM_{resolution_string}"]["cy"])
135
+ right_cam_fx = float(calib[f"RIGHT_CAM_{resolution_string}"]["fx"])
136
+ right_cam_fy = float(calib[f"RIGHT_CAM_{resolution_string}"]["fy"])
137
+ right_cam_k1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k1"])
138
+ right_cam_k2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k2"])
139
+ right_cam_p1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p1"])
140
+ right_cam_p2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p2"])
141
+ right_cam_k3 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k3"])
142
+
143
+ # Get rotations
144
+ R_zed = np.zeros(3)
145
+ R_zed[0] = float(calib["STEREO"][f"rx_{resolution_string.lower()}"])
146
+ R_zed[1] = float(calib["STEREO"][f"cv_{resolution_string.lower()}"])
147
+ R_zed[2] = float(calib["STEREO"][f"rz_{resolution_string.lower()}"])
148
+
149
+ R = cv2.Rodrigues(R_zed)[0]
150
+
151
+ # Left
152
+ cameraMatrix_left = np.array(
153
+ [[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]]
154
+ )
155
+ distCoeffs_left = np.array(
156
+ [left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3]
157
+ )
158
+
159
+ # Right
160
+ cameraMatrix_right = np.array(
161
+ [
162
+ [right_cam_fx, 0, right_cam_cx],
163
+ [0, right_cam_fy, right_cam_cy],
164
+ [0, 0, 1],
165
+ ]
166
+ )
167
+ distCoeffs_right = np.array(
168
+ [right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3]
169
+ )
170
+
171
+ # Stereo
172
+ R1, R2, P1, P2, Q = cv2.stereoRectify(
173
+ cameraMatrix1=cameraMatrix_left,
174
+ distCoeffs1=distCoeffs_left,
175
+ cameraMatrix2=cameraMatrix_right,
176
+ distCoeffs2=distCoeffs_right,
177
+ imageSize=(image_width, image_height),
178
+ R=R,
179
+ T=T,
180
+ flags=cv2.CALIB_ZERO_DISPARITY,
181
+ newImageSize=(image_width, image_height),
182
+ alpha=0,
183
+ )[:5]
184
+
185
+ # Precompute maps for cv::remap()
186
+ map_left_x, map_left_y = cv2.initUndistortRectifyMap(
187
+ cameraMatrix_left,
188
+ distCoeffs_left,
189
+ R1,
190
+ P1,
191
+ (image_width, image_height),
192
+ cv2.CV_32FC1,
193
+ )
194
+ map_right_x, map_right_y = cv2.initUndistortRectifyMap(
195
+ cameraMatrix_right,
196
+ distCoeffs_right,
197
+ R2,
198
+ P2,
199
+ (image_width, image_height),
200
+ cv2.CV_32FC1,
201
+ )
202
+
203
+ zed_calib = {
204
+ "map_left_x": map_left_x,
205
+ "map_left_y": map_left_y,
206
+ "map_right_x": map_right_x,
207
+ "map_right_y": map_right_y,
208
+ "pose_left": P1,
209
+ "pose_right": P2,
210
+ "baseline": baseline,
211
+ "image_width": image_width,
212
+ "image_height": image_height,
213
+ }
214
+
215
+ return zed_calib
216
+
217
+
218
+ def visualize_batch(
219
+ batch_dict: dict,
220
+ preds: PerceptionPrediction,
221
+ output_dir: str,
222
+ ref_frame: int = 0,
223
+ only_foreground=False,
224
+ step=0,
225
+ sequence_name=None,
226
+ writer=None,
227
+ # -- Added by Chu King on 22nd November 2025 to fix image resolution during evaluation.
228
+ resolution=[480, 640]
229
+ ):
230
+ os.makedirs(output_dir, exist_ok=True)
231
+
232
+ outputs = {}
233
+
234
+ if preds.depth_map is not None:
235
+ device = preds.depth_map.device
236
+
237
+ pcd_global_seq = []
238
+ # -- H, W = batch_dict["stereo_video"].shape[3:]
239
+ H, W = resolution
240
+
241
+ for i in range(len(batch_dict["stereo_video"])):
242
+ R, T, K = opencv_from_cameras_projection(
243
+ preds.perspective_cameras[i],
244
+ torch.tensor([H, W])[None].to(device),
245
+ )
246
+
247
+ extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1)
248
+
249
+ extr_matrix = torch.cat(
250
+ [
251
+ extrinsic_3x4_0,
252
+ torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device),
253
+ ],
254
+ dim=0,
255
+ )
256
+
257
+ inv_extr_matrix = extr_matrix.inverse().to(device)
258
+ pcd, colors = depth_to_pcd(
259
+ preds.depth_map[i, 0],
260
+ batch_dict["stereo_video"][..., :resolution[0], : resolution[1]][i][0].permute(1, 2, 0),
261
+ K[0][0][0],
262
+ K[0][0][2],
263
+ K[0][1][2],
264
+ step=1,
265
+ inv_extrinsic=inv_extr_matrix,
266
+ mask=batch_dict["fg_mask"][..., :resolution[0], : resolution[1]][i, 0] if only_foreground else None,
267
+ filter=False,
268
+ )
269
+
270
+ R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3]
271
+ pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i])))
272
+
273
+ raster_settings = PointsRasterizationSettings(
274
+ image_size=[H, W], radius=0.003, points_per_pixel=10
275
+ )
276
+ R, T, cam_ = pcd_global_seq[ref_frame][2]
277
+
278
+ median_depth = preds.depth_map.median()
279
+ cam_.cuda()
280
+
281
+ for mode in ["angle_15", "angle_-15", "changing_angle"]:
282
+ res = []
283
+
284
+ for t, (pcd, color, __) in enumerate(pcd_global_seq):
285
+
286
+ if mode == "changing_angle":
287
+ angle = math.cos((math.pi) * (t / 15)) * 15
288
+ elif mode == "angle_15":
289
+ angle = 15
290
+ elif mode == "angle_-15":
291
+ angle = -15
292
+
293
+ delta_x = median_depth * math.sin(math.radians(angle))
294
+ delta_z = median_depth * (1 - math.cos(math.radians(angle)))
295
+
296
+ cam = cam_.clone()
297
+ cam.R = torch.bmm(
298
+ cam.R,
299
+ RotateAxisAngle(angle=angle, axis="Y", device=device).get_matrix()[
300
+ :, :3, :3
301
+ ],
302
+ )
303
+ cam.T[0, 0] = cam.T[0, 0] - delta_x
304
+ cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0
305
+
306
+ rasterizer = PointsRasterizer(
307
+ cameras=cam, raster_settings=raster_settings
308
+ )
309
+ renderer = PointsRenderer(
310
+ rasterizer=rasterizer,
311
+ compositor=AlphaCompositor(background_color=(1, 1, 1)),
312
+ )
313
+ pcd_copy = pcd.clone()
314
+
315
+ point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0])
316
+ images = renderer(point_cloud)
317
+ res.append(images[0, ..., :3].cpu())
318
+ res = torch.stack(res)
319
+
320
+ video = (res * 255).numpy().astype(np.uint8)
321
+ save_name = f"{sequence_name}_reconstruction_{step}_mode_{mode}_"
322
+ if writer is None:
323
+ outputs[mode] = video
324
+ if only_foreground:
325
+ save_name += "fg_only"
326
+ else:
327
+ save_name += "full_scene"
328
+ video_out = cv2.VideoWriter(
329
+ os.path.join(
330
+ output_dir,
331
+ f"{save_name}.mp4",
332
+ ),
333
+ cv2.VideoWriter_fourcc(*"mp4v"),
334
+ fps=10,
335
+ frameSize=(res.shape[2], res.shape[1]),
336
+ isColor=True,
337
+ )
338
+
339
+ for i in range(len(video)):
340
+ video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
341
+ video_out.release()
342
+
343
+ if writer is not None:
344
+ writer.add_video(
345
+ f"{sequence_name}_reconstruction_mode_{mode}",
346
+ (res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None],
347
+ global_step=step,
348
+ fps=8,
349
+ )
350
+
351
+ return outputs
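`read_calibration` expects a ZED factory calibration file; the sketch below is only illustrative and the file path is a placeholder, not something shipped with this repository.

```python
# Hypothetical call; ./SN12345.conf must be a real ZED calibration file on disk.
calib = read_calibration("./SN12345.conf", "HD")
print(calib["image_width"], calib["image_height"])  # 1280 720
print(calib["baseline"])                            # stereo baseline from the [STEREO] section
```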
links_lite.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "real": [
3
+ "https://dl.fbaipublicfiles.com/dynamic_replica_v2/real/real_000.zip"
4
+ ],
5
+ "test": [
6
+ "https://dl.fbaipublicfiles.com/dynamic_replica_v2/test/test_000.zip"
7
+ ],
8
+ "valid": [
9
+ "https://dl.fbaipublicfiles.com/dynamic_replica_v2/valid/valid_000.zip",
10
+ "https://dl.fbaipublicfiles.com/dynamic_replica_v2/valid/valid_001.zip"
11
+ ],
12
+ "train": [
13
+ "https://dl.fbaipublicfiles.com/dynamic_replica_v2/train/train_000.zip"
14
+ ]
15
+ }
models/core/attention.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import copy
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import Module, Dropout
12
+
13
+ """
14
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
15
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
16
+ """
17
+
18
+
19
+ def elu_feature_map(x):
20
+ return torch.nn.functional.elu(x) + 1
21
+
22
+
23
+ class PositionEncodingSine(nn.Module):
24
+ """
25
+ This is a sinusoidal position encoding that generalizes to 2-dimensional images
26
+ """
27
+
28
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
29
+ """
30
+ Args:
31
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
32
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
33
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
34
+ on the final performance. For now, we keep both impls for backward compatability.
35
+ We will remove the buggy impl after re-training all variants of our released models.
36
+ """
37
+ super().__init__()
38
+
39
+ # -- d_model: embedding dimension
40
+ pe = torch.zeros((d_model, *max_shape))
41
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
42
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
43
+ if temp_bug_fix:
44
+ div_term = torch.exp(
45
+ torch.arange(0, d_model // 2, 2).float()
46
+ * (-math.log(10000.0) / (d_model // 2))
47
+ )
48
+ else: # a buggy implementation (for backward compatibility only)
49
+ div_term = torch.exp(
50
+ torch.arange(0, d_model // 2, 2).float()
51
+ * (-math.log(10000.0) / d_model // 2)
52
+ )
53
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
54
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
55
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
56
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
57
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
58
+
59
+ self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
60
+
61
+ def forward(self, x):
62
+ """
63
+ Args:
64
+ x: [N, C, H, W]
65
+ """
66
+ return x + self.pe[:, :, : x.size(2), : x.size(3)].to(x.device)
67
+
68
+
69
+ class LinearAttention(Module):
70
+ def __init__(self, eps=1e-6):
71
+ super().__init__()
72
+ self.feature_map = elu_feature_map
73
+ self.eps = eps
74
+
75
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
76
+ """Multi-Head linear attention proposed in "Transformers are RNNs"
77
+ Args:
78
+ queries: [N, L, H, D]
79
+ keys: [N, S, H, D]
80
+ values: [N, S, H, D]
81
+ q_mask: [N, L]
82
+ kv_mask: [N, S]
83
+ Returns:
84
+ queried_values: (N, L, H, D)
85
+ """
86
+ Q = self.feature_map(queries)
87
+ K = self.feature_map(keys)
88
+
89
+ # set padded position to zero
90
+ if q_mask is not None:
91
+ Q = Q * q_mask[:, :, None, None]
92
+ if kv_mask is not None:
93
+ K = K * kv_mask[:, :, None, None]
94
+ values = values * kv_mask[:, :, None, None]
95
+
96
+ v_length = values.size(1)
97
+ values = values / v_length # prevent fp16 overflow
98
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
99
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
100
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
101
+
102
+ return queried_values.contiguous()
103
+
104
+
105
+ class FullAttention(Module):
106
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
107
+ super().__init__()
108
+ self.use_dropout = use_dropout
109
+ self.dropout = Dropout(attention_dropout)
110
+
111
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
112
+ """Multi-head scaled dot-product attention, a.k.a full attention.
113
+ Args:
114
+ queries: [N, L, H, D]
115
+ keys: [N, S, H, D]
116
+ values: [N, S, H, D]
117
+ q_mask: [N, L]
118
+ kv_mask: [N, S]
119
+ Returns:
120
+ queried_values: (N, L, H, D)
121
+ """
122
+
123
+ # Compute the unnormalized attention and apply the masks
124
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
125
+ if kv_mask is not None:
126
+ QK.masked_fill_(
127
+ ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf")
128
+ )
129
+
130
+ # Compute the attention and the weighted average
131
+ softmax_temp = 1.0 / queries.size(3) ** 0.5 # sqrt(D)
132
+ A = torch.softmax(softmax_temp * QK, dim=2)
133
+ if self.use_dropout:
134
+ A = self.dropout(A)
135
+
136
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
137
+
138
+ return queried_values.contiguous()
139
+
140
+
141
+ # Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
142
+ class LoFTREncoderLayer(nn.Module):
143
+ def __init__(self, d_model, nhead, attention="linear"):
144
+ super(LoFTREncoderLayer, self).__init__()
145
+
146
+ self.dim = d_model // nhead
147
+ self.nhead = nhead
148
+
149
+ # multi-head attention
150
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
151
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
152
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
153
+
154
+ # -- LoFTR optionally uses linear attention (faster, avoids quadratic cost), otherwise normal softmax attention.
155
+ self.attention = LinearAttention() if attention == "linear" else FullAttention()
156
+ self.merge = nn.Linear(d_model, d_model, bias=False)
157
+
158
+ # feed-forward network
159
+ self.mlp = nn.Sequential(
160
+ nn.Linear(d_model * 2, d_model * 2, bias=False),
161
+ nn.ReLU(),
162
+ nn.Linear(d_model * 2, d_model, bias=False),
163
+ )
164
+
165
+ # norm and dropout
166
+ self.norm1 = nn.LayerNorm(d_model)
167
+ self.norm2 = nn.LayerNorm(d_model)
168
+
169
+ def forward(self, x, source, x_mask=None, source_mask=None):
170
+ """
171
+ Args:
172
+ x (torch.Tensor): [N, L, C]
173
+ source (torch.Tensor): [N, S, C]
174
+ x_mask (torch.Tensor): [N, L] (optional)
175
+ source_mask (torch.Tensor): [N, S] (optional)
176
+ """
177
+ bs = x.size(0)
178
+ query, key, value = x, source, source
179
+
180
+ # multi-head attention
181
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
182
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
183
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
184
+ message = self.attention(
185
+ query, key, value, q_mask=x_mask, kv_mask=source_mask
186
+ ) # [N, L, (H, D)]
187
+ message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
188
+ message = self.norm1(message)
189
+
190
+ # feed-forward network
191
+ message = self.mlp(torch.cat([x, message], dim=2))
192
+ message = self.norm2(message)
193
+
194
+ return x + message
195
+
196
+
197
+ class LocalFeatureTransformer(nn.Module):
198
+ """A Local Feature Transformer (LoFTR) module."""
199
+
200
+ def __init__(self, d_model, nhead, layer_names, attention):
201
+ super(LocalFeatureTransformer, self).__init__()
202
+
203
+ self.d_model = d_model
204
+ self.nhead = nhead
205
+ self.layer_names = layer_names
206
+ encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
207
+ self.layers = nn.ModuleList(
208
+ [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
209
+ )
210
+ self._reset_parameters()
211
+
212
+ def _reset_parameters(self):
213
+ for p in self.parameters():
214
+ if p.dim() > 1:
215
+ nn.init.xavier_uniform_(p)
216
+
217
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
218
+ """
219
+ Args:
220
+ feat0 (torch.Tensor): [N, L, C]
221
+ feat1 (torch.Tensor): [N, S, C]
222
+ mask0 (torch.Tensor): [N, L] (optional)
223
+ mask1 (torch.Tensor): [N, S] (optional)
224
+ """
225
+ assert self.d_model == feat0.size(
226
+ 2
227
+ ), "the feature number of src and transformer must be equal"
228
+
229
+ for layer, name in zip(self.layers, self.layer_names):
230
+
231
+ if name == "self":
232
+ feat0 = layer(feat0, feat0, mask0, mask0)
233
+ feat1 = layer(feat1, feat1, mask1, mask1)
234
+ elif name == "cross":
235
+ feat0 = layer(feat0, feat1, mask0, mask1)
236
+ feat1 = layer(feat1, feat0, mask1, mask0)
237
+ else:
238
+ raise KeyError
239
+
240
+ return feat0, feat1
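A minimal sketch of driving `LocalFeatureTransformer` on dummy features; all shapes are assumptions (the model feeds it 1/16-resolution feature maps flattened to `[N, HW, C]`).

```python
import torch

# One self-attention pass followed by one cross-attention pass over flattened feature maps.
transformer = LocalFeatureTransformer(
    d_model=256, nhead=8, layer_names=["self", "cross"], attention="linear"
)
feat0 = torch.rand(2, 30 * 40, 256)  # left features  [N, L, C]
feat1 = torch.rand(2, 30 * 40, 256)  # right features [N, S, C]
feat0, feat1 = transformer(feat0, feat1)
print(feat0.shape, feat1.shape)      # torch.Size([2, 1200, 256]) each
```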
models/core/corr.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False, stereo=True):
12
+ """Wrapper for grid_sample, uses pixel coordinates"""
13
+ H, W = img.shape[-2:]
14
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
15
+ xgrid = 2 * xgrid / (W - 1) - 1
16
+ if not stereo:
17
+ ygrid = 2 * ygrid / (H - 1) - 1
18
+ else:
19
+ assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
20
+ img = img.contiguous()
21
+ grid = torch.cat([xgrid, ygrid], dim=-1).contiguous()
22
+ img = F.grid_sample(img, grid, align_corners=True)
23
+
24
+ if mask:
25
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
26
+ return img, mask.float()
27
+
28
+ return img
29
+
30
+
31
+ def coords_grid(batch, ht, wd, device):
32
+ coords = torch.meshgrid(
33
+ torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
34
+ )
35
+ coords = torch.stack(coords[::-1], dim=0).float()
36
+ return coords[None].repeat(batch, 1, 1, 1)
37
+
38
+
39
+ class CorrBlock1D:
40
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
41
+ self.num_levels = num_levels
42
+ self.radius = radius
43
+ self.corr_pyramid = []
44
+ self.coords = coords_grid(
45
+ fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device
46
+ )
47
+ # all pairs correlation
48
+ corr = CorrBlock1D.corr(fmap1, fmap2)
49
+
50
+ batch, h1, w1, dim, w2 = corr.shape
51
+ corr = corr.reshape(batch * h1 * w1, dim, 1, w2)
52
+
53
+ self.corr_pyramid.append(corr)
54
+ for i in range(self.num_levels):
55
+ corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
56
+ self.corr_pyramid.append(corr)
57
+
58
+ def __call__(self, flow):
59
+ r = self.radius
60
+ coords = self.coords + flow
61
+ coords = coords[:, :1].permute(0, 2, 3, 1)
62
+ batch, h1, w1, _ = coords.shape
63
+
64
+ out_pyramid = []
65
+ for i in range(self.num_levels):
66
+ corr = self.corr_pyramid[i]
67
+ dx = torch.linspace(-r, r, 2 * r + 1)
68
+ dx = dx.view(1, 1, 2 * r + 1, 1).to(coords.device)
69
+ x0 = dx + coords.reshape(batch * h1 * w1, 1, 1, 1) / 2 ** i
70
+ y0 = torch.zeros_like(x0)
71
+
72
+ coords_lvl = torch.cat([x0, y0], dim=-1)
73
+ corr = bilinear_sampler(corr, coords_lvl)
74
+ corr = corr.view(batch, h1, w1, -1)
75
+ out_pyramid.append(corr)
76
+
77
+ out = torch.cat(out_pyramid, dim=-1)
78
+ return out.permute(0, 3, 1, 2).contiguous().float()
79
+
80
+ @staticmethod
81
+ def corr(fmap1, fmap2):
82
+ B, D, H, W1 = fmap1.shape
83
+ _, _, _, W2 = fmap2.shape
84
+ fmap1 = fmap1.view(B, D, H, W1)
85
+ fmap2 = fmap2.view(B, D, H, W2)
86
+ corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
87
+ corr = corr.reshape(B, H, W1, 1, W2).contiguous()
88
+ return corr / torch.sqrt(torch.tensor(D).float())
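A hedged sketch of how `CorrBlock1D` is used: the pyramid is built once from a left/right feature pair and then indexed with the current disparity estimate on every refinement iteration (shapes below are assumptions).

```python
import torch

fmap1 = torch.rand(1, 256, 60, 80)  # left features  [B, D, H, W]
fmap2 = torch.rand(1, 256, 60, 80)  # right features [B, D, H, W]
corr_fn = CorrBlock1D(fmap1, fmap2, num_levels=4, radius=4)

flow = torch.zeros(1, 2, 60, 80)    # zero-initialised flow/disparity
cost = corr_fn(flow)                # [1, num_levels * (2 * radius + 1), 60, 80] -> [1, 36, 60, 80]
```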
models/core/dynamic_stereo.py ADDED
@@ -0,0 +1,506 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # -- Added by Chu King on 16th November 2025 for debugging purposes.
8
+ import os, signal
9
+ import logging
10
+ import torch.distributed as dist
11
+
12
+ from typing import Dict, List
13
+ from einops import rearrange
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from collections import defaultdict
18
+
19
+
20
+ from models.core.update import (
21
+ BasicUpdateBlock,
22
+ SequenceUpdateBlock3D,
23
+ TimeAttnBlock,
24
+ )
25
+
26
+ # -- Added by Chu King on 21st November 2025
27
+ from models.core.sci_codec import sci_decoder
28
+ from models.core.extractor import BasicEncoder
29
+ from models.core.corr import CorrBlock1D
30
+
31
+ from models.core.attention import (
32
+ PositionEncodingSine,
33
+ LocalFeatureTransformer,
34
+ )
35
+ from models.core.utils.utils import InputPadder, interp
36
+
37
+ autocast = torch.cuda.amp.autocast
38
+
39
+
40
+ class DynamicStereo(nn.Module):
41
+ def __init__(
42
+ self,
43
+ max_disp: int = 192,
44
+ mixed_precision: bool = False,
45
+ num_frames: int = 5,
46
+ attention_type: str = None,
47
+ use_3d_update_block: bool = False,
48
+ different_update_blocks: bool = False,
49
+ ):
50
+ super(DynamicStereo, self).__init__()
51
+
52
+ self.max_flow = max_disp
53
+ self.mixed_precision = mixed_precision
54
+
55
+ self.hidden_dim = 128
56
+ self.context_dim = 128
57
+ dim = 256
58
+ self.dim = dim
59
+ self.dropout = 0 # -- dropout probability
60
+
61
+ # -- decide whether to use 3D update blocks (like RAFT3D) or simpler 2D blocks.
62
+ self.use_3d_update_block = use_3d_update_block
63
+
64
+ # -- Modified by Chu King on 21st November 2025
65
+ # -- CNN encoder that extracts features from images.
66
+ # * output_dim: output channels
67
+ # * norm_fn="instance": applies instance normalization.
68
+ # -- self.fnet = BasicEncoder(
69
+ # -- output_dim=dim, norm_fn="instance", dropout=self.dropout
70
+ # -- )
71
+ self.fnet = sci_decoder(
72
+ n_frame=num_frames,
73
+ n_taps=2,
74
+ output_dim=dim,
75
+ norm_fn="instance",
76
+ dropout=self.dropout
77
+ )
78
+
79
+ # -- Boolean flag to decide whether different update blocks are used for different resolutions.
80
+ self.different_update_blocks = different_update_blocks
81
+
82
+ # -- Cost volume planes (matching costs for disparity computation).
83
+ cor_planes = 4 * 9
84
+ self.depth = 4
85
+ self.attention_type = attention_type
86
+ # attention_type is a combination of the following attention types:
87
+ # self_stereo, temporal, update_time, update_space
88
+ # for example, self_stereo_temporal_update_time_update_space
89
+
90
+ if self.use_3d_update_block:
91
+ # -- Uses 3D convolutions for spatiotemporal processing.
92
+ if self.different_update_blocks:
93
+ # -- self.update_block08, self.update_block16, self.update_block04
94
+ # are update blocks for different resolution levels (i.e. 1/8, 1/16, 1/4)
95
+ self.update_block08 = SequenceUpdateBlock3D(
96
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
97
+ )
98
+ self.update_block16 = SequenceUpdateBlock3D(
99
+ hidden_dim=self.hidden_dim,
100
+ cor_planes=cor_planes,
101
+ mask_size=4,
102
+ attention_type=attention_type,
103
+ )
104
+ self.update_block04 = SequenceUpdateBlock3D(
105
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
106
+ )
107
+ else:
108
+ self.update_block = SequenceUpdateBlock3D(
109
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
110
+ )
111
+ else:
112
+ # -- Uses standard 2D update blocks.
113
+ if self.different_update_blocks:
114
+ self.update_block08 = BasicUpdateBlock(
115
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
116
+ )
117
+ self.update_block16 = BasicUpdateBlock(
118
+ hidden_dim=self.hidden_dim,
119
+ cor_planes=cor_planes,
120
+ mask_size=4,
121
+ attention_type=attention_type,
122
+ )
123
+ self.update_block04 = BasicUpdateBlock(
124
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
125
+ )
126
+ else:
127
+ self.update_block = BasicUpdateBlock(
128
+ hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
129
+ )
130
+
131
+ if attention_type is not None:
132
+ # -- The model incorporates several attention types.
133
+ if ("update_time" in attention_type) or ("temporal" in attention_type):
134
+ # -- This variable learns positional embeddings for different time steps in the sequence.
135
+ self.time_embed = nn.Parameter(torch.zeros(1, num_frames, dim))
136
+
137
+ # -- Temporal attention: processes information across different time frames.
138
+ if "temporal" in attention_type:
139
+ self.time_attn_blocks = nn.ModuleList(
140
+ [TimeAttnBlock(dim=dim, num_heads=8) for _ in range(self.depth)]
141
+ )
142
+
143
+ # -- Stereo attention: includes self-attention and cross attention blocks for processing
144
+ # left-right stereo image pairs.
145
+ if "self_stereo" in attention_type:
146
+ self.self_attn_blocks = nn.ModuleList(
147
+ [
148
+ LocalFeatureTransformer(
149
+ d_model=dim,
150
+ nhead=8,
151
+ layer_names=["self"] * 1,
152
+ attention="linear",
153
+ )
154
+ for _ in range(self.depth)
155
+ ]
156
+ )
157
+
158
+ self.cross_attn_blocks = nn.ModuleList(
159
+ [
160
+ LocalFeatureTransformer(
161
+ d_model=dim,
162
+ nhead=8,
163
+ layer_names=["cross"] * 1,
164
+ attention="linear",
165
+ )
166
+ for _ in range(self.depth)
167
+ ]
168
+ )
169
+
170
+ self.num_frames = num_frames
171
+
172
+ @torch.jit.ignore
173
+ def no_weight_decay(self):
174
+ return {"time_embed"}
175
+
176
+ def freeze_bn(self):
177
+ for m in self.modules():
178
+ if isinstance(m, nn.BatchNorm2d):
179
+ m.eval()
180
+
181
+ def convex_upsample(self, flow: torch.Tensor, mask: torch.Tensor, rate: int = 4):
182
+ """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
183
+ N, _, H, W = flow.shape
184
+ mask = mask.view(N, 1, 9, rate, rate, H, W)
185
+ mask = torch.softmax(mask, dim=2)
186
+
187
+ up_flow = F.unfold(rate * flow, [3, 3], padding=1)
188
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
189
+
190
+ up_flow = torch.sum(mask * up_flow, dim=2)
191
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
192
+ return up_flow.reshape(N, 2, rate * H, rate * W)
193
+
194
+ def zero_init(self, fmap: torch.Tensor):
195
+ N, _, H, W = fmap.shape
196
+ _x = torch.zeros([N, 1, H, W], dtype=torch.float32)
197
+ _y = torch.zeros([N, 1, H, W], dtype=torch.float32)
198
+ zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
199
+ return zero_flow
200
+
201
+ def forward_batch_test(
202
+ self, batch_dict: Dict, sci_enc_L, sci_enc_R, kernel_size: int = 14, iters: int = 20
203
+ ):
204
+ stride = kernel_size // 2
205
+ predictions = defaultdict(list)
206
+
207
+ disp_preds = []
208
+ video = batch_dict["stereo_video"]
209
+ num_ims = len(video)
210
+ print("video", video.shape)
211
+
212
+ # -- Divide a single long sequence to multiple long sequences.
213
+ # -- For SCI stereo, we only test the first sequence.
214
+ # -- for i in range(0, num_ims, stride):
215
+ for i in range(1):
216
+ left_ims = video[i : min(i + kernel_size, num_ims), 0]
217
+ # -- padder = InputPadder(left_ims.shape, divis_by=32)
218
+
219
+ right_ims = video[i : min(i + kernel_size, num_ims), 1]
220
+ # -- left_ims, right_ims = padder.pad(left_ims, right_ims)
221
+
222
+ # -- Modified by Chu King on 20th November 2025
223
+ # 0) Convert to Gray
224
+ def rgb_to_gray(x):
225
+ weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device)
226
+ gray = (x * weights[None, None, :, None, None]).sum(dim=2)
227
+ return gray # -- shape: [B, T, H, W]
228
+
229
+ video_L = rgb_to_gray(left_ims.to(next(sci_enc_L.parameters()).device)) # ~ (b, t, h, w)
230
+ video_R = rgb_to_gray(right_ims.to(next(sci_enc_R.parameters()).device)) # ~ (b, t, h, w)
231
+
232
+ # 1) Extract and normalize input videos.
233
+ # -- min_max_norm = lambda x : 2. * (x / 255.) - 1.
234
+ min_max_norm = lambda x: x / 255.
235
+ video_L = min_max_norm(video_L)
236
+ video_R = min_max_norm(video_R)
237
+
238
+ # 2) Make both tensors contiguous; a later .view() on a non-contiguous tensor would raise an error.
239
+ video_L = video_L.contiguous()
240
+ video_R = video_R.contiguous()
241
+
242
+ # 3) Coded exposure modeling.
243
+ snapshot_L = sci_enc_L(video_L)
244
+ snapshot_R = sci_enc_R(video_R)
245
+
246
+ with autocast(enabled=self.mixed_precision):
247
+ disparities_forw = self.forward(
248
+ # -- Modified by Chu King on 20th November 2025
249
+ # -- left_ims[None].cuda(),
250
+ # -- right_ims[None].cuda(),
251
+ snapshot_L,
252
+ snapshot_R,
253
+ iters=iters,
254
+ test_mode=True,
255
+ )
256
+
257
+ # -- Padding disabled by Chu King on 20th November 2025
258
+ # -- disparities_forw = padder.unpad(disparities_forw[:, 0])[:, None].cpu()
259
+ disparities_forw = disparities_forw[:, 0][:, None].cpu()
260
+
261
+ # -- We are not doing overlapping chunks in SCI stereo.
262
+ disp_preds.append(disparities_forw)
263
+ # -- if len(disp_preds) > 0 and len(disparities_forw) >= stride:
264
+ # -- if len(disparities_forw) < kernel_size:
265
+ # -- disp_preds.append(disparities_forw[stride // 2 :])
266
+ # -- else:
267
+ # -- disp_preds.append(disparities_forw[stride // 2 : -stride // 2])
268
+ # -- elif len(disp_preds) == 0:
269
+ # -- disp_preds.append(disparities_forw[: -stride // 2])
270
+
271
+ predictions["disparity"] = (torch.cat(disp_preds).squeeze(1).abs())[:, :1]
272
+
273
+ return predictions
274
+
275
+ def forward_sst_block(
276
+ self, fmap1_dw16: torch.Tensor, fmap2_dw16: torch.Tensor, T: int
277
+ ):
278
+ # -- fmap1_dw16 ~ (B*T, C, H, W) -- left-view features
279
+ # -- fmap2_dw16 ~ (B*T, C, H, W) -- right-view features
280
+ *_, h, w = fmap1_dw16.shape
281
+
282
+ # positional encoding and self-attention
283
+ pos_encoding_fn_small = PositionEncodingSine(d_model=self.dim, max_shape=(h, w))
284
+ fmap1_dw16 = pos_encoding_fn_small(fmap1_dw16)
285
+ fmap2_dw16 = pos_encoding_fn_small(fmap2_dw16)
286
+
287
+ if self.attention_type is not None:
288
+ # add time embeddings
289
+ if (
290
+ "temporal" in self.attention_type
291
+ or "update_time" in self.attention_type
292
+ ):
293
+ fmap1_dw16 = rearrange(
294
+ fmap1_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w
295
+ )
296
+ fmap2_dw16 = rearrange(
297
+ fmap2_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w
298
+ )
299
+
300
+ # interpolate if video length doesn't match
301
+ if T != self.num_frames:
302
+ time_embed = self.time_embed.transpose(1, 2)
303
+ new_time_embed = F.interpolate(time_embed, size=(T), mode="nearest")
304
+ new_time_embed = new_time_embed.transpose(1, 2).contiguous()
305
+ else:
306
+ new_time_embed = self.time_embed
307
+
308
+ fmap1_dw16 = fmap1_dw16 + new_time_embed
309
+ fmap2_dw16 = fmap2_dw16 + new_time_embed
310
+
311
+ fmap1_dw16 = rearrange(
312
+ fmap1_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w
313
+ )
314
+ fmap2_dw16 = rearrange(
315
+ fmap2_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w
316
+ )
317
+
318
+ if ("self_stereo" in self.attention_type) or (
319
+ "temporal" in self.attention_type
320
+ ):
321
+ for att_ind in range(self.depth):
322
+ if "self_stereo" in self.attention_type:
323
+ fmap1_dw16 = rearrange(
324
+ fmap1_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w
325
+ )
326
+ fmap2_dw16 = rearrange(
327
+ fmap2_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w
328
+ )
329
+
330
+ fmap1_dw16, fmap2_dw16 = self.self_attn_blocks[att_ind](
331
+ fmap1_dw16, fmap2_dw16
332
+ )
333
+ fmap1_dw16, fmap2_dw16 = self.cross_attn_blocks[att_ind](
334
+ fmap1_dw16, fmap2_dw16
335
+ )
336
+
337
+ fmap1_dw16 = rearrange(
338
+ fmap1_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w
339
+ )
340
+ fmap2_dw16 = rearrange(
341
+ fmap2_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w
342
+ )
343
+
344
+ if "temporal" in self.attention_type:
345
+ fmap1_dw16 = self.time_attn_blocks[att_ind](fmap1_dw16, T=T)
346
+ fmap2_dw16 = self.time_attn_blocks[att_ind](fmap2_dw16, T=T)
347
+ return fmap1_dw16, fmap2_dw16
348
+
349
+ def forward_update_block(
350
+ self,
351
+ update_block: nn.Module,
352
+ corr_fn: CorrBlock1D,
353
+ flow: torch.Tensor,
354
+ net: torch.Tensor,
355
+ inp: torch.Tensor,
356
+ predictions: List,
357
+ iters: int,
358
+ interp_scale: float,
359
+ t: int,
360
+ ):
361
+ for _ in range(iters):
362
+ flow = flow.detach()
363
+ out_corrs = corr_fn(flow)
364
+ with autocast(enabled=self.mixed_precision):
365
+ net, up_mask, delta_flow = update_block(net, inp, out_corrs, flow, t=t)
366
+
367
+ flow = flow + delta_flow
368
+ flow_up = flow_out = self.convex_upsample(flow, up_mask, rate=4)
369
+ if interp_scale > 1:
370
+ flow_up = interp_scale * interp(
371
+ flow_out,
372
+ (
373
+ interp_scale * flow_out.shape[2],
374
+ interp_scale * flow_out.shape[3],
375
+ ),
376
+ )
377
+ flow_up = flow_up[:, :1]
378
+ predictions.append(flow_up)
379
+ return flow_out, net
380
+
381
+ def forward(self, image1, image2, flow_init=None, iters=10, test_mode=False):
382
+ """Estimate optical flow between pair of frames"""
383
+ b, *_ = image1.shape
384
+
385
+ hdim = self.hidden_dim
386
+
387
+ with autocast(enabled=self.mixed_precision):
388
+ fmap1, fmap2 = self.fnet([image1, image2])
389
+
390
+ net, inp = torch.split(fmap1, [hdim, hdim], dim=1)
391
+ net = torch.tanh(net)
392
+ inp = F.relu(inp)
393
+ *_, h, w = fmap1.shape
394
+ # 1/4 -> 1/16
395
+ # feature
396
+ fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
397
+ fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
398
+
399
+ fmap1_dw16, fmap2_dw16 = self.forward_sst_block(fmap1_dw16, fmap2_dw16, T=self.num_frames)
400
+
401
+ net_dw16, inp_dw16 = torch.split(fmap1_dw16, [hdim, hdim], dim=1)
402
+ net_dw16 = torch.tanh(net_dw16)
403
+ inp_dw16 = F.relu(inp_dw16)
404
+
405
+ fmap1_dw8 = (
406
+ F.avg_pool2d(fmap1, 2, stride=2) + interp(fmap1_dw16, (h // 2, w // 2))
407
+ ) / 2.0
408
+ fmap2_dw8 = (
409
+ F.avg_pool2d(fmap2, 2, stride=2) + interp(fmap2_dw16, (h // 2, w // 2))
410
+ ) / 2.0
411
+
412
+ net_dw8, inp_dw8 = torch.split(fmap1_dw8, [hdim, hdim], dim=1)
413
+ net_dw8 = torch.tanh(net_dw8)
414
+ inp_dw8 = F.relu(inp_dw8)
415
+ # Cascaded refinement (1/16 + 1/8 + 1/4)
416
+ predictions = []
417
+ flow = None
418
+ flow_up = None
419
+ if flow_init is not None:
420
+ scale = h / flow_init.shape[2]
421
+ flow = -scale * interp(flow_init, (h, w))
422
+ else:
423
+ # zero initialization
424
+ flow_dw16 = self.zero_init(fmap1_dw16) # -- (N, 2, H, W)
425
+
426
+ # Recurrent Update Module
427
+ # Update 1/16
428
+ update_block = (
429
+ self.update_block16
430
+ if self.different_update_blocks
431
+ else self.update_block
432
+ )
433
+
434
+ corr_fn_att_dw16 = CorrBlock1D(fmap1_dw16, fmap2_dw16)
435
+ flow, net_dw16 = self.forward_update_block(
436
+ update_block=update_block,
437
+ corr_fn=corr_fn_att_dw16,
438
+ flow=flow_dw16,
439
+ net=net_dw16,
440
+ inp=inp_dw16,
441
+ predictions=predictions,
442
+ iters=iters // 2,
443
+ interp_scale=4,
444
+ t=self.num_frames,
445
+ )
446
+
447
+ scale = fmap1_dw8.shape[2] / flow.shape[2]
448
+ flow_dw8 = -scale * interp(flow, (fmap1_dw8.shape[2], fmap1_dw8.shape[3]))
449
+
450
+ net_dw8 = (
451
+ net_dw8
452
+ + interp(net_dw16, (2 * net_dw16.shape[2], 2 * net_dw16.shape[3]))
453
+ ) / 2.0
454
+ # Update 1/8
455
+
456
+ update_block = (
457
+ self.update_block08
458
+ if self.different_update_blocks
459
+ else self.update_block
460
+ )
461
+
462
+ corr_fn_dw8 = CorrBlock1D(fmap1_dw8, fmap2_dw8)
463
+ flow, net_dw8 = self.forward_update_block(
464
+ update_block=update_block,
465
+ corr_fn=corr_fn_dw8,
466
+ flow=flow_dw8,
467
+ net=net_dw8,
468
+ inp=inp_dw8,
469
+ predictions=predictions,
470
+ iters=iters // 2,
471
+ interp_scale=2,
472
+ t=self.num_frames,
473
+ )
474
+
475
+ scale = h / flow.shape[2]
476
+ flow = -scale * interp(flow, (h, w))
477
+
478
+ net = (
479
+ net + interp(net_dw8, (2 * net_dw8.shape[2], 2 * net_dw8.shape[3]))
480
+ ) / 2.0
481
+ # Update 1/4
482
+ update_block = (
483
+ self.update_block04 if self.different_update_blocks else self.update_block
484
+ )
485
+ corr_fn = CorrBlock1D(fmap1, fmap2)
486
+ flow, __ = self.forward_update_block(
487
+ update_block=update_block,
488
+ corr_fn=corr_fn,
489
+ flow=flow,
490
+ net=net,
491
+ inp=inp,
492
+ predictions=predictions,
493
+ iters=iters,
494
+ interp_scale=1,
495
+ t=self.num_frames,
496
+ )
497
+
498
+ predictions = torch.stack(predictions)
499
+
500
+ predictions = rearrange(predictions, "d (b t) c h w -> d t b c h w", b=b, t=self.num_frames)
501
+ flow_up = predictions[-1]
502
+
503
+ if test_mode:
504
+ return flow_up
505
+
506
+ return predictions
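The convex upsampling step follows RAFT: the update block predicts per-pixel weights that take a convex combination of a 3x3 neighbourhood of the coarse field for every fine-resolution pixel. A hedged sketch; the stand-alone model construction and all shapes are assumptions.

```python
import torch

model = DynamicStereo(num_frames=5)          # assumes the default constructor works in isolation
coarse = torch.rand(1, 2, 30, 40)            # coarse flow at 1/4 resolution
weights = torch.rand(1, 9 * 4 * 4, 30, 40)   # 9 neighbours x (rate x rate) sub-pixel positions
fine = model.convex_upsample(coarse, weights, rate=4)
print(fine.shape)                            # torch.Size([1, 2, 120, 160])
```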
models/core/extractor.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ # -- Added by Chu King on 16th November 2025 for debugging purposes.
11
+ import os, signal
12
+ import logging
13
+ import torch.distributed as dist
14
+
15
+ class ResidualBlock(nn.Module):
16
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
17
+ super(ResidualBlock, self).__init__()
18
+
19
+ self.conv1 = nn.Conv2d(
20
+ in_planes, planes, kernel_size=3, padding=1, stride=stride
21
+ )
22
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
23
+ self.relu = nn.ReLU(inplace=True)
24
+
25
+ num_groups = planes // 8
26
+
27
+ if norm_fn == "group":
28
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
29
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
30
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
31
+
32
+ elif norm_fn == "batch":
33
+ self.norm1 = nn.BatchNorm2d(planes)
34
+ self.norm2 = nn.BatchNorm2d(planes)
35
+ self.norm3 = nn.BatchNorm2d(planes)
36
+
37
+ elif norm_fn == "instance":
38
+ self.norm1 = nn.InstanceNorm2d(planes, affine=False)
39
+ self.norm2 = nn.InstanceNorm2d(planes, affine=False)
40
+ self.norm3 = nn.InstanceNorm2d(planes, affine=False)
41
+
42
+ elif norm_fn == "none":
43
+ self.norm1 = nn.Sequential()
44
+ self.norm2 = nn.Sequential()
45
+ self.norm3 = nn.Sequential()
46
+
47
+ self.downsample = nn.Sequential(
48
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
49
+ )
50
+
51
+ def forward(self, x):
52
+ y = x
53
+ y = self.relu(self.norm1(self.conv1(y)))
54
+ y = self.relu(self.norm2(self.conv2(y)))
55
+
56
+ # -- ensures that x is transformed to the correct shape so it can be added to y.
57
+ x = self.downsample(x)
58
+
59
+ return self.relu(x + y)
60
+
61
+
62
+ class BasicEncoder(nn.Module):
63
+ def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
64
+ super(BasicEncoder, self).__init__()
65
+ self.norm_fn = norm_fn
66
+
67
+ if self.norm_fn == "group":
68
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
69
+
70
+ elif self.norm_fn == "batch":
71
+ self.norm1 = nn.BatchNorm2d(64)
72
+
73
+ elif self.norm_fn == "instance":
74
+ self.norm1 = nn.InstanceNorm2d(64, affine=False)
75
+
76
+ elif self.norm_fn == "none":
77
+ self.norm1 = nn.Sequential()
78
+
79
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
80
+ self.relu1 = nn.ReLU(inplace=True)
81
+
82
+ self.in_planes = 64
83
+ self.layer1 = self._make_layer(64, stride=1)
84
+ self.layer2 = self._make_layer(96, stride=2)
85
+ self.layer3 = self._make_layer(128, stride=1)
86
+
87
+ # output convolution
88
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
89
+
90
+ self.dropout = None
91
+ if dropout > 0:
92
+ self.dropout = nn.Dropout2d(p=dropout)
93
+
94
+ # -- self.modules() is a PyTorch utility function that returns all submodules of this nn.Module recursively.
95
+ # -- This means it will loop through every layer: conv1, layer1, layer2, layer3, conv2 and so on.
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
99
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
100
+ if m.weight is not None:
101
+ nn.init.constant_(m.weight, 1)
102
+ if m.bias is not None:
103
+ nn.init.constant_(m.bias, 0)
104
+
105
+ def _make_layer(self, dim, stride=1):
106
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
107
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
108
+ layers = (layer1, layer2)
109
+
110
+ self.in_planes = dim
111
+ return nn.Sequential(*layers)
112
+
113
+ def forward(self, x):
114
+ # -- x = [L, R]
115
+ # -- L, R ~ (b*t, c, h, w)
116
+
117
+ # if input is list, combine batch dimension
118
+ is_list = isinstance(x, tuple) or isinstance(x, list)
119
+ if is_list:
120
+ batch_dim = x[0].shape[0]
121
+ x = torch.cat(x, dim=0)
122
+
123
+ x = self.conv1(x)
124
+ x = self.norm1(x)
125
+ x = self.relu1(x)
126
+
127
+ x = self.layer1(x)
128
+ x = self.layer2(x)
129
+ x = self.layer3(x)
130
+
131
+ x = self.conv2(x)
132
+
133
+ if self.dropout is not None:
134
+ x = self.dropout(x)
135
+
136
+ if is_list:
137
+ x = torch.split(x, x.shape[0] // 2, dim=0)
138
+
139
+ return x
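`BasicEncoder` is the stock RAFT-style feature extractor that the `sci_decoder` swap above replaces; a minimal sketch of what it computes (shapes are assumptions).

```python
import torch

encoder = BasicEncoder(output_dim=256, norm_fn="instance", dropout=0.0)
left = torch.rand(2, 3, 256, 320)   # stacked frames [B*T, 3, H, W]
right = torch.rand(2, 3, 256, 320)
fmap_left, fmap_right = encoder([left, right])
print(fmap_left.shape)              # torch.Size([2, 256, 64, 80]), i.e. 1/4 resolution
```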
models/core/model_zoo.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ from dynamic_stereo.models.dynamic_stereo_model import DynamicStereoModel
9
+
10
+ from pytorch3d.implicitron.tools.config import get_default_args
11
+
12
+ try:
13
+ from dynamic_stereo.models.raft_stereo_model import RAFTStereoModel
14
+
15
+ MODELS = [RAFTStereoModel, DynamicStereoModel]
16
+ except:
17
+ MODELS = [DynamicStereoModel]
18
+
19
+ _MODEL_NAME_TO_MODEL = {model_cls.__name__: model_cls for model_cls in MODELS}
20
+ _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG = {}
21
+ for model_cls in MODELS:
22
+ _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG[
23
+ model_cls.MODEL_CONFIG_NAME
24
+ ] = get_default_args(model_cls)
25
+ MODEL_NAME_NONE = "NONE"
26
+
27
+
28
+ def model_zoo(model_name: str, **kwargs):
29
+ if model_name.upper() == MODEL_NAME_NONE:
30
+ return None
31
+
32
+ model_cls = _MODEL_NAME_TO_MODEL.get(model_name)
33
+
34
+ if model_cls is None:
35
+ raise ValueError(f"No such model name: {model_name}")
36
+
37
+ model_cls_params = {}
38
+ if "model_zoo" in getattr(model_cls, "__dataclass_fields__", []):
39
+ model_cls_params["model_zoo"] = model_zoo
40
+ print(
41
+ f"{model_cls.MODEL_CONFIG_NAME} model configs:",
42
+ kwargs.get(model_cls.MODEL_CONFIG_NAME),
43
+ )
44
+ return model_cls(**model_cls_params, **kwargs.get(model_cls.MODEL_CONFIG_NAME, {}))
45
+
46
+
47
+ def get_all_model_default_configs():
48
+ return copy.deepcopy(_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG)
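`model_zoo()` resolves a model class by name and then pulls that class's own config block out of `kwargs`, keyed by its `MODEL_CONFIG_NAME`. A self-contained sketch of the same dispatch pattern, with made-up toy classes standing in for the real models:

```
# Toy illustration of the dispatch in model_zoo(); ToyModelA / ToyModelB are invented here.
class ToyModelA:
    MODEL_CONFIG_NAME = "ToyModelA"
    def __init__(self, iters=8):
        self.iters = iters

class ToyModelB:
    MODEL_CONFIG_NAME = "ToyModelB"
    def __init__(self, hidden_dim=128):
        self.hidden_dim = hidden_dim

_NAME_TO_MODEL = {cls.__name__: cls for cls in (ToyModelA, ToyModelB)}

def toy_model_zoo(model_name, **kwargs):
    if model_name.upper() == "NONE":
        return None
    cls = _NAME_TO_MODEL[model_name]
    # Each model only receives the config block registered under its own name.
    return cls(**kwargs.get(cls.MODEL_CONFIG_NAME, {}))

model = toy_model_zoo("ToyModelA", ToyModelA={"iters": 16})
assert model.iters == 16
```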
models/core/sci_codec.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from models.core.extractor import ResidualBlock
7
+
8
+ autocast = torch.cuda.amp.autocast
9
+
10
+ class ste_fn(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(ctx, x):
13
+ return (x > 0).float()
14
+ @staticmethod
15
+ def backward(ctx, grad):
16
+ return F.hardtanh(grad)
17
+
18
+ class STE(nn.Module):
19
+ def __init__(self):
20
+ super(STE, self).__init__()
21
+ def forward(self, x):
22
+ return ste_fn.apply(x)
23
+
24
+ class sci_encoder(nn.Module):
25
+ def __init__(
26
+ self,
27
+ sigma_range=[0, 1e-9],
28
+ n_frame=8,
29
+ in_channels=1,
30
+ n_taps=2,
31
+ resolution=[480, 640]):
32
+
33
+ super(sci_encoder, self).__init__()
34
+
35
+ assert n_taps in [1, 2], "[ERROR] n_taps should be either 1 or 2."
36
+
37
+ self.sigma_range = sigma_range
38
+ self.n_frame = n_frame
39
+ self.in_channels = in_channels
40
+ self.n_taps = n_taps
41
+ self.resolution = resolution
42
+
43
+ # -- Shutter code; Learnable parameters
44
+ self.ce_weight = nn.Parameter(torch.Tensor(n_frame, in_channels, *resolution))
45
+
46
+ # -- initialize
47
+ nn.init.uniform_(self.ce_weight, a=-1, b=1)
48
+
49
+ self.ste = STE()
50
+
51
+ def forward(self, frames):
52
+
53
+ # -- print ("[INFO] self.ce_weight.device: ", self.ce_weight.device)
54
+ ce_code = self.ste(self.ce_weight)
55
+ # -- print ("[INFO] ce_code.device: ", ce_code.device)
56
+
57
+ frames = frames[..., :self.resolution[0], :self.resolution[1]]
58
+ frames = frames.contiguous()
59
+ frames = torch.unsqueeze(frames, 2)
60
+
61
+ # -- print ("[INFO] ce_code.shape: ", ce_code.shape)
62
+ # -- print ("[INFO] frames.shape: ", frames.shape)
63
+
64
+ # -- repeat by the batch size
65
+ ce_code = ce_code.repeat(frames.shape[0], 1, 1, 1, 1)
66
+ # -- print ("[INFO] ce_code.shape: ", ce_code.shape)
67
+ # -- print ("[INFO] ce_code.squeeze(2).shape: ", ce_code.squeeze(2).shape)
68
+
69
+ ce_blur_img = torch.zeros(frames.shape[0], self.in_channels * self.n_taps, *self.resolution).to(frames.device) # -- (b, c, h, w)
70
+
71
+ # -- print ("[INFO] ce_blur_img.shape: ", ce_blur_img.shape)
72
+ ce_blur_img[:, 0, ...] = torch.sum( ce_code * frames, axis=1) / self.n_frame
73
+ ce_blur_img[:, 1, ...] = torch.sum((1. - ce_code) * frames, axis=1) / self.n_frame
74
+
75
+ # -- add noise
76
+ noise_level = np.random.uniform(*self.sigma_range)
77
+ ce_blur_img_noisy = ce_blur_img + torch.tensor(noise_level).to(frames.device) * torch.randn(ce_blur_img.shape).to(frames.device)
78
+
79
+ # -- concat snapshots and mask patterns
80
+ out = torch.zeros(frames.shape[0], self.n_taps + self.n_frame, *self.resolution).to(frames.device)
81
+
82
+ # -- print ("[INFO] out.shape: ", out.shape)
83
+ out[:, :self.n_taps, :, :] = ce_blur_img_noisy
84
+ out[:, self.n_taps:, :, :] = ce_code.squeeze(2)
85
+
86
+ return out
87
+
88
+ class sci_decoder(nn.Module):
89
+ def __init__(self,
90
+ n_frame=8,
91
+ n_taps=2,
92
+ output_dim=128,
93
+ norm_fn="batch",
94
+ dropout=.0):
95
+
96
+ super(sci_decoder, self).__init__()
97
+
98
+ self.n_frame = n_frame  # kept so forward() can restore the temporal dimension
+ self.norm_fn = norm_fn
99
+ if norm_fn == "group":
100
+ self.norm1 = nn.GroupNorm(num_groups=4, num_channels=4*n_frame)
101
+ elif norm_fn == "batch":
102
+ self.norm1 = nn.BatchNorm2d(4*n_frame)
103
+ elif norm_fn == "instance":
104
+ self.norm1 = nn.InstanceNorm2d(4*n_frame, affine=True)
105
+ elif norm_fn == "none":
106
+ self.norm1 = nn.Sequential()
107
+
108
+ # -- Input Convolution
109
+ # -- Assuming n_frame=8; n_ich=10; n_och=32
110
+ self.conv1 = nn.Conv2d(n_taps+n_frame, 4*n_frame, kernel_size=7, stride=2, padding=3)
111
+ self.relu1 = nn.ReLU(inplace=True)
112
+
113
+ # -- Residual Blocks
114
+ self.layer1 = self._make_layer( 4*n_frame, 4*n_frame, stride=1)
115
+ self.layer2 = self._make_layer( 4*n_frame, 16*n_frame, stride=2)
116
+ self.layer3 = self._make_layer(16*n_frame, 64*n_frame, stride=1)
117
+
118
+ # -- Output Convolution
119
+ self.conv2 = nn.Conv2d(64*n_frame, output_dim*n_frame, kernel_size=1)
120
+
121
+ if dropout > 0.:
122
+ self.dropout = nn.Dropout2d(p=dropout)
123
+ else:
124
+ self.dropout = None
125
+
126
+ # -- self.modules() is a PyTorch utility function that returns all submodules of this nn.Module recursively.
127
+ # -- This means it will loop through every layer: conv1, layer1, layer2, layer3, conv2 and so on.
128
+ for m in self.modules():
129
+ if isinstance(m, nn.Conv2d):
130
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
131
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
132
+ if m.weight is not None:
133
+ nn.init.constant_(m.weight, 1)
134
+ if m.bias is not None:
135
+ nn.init.constant_(m.bias, 0)
136
+
137
+ # -- Private function to make residual blocks
138
+ def _make_layer(self, n_ich, n_och, stride=1):
139
+ layer1 = ResidualBlock(n_ich, n_och, self.norm_fn, stride=stride)
140
+ layer2 = ResidualBlock(n_och, n_och, self.norm_fn, stride=1)
141
+ layers = (layer1, layer2)
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ # -- x = [L, R]
147
+ # -- L, R ~ (b, c, h, w); c=n_taps+n_frame
148
+
149
+ # -- if input is list, combine batch dimension
150
+ is_list = isinstance(x, tuple) or isinstance(x, list)
151
+ if is_list:
152
+ batch_dim = x[0].shape[0]
153
+ x = torch.cat(x, dim=0)
154
+
155
+ # -- print ("[INFO] x.shape: ", x.shape)
156
+
157
+ x = self.conv1(x)
158
+ x = self.norm1(x)
159
+ x = self.relu1(x)
160
+
161
+ x = self.layer1(x)
162
+ x = self.layer2(x)
163
+ x = self.layer3(x)
164
+
165
+ x = self.conv2(x)
166
+
167
+ # -- expand the temporal dimension
168
+ # -- (b, c, h, w) -> (b*t, c//t, h, w)
169
+ x = x.contiguous()
170
+ x = x.view(x.shape[0]*self.n_frame, x.shape[1]//self.n_frame, x.shape[-2], x.shape[-1])
171
+
172
+ if self.dropout is not None:
173
+ x = self.dropout(x)
174
+
175
+ # -- if input is list, split the first dimension
176
+ if is_list:
177
+ x = torch.split(x, x.shape[0] // 2, dim=0)
178
+
179
+ return x
180
+
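The learnable shutter code in `sci_encoder` is binarized with a straight-through estimator: the forward pass thresholds the weights to {0, 1}, while the backward pass lets (hardtanh-clipped) gradients flow as if the threshold were the identity, so `ce_weight` stays trainable. A minimal sketch of that estimator, re-declared here so it runs on its own:

```
import torch
import torch.nn.functional as F

class _STEFn(torch.autograd.Function):
    # Mirrors ste_fn above: hard 0/1 threshold forward, clipped identity gradient backward.
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad):
        return F.hardtanh(grad)

w = torch.randn(5, requires_grad=True)
code = _STEFn.apply(w)              # binary exposure code, e.g. tensor([1., 0., 1., 0., 1.])
loss = (code - 1.0).pow(2).sum()
loss.backward()
print(code, w.grad)                 # gradients reach w even though the forward is a step function
```

The encoder then forms a two-tap snapshot from this code: one channel accumulates `code * frames` and the other `(1 - code) * frames`, each averaged over the `n_frame` exposures.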
models/core/update.py ADDED
@@ -0,0 +1,370 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from einops import rearrange
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from models.core.attention import LoFTREncoderLayer
13
+
14
+ # -- Added by Chu King on 16th November 2025 for debugging purposes.
15
+ import os, signal
16
+ import logging
17
+ import torch.distributed as dist
18
+
19
+ # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
20
+ class FlowHead(nn.Module):
21
+ def __init__(self, input_dim=128, hidden_dim=256):
22
+ super(FlowHead, self).__init__()
23
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
24
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
25
+ self.relu = nn.ReLU(inplace=True)
26
+
27
+ def forward(self, x):
28
+ return self.conv2(self.relu(self.conv1(x)))
29
+
30
+
31
+ class SepConvGRU(nn.Module):
32
+ def __init__(self, hidden_dim=128, input_dim=192 + 128):
33
+ super(SepConvGRU, self).__init__()
34
+ self.convz1 = nn.Conv2d(
35
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
36
+ )
37
+ self.convr1 = nn.Conv2d(
38
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
39
+ )
40
+ self.convq1 = nn.Conv2d(
41
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
42
+ )
43
+
44
+ self.convz2 = nn.Conv2d(
45
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
46
+ )
47
+ self.convr2 = nn.Conv2d(
48
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
49
+ )
50
+ self.convq2 = nn.Conv2d(
51
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
52
+ )
53
+
54
+ def forward(self, h, x):
55
+ # horizontal
56
+ hx = torch.cat([h, x], dim=1)
57
+ z = torch.sigmoid(self.convz1(hx))
58
+ r = torch.sigmoid(self.convr1(hx))
59
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
60
+ h = (1 - z) * h + z * q
61
+
62
+ # vertical
63
+ hx = torch.cat([h, x], dim=1)
64
+ z = torch.sigmoid(self.convz2(hx))
65
+ r = torch.sigmoid(self.convr2(hx))
66
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
67
+ h = (1 - z) * h + z * q
68
+
69
+ return h
70
+
71
+
72
+ class ConvGRU(nn.Module):
73
+ def __init__(self, hidden_dim, input_dim, kernel_size=3):
74
+ super(ConvGRU, self).__init__()
75
+ self.convz = nn.Conv2d(
76
+ hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
77
+ )
78
+ self.convr = nn.Conv2d(
79
+ hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
80
+ )
81
+ self.convq = nn.Conv2d(
82
+ hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
83
+ )
84
+
85
+ def forward(self, h, x):
86
+ hx = torch.cat([h, x], dim=1)
87
+
88
+ z = torch.sigmoid(self.convz(hx))
89
+ r = torch.sigmoid(self.convr(hx))
90
+ q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
91
+
92
+ h = (1 - z) * h + z * q
93
+ return h
94
+
95
+
96
+ class SepConvGRU3D(nn.Module):
97
+ def __init__(self, hidden_dim=128, input_dim=192 + 128):
98
+ super(SepConvGRU3D, self).__init__()
99
+ self.convz1 = nn.Conv3d(
100
+ hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
101
+ )
102
+ self.convr1 = nn.Conv3d(
103
+ hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
104
+ )
105
+ self.convq1 = nn.Conv3d(
106
+ hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
107
+ )
108
+
109
+ self.convz2 = nn.Conv3d(
110
+ hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
111
+ )
112
+ self.convr2 = nn.Conv3d(
113
+ hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
114
+ )
115
+ self.convq2 = nn.Conv3d(
116
+ hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
117
+ )
118
+
119
+ self.convz3 = nn.Conv3d(
120
+ hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
121
+ )
122
+ self.convr3 = nn.Conv3d(
123
+ hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
124
+ )
125
+ self.convq3 = nn.Conv3d(
126
+ hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
127
+ )
128
+
129
+ def forward(self, h, x):
130
+ hx = torch.cat([h, x], dim=1)
131
+ z = torch.sigmoid(self.convz1(hx))
132
+ r = torch.sigmoid(self.convr1(hx))
133
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
134
+ h = (1 - z) * h + z * q
135
+
136
+ # vertical
137
+ hx = torch.cat([h, x], dim=1)
138
+ z = torch.sigmoid(self.convz2(hx))
139
+ r = torch.sigmoid(self.convr2(hx))
140
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
141
+ h = (1 - z) * h + z * q
142
+
143
+ # time
144
+ hx = torch.cat([h, x], dim=1)
145
+ z = torch.sigmoid(self.convz3(hx))
146
+ r = torch.sigmoid(self.convr3(hx))
147
+ q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))
148
+ h = (1 - z) * h + z * q
149
+
150
+ return h
151
+
152
+
153
+ class BasicMotionEncoder(nn.Module):
154
+ def __init__(self, cor_planes):
155
+ super(BasicMotionEncoder, self).__init__()
156
+
157
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
158
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
159
+ self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
160
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
161
+ self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
162
+
163
+ def forward(self, flow, corr):
164
+ cor = F.relu(self.convc1(corr))
165
+ cor = F.relu(self.convc2(cor))
166
+ flo = F.relu(self.convf1(flow))
167
+ flo = F.relu(self.convf2(flo))
168
+
169
+ cor_flo = torch.cat([cor, flo], dim=1)
170
+ out = F.relu(self.conv(cor_flo))
171
+ return torch.cat([out, flow], dim=1)
172
+
173
+
174
+ class Attention(nn.Module):
175
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
176
+ super().__init__()
177
+ self.num_heads = num_heads
178
+ head_dim = dim // num_heads
179
+ self.scale = qk_scale or head_dim ** -0.5
180
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
181
+ self.proj = nn.Linear(dim, dim)
182
+
183
+ def forward(self, x):
184
+ B, N, C = x.shape
185
+ # -- Bug fixed by Chu King on 22nd November 2025
186
+ qkv = self.qkv(x)
187
+ # -- qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
188
+ qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads)
189
+ qkv = qkv.permute(0, 3, 1, 2, 4) # -- (B, H, N, 3, -1)
190
+ # -- q, k, v = qkv, qkv, qkv
191
+ q, k, v = qkv.unbind(dim=3)
192
+
193
+ attn = (q @ k.transpose(-2, -1)) * self.scale
194
+
195
+ attn = attn.softmax(dim=-1)
196
+
197
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()
198
+ x = self.proj(x)
199
+ return x
200
+
201
+
202
+ class Mlp(nn.Module):
203
+ def __init__(
204
+ self,
205
+ in_features,
206
+ hidden_features=None,
207
+ out_features=None,
208
+ act_layer=nn.GELU,
209
+ drop=0.0,
210
+ ):
211
+ super().__init__()
212
+ out_features = out_features or in_features
213
+ hidden_features = hidden_features or in_features
214
+ self.fc1 = nn.Linear(in_features, hidden_features)
215
+ self.act = act_layer()
216
+ self.fc2 = nn.Linear(hidden_features, out_features)
217
+ self.drop = nn.Dropout(drop)
218
+
219
+ def forward(self, x):
220
+ x = self.fc1(x)
221
+ x = self.act(x)
222
+ x = self.drop(x)
223
+ x = self.fc2(x)
224
+ x = self.drop(x)
225
+ return x
226
+
227
+
228
+ class TimeAttnBlock(nn.Module):
229
+ def __init__(self, dim=256, num_heads=8):
230
+ super(TimeAttnBlock, self).__init__()
231
+ self.temporal_attn = Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None)
232
+ self.temporal_fc = nn.Linear(dim, dim)
233
+ self.temporal_norm1 = nn.LayerNorm(dim)
234
+
235
+ nn.init.constant_(self.temporal_fc.weight, 0)
236
+ nn.init.constant_(self.temporal_fc.bias, 0)
237
+
238
+ def forward(self, x, T=1):
239
+ _, _, h, w = x.shape
240
+
241
+ x = rearrange(x, "(b t) m h w -> (b h w) t m", h=h, w=w, t=T)
242
+ res_temporal1 = self.temporal_attn(self.temporal_norm1(x))
243
+ res_temporal1 = rearrange(
244
+ res_temporal1, "(b h w) t m -> b (h w t) m", h=h, w=w, t=T
245
+ )
246
+ res_temporal1 = self.temporal_fc(res_temporal1)
247
+ res_temporal1 = rearrange(
248
+ res_temporal1, " b (h w t) m -> b t m h w", h=h, w=w, t=T
249
+ )
250
+ x = rearrange(x, "(b h w) t m -> b t m h w", h=h, w=w, t=T)
251
+ x = x + res_temporal1
252
+ x = rearrange(x, "b t m h w -> (b t) m h w", h=h, w=w, t=T)
253
+ return x
254
+
255
+
256
+ class SpaceAttnBlock(nn.Module):
257
+ def __init__(self, dim=256, num_heads=8):
258
+ super(SpaceAttnBlock, self).__init__()
259
+ self.encoder_layer = LoFTREncoderLayer(dim, nhead=num_heads, attention="linear")
260
+
261
+ def forward(self, x, T=1):
262
+ _, _, h, w = x.shape
263
+ x = rearrange(x, "(b t) m h w -> (b t) (h w) m", h=h, w=w, t=T)
264
+ x = self.encoder_layer(x, x)
265
+ x = rearrange(x, "(b t) (h w) m -> (b t) m h w", h=h, w=w, t=T)
266
+ return x
267
+
268
+
269
+ class BasicUpdateBlock(nn.Module):
270
+ def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):
271
+ super(BasicUpdateBlock, self).__init__()
272
+ self.attention_type = attention_type
273
+ if attention_type is not None:
274
+ if "update_time" in attention_type:
275
+ self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
276
+
277
+ if "update_space" in attention_type:
278
+ self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
279
+
280
+ self.encoder = BasicMotionEncoder(cor_planes)
281
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
282
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
283
+
284
+ self.mask = nn.Sequential(
285
+ nn.Conv2d(128, 256, 3, padding=1),
286
+ nn.ReLU(inplace=True),
287
+ nn.Conv2d(256, mask_size ** 2 * 9, 1, padding=0),
288
+ )
289
+
290
+ def forward(self, net, inp, corr, flow, upsample=True, t=1):
291
+ motion_features = self.encoder(flow, corr)
292
+ inp = torch.cat((inp, motion_features), dim=1)
293
+
294
+ if self.attention_type is not None:
295
+ if "update_time" in self.attention_type:
296
+ inp = self.time_attn(inp, T=t)
297
+
298
+ if "update_space" in self.attention_type:
299
+ inp = self.space_attn(inp, T=t)
300
+
301
+ net = self.gru(net, inp)
302
+ delta_flow = self.flow_head(net)
303
+
304
+ # scale mask to balance gradients
305
+ mask = 0.25 * self.mask(net)
306
+ return net, mask, delta_flow
307
+
308
+
309
+ class FlowHead3D(nn.Module):
310
+ def __init__(self, input_dim=128, hidden_dim=256):
311
+ super(FlowHead3D, self).__init__()
312
+ self.conv1 = nn.Conv3d(input_dim, hidden_dim, 3, padding=1)
313
+ self.conv2 = nn.Conv3d(hidden_dim, 2, 3, padding=1)
314
+ self.relu = nn.ReLU(inplace=True)
315
+
316
+ def forward(self, x):
317
+ return self.conv2(self.relu(self.conv1(x)))
318
+
319
+
320
+ class SequenceUpdateBlock3D(nn.Module):
321
+ def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):
322
+ super(SequenceUpdateBlock3D, self).__init__()
323
+
324
+ # -- Extracts motion-related features from:
325
+ # * current flow estimate
326
+ # * correlation volume
327
+ self.encoder = BasicMotionEncoder(cor_planes)
328
+
329
+ # -- 3D separable convolution GRU enables temporal reasoning with 3D convolutions.
330
+ self.gru = SepConvGRU3D(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
331
+
332
+ self.flow_head = FlowHead3D(hidden_dim, hidden_dim=256)
333
+ self.mask = nn.Sequential(
334
+ nn.Conv2d(hidden_dim, hidden_dim + 128, 3, padding=1),
335
+ nn.ReLU(inplace=True),
336
+ nn.Conv2d(hidden_dim + 128, (mask_size ** 2) * 9, 1, padding=0),
337
+ )
338
+ self.attention_type = attention_type
339
+ if attention_type is not None:
340
+ if "update_time" in attention_type:
341
+ self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
342
+ if "update_space" in attention_type:
343
+ self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
344
+
345
+ def forward(self, net, inp, corrs, flows, t, upsample=True):
346
+ inp_tensor = []
347
+
348
+ motion_features = self.encoder(flows, corrs)
349
+ inp_tensor = torch.cat([inp, motion_features], dim=1)
350
+
351
+ if self.attention_type is not None:
352
+ if "update_time" in self.attention_type:
353
+ inp_tensor = self.time_attn(inp_tensor, T=t)
354
+
355
+ if "update_space" in self.attention_type:
356
+ inp_tensor = self.space_attn(inp_tensor, T=t)
357
+
358
+ net = rearrange(net, "(b t) c h w -> b c t h w", t=t)
359
+ inp_tensor = rearrange(inp_tensor, "(b t) c h w -> b c t h w", t=t)
360
+
361
+ net = self.gru(net, inp_tensor)
362
+
363
+ delta_flow = self.flow_head(net)
364
+
365
+ # scale mask to balance gradients
366
+ net = rearrange(net, " b c t h w -> (b t) c h w")
367
+ mask = 0.25 * self.mask(net)
368
+
369
+ delta_flow = rearrange(delta_flow, " b c t h w -> (b t) c h w")
370
+ return net, mask, delta_flow
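All three recurrent units above (ConvGRU, SepConvGRU, SepConvGRU3D) apply the same gated update; they differ only in how the convolutions are factorized over width, height and time. A self-contained sketch of one gate pass with illustrative dimensions:

```
import torch
import torch.nn as nn

hidden_dim, input_dim = 128, 256
convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

h = torch.zeros(2, hidden_dim, 32, 48)   # hidden state
x = torch.randn(2, input_dim, 32, 48)    # context + motion features

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx))                          # update gate
r = torch.sigmoid(convr(hx))                          # reset gate
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))   # candidate state
h_new = (1 - z) * h + z * q

assert h_new.shape == h.shape
```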
models/core/utils/config.py ADDED
@@ -0,0 +1,961 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import dataclasses
8
+ import inspect
9
+ import itertools
10
+ import sys
11
+ import warnings
12
+ from collections import Counter, defaultdict
13
+ from enum import Enum
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
15
+
16
+ from omegaconf import DictConfig, OmegaConf, open_dict
17
+ from pytorch3d.common.datatypes import get_args, get_origin
18
+
19
+
20
+ """
21
+ This functionality allows a configurable system to be determined in a dataclass-type
22
+ way. It is a generalization of omegaconf's "structured", in the dataclass case.
23
+ Core functionality:
24
+
25
+ - Configurable -- A base class used to label a class as being one which uses this
26
+ system. Uses class members and __post_init__ like a dataclass.
27
+
28
+ - expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
29
+
30
+ - get_default_args -- gets an omegaconf.DictConfig for initializing a given class.
31
+
32
+ - run_auto_creation -- Initialises nested members. To be called in __post_init__.
33
+
34
+
35
+ In addition, a Configurable may contain members whose type is decided at runtime.
36
+
37
+ - ReplaceableBase -- As a base instead of Configurable, labels a class to say that
38
+ any child class can be used instead.
39
+
40
+ - registry -- A global store of named child classes of ReplaceableBase classes.
41
+ Used as `@registry.register` decorator on class definition.
42
+
43
+
44
+ Additional utility functions:
45
+
46
+ - remove_unused_components -- used for simplifying a DictConfig instance.
47
+ - get_default_args_field -- default for DictConfig member of another configurable.
48
+ - enable_get_default_args -- Allows get_default_args on a function or plain class.
49
+
50
+
51
+ 1. The simplest usage of this functionality is as follows. First a schema is defined
52
+ in dataclass style.
53
+
54
+ class A(Configurable):
55
+ n: int = 9
56
+
57
+ class B(Configurable):
58
+ a: A
59
+
60
+ def __post_init__(self):
61
+ run_auto_creation(self)
62
+
63
+ Then it can be used like
64
+
65
+ b_args = get_default_args(B)
66
+ b = B(**b_args)
67
+
68
+ In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
69
+ members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
70
+ something like the following. (The modification itself is done by the function
71
+ `expand_args_fields`, which is called inside `get_default_args`.)
72
+
73
+ @dataclasses.dataclass
74
+ class A:
75
+ n: int = 9
76
+
77
+ @dataclasses.dataclass
78
+ class B:
79
+ a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))
80
+
81
+ def __post_init__(self):
82
+ self.a = A(**self.a_args)
83
+
84
+ 2. Pluggability. Instead of a dataclass-style member being given a concrete class,
85
+ it can be given a base class and the implementation will be looked up by name in the
86
+ global `registry` in this module. E.g.
87
+
88
+ class A(ReplaceableBase):
89
+ k: int = 1
90
+
91
+ @registry.register
92
+ class A1(A):
93
+ m: int = 3
94
+
95
+ @registry.register
96
+ class A2(A):
97
+ n: str = "2"
98
+
99
+ class B(Configurable):
100
+ a: A
101
+ a_class_type: str = "A2"
102
+ b: Optional[A]
103
+ b_class_type: Optional[str] = "A2"
104
+
105
+ def __post_init__(self):
106
+ run_auto_creation(self)
107
+
108
+ will expand to
109
+
110
+ @dataclasses.dataclass
111
+ class A:
112
+ k: int = 1
113
+
114
+ @dataclasses.dataclass
115
+ class A1(A):
116
+ m: int = 3
117
+
118
+ @dataclasses.dataclass
119
+ class A2(A):
120
+ n: str = "2"
121
+
122
+ @dataclasses.dataclass
123
+ class B:
124
+ a_class_type: str = "A2"
125
+ a_A1_args: DictConfig = dataclasses.field(
126
+ default_factory=lambda: DictConfig({"k": 1, "m": 3})
127
+ )
128
+ a_A2_args: DictConfig = dataclasses.field(
129
+ default_factory=lambda: DictConfig({"k": 1, "n": 2})
130
+ )
131
+ b_class_type: Optional[str] = "A2"
132
+ b_A1_args: DictConfig = dataclasses.field(
133
+ default_factory=lambda: DictConfig({"k": 1, "m": 3})
134
+ )
135
+ b_A2_args: DictConfig = dataclasses.field(
136
+ default_factory=lambda: DictConfig({"k": 1, "n": 2})
137
+ )
138
+
139
+ def __post_init__(self):
140
+ if self.a_class_type == "A1":
141
+ self.a = A1(**self.a_A1_args)
142
+ elif self.a_class_type == "A2":
143
+ self.a = A2(**self.a_A2_args)
144
+ else:
145
+ raise ValueError(...)
146
+
147
+ if self.b_class_type is None:
148
+ self.b = None
149
+ elif self.b_class_type == "A1":
150
+ self.b = A1(**self.b_A1_args)
151
+ elif self.b_class_type == "A2":
152
+ self.b = A2(**self.b_A2_args)
153
+ else:
154
+ raise ValueError(...)
155
+
156
+ 3. Aside from these classes, the members of these classes should be things
157
+ which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
158
+ can be built from them with `DictConfig`s and lists of them.
159
+
160
+ In addition, you can call `get_default_args` on a function or class to get
161
+ the `DictConfig` of its defaulted arguments, assuming those are all things
162
+ which `DictConfig` is happy with, so long as you add a call to
163
+ `enable_get_default_args` after its definition. If you want to use such a
164
+ thing as the default for a member of another configured class,
165
+ `get_default_args_field` is a helper.
166
+ """
167
+
168
+
169
+ _unprocessed_warning: str = (
170
+ " must be processed before it can be used."
171
+ + " This is done by calling expand_args_fields "
172
+ + "or get_default_args on it."
173
+ )
174
+
175
+ TYPE_SUFFIX: str = "_class_type"
176
+ ARGS_SUFFIX: str = "_args"
177
+ ENABLED_SUFFIX: str = "_enabled"
178
+
179
+
180
+ class ReplaceableBase:
181
+ """
182
+ Base class for dataclass-style classes which
183
+ can be stored in the registry.
184
+ """
185
+
186
+ def __new__(cls, *args, **kwargs):
187
+ """
188
+ This function only exists to raise a
189
+ warning if class construction is attempted
190
+ without processing.
191
+ """
192
+ obj = super().__new__(cls)
193
+ if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
194
+ warnings.warn(cls.__name__ + _unprocessed_warning)
195
+ return obj
196
+
197
+
198
+ class Configurable:
199
+ """
200
+ This indicates a class which is not ReplaceableBase
201
+ but still needs to be
202
+ expanded into a dataclass with expand_args_fields.
203
+ This expansion is delayed.
204
+ """
205
+
206
+ def __new__(cls, *args, **kwargs):
207
+ """
208
+ This function only exists to raise a
209
+ warning if class construction is attempted
210
+ without processing.
211
+ """
212
+ obj = super().__new__(cls)
213
+ if cls is not Configurable and not _is_actually_dataclass(cls):
214
+ warnings.warn(cls.__name__ + _unprocessed_warning)
215
+ return obj
216
+
217
+
218
+ _X = TypeVar("X", bound=ReplaceableBase)
219
+
220
+
221
+ class _Registry:
222
+ """
223
+ Register from names to classes. In particular, we say that direct subclasses of
224
+ ReplaceableBase are "base classes" and we register subclasses of each base class
225
+ in a separate namespace.
226
+ """
227
+
228
+ def __init__(self) -> None:
229
+ self._mapping: Dict[
230
+ Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
231
+ ] = defaultdict(dict)
232
+
233
+ def register(self, some_class: Type[_X]) -> Type[_X]:
234
+ """
235
+ A class decorator, to register a class in self.
236
+ """
237
+ name = some_class.__name__
238
+ self._register(some_class, name=name)
239
+ return some_class
240
+
241
+ def _register(
242
+ self,
243
+ some_class: Type[ReplaceableBase],
244
+ *,
245
+ base_class: Optional[Type[ReplaceableBase]] = None,
246
+ name: str,
247
+ ) -> None:
248
+ """
249
+ Register a new member.
250
+
251
+ Args:
252
+ cls: the new member
253
+ base_class: (optional) what the new member is a type for
254
+ name: name for the new member
255
+ """
256
+ if base_class is None:
257
+ base_class = self._base_class_from_class(some_class)
258
+ if base_class is None:
259
+ raise ValueError(
260
+ f"Cannot register {some_class}. Cannot tell what it is."
261
+ )
262
+ if some_class is base_class:
263
+ raise ValueError(f"Attempted to register the base class {some_class}")
264
+ self._mapping[base_class][name] = some_class
265
+
266
+ def get(
267
+ self, base_class_wanted: Type[ReplaceableBase], name: str
268
+ ) -> Type[ReplaceableBase]:
269
+ """
270
+ Retrieve a class from the registry by name
271
+
272
+ Args:
273
+ base_class_wanted: parent type of type we are looking for.
274
+ It determines the namespace.
275
+ This will typically be a direct subclass of ReplaceableBase.
276
+ name: what to look for
277
+
278
+ Returns:
279
+ class type
280
+ """
281
+ if self._is_base_class(base_class_wanted):
282
+ base_class = base_class_wanted
283
+ else:
284
+ base_class = self._base_class_from_class(base_class_wanted)
285
+ if base_class is None:
286
+ raise ValueError(
287
+ f"Cannot look up {base_class_wanted}. Cannot tell what it is."
288
+ )
289
+ result = self._mapping[base_class].get(name)
290
+ if result is None:
291
+ raise ValueError(f"{name} has not been registered.")
292
+ if not issubclass(result, base_class_wanted):
293
+ raise ValueError(
294
+ f"{name} resolves to {result} which does not subclass {base_class_wanted}"
295
+ )
296
+ return result
297
+
298
+ def get_all(
299
+ self, base_class_wanted: Type[ReplaceableBase]
300
+ ) -> List[Type[ReplaceableBase]]:
301
+ """
302
+ Retrieve all registered implementations from the registry
303
+
304
+ Args:
305
+ base_class_wanted: parent type of type we are looking for.
306
+ It determines the namespace.
307
+ This will typically be a direct subclass of ReplaceableBase.
308
+ Returns:
309
+ list of class types
310
+ """
311
+ if self._is_base_class(base_class_wanted):
312
+ return list(self._mapping[base_class_wanted].values())
313
+
314
+ base_class = self._base_class_from_class(base_class_wanted)
315
+ if base_class is None:
316
+ raise ValueError(
317
+ f"Cannot look up {base_class_wanted}. Cannot tell what it is."
318
+ )
319
+ return [
320
+ class_
321
+ for class_ in self._mapping[base_class].values()
322
+ if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
323
+ ]
324
+
325
+ @staticmethod
326
+ def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
327
+ """
328
+ Return whether the given type is a direct subclass of ReplaceableBase
329
+ and so gets used as a namespace.
330
+ """
331
+ return ReplaceableBase in some_class.__bases__
332
+
333
+ @staticmethod
334
+ def _base_class_from_class(
335
+ some_class: Type[ReplaceableBase],
336
+ ) -> Optional[Type[ReplaceableBase]]:
337
+ """
338
+ Find the parent class of some_class which inherits ReplaceableBase, or None
339
+ """
340
+ for base in some_class.mro()[-3::-1]:
341
+ if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
342
+ return base
343
+ return None
344
+
345
+
346
+ # Global instance of the registry
347
+ registry = _Registry()
348
+
349
+
350
+ class _ProcessType(Enum):
351
+ """
352
+ Type of member which gets rewritten by expand_args_fields.
353
+ """
354
+
355
+ CONFIGURABLE = 1
356
+ REPLACEABLE = 2
357
+ OPTIONAL_CONFIGURABLE = 3
358
+ OPTIONAL_REPLACEABLE = 4
359
+
360
+
361
+ def _default_create(
362
+ name: str, type_: Type, process_type: _ProcessType
363
+ ) -> Callable[[Any], None]:
364
+ """
365
+ Return the default creation function for a member. This is a function which
366
+ could be called in __post_init__ to initialise the member, and will be called
367
+ from run_auto_creation.
368
+
369
+ Args:
370
+ name: name of the member
371
+ type_: type of the member (with any Optional removed)
372
+ process_type: Shows whether member's declared type inherits ReplaceableBase,
373
+ in which case the actual type to be created is decided at
374
+ runtime.
375
+
376
+ Returns:
377
+ Function taking one argument, the object whose member should be
378
+ initialized.
379
+ """
380
+
381
+ def inner(self):
382
+ expand_args_fields(type_)
383
+ args = getattr(self, name + ARGS_SUFFIX)
384
+ setattr(self, name, type_(**args))
385
+
386
+ def inner_optional(self):
387
+ expand_args_fields(type_)
388
+ enabled = getattr(self, name + ENABLED_SUFFIX)
389
+ if enabled:
390
+ args = getattr(self, name + ARGS_SUFFIX)
391
+ setattr(self, name, type_(**args))
392
+ else:
393
+ setattr(self, name, None)
394
+
395
+ def inner_pluggable(self):
396
+ type_name = getattr(self, name + TYPE_SUFFIX)
397
+ if type_name is None:
398
+ setattr(self, name, None)
399
+ return
400
+
401
+ chosen_class = registry.get(type_, type_name)
402
+ if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
403
+ # If this warning is raised, it means that a new definition of
404
+ # the chosen class has been registered since our class was processed
405
+ # (i.e. expanded). A DictConfig which comes from our get_default_args
406
+ # (which might have triggered the processing) will contain the old default
407
+ # values for the members of the chosen class. Changes to those defaults which
408
+ # were made in the redefinition will not be reflected here.
409
+ warnings.warn(f"New implementation of {type_name} is being chosen.")
410
+ expand_args_fields(chosen_class)
411
+ args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
412
+ setattr(self, name, chosen_class(**args))
413
+
414
+ if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
415
+ return inner_optional
416
+ return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
417
+
418
+
419
+ def run_auto_creation(self: Any) -> None:
420
+ """
421
+ Run all the functions named in self._creation_functions.
422
+ """
423
+ for create_function in self._creation_functions:
424
+ getattr(self, create_function)()
425
+
426
+
427
+ def _is_configurable_class(C) -> bool:
428
+ return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
429
+
430
+
431
+ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
432
+ """
433
+ Get the DictConfig corresponding to the defaults in a dataclass or
434
+ configurable. Normal use is to provide a dataclass as C.
435
+ If enable_get_default_args has been called on a function or plain class,
436
+ then that function or class can be provided as C.
437
+
438
+ If C is a subclass of Configurable or ReplaceableBase, we make sure
439
+ it has been processed with expand_args_fields.
440
+
441
+ Args:
442
+ C: the class or function to be processed
443
+ _do_not_process: (internal use) When this function is called from
444
+ expand_args_fields, we specify any class currently being
445
+ processed, to make sure we don't try to process a class
446
+ while it is already being processed.
447
+
448
+ Returns:
449
+ new DictConfig object, which is typed.
450
+ """
451
+ if C is None:
452
+ return DictConfig({})
453
+
454
+ if _is_configurable_class(C):
455
+ if C in _do_not_process:
456
+ raise ValueError(
457
+ f"Internal recursion error. Need processed {C},"
458
+ f" but cannot get it. _do_not_process={_do_not_process}"
459
+ )
460
+ # This is safe to run multiple times. It will return
461
+ # straight away if C has already been processed.
462
+ expand_args_fields(C, _do_not_process=_do_not_process)
463
+
464
+ if dataclasses.is_dataclass(C):
465
+ # Note that if get_default_args_field is used somewhere in C,
466
+ # this call is recursive. No special care is needed,
467
+ # because in practice get_default_args_field is used for
468
+ # separate types than the outer type.
469
+
470
+ out: DictConfig = OmegaConf.structured(C)
471
+ exclude = getattr(C, "_processed_members", ())
472
+ with open_dict(out):
473
+ for field in exclude:
474
+ out.pop(field, None)
475
+ return out
476
+
477
+ if _is_configurable_class(C):
478
+ raise ValueError(f"Failed to process {C}")
479
+
480
+ if not inspect.isfunction(C) and not inspect.isclass(C):
481
+ raise ValueError(f"Unexpected {C}")
482
+
483
+ dataclass_name = _dataclass_name_for_function(C)
484
+ dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)
485
+ if dataclass is None:
486
+ raise ValueError(
487
+ f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
488
+ )
489
+
490
+ return OmegaConf.structured(dataclass)
491
+
492
+
493
+ def _dataclass_name_for_function(C: Any) -> str:
494
+ """
495
+ Returns the name of the dataclass which enable_get_default_args(C)
496
+ creates.
497
+ """
498
+ name = f"_{C.__name__}_default_args_"
499
+ return name
500
+
501
+
502
+ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
503
+ """
504
+ If C is a function or a plain class with an __init__ function,
505
+ and you want get_default_args(C) to work, then add
506
+ `enable_get_default_args(C)` straight after the definition of C.
507
+ This makes a dataclass corresponding to the default arguments of C
508
+ and stores it in the same module as C.
509
+
510
+ Args:
511
+ C: a function, or a class with an __init__ function. Must
512
+ have types for all its defaulted args.
513
+ overwrite: whether to allow calling this a second time on
514
+ the same function.
515
+ """
516
+ if not inspect.isfunction(C) and not inspect.isclass(C):
517
+ raise ValueError(f"Unexpected {C}")
518
+
519
+ field_annotations = []
520
+ for pname, defval in _params_iter(C):
521
+ default = defval.default
522
+ if default == inspect.Parameter.empty:
523
+ # we do not have a default value for the parameter
524
+ continue
525
+
526
+ if defval.annotation == inspect._empty:
527
+ raise ValueError(
528
+ "All arguments of the input callable have to be typed."
529
+ + f" Argument '{pname}' does not have a type annotation."
530
+ )
531
+
532
+ _, annotation = _resolve_optional(defval.annotation)
533
+
534
+ if isinstance(default, set): # force OmegaConf to convert it to ListConfig
535
+ default = tuple(default)
536
+
537
+ if isinstance(default, (list, dict)):
538
+ # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
539
+ field_ = dataclasses.field(default_factory=lambda default=default: default)
540
+ elif not _is_immutable_type(annotation, default):
541
+ continue
542
+ else:
543
+ # we can use a simple default argument for dataclass.field
544
+ field_ = dataclasses.field(default=default)
545
+ field_annotations.append((pname, defval.annotation, field_))
546
+
547
+ name = _dataclass_name_for_function(C)
548
+ module = sys.modules[C.__module__]
549
+ if hasattr(module, name):
550
+ if overwrite:
551
+ warnings.warn(f"Overwriting {name} in {C.__module__}.")
552
+ else:
553
+ raise ValueError(f"Cannot overwrite {name} in {C.__module__}.")
554
+ dc = dataclasses.make_dataclass(name, field_annotations)
555
+ dc.__module__ = C.__module__
556
+ setattr(module, name, dc)
557
+
558
+
559
+ def _params_iter(C):
560
+ """Returns dict of keyword args of a class or function C."""
561
+ if inspect.isclass(C):
562
+ return itertools.islice( # exclude `self`
563
+ inspect.signature(C.__init__).parameters.items(), 1, None
564
+ )
565
+
566
+ return inspect.signature(C).parameters.items()
567
+
568
+
569
+ def _is_immutable_type(type_: Type, val: Any) -> bool:
570
+ PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
571
+ # sometimes type can be too relaxed (e.g. Any), so we also check values
572
+ if isinstance(val, PRIMITIVE_TYPES):
573
+ return True
574
+
575
+ return type_ in PRIMITIVE_TYPES or (
576
+ inspect.isclass(type_) and issubclass(type_, Enum)
577
+ )
578
+
579
+
580
+ # copied from OmegaConf
581
+ def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
582
+ """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
583
+ if get_origin(type_) is Union:
584
+ args = get_args(type_)
585
+ if len(args) == 2 and args[1] == type(None): # noqa E721
586
+ return True, args[0]
587
+ if type_ is Any:
588
+ return True, Any
589
+
590
+ return False, type_
591
+
592
+
593
+ def _is_actually_dataclass(some_class) -> bool:
594
+ # Return whether the class some_class has been processed with
595
+ # the dataclass annotation. This is more specific than
596
+ # dataclasses.is_dataclass which returns True on anything
597
+ # deriving from a dataclass.
598
+
599
+ # Checking for __init__ would also work for our purpose.
600
+ return "__dataclass_fields__" in some_class.__dict__
601
+
602
+
603
+ def expand_args_fields(
604
+ some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
605
+ ) -> Type[_X]:
606
+ """
607
+ This expands a class which inherits Configurable or ReplaceableBase classes,
608
+ including dataclass processing. some_class is modified in place by this function.
609
+ For classes of type ReplaceableBase, you can add some_class to the registry before
610
+ or after calling this function. But potential inner classes need to be registered
611
+ before this function is run on the outer class.
612
+
613
+ The transformations this function makes, before the concluding
614
+ dataclasses.dataclass, are as follows. if X is a base class with registered
615
+ subclasses Y and Z, replace a class member
616
+
617
+ x: X
618
+
619
+ and optionally
620
+
621
+ x_class_type: str = "Y"
622
+ def create_x(self):...
623
+
624
+ with
625
+
626
+ x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
627
+ x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
628
+ def create_x(self):
629
+ self.x = registry.get(X, self.x_class_type)(
630
+ **getattr(self, f"x_{self.x_class_type}_args")
631
+ )
632
+ x_class_type: str = "UNDEFAULTED"
633
+
634
+ without adding the optional attributes if they are already there.
635
+
636
+ Similarly, replace
637
+
638
+ x: Optional[X]
639
+
640
+ and optionally
641
+
642
+ x_class_type: Optional[str] = "Y"
643
+ def create_x(self):...
644
+
645
+ with
646
+
647
+ x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
648
+ x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
649
+ def create_x(self):
650
+ if self.x_class_type is None:
651
+ self.x = None
652
+ return
653
+
654
+ self.x = registry.get(X, self.x_class_type)(
655
+ **getattr(self, f"x_{self.x_class_type}_args")
656
+ )
657
+ x_class_type: Optional[str] = "UNDEFAULTED"
658
+
659
+ without adding the optional attributes if they are already there.
660
+
661
+ Similarly, if X is a subclass of Configurable,
662
+
663
+ x: X
664
+
665
+ and optionally
666
+
667
+ def create_x(self):...
668
+
669
+ will be replaced with
670
+
671
+ x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
672
+ def create_x(self):
673
+ self.x = X(**self.x_args)
674
+
675
+ Similarly, replace,
676
+
677
+ x: Optional[X]
678
+
679
+ and optionally
680
+
681
+ def create_x(self):...
682
+ x_enabled: bool = ...
683
+
684
+ with
685
+
686
+ x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
687
+ x_enabled: bool = False
688
+ def create_x(self):
689
+ if self.x_enabled:
690
+ self.x = X(**self.x_args)
691
+ else:
692
+ self.x = None
693
+
694
+
695
+ Also adds the following class members, unannotated so that dataclass
696
+ ignores them.
697
+ - _creation_functions: Tuple[str] of all the create_ functions,
698
+ including those from base classes.
699
+ - _known_implementations: Dict[str, Type] containing the classes which
700
+ have been found from the registry.
701
+ (used only to raise a warning if one has been overwritten)
702
+ - _processed_members: a Dict[str, Any] of all the members which have been
703
+ transformed, with values giving the types they were declared to have.
704
+ (E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
705
+
706
+ Args:
707
+ some_class: the class to be processed
708
+ _do_not_process: Internal use for get_default_args: Because get_default_args calls
709
+ and is called by this function, we let it specify any class currently
710
+ being processed, to make sure we don't try to process a class while
711
+ it is already being processed.
712
+
713
+
714
+ Returns:
715
+ some_class itself, which has been modified in place. This
716
+ allows this function to be used as a class decorator.
717
+ """
718
+ if _is_actually_dataclass(some_class):
719
+ return some_class
720
+
721
+ # The functions this class's run_auto_creation will run.
722
+ creation_functions: List[str] = []
723
+ # The classes which this type knows about from the registry
724
+ # We could use a weakref.WeakValueDictionary here which would mean
725
+ # that we don't warn if the class we should have expected is elsewhere
726
+ # unused.
727
+ known_implementations: Dict[str, Type] = {}
728
+ # Names of members which have been processed.
729
+ processed_members: Dict[str, Any] = {}
730
+
731
+ # For all bases except ReplaceableBase and Configurable and object,
732
+ # we need to process them before our own processing. This is
733
+ # because dataclasses expect to inherit dataclasses and not unprocessed
734
+ # dataclasses.
735
+ for base in some_class.mro()[-3:0:-1]:
736
+ if base is ReplaceableBase:
737
+ continue
738
+ if base is Configurable:
739
+ continue
740
+ if not issubclass(base, (Configurable, ReplaceableBase)):
741
+ continue
742
+ expand_args_fields(base, _do_not_process=_do_not_process)
743
+ if "_creation_functions" in base.__dict__:
744
+ creation_functions.extend(base._creation_functions)
745
+ if "_known_implementations" in base.__dict__:
746
+ known_implementations.update(base._known_implementations)
747
+ if "_processed_members" in base.__dict__:
748
+ processed_members.update(base._processed_members)
749
+
750
+ to_process: List[Tuple[str, Type, _ProcessType]] = []
751
+ if "__annotations__" in some_class.__dict__:
752
+ for name, type_ in some_class.__annotations__.items():
753
+ underlying_and_process_type = _get_type_to_process(type_)
754
+ if underlying_and_process_type is None:
755
+ continue
756
+ underlying_type, process_type = underlying_and_process_type
757
+ to_process.append((name, underlying_type, process_type))
758
+
759
+ for name, underlying_type, process_type in to_process:
760
+ processed_members[name] = some_class.__annotations__[name]
761
+ _process_member(
762
+ name=name,
763
+ type_=underlying_type,
764
+ process_type=process_type,
765
+ some_class=some_class,
766
+ creation_functions=creation_functions,
767
+ _do_not_process=_do_not_process,
768
+ known_implementations=known_implementations,
769
+ )
770
+
771
+ for key, count in Counter(creation_functions).items():
772
+ if count > 1:
773
+ warnings.warn(f"Clash with {key} in a base class.")
774
+ some_class._creation_functions = tuple(creation_functions)
775
+ some_class._processed_members = processed_members
776
+ some_class._known_implementations = known_implementations
777
+
778
+ dataclasses.dataclass(eq=False)(some_class)
779
+ return some_class
780
+
781
+
782
+ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
783
+ """
784
+ Get a dataclass field which defaults to get_default_args(...)
785
+
786
+ Args:
787
+ As for get_default_args.
788
+
789
+ Returns:
790
+ function to return new DictConfig object
791
+ """
792
+
793
+ def create():
794
+ return get_default_args(C, _do_not_process=_do_not_process)
795
+
796
+ return dataclasses.field(default_factory=create)
797
+
798
+
799
+ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
800
+ """
801
+ If a member is annotated as `type_`, and that should be expanded in
802
+ expand_args_fields, return how it should be expanded.
803
+ """
804
+ if get_origin(type_) == Union:
805
+ # We look for Optional[X] which is a Union of X with None.
806
+ args = get_args(type_)
807
+ if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
808
+ return
809
+ underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
810
+ if (
811
+ isinstance(underlying, type)
812
+ and issubclass(underlying, ReplaceableBase)
813
+ and ReplaceableBase in underlying.__bases__
814
+ ):
815
+ return underlying, _ProcessType.OPTIONAL_REPLACEABLE
816
+
817
+ if isinstance(underlying, type) and issubclass(underlying, Configurable):
818
+ return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
819
+
820
+ if not isinstance(type_, type):
821
+ # e.g. any other Union or Tuple
822
+ return
823
+
824
+ if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
825
+ return type_, _ProcessType.REPLACEABLE
826
+
827
+ if issubclass(type_, Configurable):
828
+ return type_, _ProcessType.CONFIGURABLE
829
+
830
+
831
+ def _process_member(
832
+ *,
833
+ name: str,
834
+ type_: Type,
835
+ process_type: _ProcessType,
836
+ some_class: Type,
837
+ creation_functions: List[str],
838
+ _do_not_process: Tuple[type, ...],
839
+ known_implementations: Dict[str, Type],
840
+ ) -> None:
841
+ """
842
+ Make the modification (of expand_args_fields) to some_class for a single member.
843
+
844
+ Args:
845
+ name: member name
846
+ type_: member type (with Optional removed if needed)
847
+ process_type: whether member has dynamic type
848
+ some_class: (MODIFIED IN PLACE) the class being processed
849
+ creation_functions: (MODIFIED IN PLACE) the names of the create functions
850
+ _do_not_process: as for expand_args_fields.
851
+ known_implementations: (MODIFIED IN PLACE) known types from the registry
852
+ """
853
+ # Because we are adding defaultable members, make
854
+ # sure they go at the end of __annotations__ in case
855
+ # there are non-defaulted standard class members.
856
+ del some_class.__annotations__[name]
857
+
858
+ if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
859
+ type_name = name + TYPE_SUFFIX
860
+ if type_name not in some_class.__annotations__:
861
+ if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
862
+ some_class.__annotations__[type_name] = Optional[str]
863
+ else:
864
+ some_class.__annotations__[type_name] = str
865
+ setattr(some_class, type_name, "UNDEFAULTED")
866
+
867
+ for derived_type in registry.get_all(type_):
868
+ if derived_type in _do_not_process:
869
+ continue
870
+ if issubclass(derived_type, some_class):
871
+ # When derived_type is some_class we have a simple
872
+ # recursion to avoid. When it's a strict subclass the
873
+ # situation is even worse.
874
+ continue
875
+ known_implementations[derived_type.__name__] = derived_type
876
+ args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
877
+ if args_name in some_class.__annotations__:
878
+ raise ValueError(
879
+ f"Cannot generate {args_name} because it is already present."
880
+ )
881
+ some_class.__annotations__[args_name] = DictConfig
882
+ setattr(
883
+ some_class,
884
+ args_name,
885
+ get_default_args_field(
886
+ derived_type, _do_not_process=_do_not_process + (some_class,)
887
+ ),
888
+ )
889
+ else:
890
+ args_name = name + ARGS_SUFFIX
891
+ if args_name in some_class.__annotations__:
892
+ raise ValueError(
893
+ f"Cannot generate {args_name} because it is already present."
894
+ )
895
+ if issubclass(type_, some_class) or type_ in _do_not_process:
896
+ raise ValueError(f"Cannot process {type_} inside {some_class}")
897
+
898
+ some_class.__annotations__[args_name] = DictConfig
899
+ setattr(
900
+ some_class,
901
+ args_name,
902
+ get_default_args_field(
903
+ type_,
904
+ _do_not_process=_do_not_process + (some_class,),
905
+ ),
906
+ )
907
+ if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
908
+ enabled_name = name + ENABLED_SUFFIX
909
+ if enabled_name not in some_class.__annotations__:
910
+ some_class.__annotations__[enabled_name] = bool
911
+ setattr(some_class, enabled_name, False)
912
+
913
+ creation_function_name = f"create_{name}"
914
+ if not hasattr(some_class, creation_function_name):
915
+ setattr(
916
+ some_class,
917
+ creation_function_name,
918
+ _default_create(name, type_, process_type),
919
+ )
920
+ creation_functions.append(creation_function_name)
921
+
922
+
923
+ def remove_unused_components(dict_: DictConfig) -> None:
924
+ """
925
+ Assuming dict_ represents the state of a configurable,
926
+ modify it to remove all the portions corresponding to
927
+ pluggable parts which are not in use.
928
+ For example, if renderer_class_type is SignedDistanceFunctionRenderer,
929
+ the renderer_MultiPassEmissionAbsorptionRenderer_args will be
930
+ removed. Also, if chocolate_enabled is False, then chocolate_args will
931
+ be removed.
932
+
933
+ Args:
934
+ dict_: (MODIFIED IN PLACE) a DictConfig instance
935
+ """
936
+ keys = [key for key in dict_ if isinstance(key, str)]
937
+ suffix_length = len(TYPE_SUFFIX)
938
+ replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
939
+ args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
940
+ for replaceable in replaceables:
941
+ selected_type = dict_[replaceable + TYPE_SUFFIX]
942
+ if selected_type is None:
943
+ expect = ""
944
+ else:
945
+ expect = replaceable + "_" + selected_type + ARGS_SUFFIX
946
+ with open_dict(dict_):
947
+ for key in args_keys:
948
+ if key.startswith(replaceable + "_") and key != expect:
949
+ del dict_[key]
950
+
951
+ suffix_length = len(ENABLED_SUFFIX)
952
+ enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
953
+ for enableable in enableables:
954
+ enabled = dict_[enableable + ENABLED_SUFFIX]
955
+ if not enabled:
956
+ with open_dict(dict_):
957
+ dict_.pop(enableable + ARGS_SUFFIX, None)
958
+
959
+ for key in dict_:
960
+ if isinstance(dict_.get(key), DictConfig):
961
+ remove_unused_components(dict_[key])
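A minimal end-to-end sketch of the system above, assuming `omegaconf` and `pytorch3d` are installed and this file is importable as `models.core.utils.config` (the `Backbone`/`SmallBackbone`/`Pipeline` names are invented for the example):

```
from models.core.utils.config import (
    Configurable,
    ReplaceableBase,
    get_default_args,
    registry,
    run_auto_creation,
)

class Backbone(ReplaceableBase):
    out_dim: int = 128

@registry.register
class SmallBackbone(Backbone):
    depth: int = 2

class Pipeline(Configurable):
    backbone: Backbone
    backbone_class_type: str = "SmallBackbone"
    iters: int = 12

    def __post_init__(self):
        run_auto_creation(self)

# get_default_args expands Pipeline into a dataclass and returns its typed DictConfig,
# including a backbone_SmallBackbone_args block for the registered implementation.
cfg = get_default_args(Pipeline)
cfg.backbone_SmallBackbone_args.depth = 4
pipeline = Pipeline(**cfg)
print(type(pipeline.backbone).__name__, pipeline.backbone.depth)  # SmallBackbone 4
```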
models/core/utils/utils.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def interp(tensor, size):
11
+ return F.interpolate(
12
+ tensor,
13
+ size=size,
14
+ mode="bilinear",
15
+ align_corners=True,
16
+ )
17
+
18
+
19
+ class InputPadder:
20
+ """Pads images such that dimensions are divisible by 8"""
21
+
22
+ def __init__(self, dims, mode="sintel", divis_by=8):
23
+ self.ht, self.wd = dims[-2:]
24
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
25
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
26
+ if mode == "sintel":
27
+ self._pad = [
28
+ pad_wd // 2,
29
+ pad_wd - pad_wd // 2,
30
+ pad_ht // 2,
31
+ pad_ht - pad_ht // 2,
32
+ ]
33
+ else:
34
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
35
+
36
+ def pad(self, *inputs):
37
+ assert all((x.ndim == 4) for x in inputs)
38
+ return [F.pad(x, self._pad, mode="replicate") for x in inputs]
39
+
40
+ def unpad(self, x):
41
+ assert x.ndim == 4
42
+ ht, wd = x.shape[-2:]
43
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
44
+ return x[..., c[0] : c[1], c[2] : c[3]]
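
A usage sketch of `InputPadder`, assuming the repository root is on `PYTHONPATH` so the module imports as `models.core.utils.utils`; the image size is an arbitrary example.

```
import torch
from models.core.utils.utils import InputPadder

left = torch.rand(1, 3, 375, 1242)   # (B, C, H, W), not divisible by 8
right = torch.rand(1, 3, 375, 1242)

padder = InputPadder(left.shape)            # symmetric "sintel"-style padding
left_p, right_p = padder.pad(left, right)   # both now have H and W divisible by 8
restored = padder.unpad(left_p)             # cropped back to 375 x 1242
assert restored.shape == left.shape
```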
models/dynamic_stereo_model.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import ClassVar
8
+
9
+ import torch
10
+ from pytorch3d.implicitron.tools.config import Configurable
11
+
12
+ from dynamic_stereo.models.core.dynamic_stereo import DynamicStereo
13
+
14
+
15
+ class DynamicStereoModel(Configurable, torch.nn.Module):
16
+
17
+ MODEL_CONFIG_NAME: ClassVar[str] = "DynamicStereoModel"
18
+
19
+ # model_weights: str = "./checkpoints/dynamic_stereo_sf.pth"
20
+ model_weights: str = "./checkpoints/dynamic_stereo_dr_sf.pth"
21
+ kernel_size: int = 20
22
+
23
+ def __post_init__(self):
24
+ super().__init__()
25
+
26
+ self.mixed_precision = False
27
+ model = DynamicStereo(
28
+ mixed_precision=self.mixed_precision,
29
+ num_frames=5,
30
+ attention_type="self_stereo_temporal_update_time_update_space",
31
+ use_3d_update_block=True,
32
+ different_update_blocks=True,
33
+ )
34
+
35
+ state_dict = torch.load(self.model_weights, map_location="cpu")
36
+ if "model" in state_dict:
37
+ state_dict = state_dict["model"]
38
+ if "state_dict" in state_dict:
39
+ state_dict = state_dict["state_dict"]
40
+ state_dict = {"module." + k: v for k, v in state_dict.items()}
41
+ model.load_state_dict(state_dict, strict=False)
42
+
43
+ self.model = model
44
+ self.model.to("cuda")
45
+ self.model.eval()
46
+
47
+ def forward(self, batch_dict, iters=20):
48
+ return self.model.forward_batch_test(
49
+ batch_dict, kernel_size=self.kernel_size, iters=iters
50
+ )
models/raft_stereo_model.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from types import SimpleNamespace
9
+ from typing import ClassVar
10
+
11
+ import torch
12
+ from pytorch3d.implicitron.tools.config import Configurable
13
+
14
+ import importlib
15
+ import sys
16
+
17
+ sys.path.append("third_party/RAFT-Stereo")
18
+ raft_stereo = importlib.import_module(
19
+ "dynamic_stereo.third_party.RAFT-Stereo.core.raft_stereo"
20
+ )
21
+ raft_stereo_utils = importlib.import_module(
22
+ "dynamic_stereo.third_party.RAFT-Stereo.core.utils.utils"
23
+ )
24
+ autocast = torch.cuda.amp.autocast
25
+
26
+
27
+ class RAFTStereoModel(Configurable, torch.nn.Module):
28
+ MODEL_CONFIG_NAME: ClassVar[str] = "RAFTStereoModel"
29
+ model_weights: str = "./third_party/RAFT-Stereo/models/raftstereo-middlebury.pth"
30
+
31
+ def __post_init__(self):
32
+ super().__init__()
33
+
34
+ model_args = SimpleNamespace(
35
+ hidden_dims=[128] * 3,
36
+ corr_implementation="reg",
37
+ shared_backbone=False,
38
+ corr_levels=4,
39
+ corr_radius=4,
40
+ n_downsample=2,
41
+ slow_fast_gru=False,
42
+ n_gru_layers=3,
43
+ mixed_precision=False,
44
+ context_norm="batch",
45
+ )
46
+ self.args = model_args
47
+ model = torch.nn.DataParallel(
48
+ raft_stereo.RAFTStereo(model_args), device_ids=[0]
49
+ )
50
+
51
+ state_dict = torch.load(self.model_weights, map_location="cpu")
52
+ if "state_dict" in state_dict:
53
+ state_dict = state_dict["state_dict"]
54
+ state_dict = {"module." + k: v for k, v in state_dict.items()}
55
+ model.load_state_dict(state_dict)
56
+
57
+ self.model = model.module
58
+ self.model.to("cuda")
59
+ self.model.eval()
60
+
61
+ def forward(self, batch_dict, iters=32):
62
+ predictions = defaultdict(list)
63
+ for stereo_pair in batch_dict["stereo_video"]:
64
+ left_image_rgb = stereo_pair[None, 0].cuda()
65
+ right_image_rgb = stereo_pair[None, 1].cuda()
66
+
67
+ padder = raft_stereo_utils.InputPadder(left_image_rgb.shape, divis_by=32)
68
+ left_image_rgb, right_image_rgb = padder.pad(
69
+ left_image_rgb, right_image_rgb
70
+ )
71
+
72
+ with autocast(enabled=self.args.mixed_precision):
73
+ _, flow_up = self.model.forward(
74
+ left_image_rgb,
75
+ right_image_rgb,
76
+ iters=iters,
77
+ test_mode=True,
78
+ )
79
+ flow_up = padder.unpad(flow_up)
80
+ predictions["disparity"].append(flow_up)
81
+ predictions["disparity"] = (
82
+ torch.stack(predictions["disparity"]).squeeze(1).abs()
83
+ )
84
+ return predictions
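
For reference, a hedged illustration of the input format both model wrappers above consume: `batch_dict["stereo_video"]` stacks T stereo pairs with the left view at index 0 and the right view at index 1. The tensor below is random data; running an actual model additionally requires a GPU and the downloaded checkpoints.

```
import torch

T, H, W = 5, 480, 640
batch_dict = {"stereo_video": torch.rand(T, 2, 3, H, W)}  # (frames, left/right, RGB, H, W)
# predictions = model.forward(batch_dict, iters=32)  # with an instantiated RAFTStereoModel on GPU
```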
notebooks/Dynamic_Replica_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/evaluate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ hydra-core==1.1
2
+ einops==0.4.1
3
+ flow_vis==0.1
4
+ imageio==2.21.1
5
+ matplotlib==3.5.3
6
+ munch==2.5.0
7
+ numpy==1.23.5
8
+ omegaconf==2.1.0
9
+ opencv_python==4.6.0.66
10
+ opt_einsum==3.3.0
11
+ Pillow==9.5.0
12
+ pytorch_lightning==1.6.0
13
+ requests
14
+ scikit_image==0.19.2
15
+ scipy==1.10.0
16
+ setuptools==65.6.3
17
+ tabulate==0.8.10
18
+ tqdm==4.64.1
19
+ moviepy
20
+ jupyter
scripts/checksum_check.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import glob
9
+ import argparse
10
+ import hashlib
11
+ import json
12
+
13
+ from typing import Optional
14
+ from multiprocessing import Pool
15
+ from tqdm import tqdm
16
+
17
+
18
+ DEFAULT_SHA256S_FILE = os.path.join(__file__.rsplit(os.sep, 2)[0], "dr_sha256.json")
19
+ BLOCKSIZE = 65536
20
+
21
+
22
+ def main(
23
+ download_folder: str,
24
+ sha256s_file: str,
25
+ dump: bool = False,
26
+ n_sha256_workers: int = 4
27
+ ):
28
+ if not os.path.isfile(sha256s_file):
29
+ raise ValueError(f"The SHA256 file does not exist ({sha256s_file}).")
30
+
31
+ expected_sha256s = get_expected_sha256s(
32
+ sha256s_file=sha256s_file
33
+ )
34
+
35
+ zipfiles = sorted(glob.glob(os.path.join(download_folder, "*.zip")))
36
+ print(f"Extracting SHA256 hashes for {len(zipfiles)} files in {download_folder}.")
37
+ extracted_sha256s_list = []
38
+ with Pool(processes=n_sha256_workers) as sha_pool:
39
+ for extracted_hash in tqdm(
40
+ sha_pool.imap(_sha256_file_and_print, zipfiles),
41
+ total=len(zipfiles),
42
+ ):
43
+ extracted_sha256s_list.append(extracted_hash)
44
+ pass
45
+
46
+ extracted_sha256s = dict(
47
+ zip([os.path.split(z)[-1] for z in zipfiles], extracted_sha256s_list)
48
+ )
49
+
50
+ if dump:
51
+ print(extracted_sha256s)
52
+ with open(sha256s_file, "w") as f:
53
+ json.dump(extracted_sha256s, f, indent=2)
54
+
55
+
56
+ missing_keys, invalid_keys = [], []
57
+ for k in expected_sha256s.keys():
58
+ if k not in extracted_sha256s:
59
+ print(f"{k} missing!")
60
+ missing_keys.append(k)
61
+ elif expected_sha256s[k] != extracted_sha256s[k]:
62
+ print(
63
+ f"'{k}' does not match!"
64
+ + f" ({expected_sha256s[k]} != {extracted_sha256s[k]})"
65
+ )
66
+ invalid_keys.append(k)
67
+ if len(invalid_keys) + len(missing_keys) > 0:
68
+ raise ValueError(
69
+ f"Checksum checker failed!"
70
+ + f" Non-matching checksums: {str(invalid_keys)};"
71
+ + f" missing files: {str(missing_keys)}."
72
+ )
73
+
74
+
75
+ def get_expected_sha256s(
76
+ sha256s_file: str
77
+ ):
78
+ with open(sha256s_file, "r") as f:
79
+ expected_sha256s = json.load(f)
80
+ return expected_sha256s
81
+
82
+
83
+ def check_dr_sha256(
84
+ path: str,
85
+ sha256s_file: str,
86
+ expected_sha256s: Optional[dict] = None,
87
+ do_assertion: bool = True,
88
+ ):
89
+ zipname = os.path.split(path)[-1]
90
+ if expected_sha256s is None:
91
+ expected_sha256s = get_expected_sha256s(
92
+ sha256s_file=sha256s_file,
93
+ )
94
+ extracted_hash = sha256_file(path)
95
+ if do_assertion:
96
+ assert (
97
+ extracted_hash == expected_sha256s[zipname]
98
+ ), f"{zipname}: ({extracted_hash} != {expected_sha256s[zipname]})"
99
+ else:
100
+ return extracted_hash == expected_sha256s[zipname]
101
+
102
+
103
+ def sha256_file(path: str):
104
+ sha256_hash = hashlib.sha256()
105
+ with open(path, "rb") as f:
106
+ file_buffer = f.read(BLOCKSIZE)
107
+ while len(file_buffer) > 0:
108
+ sha256_hash.update(file_buffer)
109
+ file_buffer = f.read(BLOCKSIZE)
110
+ digest_ = sha256_hash.hexdigest()
111
+ return digest_
112
+
113
+
114
+ def _sha256_file_and_print(path: str):
115
+ digest_ = sha256_file(path)
116
+ print(f"{path}: {digest_}")
117
+ return digest_
118
+
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser(
123
+ description="Check SHA256 hashes of the Dynamic Replica dataset."
124
+ )
125
+ parser.add_argument(
126
+ "--download_folder",
127
+ type=str,
128
+ help="A local target folder for downloading the the dataset files.",
129
+ )
130
+ parser.add_argument(
131
+ "--sha256s_file",
132
+ type=str,
133
+ help="A local target folder for downloading the the dataset files.",
134
+ default=DEFAULT_SHA256S_FILE,
135
+ )
136
+ parser.add_argument(
137
+ "--num_workers",
138
+ type=int,
139
+ default=4,
140
+ help="The number of sha256 extraction workers.",
141
+ )
142
+ parser.add_argument(
143
+ "--dump_sha256s",
144
+ action="store_true",
145
+ help="Store sha256s hashes.",
146
+ )
147
+
148
+ args = parser.parse_args()
149
+ main(
150
+ str(args.download_folder),
151
+ dump=bool(args.dump_sha256s),
152
+ n_sha256_workers=int(args.num_workers),
153
+ sha256s_file=str(args.sha256s_file),
154
+ )
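
A minimal sketch of verifying a single downloaded archive with `check_dr_sha256`, assuming `scripts/` is on `PYTHONPATH` (the download scripts append it themselves); the archive path is a hypothetical example.

```
from checksum_check import check_dr_sha256

check_dr_sha256(
    "./dynamic_replica_data/valid_000.zip",      # hypothetical local archive
    sha256s_file="./scripts/dr_sha256.json",
)  # raises AssertionError if the hash does not match dr_sha256.json
```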
scripts/download_dynamic_replica.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import sys
9
+
10
+ sys.path.append("./scripts/")
11
+ from download_utils import build_arg_parser, download_dataset
12
+
13
+
14
+ DEFAULT_LINK_LIST_FILE = os.path.join(os.path.dirname(__file__), "links.json")
15
+ DEFAULT_SHA256S_FILE = os.path.join(os.path.dirname(__file__), "dr_sha256.json")
16
+
17
+
18
+ if __name__ == "__main__":
19
+ parser = build_arg_parser(
20
+ "dynamic_replica", DEFAULT_LINK_LIST_FILE, DEFAULT_SHA256S_FILE
21
+ )
22
+
23
+ args = parser.parse_args()
24
+ os.makedirs(args.download_folder, exist_ok=True)
25
+ download_dataset(
26
+ str(args.link_list_file),
27
+ str(args.download_folder),
28
+ n_download_workers=int(args.n_download_workers),
29
+ n_extract_workers=int(args.n_extract_workers),
30
+ download_splits=args.download_splits,
31
+ checksum_check=bool(args.checksum_check),
32
+ clear_archives_after_unpacking=bool(args.clear_archives_after_unpacking),
33
+ sha256s_file=str(args.sha256_file),
34
+ skip_downloaded_archives=not bool(args.redownload_existing_archives),
35
+ )
scripts/download_utils.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import shutil
9
+ import requests
10
+ import functools
11
+ import json
12
+ import warnings
13
+
14
+ from argparse import ArgumentParser
15
+ from typing import List, Optional
16
+ from multiprocessing import Pool
17
+ from tqdm import tqdm
18
+
19
+ import sys
20
+
21
+ sys.path.append("./scripts/")
22
+
23
+ from checksum_check import check_dr_sha256
24
+
25
+
26
+ def download_dataset(
27
+ link_list_file: str,
28
+ download_folder: str,
29
+ n_download_workers: int = 4,
30
+ n_extract_workers: int = 4,
31
+ download_splits: List[str] = ['real', 'valid', 'test', 'train'],
32
+ checksum_check: bool = False,
33
+ clear_archives_after_unpacking: bool = False,
34
+ skip_downloaded_archives: bool = True,
35
+ sha256s_file: Optional[str] = None,
36
+ ):
37
+ """
38
+ Downloads and unpacks the dataset in CO3D format.
39
+ Note: The script will make a folder `<download_folder>/_in_progress`, which
40
+ stores files whose download is in progress. The folder can be safely deleted
41
+ once the download is finished.
42
+ Args:
43
+ link_list_file: A text file with the list of zip file download links.
44
+ download_folder: A local target folder for downloading the
45
+ the dataset files.
46
+ n_download_workers: The number of parallel workers
47
+ for downloading the dataset files.
48
+ n_extract_workers: The number of parallel workers
49
+ for extracting the dataset files.
50
+ download_splits: A list of data splits to download.
51
+ Must be in ['real', 'valid', 'test', 'train'].
52
+ checksum_check: Enable validation of the downloaded file's checksum before
53
+ extraction.
54
+ clear_archives_after_unpacking: Delete the unnecessary downloaded archive files
55
+ after unpacking.
56
+ skip_downloaded_archives: Skip re-downloading already downloaded archives.
57
+ """
58
+
59
+ if checksum_check and not sha256s_file:
60
+ raise ValueError(
61
+ "checksum_check is requested but ground-truth SHA256 file not provided!"
62
+ )
63
+
64
+ if not os.path.isfile(link_list_file):
65
+ raise ValueError(
66
+ "Please specify `link_list_file` with a valid path to a json"
67
+ " with zip file download links."
68
+ # " The file is stored in the DynamicStereo github:"
69
+ # " https://github.com/facebookresearch/dynamic_stereo/blob/main/dynamic_stereo/links.json"
70
+ )
71
+
72
+ if not os.path.isdir(download_folder):
73
+ raise ValueError(
74
+ "Please specify `download_folder` with a valid path to a target folder"
75
+ + " for downloading the dataset."
76
+ + f" {download_folder} does not exist."
77
+ )
78
+
79
+ # read the link file
80
+ with open(link_list_file, "r") as f:
81
+ links = json.load(f)
82
+
83
+ for split in download_splits:
84
+ if split not in ['real', 'valid', 'test', 'train']:
85
+ raise ValueError(
86
+ f"Download split {str(split)} is not valid"
87
+ )
88
+
89
+ data_links = []
90
+ for split_name, urls in links.items():
91
+ if split_name in download_splits:
92
+ for url in urls:
93
+ link_name = os.path.split(url)[-1]
94
+ data_links.append((split_name, link_name, url))
95
+
96
+
97
+ with Pool(processes=n_download_workers) as download_pool:
98
+ download_ok = {}
99
+ for link_name, ok in tqdm(
100
+ download_pool.imap(
101
+ functools.partial(
102
+ _download_split_file,
103
+ download_folder,
104
+ checksum_check,
105
+ sha256s_file,
106
+ skip_downloaded_archives,
107
+ ),
108
+ data_links,
109
+ ),
110
+ total=len(data_links),
111
+ ):
112
+ download_ok[link_name] = ok
113
+
114
+ with Pool(processes=n_extract_workers) as extract_pool:
115
+ for _ in tqdm(
116
+ extract_pool.imap(
117
+ functools.partial(
118
+ _unpack_split_file,
119
+ download_folder,
120
+ clear_archives_after_unpacking,
121
+ ),
122
+ data_links,
123
+ ),
124
+ total=len(data_links),
125
+ ):
126
+ pass
127
+ print("Done")
128
+
129
+
130
+
131
+ def build_arg_parser(
132
+ dataset_name: str,
133
+ default_link_list_file: str,
134
+ default_sha256_file: str,
135
+ ) -> ArgumentParser:
136
+ parser = ArgumentParser(description=f"Download the {dataset_name} dataset.")
137
+ parser.add_argument(
138
+ "--download_folder",
139
+ type=str,
140
+ required=True,
141
+ help="A local target folder for downloading the the dataset files.",
142
+ )
143
+ parser.add_argument(
144
+ "--n_download_workers",
145
+ type=int,
146
+ default=4,
147
+ help="The number of parallel workers for downloading the dataset files.",
148
+ )
149
+ parser.add_argument(
150
+ "--n_extract_workers",
151
+ type=int,
152
+ default=4,
153
+ help="The number of parallel workers for extracting the dataset files.",
154
+ )
155
+ parser.add_argument(
156
+ "--download_splits",
157
+ default=['real', 'valid', 'test', 'train'],
158
+ nargs='+',
159
+ help=f"A comma-separated list of {dataset_name} splits to download.",
160
+ )
161
+ parser.add_argument(
162
+ "--link_list_file",
163
+ type=str,
164
+ default=default_link_list_file,
165
+ help=(
166
+ f"The file with html links to the {dataset_name} dataset files."
167
+ + " In most cases the default local file `links.json` should be used."
168
+ ),
169
+ )
170
+ parser.add_argument(
171
+ "--sha256_file",
172
+ type=str,
173
+ default=default_sha256_file,
174
+ help=(
175
+ f"The file with SHA256 hashes of {dataset_name} dataset files."
176
+ + " In most cases the default local file `dr_sha256.json` should be used."
177
+ ),
178
+ )
179
+ parser.add_argument(
180
+ "--checksum_check",
181
+ action="store_true",
182
+ default=True,
183
+ help="Check the SHA256 checksum of each downloaded file before extraction.",
184
+ )
185
+ parser.add_argument(
186
+ "--no_checksum_check",
187
+ action="store_false",
188
+ dest="checksum_check",
189
+ default=False,
190
+ help="Does not check the SHA256 checksum of each downloaded file before extraction.",
191
+ )
192
+ parser.set_defaults(checksum_check=True)
193
+ parser.add_argument(
194
+ "--clear_archives_after_unpacking",
195
+ action="store_true",
196
+ default=False,
197
+ help="Delete the unnecessary downloaded archive files after unpacking.",
198
+ )
199
+ parser.add_argument(
200
+ "--redownload_existing_archives",
201
+ action="store_true",
202
+ default=False,
203
+ help="Redownload the already-downloaded archives.",
204
+ )
205
+
206
+ return parser
207
+
208
+ def _unpack_split_file(
209
+ download_folder: str,
210
+ clear_archive: bool,
211
+ link: str,
212
+ ):
213
+ split, link_name, url = link
214
+ local_fl = os.path.join(download_folder, link_name)
215
+ print(f"Unpacking dataset file {local_fl} ({link_name}) to {download_folder}.")
216
+
217
+ download_folder_split = os.path.join(download_folder, split)
218
+ # os.makedirs(download_folder_split, exist_ok=True)
219
+ shutil.unpack_archive(local_fl, download_folder_split)
220
+ if clear_archive:
221
+ os.remove(local_fl)
222
+
223
+ def _download_split_file(
224
+ download_folder: str,
225
+ checksum_check: bool,
226
+ sha256s_file: Optional[str],
227
+ skip_downloaded_files: bool,
228
+ link: str,
229
+ ):
230
+ __, link_name, url = link
231
+ local_fl_final = os.path.join(download_folder, link_name)
232
+
233
+ if skip_downloaded_files and os.path.isfile(local_fl_final):
234
+ print(f"Skipping {local_fl_final}, already downloaded!")
235
+ return link_name, True
236
+
237
+ in_progress_folder = os.path.join(download_folder, "_in_progress")
238
+ os.makedirs(in_progress_folder, exist_ok=True)
239
+ local_fl = os.path.join(in_progress_folder, link_name)
240
+
241
+ print(f"Downloading dataset file {link_name} ({url}) to {local_fl}.")
242
+ _download_with_progress_bar(url, local_fl, link_name)
243
+ if checksum_check:
244
+ print(f"Checking SHA256 for {local_fl}.")
245
+ try:
246
+ check_dr_sha256(
247
+ local_fl,
248
+ sha256s_file=sha256s_file,
249
+ )
250
+ except AssertionError:
251
+ warnings.warn(
252
+ f"Checksums for {local_fl} did not match!"
253
+ + " This is likely due to a network failure,"
254
+ + " please restart the download script."
255
+ )
256
+ return link_name, False
257
+
258
+ os.rename(local_fl, local_fl_final)
259
+ return link_name, True
260
+
261
+
262
+ def _download_with_progress_bar(url: str, fname: str, filename: str):
263
+
264
+ # taken from https://stackoverflow.com/a/62113293/986477
265
+ resp = requests.get(url, stream=True)
266
+ print(url)
267
+ total = int(resp.headers.get("content-length", 0))
268
+ with open(fname, "wb") as file, tqdm(
269
+ desc=fname,
270
+ total=total,
271
+ unit="iB",
272
+ unit_scale=True,
273
+ unit_divisor=1024,
274
+ ) as bar:
275
+ for datai, data in enumerate(resp.iter_content(chunk_size=1024)):
276
+ size = file.write(data)
277
+ bar.update(size)
278
+ if datai % max((max(total // 1024, 1) // 20), 1) == 0:
279
+ print(f"{filename}: Downloaded {100.0*(float(bar.n)/max(total, 1)):3.1f}%.")
280
+ print(bar)
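
A sketch of driving `download_dataset` directly, mirroring what `scripts/download_dynamic_replica.py` does; the link-list path and target folder below are assumptions.

```
import os
import sys

sys.path.append("./scripts/")
from download_utils import download_dataset

os.makedirs("./dynamic_replica_data", exist_ok=True)
download_dataset(
    link_list_file="./scripts/links.json",       # or a reduced link list
    download_folder="./dynamic_replica_data",
    n_download_workers=4,
    n_extract_workers=4,
    download_splits=["valid"],
    checksum_check=True,
    sha256s_file="./scripts/dr_sha256.json",
)
```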
scripts/dr_sha256.json ADDED
@@ -0,0 +1,106 @@
1
+ {
2
+ "real_000.zip": "e5c2aac04146d783c64f76d0ef7a9e8d49d80ffac99d2a795563517f15943a6f",
3
+ "valid_000.zip": "0f35bee47030ae1a30289beb92ba69c5336491e0f07aab0a05cb5505173d1faf",
4
+ "valid_001.zip": "cb37d3b1f643118ae22840b4212b00c00a8fe137099d3730a07796a5fefab24a",
5
+ "valid_002.zip": "5535f2a98e06c68cf97e3259e962e3d44465a1820369e4425c4ef2a719b01ad0",
6
+ "valid_003.zip": "e19db94514d22829743aa363698f407ecfd98d8f08eab037289a420939ef5143",
7
+ "valid_004.zip": "953328f24ba0c3e8709df3829cce238305a8998bf7ae938c80069fab6f513862",
8
+ "valid_005.zip": "27ce4c7424292dcf3e8e0b370fbbc848bd6d73ae28ea5832fddfa8e9c17d6011",
9
+ "test_000.zip": "a56fa676a7a3dc52b33f1571d41fb0221e289735acccb7b9ad42dfb13fdac68c",
10
+ "test_001.zip": "43580e89331826182f41d2ce9f06f62da46617fea9e612a16b2610de8ffdc10b",
11
+ "test_002.zip": "33551fb68979d3d2f20e1976d9169a84ad58658c459aba4d7a2671c8d66904b9",
12
+ "test_003.zip": "45ad28d7555e3579225d26dfcb8244b65de0d1ee749560cc6dd84f121b4b40de",
13
+ "test_004.zip": "d736b56fe15410525deda1c16c0b8da4497383480a4328da92bc0ddb64a62d52",
14
+ "test_005.zip": "3ae331047019a39c6306a17407c72e40dc15b5113f6f9ef72aba2da0b859ea7d",
15
+ "test_006.zip": "94341c8ac8ed1d7f11816ad121e6c5821a751fdc3d3122a653c86f7b5845ca80",
16
+ "test_007.zip": "4e18facbd507e16fc41d90d5c2ce1b44c511d3e2986e1ccdf5d264748d5d7e15",
17
+ "test_008.zip": "e4d5aa0c25eb01863bbced477e17fddd9d8217d23d238bb06b0b688a0f6ed8e3",
18
+ "test_009.zip": "5a413411cfc376078ed0357708084f71949159c13119aabb5c9ae1ffde33b6b7",
19
+ "test_010.zip": "82ea42c7544385aa2d41271e63399534a398dbbef8a06cb990c8bb34296928c8",
20
+ "train_000.zip": "e9fd9af579b0d08d538551c0ab6f7231a1fd162139667803e672cc0dc8b98b03",
21
+ "train_001.zip": "65cb438c7a48567f85db8e54109db6c25d2a621fcbd3267c542a8a640e1dad56",
22
+ "train_002.zip": "c3d9a76a955dd9feb0275837a17133a1d7ee76c963f1c6fa7630deb0ca8209b2",
23
+ "train_003.zip": "13e108f78c7da1f1c1469dd87fab55a6e4ec79f1fcb6d7d16cc9006a933979f4",
24
+ "train_004.zip": "171b92a62b46a68f1d89c2326ba67b6433faf087bc1eecc7a947c19d0f90d3e6",
25
+ "train_005.zip": "75461ffe13cfbd87b4f0f9ffc83002b8381f5a0a212ece18b8961012f865a46e",
26
+ "train_006.zip": "7546f94817814031a738082e6b30858d0057710af052a88fa505a961b6699886",
27
+ "train_007.zip": "371dd100b215bcd41129def1c8fd07f974af11a9b3d3b9966ce5d9700b9929ad",
28
+ "train_008.zip": "313f5c2089c6afc1691edf054e8b2af9eb8b2d91f791153763758c8d91abee48",
29
+ "train_009.zip": "9cbb9f44bb6b7dcc74f00a51db4d2a8797c95a0d880d63ef1612d3883b16b995",
30
+ "train_010.zip": "eb158fccc23a4b41358ec94be203f49a677f86626af7a88f0e649454c409c706",
31
+ "train_011.zip": "f8b3f8c738cdcdbbdf346a4dd78b99883b5d4ab74c11b64ec7b4f8ccd3b68ffc",
32
+ "train_012.zip": "b364ba9d35d7e55019d3554cf65b295d2358859c222b3b847b0f2cced948cfce",
33
+ "train_013.zip": "c8a50efbd93e6e422eabf1846dac2d75e81dfcfcd4d785fe18b01526af9695f6",
34
+ "train_014.zip": "52a768ce76310861cf1fc990ebb8d16f0c27fceff02c12b11638d36ca1c3a927",
35
+ "train_015.zip": "67bf0ba775948997f5ab3cc810b6d0e8149758334210ace6f5cdfc529fe7d26e",
36
+ "train_016.zip": "d5b9a26736421d8f330fd5e531d26071531501a88609d29d580b9d56b6bc17a3",
37
+ "train_017.zip": "5f2d2c93e7944baf1e6d3dee671b12abb7476a75cbd6f572af86fe5c22472fa6",
38
+ "train_018.zip": "77aa801b6b0359b970466329e4a05b937df94b650228cf4797a2a029606b8e5b",
39
+ "train_019.zip": "30934c91cc0ae69acef6a89e4a5180686bd04080e2384a8bde5877cbaaadc575",
40
+ "train_020.zip": "901d5c08705a70053a3e865354a4e7149c35f026b6ed166fee029d829d88c124",
41
+ "train_021.zip": "f27019ff58e54a004ed2cf2106ed459a31b010ed82d32028b0e196dd365b8b0e",
42
+ "train_022.zip": "0600346a2ce162f7e9824e90c553b69a656d4731c86d903e300d932ec8ba7600",
43
+ "train_023.zip": "660d768e4b1bfe742a42ae6ee84f5e91c930789488a7c7f118e5d0edd1f1a010",
44
+ "train_024.zip": "1f8792002baceaba8f93f93be1bee7c83a48c677e4b2d025b6f0047a796e94cd",
45
+ "train_025.zip": "0b92b3f41c18fded8fcb7aba44e7d8738750b8155c907924200fdf4dc1718794",
46
+ "train_026.zip": "4dc401639317527231abfef07221b8d7db2d0950008828104cd1f72092325d05",
47
+ "train_027.zip": "e8313eaa21163f9dd2ff4558d16b1c9cf4962c2e4c0403d6a315955660a98b14",
48
+ "train_028.zip": "d73edf1c500b4311795aaae0a03b3bc04a2c266e2a20b27ba9b6e72fb27fd277",
49
+ "train_029.zip": "c5e4d302c62e693626445aba19638711108049235b0075558e7949b189050c56",
50
+ "train_030.zip": "506b9ba7a740b0bf84159546f797437a48a24e468cb949f2189e51cf404c6170",
51
+ "train_031.zip": "f36bb4b77fdb255dae2050884cf59cd3f8e46e77ea2984b4b219b799c4aac089",
52
+ "train_032.zip": "fddca4efc40ed8d05adf9d519e4fb5b486ac77e8fa08c98d5c4be15867fda8a0",
53
+ "train_033.zip": "c24d2b5c04f3e90b265fd0762e7ae19fb01a7c1948a4c09451383a9eec9f640f",
54
+ "train_034.zip": "5828fbd615c4476f6107fe844cbf81632eff2f9c75194cb84d749630d9359e14",
55
+ "train_035.zip": "7b60fe125fd1a9ba7991e2accd0f2b212968983b4631d43eccff9836a0c35ba8",
56
+ "train_036.zip": "0f4eaf464a2afc62447a802159b3844487b80e9d1c9b0a7d324b0d5914514d60",
57
+ "train_037.zip": "ba85a6692d86e48c4c787b334d4384c08b914e4cee7f3d2692dcae1bbac55878",
58
+ "train_038.zip": "c67b0f5305560d8089bdc2f6212c05256c044e50a715d59b864fbef705bc6b5c",
59
+ "train_039.zip": "f4b66c9e1360a8d6d8337c94eefb1132d865c2735c6b78ba726a590073174aad",
60
+ "train_040.zip": "2c64b76d028fcc153f267925b79a24cf3bb0e42cc7716773df2139f5cec5e319",
61
+ "train_041.zip": "22b1c0ab99a7f8bd0d36c2d2511d3d469cc390776c38132d1e8f1ad7aae5d4ff",
62
+ "train_042.zip": "8f2afaecb9f90947c9071111fde9c015acfceb432ae0bf94deff3ecd581b26c8",
63
+ "train_043.zip": "adf7ea7c356339b10b797c49163252704b4e6b0cebcc741d3374f8c9467f6b43",
64
+ "train_044.zip": "3d0fe4a85fd22ff9c8ed468ca8173d93406a72fadf800d9e6bbf209348cf8965",
65
+ "train_045.zip": "70874eca6bce66cb7681092755d066968e9c8fc32a266d7c0d2f29c01b2b2669",
66
+ "train_046.zip": "01adcdbba0a25383e2281ce02a946f6bc824e1b8e16cf88e85a4ad275203884c",
67
+ "train_047.zip": "50ed632ae330acf60c1b2e22b28fbfab5ccf0e8f23320b2911dcc2d43db048b6",
68
+ "train_048.zip": "f302984f486df60d7a281e2b0a9b6d32456fc6042eb596cb5ef54ee919ccd7bb",
69
+ "train_049.zip": "8e8e0a426796f76dfb2d29cb855894fd01cc954b017aa1d06ae1a121fb310088",
70
+ "train_050.zip": "051f0dd8e612e7073dd20585c42681daeff853a6ee0de6f2e8ff4581cdf4f83b",
71
+ "train_051.zip": "3f39b3732c32b960aef4bf3f152b1a72195dc4ab4bbc10116a05875ca8d40417",
72
+ "train_052.zip": "361b9bcd3364c63c8f2814dfacf91489b79c9cedf03ffcb03b3dacfb77cee3a1",
73
+ "train_053.zip": "f6afe23b3005b1889f76ea9c10ac42f7c4f07cefbe737781229640b834f8ede2",
74
+ "train_054.zip": "ef993bd657104770df8e07a9d7c8ac1d1c3ac57b91f66796bea97f03e5a01df2",
75
+ "train_055.zip": "ec0dea8199e1db7bd8e19f85b0d1a9ab9e8fc2be2c5da5b3455f96e074ad7f22",
76
+ "train_056.zip": "44259829f6832c3dc14b893d5f5b7b6f784a09570f26e9cc9749807a1b05b21e",
77
+ "train_057.zip": "263b712fe2ded353cb248324305f831d8b14aa0858f005067bb27e88decd7f32",
78
+ "train_058.zip": "c44fb44365bc4cd8c4c9bb13d70fa9bb290708b7d3fe44fd79c6eed42702ed70",
79
+ "train_059.zip": "43dd65609afb3992273f914b4d0108187f85eaf1f252f85556f10e40816d5e6c",
80
+ "train_060.zip": "97b2abe90259f4629d7c1c1cec2427f155252403f5dcfea563e2d1338ae63150",
81
+ "train_061.zip": "9d8c790d1806659617ddd6dd99ae56388b5eb9f311c47a079ac8fa5df8f44f57",
82
+ "train_062.zip": "5b4398d6a8709ddf1b050b03b19dfe8aacf3378a4879402f457f12bd97ab99df",
83
+ "train_063.zip": "05024f1b0671cb3026db0b9e801c9aab000b828784839f970a8ad0bc23125435",
84
+ "train_064.zip": "b9bba3999971745ea2cdce69c00c49b109ba02c9f3169614d1d229e468bebc68",
85
+ "train_065.zip": "ff4084dd7c017478b872fd7c9152df5271a7088489d3b86cc21968db272356ef",
86
+ "train_066.zip": "9d8158fd6691065c1cb76ac36c3be90b065e8848856a66b10475b11e1261dd4d",
87
+ "train_067.zip": "3e4b9ebef2bdecab5774a72037d9f1f7c40359e6a2d00851c0c40bdd686373c5",
88
+ "train_068.zip": "a89d53ce7c79af32a659a2a59138568ada1395c56c6063f4f49c1d4e052cf9cd",
89
+ "train_069.zip": "3f66206486af3f0bfa04ce8f664b6af6aa7fd2ad8ebadd5c75039de8c5ffea91",
90
+ "train_070.zip": "e8a95aad5f81e7185a7dacb9031a5c27010ec17302e2e35f7f1de3dc88e02a7b",
91
+ "train_071.zip": "677bf42f8d576c79189cd5af2abf420990368d9c7d768a21a10fc0939dde121f",
92
+ "train_072.zip": "f8d5ea223dc13663bbaae6c5bbd732db15f1c249e7fe2da44b5a6ba5b7dbf505",
93
+ "train_073.zip": "3057bda88ebd5bffb0da030d1126e1fb4fed4b5fbfc547dc0be669ece39979c1",
94
+ "train_074.zip": "f3a01d19e6fedd44679d76ee93051b91b616a55b6b22861db126b8d2bfdba7ce",
95
+ "train_075.zip": "0faa29f3f712f744e003da29b249896cc770fb9b357e8a4c447eeb6ad2798ce2",
96
+ "train_076.zip": "d9943f9b72be89dd8f1273bd02133ab24b81e3c3f794e13362a96b0826518696",
97
+ "train_077.zip": "cfab28d27c1532a91980b65baa4d40c8e13144788b9ae7a4c36ce8b909e51e55",
98
+ "train_078.zip": "b06277baadbe60b2019d0f7b6ed637b23957b6320797bf4b6b9099dc4df0cc7e",
99
+ "train_079.zip": "2163ef05752f7a8813fa9cd5661547bc280239fd3bd903b94a8aef37182e9645",
100
+ "train_080.zip": "13ae6b86afe4aa00ce19f4f7a8df24d11742340c5775fca02f6e1f70cd9a3be7",
101
+ "train_081.zip": "a2512084c16220e0acd207f5e330dd319a30c3445b5034f2c14f9a65111628a3",
102
+ "train_082.zip": "d9615ac989465bc85cf990167ce176af55b8affeebb58d5021c215c1f7235c8a",
103
+ "train_083.zip": "539710fcc33b043dd24499d3987852a35c8a1c5fb75f7530a9caebf57fd5f324",
104
+ "train_084.zip": "33232eb1d68e493a25126f22e31326b7c1195ea511c332a1413e83a0245bdae6",
105
+ "train_085.zip": "13e575f24a77278b7de25e3d186f6201692b3e45ed4701b071d5a770c0e1d590"
106
+ }
setup.csh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/csh
2
+
3
+ python -m virtualenv venv
4
+
5
+ # -- PyTorch wheels built for CUDA 12.1 (cu121)
6
+ pip install torch==2.1.0+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
7
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
8
+ python -m pip install pip==24.0
9
+ pip install -r requirements.txt
train.csh ADDED
@@ -0,0 +1,34 @@
1
+ #!/bin/csh
2
+
3
+ set arg_count = $#argv
4
+ if ( $arg_count >= 1 ) then
5
+ if ( "$argv[1]" == "-clean" || "$argv[1]" == "-clean_only" ) then
6
+ echo "[INFO] Killing alll other GPU processes to free up resources."
7
+
8
+ sh -c 'ps | grep python | sed "s/ pts.\+$//g" > .tmp.csh'
9
+ chmod +x .tmp.csh
10
+ sed -i "s/^/kill -9 /g" .tmp.csh
11
+ source .tmp.csh
12
+ rm -rf .tmp.csh
13
+ rm -rf debug_rank_*
14
+ rm -rf dynamicstereo_sf_dr
15
+ endif
16
+
17
+ if ( "$argv[1]" == "-clean_only" ) then
18
+ exit 0
19
+ endif
20
+ endif
21
+
22
+ setenv PYTORCH_CUDA_ALLOC_CONF "max_split_size_mb:32,garbage_collection_threshold:0.5,expandable_segments:False"
23
+ setenv CUDA_LAUNCH_BLOCKING 1
24
+ setenv PYTORCH_NO_CUDA_MEMORY_CACHING 1
25
+ setenv CUBLAS_WORKSPACE_CONFIG ":16:8"
26
+ setenv CUDA_VISIBLE_DEVICES 3
27
+
28
+ # -- GPU OOM Error when trained with sample_len=8 on kilby.
29
+ python train.py --batch_size 1 \
30
+ --image_size 480 640 --saturation_range 0 1.4 --num_steps 200000 \
31
+ --ckpt_path dynamicstereo_sf_dr \
32
+ --sample_len 8 --lr 0.0003 --train_iters 8 --valid_iters 8 \
33
+ --num_workers 28 --save_freq 100 --update_block_3d --different_update_blocks \
34
+ --attention_type self_stereo_temporal_update_time_update_space --train_datasets dynamic_replica
train.py ADDED
@@ -0,0 +1,565 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import os
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+
16
+ from munch import DefaultMunch
17
+ import json
18
+ from pytorch_lightning.lite import LightningLite
19
+ from torch.cuda.amp import GradScaler
20
+
21
+ from train_utils.utils import (
22
+ run_test_eval,
23
+ save_ims_to_tb,
24
+ count_parameters,
25
+ )
26
+ from train_utils.logger import Logger
27
+ from models.core.dynamic_stereo import DynamicStereo
28
+ from models.core.sci_codec import sci_encoder
29
+ from evaluation.core.evaluator import Evaluator
30
+ from train_utils.losses import sequence_loss
31
+ import datasets.dynamic_stereo_datasets as datasets
32
+
33
+ class wrapper(nn.Module):
34
+ def __init__(
35
+ self,
36
+ sigma_range=[0, 1e-9],
37
+ num_frames=8,
38
+ in_channels=1,
39
+ n_taps=2,
40
+ resolution=[480, 640],
41
+ mixed_precision=True,
42
+ attention_type="self_stereo_temporal_update_time_update_space",
43
+ update_block_3d=True,
44
+ different_update_blocks=True,
45
+ train_iters=16):
46
+
47
+ super(wrapper, self).__init__()
48
+
49
+ self.train_iters = train_iters
50
+
51
+ self.sci_enc_L = sci_encoder(sigma_range=sigma_range,
52
+ n_frame=num_frames,
53
+ in_channels=in_channels,
54
+ n_taps=n_taps,
55
+ resolution=resolution)
56
+ self.sci_enc_R = sci_encoder(sigma_range=sigma_range,
57
+ n_frame=num_frames,
58
+ in_channels=in_channels,
59
+ n_taps=n_taps,
60
+ resolution=resolution)
61
+
62
+ self.stereo = DynamicStereo(max_disp=256,
63
+ mixed_precision=mixed_precision,
64
+ num_frames=num_frames,
65
+ attention_type=attention_type,
66
+ use_3d_update_block=update_block_3d,
67
+ different_update_blocks=different_update_blocks)
68
+
69
+ def forward(self, batch):
70
+ # ---- ---- FORWARD PASS ---- ----
71
+ # -- Modified by Chu King on 20th November 2025
72
+
73
+ # -- print ("[INFO] batch[\"img\"].device: ", batch["img"].device)
74
+
75
+ # 0) Convert to Gray
76
+ def rgb_to_gray(x):
77
+ weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device)
78
+ gray = (x * weights[None, None, :, None, None]).sum(dim=2)
79
+ return gray # -- shape: [B, T, H, W]
80
+
81
+ video_L = rgb_to_gray(batch["img"][:, :, 0]).cuda() # ~ (b, t, h, w)
82
+ video_R = rgb_to_gray(batch["img"][:, :, 1]).cuda() # ~ (b, t, h, w)
83
+
84
+ # -- print ("[INFO] video_L.device: ", video_L.device)
85
+
86
+ # 1) Extract and normalize input videos.
87
+ # -- min_max_norm = lambda x : 2. * (x / 255.) - 1.
88
+ min_max_norm = lambda x: x / 255.
89
+ video_L = min_max_norm(video_L) # ~ (b, t, h, w)
90
+ video_R = min_max_norm(video_R) # ~ (b, t, h, w)
91
+ # -- print ("[INFO] video_L.device: ", video_L.device)
92
+
93
+ # 2) Make the tensors contiguous; a later .view() on a non-contiguous tensor would raise an error.
94
+ video_L = video_L.contiguous()
95
+ video_R = video_R.contiguous()
96
+
97
+ # -- print ("[INFO] video_L.device: ", video_L.device)
98
+
99
+ # 3) Coded exposure modeling.
100
+ snapshot_L = self.sci_enc_L(video_L) # ~ (b, c, h, w) -- c=2 for 2 taps
101
+ snapshot_R = self.sci_enc_R(video_R) # ~ (b, c, h, w) -- c=2 for 2 taps
102
+
103
+ # -- print ("[INFO] self.sci_enc_L.device: ", next(self.sci_enc_R.parameters()).device)
104
+ # -- print ("[INFO] snapshot_L.device: ", snapshot_L.device)
105
+
106
+ # 4) Dynamic Stereo
107
+ output = {}
108
+
109
+ disparities = self.stereo(
110
+ snapshot_L,
111
+ snapshot_R,
112
+ iters=self.train_iters,
113
+ test_mode=False
114
+ )
115
+
116
+ n_views = len(batch["disp"][0]) # -- sample_len
117
+ for i in range(n_views):
118
+ seq_loss, metrics = sequence_loss(
119
+ disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
120
+ )
121
+
122
+ output[f"disp_{i}"] = {"loss": seq_loss / n_views, "metrics": metrics}
123
+ output["disparity"] = {
124
+ "predictions": torch.cat(
125
+ [disparities[-1, i, 0] for i in range(n_views)], dim=1
126
+ ).detach(),
127
+ }
128
+ return output
129
+
130
+ def fetch_optimizer(args, model):
131
+ """Create the optimizer and learning rate scheduler"""
132
+ optimizer = optim.AdamW(
133
+ model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
134
+ )
135
+ scheduler = optim.lr_scheduler.OneCycleLR(
136
+ optimizer,
137
+ args.lr,
138
+ args.num_steps + 100,
139
+ pct_start=0.01,
140
+ cycle_momentum=False,
141
+ anneal_strategy="linear",
142
+ )
143
+ return optimizer, scheduler
144
+
145
+
146
+ # -- Modified by Chu King on 20th November 2025
147
+ # -- Take snapshots instead of videos as input.
148
+ # -- def forward_batch(batch, model, args):
149
+ def forward_batch(snapshot_L, snapshot_R, model, args):
150
+ output = {}
151
+
152
+ disparities = model(
153
+ # -- batch["img"][:, :, 0],
154
+ # -- batch["img"][:, :, 1],
155
+ snapshot_L,
156
+ snapshot_R,
157
+ iters=args.train_iters,
158
+ test_mode=False,
159
+ )
160
+ num_traj = len(batch["disp"][0])
161
+ for i in range(num_traj):
162
+ seq_loss, metrics = sequence_loss(
163
+ disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
164
+ )
165
+
166
+ output[f"disp_{i}"] = {"loss": seq_loss / num_traj, "metrics": metrics}
167
+ output["disparity"] = {
168
+ "predictions": torch.cat(
169
+ [disparities[-1, i, 0] for i in range(num_traj)], dim=1
170
+ ).detach(),
171
+ }
172
+ return output
173
+
174
+
175
+ class Lite(LightningLite):
176
+ def run(self, args):
177
+ self.seed_everything(0)
178
+
179
+ # ----------------------------------------- Loading Dataset -----------------------------------------------
180
+ # -- Modified by Chu King on 15th November 2025 to allow quick testing with only 1 training video on the workstation.
181
+ # -- The number of subframes should be fixed for SCI stereo.
182
+ eval_dataloader_dr = datasets.DynamicReplicaDataset(
183
+ # -- split="valid", sample_len=40, only_first_n_samples=1, VERBOSE=False
184
+ split="valid", sample_len=args.sample_len, only_first_n_samples=1, VERBOSE=False
185
+ )
186
+
187
+ eval_dataloader_sintel_clean = datasets.SequenceSintelStereo(dstype="clean")
188
+ eval_dataloader_sintel_final = datasets.SequenceSintelStereo(dstype="final")
189
+
190
+ eval_dataloaders = [
191
+ ("sintel_clean", eval_dataloader_sintel_clean),
192
+ ("sintel_final", eval_dataloader_sintel_final),
193
+ ("dynamic_replica", eval_dataloader_dr),
194
+ ]
195
+
196
+ evaluator = Evaluator()
197
+
198
+ eval_vis_cfg = {
199
+ "visualize_interval": 1, # Use 0 for no visualization
200
+ "exp_dir": args.ckpt_path,
201
+ }
202
+ eval_vis_cfg = DefaultMunch.fromDict(eval_vis_cfg, object())
203
+ evaluator.setup_visualization(eval_vis_cfg)
204
+
205
+ # ----------------------------------------- Model Instantiation -----------------------------------------------
206
+ # -- Added by Chu King on 20th November 2025
207
+ # -- Instantiate the model
208
+ model = wrapper(sigma_range=[0, 1e-9],
209
+ num_frames=args.sample_len,
210
+ in_channels=1,
211
+ n_taps=2,
212
+ resolution=args.image_size,
213
+ mixed_precision=args.mixed_precision,
214
+ attention_type=args.attention_type,
215
+ update_block_3d=args.update_block_3d,
216
+ different_update_blocks=args.different_update_blocks,
217
+ train_iters=args.train_iters)
218
+
219
+ with open(args.ckpt_path + "/meta.json", "w") as file:
220
+ json.dump(vars(args), file, sort_keys=True, indent=4)
221
+
222
+ model.cuda()
223
+
224
+ logging.info("count_parameters(model): {}".format(count_parameters(model)))
225
+
226
+ train_loader = datasets.fetch_dataloader(args)
227
+ train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
228
+
229
+ logging.info(f"Train loader size: {len(train_loader)}")
230
+
231
+ optimizer, scheduler = fetch_optimizer(args, model)
232
+
233
+ total_steps = 0
234
+ logger = Logger(model, scheduler, args.ckpt_path)
235
+
236
+ # ----------------------------------------- Loading Checkpoint -----------------------------------------------
237
+ folder_ckpts = [
238
+ f
239
+ for f in os.listdir(args.ckpt_path)
240
+ if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
241
+ ]
242
+ if len(folder_ckpts) > 0:
243
+ ckpt_path = sorted(folder_ckpts)[-1]
244
+ ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
245
+ logging.info(f"Loading checkpoint {ckpt_path}")
246
+ if "model" in ckpt:
247
+ model.load_state_dict(ckpt["model"])
248
+ else:
249
+ model.load_state_dict(ckpt)
250
+ if "optimizer" in ckpt:
251
+ logging.info("Load optimizer")
252
+ optimizer.load_state_dict(ckpt["optimizer"])
253
+ if "scheduler" in ckpt:
254
+ logging.info("Load scheduler")
255
+ scheduler.load_state_dict(ckpt["scheduler"])
256
+ if "total_steps" in ckpt:
257
+ total_steps = ckpt["total_steps"]
258
+ logging.info(f"Load total_steps {total_steps}")
259
+
260
+ elif args.restore_ckpt is not None:
261
+ assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
262
+ ".pt"
263
+ )
264
+ logging.info("Loading checkpoint...")
265
+ strict = True
266
+
267
+ state_dict = self.load(args.restore_ckpt)
268
+ if "model" in state_dict:
269
+ state_dict = state_dict["model"]
270
+ # -- Since we wrapped the model in torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel,
271
+ # PyTorch automatically prefixes all parameter names with "module.":
272
+ # state_dict = {
273
+ # 'module.conv1.weight': tensor(...),
274
+ # 'module.conv1.bias': tensor(...),
275
+ # 'module.fc.weight': tensor(...),
276
+ # 'module.fc.bias': tensor(...),
277
+ # }
278
+ # -- So we need to strip the "module." prefix:
279
+ if list(state_dict.keys())[0].startswith("module."):
280
+ state_dict = {
281
+ k.replace("module.", ""): v for k, v in state_dict.items()
282
+ }
283
+ model.load_state_dict(state_dict, strict=strict)
284
+
285
+ logging.info(f"Done loading checkpoint")
286
+ # ----------------------------------------- Optimizer, Scheduler -----------------------------------------------
287
+
288
+ model, optimizer = self.setup(model, optimizer, move_to_device=False)
289
+ model.cuda()
290
+ model.train()
291
+ model.module.module.stereo.freeze_bn() # -- We keep BatchNorm frozen
292
+
293
+ save_freq = args.save_freq
294
+ scaler = GradScaler(enabled=args.mixed_precision)
295
+
296
+ # ----------------------------------------- Training Loop -----------------------------------------------
297
+ should_keep_training = True
298
+ global_batch_num = 0
299
+ epoch = -1
300
+ while should_keep_training:
301
+ epoch += 1
302
+
303
+ for i_batch, batch in enumerate(tqdm(train_loader)):
304
+ optimizer.zero_grad()
305
+ if batch is None:
306
+ print("batch is None")
307
+ continue
308
+
309
+ for k, v in batch.items():
310
+ batch[k] = v.cuda()
311
+
312
+ assert model.training
313
+
314
+ # ---- ---- FORWARD PASS ---- ----
315
+ # -- Modified by Chu King on 20th November 2025
316
+ output = model(batch)
317
+
318
+ loss = 0
319
+ logger.update()
320
+ for k, v in output.items():
321
+ if "loss" in v:
322
+ loss += v["loss"]
323
+ logger.writer.add_scalar(
324
+ f"live_{k}_loss", v["loss"].item(), total_steps
325
+ )
326
+ if "metrics" in v:
327
+ logger.push(v["metrics"], k)
328
+
329
+ if self.global_rank == 0:
330
+ if total_steps % save_freq == save_freq - 1:
331
+ save_ims_to_tb(logger.writer, batch, output, total_steps)
332
+ if len(output) > 1:
333
+ logger.writer.add_scalar(
334
+ f"live_total_loss", loss.item(), total_steps
335
+ )
336
+ logger.writer.add_scalar(
337
+ f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
338
+ )
339
+ global_batch_num += 1
340
+ self.barrier()
341
+
342
+ # ---- ---- BACKWARD PASS ---- ----
343
+ self.backward(scaler.scale(loss))
344
+ scaler.unscale_(optimizer)
345
+
346
+ # -- Prevent exploding gradients in RNNs or very deep networks
347
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
348
+
349
+ scaler.step(optimizer)
350
+ scheduler.step()
351
+ scaler.update()
352
+ total_steps += 1
353
+
354
+ if self.global_rank == 0:
355
+
356
+ if (i_batch >= len(train_loader) - 1) or (
357
+ total_steps == 1 and args.validate_at_start
358
+ ):
359
+ ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
360
+ save_path = Path(
361
+ f"{args.ckpt_path}/model_{args.name}_{ckpt_iter}.pth"
362
+ )
363
+
364
+ save_dict = {
365
+ "model": model.module.module.state_dict(),
366
+ "optimizer": optimizer.state_dict(),
367
+ "scheduler": scheduler.state_dict(),
368
+ "total_steps": total_steps,
369
+ }
370
+
371
+ logging.info(f"Saving file {save_path}")
372
+ self.save(save_dict, save_path)
373
+
374
+ # ---- ---- EVALUATION ---- ----
375
+ if epoch % args.evaluate_every_n_epoch == 0:
376
+ # -- Added by Chu King on 21st November 2025
377
+ model.eval()
378
+
379
+ logging.info(f"Evaluation at epoch {epoch}")
380
+ run_test_eval(
381
+ args.ckpt_path,
382
+ "valid",
383
+ evaluator,
384
+ model.module.module.sci_enc_L,
385
+ model.module.module.sci_enc_R,
386
+ model.module.module.stereo,
387
+ eval_dataloaders,
388
+ logger.writer,
389
+ total_steps,
390
+ resolution=args.image_size
391
+ )
392
+
393
+ # -- Added by Chu King on 20th November 2025 for SCI stereo
394
+ model.train()
395
+
396
+ model.module.module.stereo.freeze_bn()
397
+
398
+ self.barrier()
399
+ if total_steps > args.num_steps:
400
+ should_keep_training = False
401
+ break
402
+
403
+ logger.close()
404
+ # ----------------------------------------- Save models after training -----------------------------------------------
405
+ # -- Modified by Chu King on 20th November 2025 to save SCI encoders' models.
406
+ # -- PATH = f"{args.ckpt_path}/{args.name}_final.pth"
407
+ PATH = f"{args.ckpt_path}/{args.name}_model_final.pth"
408
+ torch.save(model.module.module.state_dict(), PATH)
409
+
410
+ # ----------------------------------------- Testing -----------------------------------------------
411
+ # -- Modified by Chu King on 20th November 2025
412
+ test_dataloader_dr = datasets.DynamicReplicaDataset(
413
+ # -- The number of subframes should be fixed for SCI stereo
414
+ # -- split="test", sample_len=150, only_first_n_samples=1
415
+ split="test", sample_len=args.sample_len, only_first_n_samples=1
416
+ )
417
+ test_dataloaders = [
418
+ ("sintel_clean", eval_dataloader_sintel_clean),
419
+ ("sintel_final", eval_dataloader_sintel_final),
420
+ ("dynamic_replica", test_dataloader_dr),
421
+ ]
422
+
423
+ # -- Modified by Chu King on 21st November 2025
424
+ model.eval()
425
+ run_test_eval(
426
+ args.ckpt_path,
427
+ "test",
428
+ evaluator,
429
+ model.module.module.sci_enc_L,
430
+ model.module.module.sci_enc_R,
431
+ model.module.module.stereo,
432
+ test_dataloaders,
433
+ logger.writer,
434
+ total_steps,
435
+ resolution=args.image_size
436
+ )
437
+
438
+
439
+ if __name__ == "__main__":
440
+ parser = argparse.ArgumentParser()
441
+ parser.add_argument("--name", default="dynamic-stereo", help="name your experiment")
442
+ parser.add_argument("--restore_ckpt", help="restore checkpoint")
443
+ parser.add_argument("--ckpt_path", help="path to save checkpoints")
444
+ parser.add_argument(
445
+ "--mixed_precision", action="store_true", help="use mixed precision"
446
+ )
447
+
448
+ # Training parameters
449
+ parser.add_argument(
450
+ "--batch_size", type=int, default=6, help="batch size used during training."
451
+ )
452
+ parser.add_argument(
453
+ "--train_datasets",
454
+ nargs="+",
455
+ default=["things", "monkaa", "driving"],
456
+ help="training datasets.",
457
+ )
458
+ parser.add_argument("--lr", type=float, default=0.0002, help="max learning rate.")
459
+
460
+ parser.add_argument(
461
+ "--num_steps", type=int, default=100000, help="length of training schedule."
462
+ )
463
+ parser.add_argument(
464
+ "--image_size",
465
+ type=int,
466
+ nargs="+",
467
+ default=[320, 720],
468
+ help="size of the random image crops used during training.",
469
+ )
470
+ parser.add_argument(
471
+ "--train_iters",
472
+ type=int,
473
+ default=16,
474
+ help="number of updates to the disparity field in each forward pass.",
475
+ )
476
+ parser.add_argument(
477
+ "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
478
+ )
479
+
480
+ parser.add_argument(
481
+ "--sample_len", type=int, default=2, help="length of training video samples"
482
+ )
483
+ parser.add_argument(
484
+ "--validate_at_start", action="store_true", help="validate the model at start"
485
+ )
486
+ parser.add_argument("--save_freq", type=int, default=100, help="save frequency")
487
+
488
+ parser.add_argument(
489
+ "--evaluate_every_n_epoch",
490
+ type=int,
491
+ default=1,
492
+ help="evaluate every n epoch",
493
+ )
494
+
495
+ parser.add_argument(
496
+ "--num_workers", type=int, default=6, help="number of dataloader workers."
497
+ )
498
+ # Validation parameters
499
+ parser.add_argument(
500
+ "--valid_iters",
501
+ type=int,
502
+ default=32,
503
+ help="number of updates to the disparity field in each forward pass during validation.",
504
+ )
505
+ # Architecture choices
506
+ parser.add_argument(
507
+ "--different_update_blocks",
508
+ action="store_true",
509
+ help="use different update blocks for each resolution",
510
+ )
511
+ parser.add_argument(
512
+ "--attention_type",
513
+ type=str,
514
+ help="attention type of the SST and update blocks. \
515
+ Any combination of 'self_stereo', 'temporal', 'update_time', 'update_space' connected by an underscore.",
516
+ )
517
+ parser.add_argument(
518
+ "--update_block_3d", action="store_true", help="use Conv3D update block"
519
+ )
520
+ # Data augmentation
521
+ parser.add_argument(
522
+ "--img_gamma", type=float, nargs="+", default=None, help="gamma range"
523
+ )
524
+ parser.add_argument(
525
+ "--saturation_range",
526
+ type=float,
527
+ nargs="+",
528
+ default=None,
529
+ help="color saturation",
530
+ )
531
+ parser.add_argument(
532
+ "--do_flip",
533
+ default=False,
534
+ choices=["h", "v"],
535
+ help="flip the images horizontally or vertically",
536
+ )
537
+ parser.add_argument(
538
+ "--spatial_scale",
539
+ type=float,
540
+ nargs="+",
541
+ default=[0, 0],
542
+ help="re-scale the images randomly",
543
+ )
544
+ parser.add_argument(
545
+ "--noyjitter",
546
+ action="store_true",
547
+ help="don't simulate imperfect rectification",
548
+ )
549
+ args = parser.parse_args()
550
+
551
+ logging.basicConfig(
552
+ level=logging.INFO,
553
+ format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
554
+ )
555
+
556
+ Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
557
+ from pytorch_lightning.strategies import DDPStrategy
558
+
559
+ Lite(
560
+ # -- strategy=DDPStrategy(find_unused_parameters=True),
561
+ strategy=DDPStrategy(find_unused_parameters=False),
562
+ devices="auto",
563
+ accelerator="gpu",
564
+ precision=32,
565
+ ).run(args)
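
A standalone sketch of the grayscale conversion used in `wrapper.forward`, assuming frames arrive as `(B, T, C, H, W)` RGB in the 0-255 range; the clip below is random data.

```
import torch

def rgb_to_gray(x):
    # BT.601 luma weights, applied over the channel dimension (dim=2)
    weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device)
    return (x * weights[None, None, :, None, None]).sum(dim=2)  # -> (B, T, H, W)

frames = torch.rand(1, 8, 3, 480, 640) * 255.0   # one 8-frame left-view clip
gray = rgb_to_gray(frames) / 255.0               # normalized as in the training loop
print(gray.shape)                                # torch.Size([1, 8, 480, 640])
```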
train_utils/logger.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ from torch.utils.tensorboard import SummaryWriter
11
+
12
+
13
+ class Logger:
14
+
15
+ SUM_FREQ = 100
16
+
17
+ def __init__(self, model, scheduler, ckpt_path):
18
+ self.model = model
19
+ self.scheduler = scheduler
20
+ self.total_steps = 0
21
+ self.running_loss = {}
22
+ self.ckpt_path = ckpt_path
23
+ self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
24
+
25
+ def _print_training_status(self):
26
+ metrics_data = [
27
+ self.running_loss[k] / Logger.SUM_FREQ
28
+ for k in sorted(self.running_loss.keys())
29
+ ]
30
+ training_str = "[{:6d}] ".format(self.total_steps + 1)
31
+ metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
32
+
33
+ # print the training status
34
+ logging.info(
35
+ f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
36
+ )
37
+
38
+ if self.writer is None:
39
+ self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
40
+ for k in self.running_loss:
41
+ self.writer.add_scalar(
42
+ k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
43
+ )
44
+ self.running_loss[k] = 0.0
45
+
46
+ def push(self, metrics, task):
47
+ for key in metrics:
48
+ task_key = str(key) + "_" + task
49
+ if task_key not in self.running_loss:
50
+ self.running_loss[task_key] = 0.0
51
+ self.running_loss[task_key] += metrics[key]
52
+
53
+ def update(self):
54
+ self.total_steps += 1
55
+ if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
56
+ self._print_training_status()
57
+ self.running_loss = {}
58
+
59
+ def write_dict(self, results):
60
+ if self.writer is None:
61
+ self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
62
+
63
+ for key in results:
64
+ self.writer.add_scalar(key, results[key], self.total_steps)
65
+
66
+ def close(self):
67
+ self.writer.close()
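
A dry-run sketch of `Logger`, assuming it is imported from `train_utils/logger.py`; `model` and `scheduler` are only stored by the constructor, so `None` placeholders suffice here.

```
from train_utils.logger import Logger

logger = Logger(model=None, scheduler=None, ckpt_path="./debug_logger")
for step in range(200):
    logger.push({"epe": 1.0 / (step + 1)}, task="disp_0")
    logger.update()   # averages and flushes to TensorBoard every SUM_FREQ steps
logger.close()
```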
train_utils/losses.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ # -- Added by Chu King on 23rd November 2025 to check for NaNs
10
+ import math
11
+
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+
15
+ def flow_to_rgb(flow):
16
+ # flow: [2, H, W]
17
+ u = flow[0]
18
+ v = flow[1]
19
+ rad = np.sqrt(u ** 2 + v ** 2)
20
+ ang = np.arctan2(v, u)
21
+
22
+ hsv = np.zeros((flow.shape[1], flow.shape[2], 3), dtype=np.float32)
23
+ hsv[..., 0] = (ang + np.pi) / (2 * np.pi)
24
+ hsv[..., 1] = 1.0
25
+ hsv[..., 2] = np.clip(rad / np.percentile(rad, 99), 0, 1)
26
+
27
+ rgb = plt.cm.hsv(hsv)
28
+ return rgb[..., :3]
29
+
30
+ def visualize_flow_debug(flow_pred, flow_gt, epe, step=0, save_path="debug"):
31
+ flow_pred_np = flow_pred.detach().cpu().numpy()
32
+ flow_gt_np = flow_gt.detach().cpu().numpy()
33
+ epe_np = epe
34
+
35
+ flow_pred0 = flow_pred_np[0, 0, :, :]
36
+ flow_gt0 = flow_gt_np[0, 0, :, :]
37
+ epe0 = epe_np
38
+
39
+ fig, axs = plt.subplots(1, 2, figsize=(15, 5))
40
+
41
+ axs[0].imshow(flow_to_rgb(flow_pred0))
42
+ axs[0].set_title("Predicted Flow")
43
+ axs[0].axis("off")
44
+
45
+ axs[1].imshow(flow_to_rgb(flow_gt0))
46
+ axs[1].set_title("Ground Truth Flow")
47
+ axs[1].axis("off")
48
+
49
+ # -- axs[2].imshow(epe0, cmap="inferno")
50
+ # -- axs[2].set_title("EPE heatmap")
51
+ # -- axs[2].axis("off")
52
+
53
+ fig.suptitle(f"STEP = {step}")
54
+
55
+ plt.tight_layout()
56
+ os.makedirs(save_path, exist_ok=True)  # make sure the output folder exists
+ plt.savefig(f"{save_path}/flow_debug_{step}.png")
57
+ plt.close()
58
+
59
+ def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
60
+ """Loss function defined over sequence of flow predictions"""
61
+ n_predictions = len(flow_preds)
62
+ assert n_predictions >= 1
63
+ flow_loss = 0.0
64
+ # exclude invalid pixels and extremely large displacements
65
+ mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)
66
+
67
+ if len(valid.shape) != len(flow_gt.shape):
68
+ valid = valid.unsqueeze(1)
69
+
70
+ valid = (valid >= 0.5) & (mag < max_flow)
71
+
72
+ if valid.shape != flow_gt.shape:
73
+ valid = torch.cat([valid, valid], dim=1)
74
+ assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
75
+ assert not torch.isinf(flow_gt[valid.bool()]).any()
76
+
77
+ for i in range(n_predictions):
78
+ assert (
79
+ not torch.isnan(flow_preds[i]).any()
80
+ and not torch.isinf(flow_preds[i]).any()
81
+ )
82
+
83
+ if n_predictions == 1:
84
+ i_weight = 1
85
+ else:
86
+ # We adjust the loss_gamma so it is consistent for any number of iterations
87
+ adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
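+ # With this adjustment the earliest prediction always receives weight
+ # loss_gamma ** 15 (adjusted_loss_gamma ** (n_predictions - 1) == loss_gamma ** 15),
+ # regardless of how many refinement iterations were actually run.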
88
+ i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
89
+
90
+ flow_pred = flow_preds[i].clone()
91
+ if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:
92
+ flow_pred = flow_pred[:, :1]
93
+
94
+ i_loss = (flow_pred - flow_gt).abs()
95
+
96
+ assert i_loss.shape == valid.shape, [
97
+ i_loss.shape,
98
+ valid.shape,
99
+ flow_gt.shape,
100
+ flow_pred.shape,
101
+ ]
102
+ flow_loss += i_weight * i_loss[valid.bool()].mean()
103
+
104
+ epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
105
+
106
+ valid = valid[:, 0]
107
+ epe = epe.view(-1)
108
+ epe = epe[valid.reshape(epe.shape)]
109
+
110
+ # -- Added by Chu King to deal with the case when there is no valid disparity.
111
+ if valid.sum().item() == 0:
112
+ metrics = {"epe": 0.0, "1px": 0.0, "3px": 0.0, "5px": 0.0}
113
+ else:
114
+ metrics = {
115
+ "epe": epe.mean().item(),
116
+ "1px": (epe < 1).float().mean().item(),
117
+ "3px": (epe < 3).float().mean().item(),
118
+ "5px": (epe < 5).float().mean().item(),
119
+ }
120
+
121
+ for k, v in metrics.items():
122
+ if math.isnan(v):
123
+ print ("[ERROR] Nan detected for k: ", k)
124
+ if torch.isnan(flow_preds[-1]).any(): print("[WARNING] NaN in flow_preds")
125
+ if torch.isinf(flow_preds[-1]).any(): print("[WARNING] Inf in flow_preds")
126
+ if torch.isnan(flow_gt).any(): print("[WARNING] NaN in flow_gt")
127
+ if torch.isinf(flow_gt).any(): print("[WARNING] Inf in flow_gt")
128
+
129
+ raw_diff = flow_preds[-1] - flow_gt
130
+ if torch.isnan(raw_diff).any(): print("[WARNING] NaN in flow_diff")
131
+
132
+ sq = (raw_diff ** 2)
133
+ if torch.isnan(sq).any(): print("[WARNING] NaN in square")
134
+
135
+ sum_sq = torch.sum(sq, dim=1)
136
+ if torch.isnan(sum_sq).any(): print("[WARNING] NaN in sum")
137
+
138
+ epe = sum_sq.sqrt()
139
+ if torch.isnan(epe).any(): print("[WARNING] NaN in sqrt")
140
+ if torch.isinf(epe).any(): print("[WARNING] Inf in sqrt")
141
+
142
+ num_valid = valid.sum().item()
143
+ print("[INFO] Valid pixels:", num_valid)
144
+ if num_valid == 0:
145
+ print("[WARNING]: No valid pixels metrics will be NaN.")
146
+
147
+ if (epe > 1e6).any():
148
+ print("[INFP] Large EPE values detected:", epe.max().item())
149
+
150
+ print ("[INFO] Flow pred sample:", flow_preds[-1].view(-1)[:10])
151
+ print ("[INFO] Flow gt sample:", flow_gt.view(-1)[:10])
152
+ print ("[INFO] EPE sample:", epe.view(-1)[:10])
153
+ print ("[INFO] Valid sample:", valid.view(-1)[:10])
154
+
155
+ visualize_flow_debug(flow_preds[-1], flow_gt, v, step=0, save_path="debug")
156
+ raise SystemExit("Nan detected.")
157
+
158
+ return flow_loss, metrics
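+
+ # Minimal usage sketch (assumed shapes: each prediction and flow_gt are
+ # [B, 2, H, W] tensors and valid is a [B, H, W] mask; all names below are
+ # placeholders, not part of this repository):
+ #   preds = model(left_images, right_images)  # list of per-iteration predictions
+ #   loss, metrics = sequence_loss(preds, flow_gt, valid)
+ #   loss.backward()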
train_utils/utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import os
9
+ import torch
10
+
11
+ import json
12
+ import flow_vis
13
+ import matplotlib.pyplot as plt
14
+
15
+ import datasets.dynamic_stereo_datasets as datasets
16
+ from evaluation.utils.utils import aggregate_and_print_results
17
+
18
+
19
+ def count_parameters(model):
20
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
21
+
22
+
23
+ def run_test_eval(ckpt_path, eval_type, evaluator, sci_enc_L, sci_enc_R, model, dataloaders, writer, step, resolution=[480, 640]):
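+ # Evaluate the model on every provided dataloader, log the selected metrics
+ # to TensorBoard and dump the aggregated results to a JSON file in ckpt_path.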
24
+
25
+ # -- Evalution of real scenes disabled by Chu King on 16th November 2025 as depth data
26
+ # are not available.
27
+ # -- for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]:
28
+ # -- seq_len_real = 50
29
+ # -- ds_path = f"./dynamic_replica_data/real/{real_sequence_name}"
30
+ # -- real_dataset = datasets.DynamicReplicaDataset(
31
+ # -- split="test", root=ds_path, sample_len=seq_len_real, only_first_n_samples=1,
32
+ # -- VERBOSE=False # -- Added by Chu King on 16th November 2025 for debugging purposes
33
+ # -- )
34
+
35
+ # -- evaluator.evaluate_sequence(
36
+ # -- model=model.module.module,
37
+ # -- test_dataloader=real_dataset,
38
+ # -- writer=writer,
39
+ # -- step=step,
40
+ # -- train_mode=True,
41
+ # -- )
42
+
43
+ for ds_name, dataloader in dataloaders:
44
+ evaluator.visualize_interval = 1 if "sintel" not in ds_name else 0
45
+
46
+ evaluate_result = evaluator.evaluate_sequence(
47
+ sci_enc_L=sci_enc_L,
48
+ sci_enc_R=sci_enc_R,
49
+ model=model,
50
+ test_dataloader=dataloader,
51
+ writer=writer if "sintel" not in ds_name else None,
52
+ step=step,
53
+ train_mode=True,
54
+ resolution=resolution
55
+ )
56
+
57
+ aggregate_result = aggregate_and_print_results(
58
+ evaluate_result,
59
+ )
60
+
61
+ save_metrics = [
62
+ "flow_mean_accuracy_5px",
63
+ "flow_mean_accuracy_3px",
64
+ "flow_mean_accuracy_1px",
65
+ "flow_epe_traj_mean",
66
+ ]
67
+ for epe_name in ("epe", "temp_epe", "temp_epe_r"):
68
+ for m in [
69
+ f"disp_{epe_name}_bad_0.5px",
70
+ f"disp_{epe_name}_bad_1px",
71
+ f"disp_{epe_name}_bad_2px",
72
+ f"disp_{epe_name}_bad_3px",
73
+ f"disp_{epe_name}_mean",
74
+ ]:
75
+ save_metrics.append(m)
76
+
77
+ for k, v in aggregate_result.items():
78
+ if k in save_metrics:
79
+ writer.add_scalars(
80
+ f"{ds_name}_{k.rsplit('_', 1)[0]}",
81
+ {f"{ds_name}_{k}": v},
82
+ step,
83
+ )
84
+
85
+ result_file = os.path.join(
86
+ ckpt_path,
87
+ f"result_{ds_name}_{eval_type}_{step}_mimo.json",
88
+ )
89
+ print(f"Dumping {eval_type} results to {result_file}.")
90
+ with open(result_file, "w") as f:
91
+ json.dump(aggregate_result, f)
92
+
93
+
94
+ def fig2data(fig):
95
+ """
96
+ fig = plt.figure()
97
+ image = fig2data(fig)
98
+ @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
99
+ @param fig a matplotlib figure
100
+ @return a numpy 3D array of RGBA values
101
+ """
102
+ import PIL.Image as Image
103
+
104
+ # draw the renderer
105
+ fig.canvas.draw()
106
+
107
+ # Get the RGBA buffer from the figure
108
+ w, h = fig.canvas.get_width_height()
109
+ buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
110
+ buf.shape = (w, h, 3)
111
+
112
+ image = Image.frombytes("RGB", (w, h), buf.tobytes())
113
+ image = np.asarray(image)
114
+ return image
115
+
116
+
117
+ def save_ims_to_tb(writer, batch, output, total_steps):
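+ # Log the stereo input pair, the validity-masked ground-truth disparity and
+ # the per-task predictions / ground truth to TensorBoard as images.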
118
+ writer.add_image(
119
+ "train_im",
120
+ torch.cat([torch.cat([im[0], im[1]], dim=-1) for im in batch["img"][0]], dim=-2)
121
+ / 255.0,
122
+ total_steps,
123
+ dataformats="CHW",
124
+ )
125
+ if "disp" in batch and len(batch["disp"]) > 0:
126
+ disp_im = [
127
+ (torch.cat([im[0], im[1]], dim=-1) * torch.cat([val[0], val[1]], dim=-1))
128
+ for im, val in zip(batch["disp"][0], batch["valid_disp"][0])
129
+ ]
130
+
131
+ disp_im = torch.cat(disp_im, dim=1)
132
+
133
+ figure = plt.figure()
134
+ plt.imshow(disp_im.cpu()[0])
135
+ disp_im = fig2data(figure).copy()
136
+
137
+ writer.add_image(
138
+ "train_disp",
139
+ disp_im,
140
+ total_steps,
141
+ dataformats="HWC",
142
+ )
143
+
144
+ for k, v in output.items():
145
+ if "predictions" in v:
146
+ pred = v["predictions"]
147
+ if k == "disparity":
148
+ figure = plt.figure()
149
+ plt.imshow(pred.cpu()[0])
150
+ pred = fig2data(figure).copy()
151
+ dataformat = "HWC"
152
+ else:
153
+ pred = torch.tensor(
154
+ flow_vis.flow_to_color(
155
+ pred.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
156
+ )
157
+ / 255.0
158
+ )
159
+ dataformat = "HWC"
160
+ writer.add_image(
161
+ f"pred_{k}",
162
+ pred,
163
+ total_steps,
164
+ dataformats=dataformat,
165
+ )
166
+ if "gt" in v:
167
+ gt = v["gt"]
168
+ gt = torch.tensor(
169
+ flow_vis.flow_to_color(
170
+ gt.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
171
+ )
172
+ / 255.0
173
+ )
174
+ dataformat = "HWC"
175
+ writer.add_image(
176
+ f"gt_{k}",
177
+ gt,
178
+ total_steps,
179
+ dataformats=dataformat,
180
+ )