Commit 2c76547 · Parent: 70bca94

Copied from github repository.
Files changed:
- README.md +139 -0
- datasets/augmentor.py +200 -0
- datasets/dynamic_stereo_datasets.py +743 -0
- datasets/frame_utils.py +118 -0
- evaluation/configs/eval_dynamic_replica_150_frames.yaml +8 -0
- evaluation/configs/eval_dynamic_replica_40_frames.yaml +8 -0
- evaluation/configs/eval_real_data.yaml +9 -0
- evaluation/configs/eval_sintel_clean.yaml +9 -0
- evaluation/configs/eval_sintel_final.yaml +9 -0
- evaluation/core/evaluator.py +152 -0
- evaluation/evaluate.py +143 -0
- evaluation/utils/eval_utils.py +213 -0
- evaluation/utils/utils.py +351 -0
- links_lite.json +15 -0
- models/core/attention.py +240 -0
- models/core/corr.py +88 -0
- models/core/dynamic_stereo.py +506 -0
- models/core/extractor.py +139 -0
- models/core/model_zoo.py +48 -0
- models/core/sci_codec.py +180 -0
- models/core/update.py +370 -0
- models/core/utils/config.py +961 -0
- models/core/utils/utils.py +44 -0
- models/dynamic_stereo_model.py +50 -0
- models/raft_stereo_model.py +84 -0
- notebooks/Dynamic_Replica_demo.ipynb +0 -0
- notebooks/evaluate.ipynb +0 -0
- requirements.txt +20 -0
- scripts/checksum_check.py +154 -0
- scripts/download_dynamic_replica.py +35 -0
- scripts/download_utils.py +280 -0
- scripts/dr_sha256.json +106 -0
- setup.csh +9 -0
- train.csh +34 -0
- train.py +565 -0
- train_utils/logger.py +67 -0
- train_utils/losses.py +158 -0
- train_utils/utils.py +180 -0
README.md
ADDED
@@ -0,0 +1,139 @@
# [CVPR 2023] DynamicStereo: Consistent Dynamic Depth from Stereo Videos

**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**

[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)

[[`Paper`](https://research.facebook.com/publications/dynamicstereo-consistent-dynamic-depth-from-stereo-videos/)] [[`Project`](https://dynamic-stereo.github.io/)] [[`BibTeX`](#citing-dynamicstereo)]



**DynamicStereo** is a transformer-based architecture for temporally consistent depth estimation from stereo videos. It has been trained on a combination of two datasets: [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) and **Dynamic Replica**, which we present below.

## Dataset

https://user-images.githubusercontent.com/37815420/236239579-7877623c-716b-4074-a14e-944d095f1419.mp4

The dataset consists of 145,200 *stereo* frames (524 videos) showing humans and animals in motion.

We provide annotations for both the *left and right* views; see [this notebook](https://github.com/facebookresearch/dynamic_stereo/blob/main/notebooks/Dynamic_Replica_demo.ipynb):
- camera intrinsics and extrinsics
- image depth (can be converted to disparity with the intrinsics; see the sketch after this list)
- instance segmentation masks
- binary foreground / background segmentation masks
- optical flow (released!)
- long-range pixel trajectories (released!)
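
For reference, a minimal sketch of the depth-to-disparity conversion mentioned above, assuming a rectified stereo pair with focal length in pixels and a known baseline. The function name and arguments are illustrative; the conversion actually used by this codebase lives in `evaluation/utils/eval_utils.py` (`depth2disparity_scale`).

```
import numpy as np

def depth_to_disparity(depth, focal_length_px, baseline, eps=1e-5):
    # For a rectified pair: disparity = focal_length * baseline / depth.
    # Near-zero depths are masked out instead of producing huge disparities.
    depth = np.asarray(depth, dtype=np.float32)
    valid = depth > eps
    disparity = np.zeros_like(depth)
    disparity[valid] = focal_length_px * baseline / depth[valid]
    return disparity, valid
```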

### Download the Dynamic Replica dataset
Due to the enormous size of the original dataset, we created the `links_lite.json` file to enable quick testing by downloading just a small portion of the dataset.

```
python ./scripts/download_dynamic_replica.py --link_list_file links_lite.json --download_folder ./dynamic_replica_data --download_splits test train valid real
```

To download the full dataset, please visit [the original site](https://github.com/facebookresearch/dynamic_stereo) created by Meta.

## Installation

The instructions below install DynamicStereo with the latest PyTorch3D, PyTorch 1.12.1 and CUDA 11.3.

### Set up the root for all source files:
```
git clone https://github.com/facebookresearch/dynamic_stereo
cd dynamic_stereo
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
```
### Create a conda env:
```
conda create -n dynamicstereo python=3.8
conda activate dynamicstereo
```
### Install requirements
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
# Installing PyTorch3D takes a while. In the meantime, you may want to take a break and enjoy a cup of coffee.
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
pip install -r requirements.txt
```

### (Optional) Install RAFT-Stereo
```
mkdir third_party
cd third_party
git clone https://github.com/princeton-vl/RAFT-Stereo
cd RAFT-Stereo
bash download_models.sh
cd ../..
```


## Evaluation
To download the checkpoints, follow the instructions below:
```
mkdir checkpoints
cd checkpoints
wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth
wget https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth
cd ..
```
You can also download the checkpoints manually by clicking the links below. Copy the checkpoints to `./dynamic_stereo/checkpoints`.

- [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_sf.pth) trained on SceneFlow
- [DynamicStereo](https://dl.fbaipublicfiles.com/dynamic_replica_v1/dynamic_stereo_dr_sf.pth) trained on SceneFlow and *Dynamic Replica*

To evaluate DynamicStereo:
```
python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \
  MODEL.model_name=DynamicStereoModel exp_dir=./outputs/test_dynamic_replica_ds \
  MODEL.DynamicStereoModel.model_weights=./checkpoints/dynamic_stereo_sf.pth
```
Due to the high image resolution, evaluation on *Dynamic Replica* requires a 32GB GPU. If you don't have enough GPU memory, you can decrease `kernel_size` from 20 to 10 by adding `MODEL.DynamicStereoModel.kernel_size=10` to the above command. Another option is to decrease the dataset resolution.

With `kernel_size=20`, you should reproduce the numbers from *Table 5* in the [paper](https://arxiv.org/pdf/2305.02296.pdf).

Reconstructions of all the *Dynamic Replica* splits (including *real*) will be visualized and saved to `exp_dir`.

If you installed [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo), you can run:
```
python ./evaluation/evaluate.py --config-name eval_dynamic_replica_40_frames \
  MODEL.model_name=RAFTStereoModel exp_dir=./outputs/test_dynamic_replica_raft
```

Other public datasets we use:
- [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
- [Sintel](http://sintel.is.tue.mpg.de/stereo)
- [Middlebury](https://vision.middlebury.edu/stereo/data/)
- [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data)
- [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_stereo.php)

## Training
Training requires a 32GB GPU. You can decrease `image_size` and / or `sample_len` if you don't have enough GPU memory.
You need to download SceneFlow before training. Alternatively, you can train on *Dynamic Replica* only.
```
python train.py --batch_size 1 \
  --spatial_scale -0.2 0.4 --image_size 384 512 --saturation_range 0 1.4 --num_steps 200000 \
  --ckpt_path dynamicstereo_sf_dr \
  --sample_len 5 --lr 0.0003 --train_iters 10 --valid_iters 20 \
  --num_workers 28 --save_freq 100 --update_block_3d --different_update_blocks \
  --attention_type self_stereo_temporal_update_time_update_space --train_datasets dynamic_replica things monkaa driving
```
If you want to train on SceneFlow only, remove `dynamic_replica` from `--train_datasets`.



## License
The majority of dynamic_stereo is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: [RAFT-Stereo](https://github.com/princeton-vl/RAFT-Stereo) is licensed under the MIT license, and [LoFTR](https://github.com/zju3dv/LoFTR) and [CREStereo](https://github.com/megvii-research/CREStereo) are licensed under the Apache 2.0 license.


## Citing DynamicStereo
If you use DynamicStereo or Dynamic Replica in your research, please use the following BibTeX entry.
```
@article{karaev2023dynamicstereo,
  title={DynamicStereo: Consistent Dynamic Depth from Stereo Videos},
  author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
  journal={CVPR},
  year={2023}
}
```
datasets/augmentor.py
ADDED
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import random
from PIL import Image

import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

from torchvision.transforms import ColorJitter, functional, Compose


class AdjustGamma(object):
    def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
        self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = (
            gamma_min,
            gamma_max,
            gain_min,
            gain_max,
        )

    def __call__(self, sample):
        gain = random.uniform(self.gain_min, self.gain_max)
        gamma = random.uniform(self.gamma_min, self.gamma_max)
        return functional.adjust_gamma(sample, gamma, gain)

    def __repr__(self):
        return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"


class SequenceDispFlowAugmentor:
    def __init__(
        self,
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        do_flip=True,
        yjitter=False,
        saturation_range=[0.6, 1.4],
        gamma=[1, 1, 1, 1],
    ):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 1.0
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.yjitter = yjitter
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = Compose(
            [
                ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=saturation_range,
                    hue=0.5 / 3.14,
                ),
                AdjustGamma(*gamma),
            ]
        )
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, seq):
        """Photometric augmentation"""

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            for i in range(len(seq)):
                for cam in (0, 1):
                    seq[i][cam] = np.array(
                        self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8
                    )
        # symmetric
        else:
            image_stack = np.concatenate(
                [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
            )
            image_stack = np.array(
                self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
            )
            split = np.split(image_stack, len(seq) * 2, axis=0)
            for i in range(len(seq)):
                seq[i][0] = split[2 * i]
                seq[i][1] = split[2 * i + 1]
        return seq

    def eraser_transform(self, seq, bounds=[50, 100]):
        """Occlusion augmentation"""
        ht, wd = seq[0][0].shape[:2]
        for i in range(len(seq)):
            for cam in (0, 1):
                if np.random.rand() < self.eraser_aug_prob:
                    mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
                    for _ in range(np.random.randint(1, 3)):
                        x0 = np.random.randint(0, wd)
                        y0 = np.random.randint(0, ht)
                        dx = np.random.randint(bounds[0], bounds[1])
                        dy = np.random.randint(bounds[0], bounds[1])
                        seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color

        return seq

    def spatial_transform(self, img, disp):
        # randomly sample scale
        ht, wd = img[0][0].shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
        )

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            for i in range(len(img)):
                for cam in (0, 1):
                    img[i][cam] = cv2.resize(
                        img[i][cam],
                        None,
                        fx=scale_x,
                        fy=scale_y,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    if len(disp[i]) > 0:
                        disp[i][cam] = cv2.resize(
                            disp[i][cam],
                            None,
                            fx=scale_x,
                            fy=scale_y,
                            interpolation=cv2.INTER_LINEAR,
                        )
                        disp[i][cam] = disp[i][cam] * [scale_x, scale_y]

        if self.yjitter:
            y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2)
            x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2)

            for i in range(len(img)):
                y1 = y0 + np.random.randint(-2, 2 + 1)
                img[i][0] = img[i][0][
                    y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                ]
                img[i][1] = img[i][1][
                    y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                ]
                if len(disp[i]) > 0:
                    disp[i][0] = disp[i][0][
                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
                    disp[i][1] = disp[i][1][
                        y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
        else:
            y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
            x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
            for i in range(len(img)):
                for cam in (0, 1):
                    img[i][cam] = img[i][cam][
                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
                    if len(disp[i]) > 0:
                        disp[i][cam] = disp[i][cam][
                            y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                        ]

        return img, disp

    def __call__(self, img, disp):
        img = self.color_transform(img)
        img = self.eraser_transform(img)
        img, disp = self.spatial_transform(img, disp)

        for i in range(len(img)):
            for cam in (0, 1):
                img[i][cam] = np.ascontiguousarray(img[i][cam])
                if len(disp[i]) > 0:
                    disp[i][cam] = np.ascontiguousarray(disp[i][cam])

        return img, disp
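
Below is a minimal usage sketch of the augmentor above. The shapes and random data are illustrative; in this repository the inputs come from `StereoSequenceDataset`, which passes per-frame `[left, right]` image lists together with 2-channel disparity maps.

```
import numpy as np
from datasets.augmentor import SequenceDispFlowAugmentor

# Two fake stereo frames: per frame a [left, right] pair of uint8 images
# and a matching pair of 2-channel float32 disparity maps.
T, H, W = 2, 400, 640
img = [[np.random.randint(0, 255, (H, W, 3), dtype=np.uint8) for _ in range(2)] for _ in range(T)]
disp = [[np.zeros((H, W, 2), dtype=np.float32) for _ in range(2)] for _ in range(T)]

augmentor = SequenceDispFlowAugmentor(crop_size=(384, 512), min_scale=-0.2, max_scale=0.4)
img_aug, disp_aug = augmentor(img, disp)
print(img_aug[0][0].shape, disp_aug[0][0].shape)  # (384, 512, 3) (384, 512, 2)
```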
datasets/dynamic_stereo_datasets.py
ADDED
@@ -0,0 +1,743 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch

# -- Added by Chu King on 16th November 2025 for debugging purposes.
import torch.distributed as dist
import signal

import os
import copy
import gzip
import logging
import torch
import numpy as np
import torch.utils.data as data
import torch.nn.functional as F
import os.path as osp
from glob import glob

from collections import defaultdict
from PIL import Image
from dataclasses import dataclass
from typing import List, Optional
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.implicitron.dataset.types import (
    FrameAnnotation as ImplicitronFrameAnnotation,
    load_dataclass,
)

from datasets import frame_utils
from evaluation.utils.eval_utils import depth2disparity_scale
from datasets.augmentor import SequenceDispFlowAugmentor


@dataclass
class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):
    """A dataclass used to load annotations from json."""

    camera_name: Optional[str] = None


class StereoSequenceDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False, reader=None):
        self.augmentor = None
        self.sparse = sparse
        self.img_pad = (
            aug_params.pop("img_pad", None) if aug_params is not None else None
        )
        if aug_params is not None and "crop_size" in aug_params:
            if sparse:
                raise ValueError("Sparse augmentor is not implemented")
            else:
                self.augmentor = SequenceDispFlowAugmentor(**aug_params)

        if reader is None:
            self.disparity_reader = frame_utils.read_gen
        else:
            self.disparity_reader = reader
        self.depth_reader = self._load_16big_png_depth
        self.is_test = False
        self.sample_list = []
        self.extra_info = []
        self.depth_eps = 1e-5

    def _load_16big_png_depth(self, depth_png):
        with Image.open(depth_png) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            depth = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0]))
            )
        return depth

    def _get_pytorch3d_camera(
        self, entry_viewpoint, image_size, scale: float
    ) -> PerspectiveCameras:
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        # if self.image_height is None or self.image_width is None:
        out_size = list(reversed(image_size))

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )

    def _get_output_tensor(self, sample):
        output_tensor = defaultdict(list)
        sample_size = len(sample["image"]["left"])
        output_tensor_keys = ["img", "disp", "valid_disp", "mask"]
        add_keys = ["viewpoint", "metadata"]
        for add_key in add_keys:
            if add_key in sample:
                output_tensor_keys.append(add_key)

        for key in output_tensor_keys:
            output_tensor[key] = [[] for _ in range(sample_size)]

        if "viewpoint" in sample:
            viewpoint_left = self._get_pytorch3d_camera(
                sample["viewpoint"]["left"][0],
                sample["metadata"]["left"][0][1],
                scale=1.0,
            )
            viewpoint_right = self._get_pytorch3d_camera(
                sample["viewpoint"]["right"][0],
                sample["metadata"]["right"][0][1],
                scale=1.0,
            )
            depth2disp_scale = depth2disparity_scale(
                viewpoint_left,
                viewpoint_right,
                torch.Tensor(sample["metadata"]["left"][0][1])[None],
            )

        for i in range(sample_size):
            for cam in ["left", "right"]:
                if "mask" in sample and cam in sample["mask"]:
                    mask = frame_utils.read_gen(sample["mask"][cam][i])
                    mask = np.array(mask) / 255.0
                    output_tensor["mask"][i].append(mask)

                if "viewpoint" in sample and cam in sample["viewpoint"]:
                    viewpoint = self._get_pytorch3d_camera(
                        sample["viewpoint"][cam][i],
                        sample["metadata"][cam][i][1],
                        scale=1.0,
                    )
                    output_tensor["viewpoint"][i].append(viewpoint)

                if "metadata" in sample and cam in sample["metadata"]:
                    metadata = sample["metadata"][cam][i]
                    output_tensor["metadata"][i].append(metadata)

                if cam in sample["image"]:

                    img = frame_utils.read_gen(sample["image"][cam][i])
                    img = np.array(img).astype(np.uint8)

                    # grayscale images
                    if len(img.shape) == 2:
                        img = np.tile(img[..., None], (1, 1, 3))
                    else:
                        img = img[..., :3]
                    output_tensor["img"][i].append(img)

                if cam in sample["disparity"]:
                    disp = self.disparity_reader(sample["disparity"][cam][i])
                    if isinstance(disp, tuple):
                        disp, valid_disp = disp
                    else:
                        valid_disp = disp < 512
                    disp = np.array(disp).astype(np.float32)

                    disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)

                    output_tensor["disp"][i].append(disp)
                    output_tensor["valid_disp"][i].append(valid_disp)

                elif "depth" in sample and cam in sample["depth"]:
                    depth = self.depth_reader(sample["depth"][cam][i])

                    depth_mask = depth < self.depth_eps
                    depth[depth_mask] = self.depth_eps

                    disp = depth2disp_scale / depth
                    disp[depth_mask] = 0
                    valid_disp = (disp < 512) * (1 - depth_mask)

                    disp = np.array(disp).astype(np.float32)
                    disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
                    output_tensor["disp"][i].append(disp)
                    output_tensor["valid_disp"][i].append(valid_disp)

        return output_tensor

    def __getitem__(self, index):
        im_tensor = {"img": None}
        sample = self.sample_list[index]
        if self.is_test:
            sample_size = len(sample["image"]["left"])
            im_tensor["img"] = [[] for _ in range(sample_size)]
            for i in range(sample_size):
                for cam in ["left", "right"]:
                    img = frame_utils.read_gen(sample["image"][cam][i])
                    img = np.array(img).astype(np.uint8)[..., :3]
                    img = torch.from_numpy(img).permute(2, 0, 1).float()
                    im_tensor["img"][i].append(img)
            im_tensor["img"] = torch.stack(im_tensor["img"])
            return im_tensor, self.extra_info[index]

        index = index % len(self.sample_list)

        try:
            output_tensor = self._get_output_tensor(sample)
        except:
            logging.warning(f"Exception in loading sample {index}!")
            index = np.random.randint(len(self.sample_list))
            logging.info(f"New index is {index}")
            sample = self.sample_list[index]
            output_tensor = self._get_output_tensor(sample)
        sample_size = len(sample["image"]["left"])

        if self.augmentor is not None:
            output_tensor["img"], output_tensor["disp"] = self.augmentor(
                output_tensor["img"], output_tensor["disp"]
            )
        for i in range(sample_size):
            for cam in (0, 1):
                if cam < len(output_tensor["img"][i]):
                    img = (
                        torch.from_numpy(output_tensor["img"][i][cam])
                        .permute(2, 0, 1)
                        .float()
                    )
                    if self.img_pad is not None:
                        padH, padW = self.img_pad
                        img = F.pad(img, [padW] * 2 + [padH] * 2)
                    output_tensor["img"][i][cam] = img

                if cam < len(output_tensor["disp"][i]):
                    disp = (
                        torch.from_numpy(output_tensor["disp"][i][cam])
                        .permute(2, 0, 1)
                        .float()
                    )

                    if self.sparse:
                        valid_disp = torch.from_numpy(
                            output_tensor["valid_disp"][i][cam]
                        )
                    else:
                        valid_disp = (
                            (disp[0].abs() < 512)
                            & (disp[1].abs() < 512)
                            & (disp[0].abs() != 0)
                        )
                    disp = disp[:1]

                    output_tensor["disp"][i][cam] = disp
                    output_tensor["valid_disp"][i][cam] = valid_disp.float()

                if "mask" in output_tensor and cam < len(output_tensor["mask"][i]):
                    mask = torch.from_numpy(output_tensor["mask"][i][cam]).float()
                    output_tensor["mask"][i][cam] = mask

                if "viewpoint" in output_tensor and cam < len(
                    output_tensor["viewpoint"][i]
                ):
                    viewpoint = output_tensor["viewpoint"][i][cam]
                    output_tensor["viewpoint"][i][cam] = viewpoint

        res = {}
        if "viewpoint" in output_tensor and self.split != "train":
            res["viewpoint"] = output_tensor["viewpoint"]
        if "metadata" in output_tensor and self.split != "train":
            res["metadata"] = output_tensor["metadata"]

        for k, v in output_tensor.items():
            if k != "viewpoint" and k != "metadata":
                for i in range(len(v)):
                    if len(v[i]) > 0:
                        v[i] = torch.stack(v[i])
                if len(v) > 0 and (len(v[0]) > 0):
                    res[k] = torch.stack(v)
        return res

    def __mul__(self, v):
        copy_of_self = copy.deepcopy(self)
        copy_of_self.sample_list = v * copy_of_self.sample_list
        copy_of_self.extra_info = v * copy_of_self.extra_info
        return copy_of_self

    def __len__(self):
        return len(self.sample_list)


class DynamicReplicaDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./dynamic_replica_data",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
        t_step_validation=1,  # -- Added by Chu King on 24th November 2025 to control the separation between consecutive samples in validation
        VERBOSE=False,  # -- Added by Chu King on 16th November 2025 for debugging purposes
    ):
        super(DynamicReplicaDataset, self).__init__(aug_params)
        self.root = root
        self.sample_len = sample_len
        self.split = split

        frame_annotations_file = f"frame_annotations_{split}.jgz"

        with gzip.open(
            osp.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(
                zipfile, List[DynamicReplicaFrameAnnotation]
            )
        seq_annot = defaultdict(lambda: defaultdict(list))
        for frame_annot in frame_annots_list:
            seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append(
                frame_annot
            )

        # -- Added by Chu King on 16th November 2025 for debugging purposes
        if VERBOSE:
            rank = dist.get_rank() if dist.is_initialized() else 0
            with open(f"debug_rank_{rank}.txt", "a") as f:
                f.write("[INFO] seq_annot: {}\n".format(seq_annot))
            # -- os.kill(os.getpid(), signal.SIGABRT)

        for seq_name in seq_annot.keys():

            # -- Added by Chu King on 16th November 2025 for debugging purposes
            if VERBOSE:
                rank = dist.get_rank() if dist.is_initialized() else 0
                with open(f"debug_rank_{rank}.txt", "a") as f:
                    f.write("---- ----\n")
                    f.write("[INFO] seq_name: {}\n".format(seq_name))

            try:
                filenames = defaultdict(lambda: defaultdict(list))
                for cam in ["left", "right"]:
                    for framedata in seq_annot[seq_name][cam]:
                        im_path = osp.join(root, split, framedata.image.path)
                        depth_path = osp.join(root, split, framedata.depth.path)
                        mask_path = osp.join(root, split, framedata.mask.path)

                        # -- Added by Chu King on 16th November 2025 for debugging purposes
                        if VERBOSE:
                            rank = dist.get_rank() if dist.is_initialized() else 0
                            with open(f"debug_rank_{rank}.txt", "a") as f:
                                f.write("[INFO] cam: {}\n".format(cam))
                                f.write("[INFO] framedata: {}\n".format(framedata))
                                f.write("[INFO] framedata.viewpoint: {}\n".format(framedata.viewpoint))
                                f.write("[INFO] im_path: {}\n".format(im_path))
                                f.write("[INFO] depth_path: {}\n".format(depth_path))
                                f.write("[INFO] mask_path: {}\n".format(mask_path))

                        # -- Modified by Chu King on 16th November 2025 to clarify the nature of assertion errors.
                        assert os.path.isfile(im_path), "[ERROR] Rectified image path {} doesn't exist.".format(im_path)

                        tokens = root.split("/")
                        # -- if split != "test" and "real" not in tokens:
                        # --     assert os.path.isfile(depth_path), "[ERROR] Depth path {} doesn't exist.".format(depth_path)
                        if not os.path.isfile(depth_path):
                            if split != "test" or "real" not in tokens:
                                print("[WARNING] Depth path {} doesn't exist.".format(depth_path))

                        assert os.path.isfile(mask_path), "[ERROR] Mask path {} doesn't exist.".format(mask_path)

                        filenames["image"][cam].append(im_path)
                        filenames["mask"][cam].append(mask_path)
                        filenames["depth"][cam].append(depth_path)
                        filenames["viewpoint"][cam].append(framedata.viewpoint)
                        filenames["metadata"][cam].append(
                            [framedata.sequence_name, framedata.image.size]
                        )

                        for k in filenames.keys():
                            assert (
                                len(filenames[k][cam])
                                == len(filenames["image"][cam])
                                > 0
                            ), framedata.sequence_name

                if not os.path.isfile(depth_path):
                    del filenames["depth"]

                seq_len = len(filenames["image"][cam])

                print("seq_len", seq_name, seq_len)
                if split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        # -- step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        # -- Modified by Chu King on 24th November 2025 to handle high-speed motion.
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 12)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                    ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )

                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if self.sample_len > 0 else seq_len
                    counter = 0

                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            # -- Modified by Chu King on 24th November 2025 to control the separation between samples during validation.
                            # -- for idx in range(ref_idx, ref_idx + step):
                            for idx in range(ref_idx, ref_idx + step * t_step_validation, t_step_validation):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", seq_name)

        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
        logging.info(f"Added {len(self.sample_list)} from Dynamic Replica {split}")


class SequenceSceneFlowDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./datasets",
        dstype="frames_cleanpass",
        sample_len=1,
        things_test=False,
        add_things=True,
        add_monkaa=True,
        add_driving=True,
    ):
        super(SequenceSceneFlowDataset, self).__init__(aug_params)
        self.root = root
        self.dstype = dstype
        self.sample_len = sample_len
        if things_test:
            self._add_things("TEST")
        else:
            if add_things:
                self._add_things("TRAIN")
            if add_monkaa:
                self._add_monkaa()
            if add_driving:
                self._add_driving()

    def _add_things(self, split="TRAIN"):
        """Add FlyingThings3D data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "FlyingThings3D")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(root, self.dstype, split, f"*/*/{cam}/"))
            )
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]

        # Choose a random subset of 400 images for validation
        state = np.random.get_state()
        np.random.seed(1000)
        val_idxs = set(np.random.permutation(len(image_paths["left"]))[:40])
        np.random.set_state(state)
        np.random.seed(0)
        num_seq = len(image_paths["left"])

        for seq_idx in range(num_seq):
            if (split == "TEST" and seq_idx in val_idxs) or (
                split == "TRAIN" and not seq_idx in val_idxs
            ):
                images, disparities = defaultdict(list), defaultdict(list)
                for cam in ["left", "right"]:
                    images[cam] = sorted(
                        glob(osp.join(image_paths[cam][seq_idx], "*.png"))
                    )
                    disparities[cam] = sorted(
                        glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                    )

                self._append_sample(images, disparities)

        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
        )

    def _add_monkaa(self):
        """Add Monkaa data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "Monkaa")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f"*/{cam}/")))
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]

        num_seq = len(image_paths["left"])

        for seq_idx in range(num_seq):
            images, disparities = defaultdict(list), defaultdict(list)
            for cam in ["left", "right"]:
                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
                disparities[cam] = sorted(
                    glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                )

            self._append_sample(images, disparities)

        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
        )

    def _add_driving(self):
        """Add Driving data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "Driving")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(root, self.dstype, f"*/*/*/{cam}/"))
            )
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]

        num_seq = len(image_paths["left"])
        for seq_idx in range(num_seq):
            images, disparities = defaultdict(list), defaultdict(list)
            for cam in ["left", "right"]:
                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
                disparities[cam] = sorted(
                    glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                )

            self._append_sample(images, disparities)

        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
        )

    def _append_sample(self, images, disparities):
        seq_len = len(images["left"])
        for ref_idx in range(0, seq_len - self.sample_len):
            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                for idx in range(ref_idx, ref_idx + self.sample_len):
                    sample["image"][cam].append(images[cam][idx])
                    sample["disparity"][cam].append(disparities[cam][idx])
            self.sample_list.append(sample)

            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                for idx in range(ref_idx, ref_idx + self.sample_len):
                    sample["image"][cam].append(images[cam][seq_len - idx - 1])
                    sample["disparity"][cam].append(disparities[cam][seq_len - idx - 1])
            self.sample_list.append(sample)


class SequenceSintelStereo(StereoSequenceDataset):
    def __init__(
        self,
        dstype="clean",
        aug_params=None,
        root="./datasets",
    ):
        super().__init__(
            aug_params, sparse=True, reader=frame_utils.readDispSintelStereo
        )
        self.dstype = dstype
        original_length = len(self.sample_list)
        image_root = osp.join(root, "sintel_stereo", "training")

        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(image_root, f"{self.dstype}_{cam}/*"))
            )

        cam = "left"
        disparity_paths[cam] = [
            path.replace(f"{self.dstype}_{cam}", "disparities")
            for path in image_paths[cam]
        ]

        num_seq = len(image_paths["left"])
        # for each sequence
        for seq_idx in range(num_seq):
            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                sample["image"][cam] = sorted(
                    glob(osp.join(image_paths[cam][seq_idx], "*.png"))
                )
            cam = "left"
            sample["disparity"][cam] = sorted(
                glob(osp.join(disparity_paths[cam][seq_idx], "*.png"))
            )
            for im1, disp in zip(sample["image"][cam], sample["disparity"][cam]):
                assert (
                    im1.split("/")[-1].split(".")[0]
                    == disp.split("/")[-1].split(".")[0]
                ), (im1.split("/")[-1].split(".")[0], disp.split("/")[-1].split(".")[0])
            self.sample_list.append(sample)

        logging.info(
            f"Added {len(self.sample_list) - original_length} from SintelStereo {self.dstype}"
        )


def fetch_dataloader(args):
    """Create the data loader for the corresponding training set"""

    aug_params = {
        "crop_size": args.image_size,
        "min_scale": args.spatial_scale[0],
        "max_scale": args.spatial_scale[1],
        "do_flip": False,
        "yjitter": not args.noyjitter,
    }
    if hasattr(args, "saturation_range") and args.saturation_range is not None:
        aug_params["saturation_range"] = args.saturation_range
    if hasattr(args, "img_gamma") and args.img_gamma is not None:
        aug_params["gamma"] = args.img_gamma
    if hasattr(args, "do_flip") and args.do_flip is not None:
        aug_params["do_flip"] = args.do_flip

    train_dataset = None

    add_monkaa = "monkaa" in args.train_datasets
    add_driving = "driving" in args.train_datasets
    add_things = "things" in args.train_datasets
    add_dynamic_replica = "dynamic_replica" in args.train_datasets

    new_dataset = None

    if add_monkaa or add_driving or add_things:
        clean_dataset = SequenceSceneFlowDataset(
            aug_params,
            dstype="frames_cleanpass",
            sample_len=args.sample_len,
            add_monkaa=add_monkaa,
            add_driving=add_driving,
            add_things=add_things,
        )

        final_dataset = SequenceSceneFlowDataset(
            aug_params,
            dstype="frames_finalpass",
            sample_len=args.sample_len,
            add_monkaa=add_monkaa,
            add_driving=add_driving,
            add_things=add_things,
        )

        new_dataset = clean_dataset + final_dataset

    if add_dynamic_replica:
        dr_dataset = DynamicReplicaDataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = dr_dataset
        else:
            new_dataset = new_dataset + dr_dataset

    logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
    train_dataset = (
        new_dataset if train_dataset is None else train_dataset + new_dataset
    )

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
    )

    logging.info("Training with %d image pairs" % len(train_dataset))
    return train_loader
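
A minimal sketch of how the Dynamic Replica loader above is typically instantiated; the paths and parameters are illustrative and assume the download layout described in the README.

```
from datasets.dynamic_stereo_datasets import DynamicReplicaDataset

# 40-frame validation clips, no augmentation (aug_params=None), first clip per sequence.
valid_set = DynamicReplicaDataset(
    root="./dynamic_replica_data", split="valid", sample_len=40, only_first_n_samples=1
)
sample = valid_set[0]
# "img" is (T, 2, 3, H, W): T frames, left/right views, RGB channels;
# "disp" is (T, 2, 1, H, W) and "valid_disp" / "mask" are (T, 2, H, W).
print(sample["img"].shape, sample["disp"].shape, sample["valid_disp"].shape)
```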
datasets/frame_utils.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from PIL import Image
from os.path import *
import re
import imageio
import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

TAG_CHAR = np.array([202021.25], np.float32)


def readFlow(fn):
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    # print 'fn = %s'%(fn)
    with open(fn, "rb") as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print("Magic number incorrect. Invalid .flo file")
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            # print 'Reading %d x %d flo file\n' % (w, h)
            data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))


def readPFM(file):
    file = open(file, "rb")

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b"PF":
        color = True
    elif header == b"Pf":
        color = False
    else:
        raise Exception("Not a PFM file.")

    dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception("Malformed PFM header.")

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = "<"
        scale = -scale
    else:
        endian = ">"  # big-endian

    data = np.fromfile(file, endian + "f")
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data


def readDispSintelStereo(file_name):
    """Return disparity read from filename."""
    f_in = np.array(Image.open(file_name))
    d_r = f_in[:, :, 0].astype("float64")
    d_g = f_in[:, :, 1].astype("float64")
    d_b = f_in[:, :, 2].astype("float64")

    disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
    mask = np.array(Image.open(file_name.replace("disparities", "occlusions")))
    valid = (mask == 0) & (disp > 0)
    return disp, valid


def readDispMiddlebury(file_name):
    assert basename(file_name) == "disp0GT.pfm"
    disp = readPFM(file_name).astype(np.float32)
    assert len(disp.shape) == 2
    nocc_pix = file_name.replace("disp0GT.pfm", "mask0nocc.png")
    assert exists(nocc_pix)
    nocc_pix = imageio.imread(nocc_pix) == 255
    assert np.any(nocc_pix)
    return disp, nocc_pix


def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
        return Image.open(file_name)
    elif ext == ".bin" or ext == ".raw":
        return np.load(file_name)
    elif ext == ".flo":
        return readFlow(file_name).astype(np.float32)
    elif ext == ".pfm":
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []
evaluation/configs/eval_dynamic_replica_150_frames.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default_config_eval
|
| 3 |
+
visualize_interval: 0
|
| 4 |
+
exp_dir: ./outputs/dynamic_stereo_DR
|
| 5 |
+
sample_len: 150
|
| 6 |
+
MODEL:
|
| 7 |
+
model_name: DynamicStereoModel
|
| 8 |
+
|
evaluation/configs/eval_dynamic_replica_40_frames.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default_config_eval
|
| 3 |
+
visualize_interval: 0
|
| 4 |
+
exp_dir: ./outputs/dynamic_stereo_DR
|
| 5 |
+
sample_len: 40
|
| 6 |
+
MODEL:
|
| 7 |
+
model_name: DynamicStereoModel
|
| 8 |
+
|
evaluation/configs/eval_real_data.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default_config_eval
|
| 3 |
+
visualize_interval: 1
|
| 4 |
+
exp_dir: ./outputs/dynamic_stereo_real
|
| 5 |
+
dataset_name: real
|
| 6 |
+
sample_len: 40
|
| 7 |
+
MODEL:
|
| 8 |
+
model_name: DynamicStereoModel
|
| 9 |
+
|
evaluation/configs/eval_sintel_clean.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default_config_eval
|
| 3 |
+
visualize_interval: -1
|
| 4 |
+
exp_dir: ./outputs/dynamic_stereo_sintel_clean
|
| 5 |
+
sample_len: 30
|
| 6 |
+
dataset_name: sintel
|
| 7 |
+
dstype: clean
|
| 8 |
+
MODEL:
|
| 9 |
+
model_name: DynamicStereoModel
|
evaluation/configs/eval_sintel_final.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default_config_eval
|
| 3 |
+
visualize_interval: -1
|
| 4 |
+
exp_dir: ./outputs/dynamic_stereo_sintel_final
|
| 5 |
+
sample_len: 30
|
| 6 |
+
dataset_name: sintel
|
| 7 |
+
dstype: final
|
| 8 |
+
MODEL:
|
| 9 |
+
model_name: DynamicStereoModel
|
evaluation/core/evaluator.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
from pytorch3d.implicitron.tools.config import Configurable
|
| 14 |
+
|
| 15 |
+
from evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
|
| 16 |
+
from evaluation.utils.utils import (
|
| 17 |
+
PerceptionPrediction,
|
| 18 |
+
pretty_print_perception_metrics,
|
| 19 |
+
visualize_batch,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Evaluator(Configurable):
|
| 24 |
+
"""
|
| 25 |
+
A class defining the DynamicStereo evaluator.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
eps: Threshold for converting disparity to depth.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
eps = 1e-5
|
| 32 |
+
|
| 33 |
+
def setup_visualization(self, cfg: DictConfig) -> None:
|
| 34 |
+
# Visualization
|
| 35 |
+
self.visualize_interval = cfg.visualize_interval
|
| 36 |
+
self.exp_dir = cfg.exp_dir
|
| 37 |
+
if self.visualize_interval > 0:
|
| 38 |
+
self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")
|
| 39 |
+
|
| 40 |
+
@torch.no_grad()
|
| 41 |
+
def evaluate_sequence(
|
| 42 |
+
self,
|
| 43 |
+
sci_enc_L,
|
| 44 |
+
sci_enc_R,
|
| 45 |
+
model,
|
| 46 |
+
test_dataloader: torch.utils.data.DataLoader,
|
| 47 |
+
is_real_data: bool = False,
|
| 48 |
+
step=None,
|
| 49 |
+
writer=None,
|
| 50 |
+
train_mode=False,
|
| 51 |
+
interp_shape=None,
|
| 52 |
+
resolution=[480, 640]
|
| 53 |
+
):
|
| 54 |
+
# -- Modified by Chu King on 20th November 2025 for SCI Stereo.
|
| 55 |
+
# -- model.eval()
|
| 56 |
+
|
| 57 |
+
per_batch_eval_results = []
|
| 58 |
+
|
| 59 |
+
if self.visualize_interval > 0:
|
| 60 |
+
os.makedirs(self.visualize_dir, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
|
| 63 |
+
batch_dict = defaultdict(list)
|
| 64 |
+
batch_dict["stereo_video"] = sequence["img"]
|
| 65 |
+
if not is_real_data:
|
| 66 |
+
batch_dict["disparity"] = sequence["disp"][:, 0].abs()
|
| 67 |
+
batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1] # ~ (T, 1, 720, 1280)
|
| 68 |
+
|
| 69 |
+
if "mask" in sequence:
|
| 70 |
+
batch_dict["fg_mask"] = sequence["mask"][:, :1]
|
| 71 |
+
else:
|
| 72 |
+
batch_dict["fg_mask"] = torch.ones_like(
|
| 73 |
+
batch_dict["disparity_mask"]
|
| 74 |
+
)
|
| 75 |
+
elif interp_shape is not None:
|
| 76 |
+
left_video = batch_dict["stereo_video"][:, 0]
|
| 77 |
+
left_video = F.interpolate(
|
| 78 |
+
left_video, tuple(interp_shape), mode="bilinear"
|
| 79 |
+
)
|
| 80 |
+
right_video = batch_dict["stereo_video"][:, 1]
|
| 81 |
+
right_video = F.interpolate(
|
| 82 |
+
right_video, tuple(interp_shape), mode="bilinear"
|
| 83 |
+
)
|
| 84 |
+
batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)
|
| 85 |
+
|
| 86 |
+
# -- This method is always invoked with train_mode=True.
|
| 87 |
+
if train_mode:
|
| 88 |
+
# -- Modified by Chu King on 20th November 2025.
|
| 89 |
+
# -- predictions = model.forward_batch_test(batch_dict)
|
| 90 |
+
predictions = model.forward_batch_test(batch_dict, sci_enc_L, sci_enc_R)
|
| 91 |
+
else:
|
| 92 |
+
predictions = model(batch_dict)
|
| 93 |
+
|
| 94 |
+
assert "disparity" in predictions
|
| 95 |
+
predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()
|
| 96 |
+
|
| 97 |
+
# -- print ("[INFO] predictions[\"disparity\"].shape", predictions["disparity"].shape)
|
| 98 |
+
# -- print ("[INFO] batch_dict[\"disparity_mask\"][..., :resolution[0], :resolution[1]].shape", batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].shape)
|
| 99 |
+
# -- print ("[INFO] batch_dict[\"disparity_mask\"][..., :resolution[0], :resolution[1]].round().shape", batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].round().shape)
|
| 100 |
+
|
| 101 |
+
if not is_real_data:
|
| 102 |
+
predictions["disparity"] = predictions["disparity"] * (
|
| 103 |
+
# -- Modified by Chu King on 22nd November 2025
|
| 104 |
+
# -- batch_dict["disparity_mask"].round()
|
| 105 |
+
batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]].round()
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
batch_eval_result, seq_length = eval_batch(batch_dict, predictions)
|
| 109 |
+
|
| 110 |
+
per_batch_eval_results.append((batch_eval_result, seq_length))
|
| 111 |
+
pretty_print_perception_metrics(batch_eval_result)
|
| 112 |
+
|
| 113 |
+
if (self.visualize_interval > 0) and (
|
| 114 |
+
batch_idx % self.visualize_interval == 0
|
| 115 |
+
):
|
| 116 |
+
perception_prediction = PerceptionPrediction()
|
| 117 |
+
|
| 118 |
+
pred_disp = predictions["disparity"]
|
| 119 |
+
pred_disp[pred_disp < self.eps] = self.eps
|
| 120 |
+
|
| 121 |
+
scale = depth2disparity_scale(
|
| 122 |
+
sequence["viewpoint"][0][0],
|
| 123 |
+
sequence["viewpoint"][0][1],
|
| 124 |
+
torch.tensor([pred_disp.shape[2], pred_disp.shape[3]])[None],
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
perception_prediction.depth_map = (scale / pred_disp).cuda()
|
| 128 |
+
perspective_cameras = []
|
| 129 |
+
for cam in sequence["viewpoint"]:
|
| 130 |
+
perspective_cameras.append(cam[0])
|
| 131 |
+
|
| 132 |
+
perception_prediction.perspective_cameras = perspective_cameras
|
| 133 |
+
|
| 134 |
+
# -- Modified by Chu King on 22nd November 2025 to fix image resolution during training.
|
| 135 |
+
if "stereo_original_video" in batch_dict:
|
| 136 |
+
batch_dict["stereo_video"] = batch_dict["stereo_original_video"][..., :resolution[0], :resolution[1]].clone()
|
| 137 |
+
|
| 138 |
+
for k, v in batch_dict.items():
|
| 139 |
+
if isinstance(v, torch.Tensor):
|
| 140 |
+
batch_dict[k] = v.cuda()
|
| 141 |
+
|
| 142 |
+
visualize_batch(
|
| 143 |
+
batch_dict,
|
| 144 |
+
perception_prediction,
|
| 145 |
+
output_dir=self.visualize_dir,
|
| 146 |
+
sequence_name=sequence["metadata"][0][0][0],
|
| 147 |
+
step=step,
|
| 148 |
+
writer=writer,
|
| 149 |
+
# -- Added by Chu King on 22nd November 2025 to fix image resolution during evaluation.
|
| 150 |
+
resolution=resolution
|
| 151 |
+
)
|
| 152 |
+
return per_batch_eval_results
|
evaluation/evaluate.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any, Dict, Optional
|
| 11 |
+
|
| 12 |
+
import hydra
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
|
| 18 |
+
from dynamic_stereo.evaluation.utils.utils import aggregate_and_print_results
|
| 19 |
+
|
| 20 |
+
import dynamic_stereo.datasets.dynamic_stereo_datasets as datasets
|
| 21 |
+
|
| 22 |
+
from dynamic_stereo.models.core.model_zoo import (
|
| 23 |
+
get_all_model_default_configs,
|
| 24 |
+
model_zoo,
|
| 25 |
+
)
|
| 26 |
+
from pytorch3d.implicitron.tools.config import get_default_args_field
|
| 27 |
+
from dynamic_stereo.evaluation.core.evaluator import Evaluator
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(eq=False)
|
| 31 |
+
class DefaultConfig:
|
| 32 |
+
exp_dir: str = "./outputs"
|
| 33 |
+
|
| 34 |
+
# one of [sintel, dynamicreplica, things]
|
| 35 |
+
dataset_name: str = "dynamicreplica"
|
| 36 |
+
|
| 37 |
+
sample_len: int = -1
|
| 38 |
+
dstype: Optional[str] = None
|
| 39 |
+
# clean, final
|
| 40 |
+
MODEL: Dict[str, Any] = field(
|
| 41 |
+
default_factory=lambda: get_all_model_default_configs()
|
| 42 |
+
)
|
| 43 |
+
EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator)
|
| 44 |
+
|
| 45 |
+
seed: int = 42
|
| 46 |
+
gpu_idx: int = 0
|
| 47 |
+
|
| 48 |
+
visualize_interval: int = 0 # Use 0 for no visualization
|
| 49 |
+
|
| 50 |
+
# Override hydra's working directory to current working dir,
|
| 51 |
+
# also disable storing the .hydra logs:
|
| 52 |
+
hydra: dict = field(
|
| 53 |
+
default_factory=lambda: {
|
| 54 |
+
"run": {"dir": "."},
|
| 55 |
+
"output_subdir": None,
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def run_eval(cfg: DefaultConfig):
|
| 61 |
+
"""
|
| 62 |
+
Evaluates new view synthesis metrics of a specified model
|
| 63 |
+
on a benchmark dataset.
|
| 64 |
+
"""
|
| 65 |
+
# make the experiment directory
|
| 66 |
+
os.makedirs(cfg.exp_dir, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
# dump the exp cofig to the exp_dir
|
| 69 |
+
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
| 70 |
+
with open(cfg_file, "w") as f:
|
| 71 |
+
OmegaConf.save(config=cfg, f=f)
|
| 72 |
+
|
| 73 |
+
torch.manual_seed(cfg.seed)
|
| 74 |
+
np.random.seed(cfg.seed)
|
| 75 |
+
evaluator = Evaluator(**cfg.EVALUATOR)
|
| 76 |
+
|
| 77 |
+
model = model_zoo(**cfg.MODEL)
|
| 78 |
+
model.cuda(0)
|
| 79 |
+
evaluator.setup_visualization(cfg)
|
| 80 |
+
|
| 81 |
+
if cfg.dataset_name == "dynamicreplica":
|
| 82 |
+
test_dataloader = datasets.DynamicReplicaDataset(
|
| 83 |
+
split="valid", sample_len=cfg.sample_len, only_first_n_samples=1
|
| 84 |
+
)
|
| 85 |
+
elif cfg.dataset_name == "sintel":
|
| 86 |
+
test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype)
|
| 87 |
+
elif cfg.dataset_name == "things":
|
| 88 |
+
test_dataloader = datasets.SequenceSceneFlowDatasets(
|
| 89 |
+
{},
|
| 90 |
+
dstype=cfg.dstype,
|
| 91 |
+
sample_len=cfg.sample_len,
|
| 92 |
+
add_monkaa=False,
|
| 93 |
+
add_driving=False,
|
| 94 |
+
things_test=True,
|
| 95 |
+
)
|
| 96 |
+
elif cfg.dataset_name == "real":
|
| 97 |
+
for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]:
|
| 98 |
+
ds_path = f"./dynamic_replica_data/real/{real_sequence_name}"
|
| 99 |
+
# seq_len_real = 20
|
| 100 |
+
real_dataset = datasets.DynamicReplicaDataset(
|
| 101 |
+
split="test",
|
| 102 |
+
sample_len=cfg.sample_len,
|
| 103 |
+
root=ds_path,
|
| 104 |
+
only_first_n_samples=1,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
evaluator.evaluate_sequence(
|
| 108 |
+
model=model,
|
| 109 |
+
test_dataloader=real_dataset,
|
| 110 |
+
is_real_data=True,
|
| 111 |
+
train_mode=False,
|
| 112 |
+
)
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
print()
|
| 116 |
+
|
| 117 |
+
evaluate_result = evaluator.evaluate_sequence(
|
| 118 |
+
model,
|
| 119 |
+
test_dataloader,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
aggreegate_result = aggregate_and_print_results(evaluate_result)
|
| 123 |
+
|
| 124 |
+
result_file = os.path.join(cfg.exp_dir, f"result_eval.json")
|
| 125 |
+
|
| 126 |
+
print(f"Dumping eval results to {result_file}.")
|
| 127 |
+
with open(result_file, "w") as f:
|
| 128 |
+
json.dump(aggreegate_result, f)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
cs = hydra.core.config_store.ConfigStore.instance()
|
| 132 |
+
cs.store(name="default_config_eval", node=DefaultConfig)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
| 136 |
+
def evaluate(cfg: DefaultConfig) -> None:
|
| 137 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 138 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
| 139 |
+
run_eval(cfg)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
evaluate()
|
evaluation/utils/eval_utils.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Dict, Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from pytorch3d.utils import opencv_from_cameras_projection
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(eq=True, frozen=True)
|
| 15 |
+
class PerceptionMetric:
|
| 16 |
+
metric: str
|
| 17 |
+
depth_scaling_norm: Optional[str] = None
|
| 18 |
+
suffix: str = ""
|
| 19 |
+
index: str = ""
|
| 20 |
+
|
| 21 |
+
def __str__(self):
|
| 22 |
+
return (
|
| 23 |
+
self.metric
|
| 24 |
+
+ self.index
|
| 25 |
+
+ (
|
| 26 |
+
("_norm_" + self.depth_scaling_norm)
|
| 27 |
+
if self.depth_scaling_norm is not None
|
| 28 |
+
else ""
|
| 29 |
+
)
|
| 30 |
+
+ self.suffix
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def eval_endpoint_error_sequence(
|
| 35 |
+
x: torch.Tensor,
|
| 36 |
+
y: torch.Tensor,
|
| 37 |
+
mask: torch.Tensor,
|
| 38 |
+
crop: int = 0,
|
| 39 |
+
mask_thr: float = 0.5,
|
| 40 |
+
clamp_thr: float = 1e-5,
|
| 41 |
+
) -> Dict[str, torch.Tensor]:
|
| 42 |
+
|
| 43 |
+
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
|
| 44 |
+
x.shape,
|
| 45 |
+
y.shape,
|
| 46 |
+
mask.shape,
|
| 47 |
+
)
|
| 48 |
+
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
|
| 49 |
+
|
| 50 |
+
# chuck out the border
|
| 51 |
+
if crop > 0:
|
| 52 |
+
if crop > min(y.shape[2:]) - crop:
|
| 53 |
+
raise ValueError("Incorrect crop size.")
|
| 54 |
+
y = y[:, :, crop:-crop, crop:-crop]
|
| 55 |
+
x = x[:, :, crop:-crop, crop:-crop]
|
| 56 |
+
mask = mask[:, :, crop:-crop, crop:-crop]
|
| 57 |
+
|
| 58 |
+
y = y * (mask > mask_thr).float()
|
| 59 |
+
x = x * (mask > mask_thr).float()
|
| 60 |
+
y[torch.isnan(y)] = 0
|
| 61 |
+
|
| 62 |
+
results = {}
|
| 63 |
+
for epe_name in ("epe", "temp_epe"):
|
| 64 |
+
if epe_name == "epe":
|
| 65 |
+
endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt()
|
| 66 |
+
elif epe_name == "temp_epe":
|
| 67 |
+
delta_mask = mask[:-1] * mask[1:]
|
| 68 |
+
endpoint_error = (
|
| 69 |
+
(delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2)
|
| 70 |
+
.sum(dim=1)
|
| 71 |
+
.sqrt()
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# epe_nonzero = endpoint_error != 0
|
| 75 |
+
nonzero = torch.count_nonzero(endpoint_error)
|
| 76 |
+
|
| 77 |
+
epe_mean = endpoint_error.sum() / torch.clamp(
|
| 78 |
+
nonzero, clamp_thr
|
| 79 |
+
) # average error for all the sequence pixels
|
| 80 |
+
epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp(
|
| 81 |
+
nonzero, clamp_thr
|
| 82 |
+
)
|
| 83 |
+
epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp(
|
| 84 |
+
nonzero, clamp_thr
|
| 85 |
+
)
|
| 86 |
+
epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp(
|
| 87 |
+
nonzero, clamp_thr
|
| 88 |
+
)
|
| 89 |
+
epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp(
|
| 90 |
+
nonzero, clamp_thr
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
results[f"{epe_name}_mean"] = epe_mean[None]
|
| 94 |
+
results[f"{epe_name}_bad_0.5px"] = epe_inv_accuracy_05px[None] * 100
|
| 95 |
+
results[f"{epe_name}_bad_1px"] = epe_inv_accuracy_1px[None] * 100
|
| 96 |
+
results[f"{epe_name}_bad_2px"] = epe_inv_accuracy_2px[None] * 100
|
| 97 |
+
results[f"{epe_name}_bad_3px"] = epe_inv_accuracy_3px[None] * 100
|
| 98 |
+
return results
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def depth2disparity_scale(left_camera, right_camera, image_size_tensor):
|
| 102 |
+
# # opencv camera matrices
|
| 103 |
+
(_, T1, K1), (_, T2, _) = [
|
| 104 |
+
opencv_from_cameras_projection(
|
| 105 |
+
f,
|
| 106 |
+
image_size_tensor,
|
| 107 |
+
)
|
| 108 |
+
for f in (left_camera, right_camera)
|
| 109 |
+
]
|
| 110 |
+
fix_baseline = T1[0][0] - T2[0][0]
|
| 111 |
+
focal_length_px = K1[0][0][0]
|
| 112 |
+
# following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth
|
| 113 |
+
return focal_length_px * fix_baseline
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def depth_to_pcd(
|
| 117 |
+
depth_map,
|
| 118 |
+
img,
|
| 119 |
+
focal_length,
|
| 120 |
+
cx,
|
| 121 |
+
cy,
|
| 122 |
+
step: int = None,
|
| 123 |
+
inv_extrinsic=None,
|
| 124 |
+
mask=None,
|
| 125 |
+
filter=False,
|
| 126 |
+
):
|
| 127 |
+
__, w, __ = img.shape
|
| 128 |
+
if step is None:
|
| 129 |
+
step = int(w / 100)
|
| 130 |
+
Z = depth_map[::step, ::step]
|
| 131 |
+
colors = img[::step, ::step, :]
|
| 132 |
+
|
| 133 |
+
Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step
|
| 134 |
+
Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step
|
| 135 |
+
|
| 136 |
+
X = (Pixels_X[None, :] - cx) * Z / focal_length
|
| 137 |
+
Y = (Pixels_Y[:, None] - cy) * Z / focal_length
|
| 138 |
+
|
| 139 |
+
inds = Z > 0
|
| 140 |
+
|
| 141 |
+
if mask is not None:
|
| 142 |
+
inds = inds * (mask[::step, ::step] > 0)
|
| 143 |
+
|
| 144 |
+
X = X[inds].reshape(-1)
|
| 145 |
+
Y = Y[inds].reshape(-1)
|
| 146 |
+
Z = Z[inds].reshape(-1)
|
| 147 |
+
colors = colors[inds]
|
| 148 |
+
pcd = torch.stack([X, Y, Z]).T
|
| 149 |
+
|
| 150 |
+
if inv_extrinsic is not None:
|
| 151 |
+
pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)])
|
| 152 |
+
pcd = (inv_extrinsic @ pcd_ext)[:3, :].T
|
| 153 |
+
|
| 154 |
+
if filter:
|
| 155 |
+
pcd, filt_inds = filter_outliers(pcd)
|
| 156 |
+
colors = colors[filt_inds]
|
| 157 |
+
return pcd, colors
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def filter_outliers(pcd, sigma=3):
|
| 161 |
+
mean = pcd.mean(0)
|
| 162 |
+
std = pcd.std(0)
|
| 163 |
+
inds = ((pcd - mean).abs() < sigma * std)[:, 2]
|
| 164 |
+
pcd = pcd[inds]
|
| 165 |
+
return pcd, inds
|
| 166 |
+
|
| 167 |
+
# -- Modified by Chu King on 22nd November 2025 to fix the resolution during evaluation.
|
| 168 |
+
def eval_batch(batch_dict, predictions, resolution=[480, 640]) -> Dict[str, Union[float, torch.Tensor]]:
|
| 169 |
+
"""
|
| 170 |
+
Produce performance metrics for a single batch of perception
|
| 171 |
+
predictions.
|
| 172 |
+
Args:
|
| 173 |
+
frame_data: A PixarFrameData object containing the input to the new view
|
| 174 |
+
synthesis method.
|
| 175 |
+
preds: A PerceptionPrediction object with the predicted data.
|
| 176 |
+
Returns:
|
| 177 |
+
results: A dictionary holding evaluation metrics.
|
| 178 |
+
"""
|
| 179 |
+
results = {}
|
| 180 |
+
|
| 181 |
+
assert "disparity" in predictions
|
| 182 |
+
mask_now = torch.ones_like(batch_dict["fg_mask"][..., :resolution[0], :resolution[1]])
|
| 183 |
+
|
| 184 |
+
mask_now = mask_now * batch_dict["disparity_mask"][..., :resolution[0], :resolution[1]]
|
| 185 |
+
|
| 186 |
+
eval_flow_traj_output = eval_endpoint_error_sequence(
|
| 187 |
+
predictions["disparity"], batch_dict["disparity"][..., :resolution[0], :resolution[1]], mask_now
|
| 188 |
+
)
|
| 189 |
+
for epe_name in ("epe", "temp_epe"):
|
| 190 |
+
results[PerceptionMetric(f"disp_{epe_name}_mean")] = eval_flow_traj_output[
|
| 191 |
+
f"{epe_name}_mean"
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
results[PerceptionMetric(f"disp_{epe_name}_bad_3px")] = eval_flow_traj_output[
|
| 195 |
+
f"{epe_name}_bad_3px"
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
results[PerceptionMetric(f"disp_{epe_name}_bad_2px")] = eval_flow_traj_output[
|
| 199 |
+
f"{epe_name}_bad_2px"
|
| 200 |
+
]
|
| 201 |
+
|
| 202 |
+
results[PerceptionMetric(f"disp_{epe_name}_bad_1px")] = eval_flow_traj_output[
|
| 203 |
+
f"{epe_name}_bad_1px"
|
| 204 |
+
]
|
| 205 |
+
|
| 206 |
+
results[PerceptionMetric(f"disp_{epe_name}_bad_0.5px")] = eval_flow_traj_output[
|
| 207 |
+
f"{epe_name}_bad_0.5px"
|
| 208 |
+
]
|
| 209 |
+
if "endpoint_error_per_pixel" in eval_flow_traj_output:
|
| 210 |
+
results["disp_endpoint_error_per_pixel"] = eval_flow_traj_output[
|
| 211 |
+
"endpoint_error_per_pixel"
|
| 212 |
+
]
|
| 213 |
+
return (results, len(predictions["disparity"]))
|
evaluation/utils/utils.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import configparser
|
| 9 |
+
import os
|
| 10 |
+
import math
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
import torch
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from tabulate import tabulate
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from pytorch3d.structures import Pointclouds
|
| 20 |
+
from pytorch3d.transforms import RotateAxisAngle
|
| 21 |
+
from pytorch3d.utils import (
|
| 22 |
+
opencv_from_cameras_projection,
|
| 23 |
+
)
|
| 24 |
+
from pytorch3d.renderer import (
|
| 25 |
+
AlphaCompositor,
|
| 26 |
+
PointsRasterizationSettings,
|
| 27 |
+
PointsRasterizer,
|
| 28 |
+
PointsRenderer,
|
| 29 |
+
)
|
| 30 |
+
from evaluation.utils.eval_utils import depth_to_pcd
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class PerceptionPrediction:
|
| 35 |
+
"""
|
| 36 |
+
Holds the tensors that describe a result of any perception module.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
depth_map: Optional[torch.Tensor] = None
|
| 40 |
+
disparity: Optional[torch.Tensor] = None
|
| 41 |
+
image_rgb: Optional[torch.Tensor] = None
|
| 42 |
+
fg_probability: Optional[torch.Tensor] = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def aggregate_eval_results(per_batch_eval_results, reduction="mean"):
|
| 46 |
+
|
| 47 |
+
total_length = 0
|
| 48 |
+
aggregate_results = defaultdict(list)
|
| 49 |
+
for result in per_batch_eval_results:
|
| 50 |
+
if isinstance(result, tuple):
|
| 51 |
+
reduction = "sum"
|
| 52 |
+
length = result[1]
|
| 53 |
+
total_length += length
|
| 54 |
+
result = result[0]
|
| 55 |
+
for metric, val in result.items():
|
| 56 |
+
if reduction == "sum":
|
| 57 |
+
aggregate_results[metric].append(val * length)
|
| 58 |
+
|
| 59 |
+
if reduction == "mean":
|
| 60 |
+
return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()}
|
| 61 |
+
elif reduction == "sum":
|
| 62 |
+
return {
|
| 63 |
+
k: torch.cat(v).sum().item() / float(total_length)
|
| 64 |
+
for k, v in aggregate_results.items()
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def aggregate_and_print_results(
|
| 69 |
+
per_batch_eval_results: List[dict],
|
| 70 |
+
):
|
| 71 |
+
print("")
|
| 72 |
+
result = aggregate_eval_results(
|
| 73 |
+
per_batch_eval_results,
|
| 74 |
+
)
|
| 75 |
+
pretty_print_perception_metrics(result)
|
| 76 |
+
result = {str(k): v for k, v in result.items()}
|
| 77 |
+
|
| 78 |
+
print("")
|
| 79 |
+
return result
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def pretty_print_perception_metrics(results):
|
| 83 |
+
|
| 84 |
+
metrics = sorted(list(results.keys()), key=lambda x: x.metric)
|
| 85 |
+
|
| 86 |
+
print("===== Perception results =====")
|
| 87 |
+
print(
|
| 88 |
+
tabulate(
|
| 89 |
+
[[metric, results[metric]] for metric in metrics],
|
| 90 |
+
)
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def read_calibration(calibration_file, resolution_string):
|
| 95 |
+
# ported from https://github.com/stereolabs/zed-open-capture/
|
| 96 |
+
# blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172
|
| 97 |
+
|
| 98 |
+
zed_resolutions = {
|
| 99 |
+
"2K": (1242, 2208),
|
| 100 |
+
"FHD": (1080, 1920),
|
| 101 |
+
"HD": (720, 1280),
|
| 102 |
+
# "qHD": (540, 960),
|
| 103 |
+
"VGA": (376, 672),
|
| 104 |
+
}
|
| 105 |
+
assert resolution_string in zed_resolutions.keys()
|
| 106 |
+
image_height, image_width = zed_resolutions[resolution_string]
|
| 107 |
+
|
| 108 |
+
# Open camera configuration file
|
| 109 |
+
assert os.path.isfile(calibration_file)
|
| 110 |
+
calib = configparser.ConfigParser()
|
| 111 |
+
calib.read(calibration_file)
|
| 112 |
+
|
| 113 |
+
# Get translations
|
| 114 |
+
T = np.zeros((3, 1))
|
| 115 |
+
T[0] = float(calib["STEREO"]["baseline"])
|
| 116 |
+
T[1] = float(calib["STEREO"]["ty"])
|
| 117 |
+
T[2] = float(calib["STEREO"]["tz"])
|
| 118 |
+
|
| 119 |
+
baseline = T[0]
|
| 120 |
+
|
| 121 |
+
# Get left parameters
|
| 122 |
+
left_cam_cx = float(calib[f"LEFT_CAM_{resolution_string}"]["cx"])
|
| 123 |
+
left_cam_cy = float(calib[f"LEFT_CAM_{resolution_string}"]["cy"])
|
| 124 |
+
left_cam_fx = float(calib[f"LEFT_CAM_{resolution_string}"]["fx"])
|
| 125 |
+
left_cam_fy = float(calib[f"LEFT_CAM_{resolution_string}"]["fy"])
|
| 126 |
+
left_cam_k1 = float(calib[f"LEFT_CAM_{resolution_string}"]["k1"])
|
| 127 |
+
left_cam_k2 = float(calib[f"LEFT_CAM_{resolution_string}"]["k2"])
|
| 128 |
+
left_cam_p1 = float(calib[f"LEFT_CAM_{resolution_string}"]["p1"])
|
| 129 |
+
left_cam_p2 = float(calib[f"LEFT_CAM_{resolution_string}"]["p2"])
|
| 130 |
+
left_cam_k3 = float(calib[f"LEFT_CAM_{resolution_string}"]["k3"])
|
| 131 |
+
|
| 132 |
+
# Get right parameters
|
| 133 |
+
right_cam_cx = float(calib[f"RIGHT_CAM_{resolution_string}"]["cx"])
|
| 134 |
+
right_cam_cy = float(calib[f"RIGHT_CAM_{resolution_string}"]["cy"])
|
| 135 |
+
right_cam_fx = float(calib[f"RIGHT_CAM_{resolution_string}"]["fx"])
|
| 136 |
+
right_cam_fy = float(calib[f"RIGHT_CAM_{resolution_string}"]["fy"])
|
| 137 |
+
right_cam_k1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k1"])
|
| 138 |
+
right_cam_k2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k2"])
|
| 139 |
+
right_cam_p1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p1"])
|
| 140 |
+
right_cam_p2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p2"])
|
| 141 |
+
right_cam_k3 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k3"])
|
| 142 |
+
|
| 143 |
+
# Get rotations
|
| 144 |
+
R_zed = np.zeros(3)
|
| 145 |
+
R_zed[0] = float(calib["STEREO"][f"rx_{resolution_string.lower()}"])
|
| 146 |
+
R_zed[1] = float(calib["STEREO"][f"cv_{resolution_string.lower()}"])
|
| 147 |
+
R_zed[2] = float(calib["STEREO"][f"rz_{resolution_string.lower()}"])
|
| 148 |
+
|
| 149 |
+
R = cv2.Rodrigues(R_zed)[0]
|
| 150 |
+
|
| 151 |
+
# Left
|
| 152 |
+
cameraMatrix_left = np.array(
|
| 153 |
+
[[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]]
|
| 154 |
+
)
|
| 155 |
+
distCoeffs_left = np.array(
|
| 156 |
+
[left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3]
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Right
|
| 160 |
+
cameraMatrix_right = np.array(
|
| 161 |
+
[
|
| 162 |
+
[right_cam_fx, 0, right_cam_cx],
|
| 163 |
+
[0, right_cam_fy, right_cam_cy],
|
| 164 |
+
[0, 0, 1],
|
| 165 |
+
]
|
| 166 |
+
)
|
| 167 |
+
distCoeffs_right = np.array(
|
| 168 |
+
[right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Stereo
|
| 172 |
+
R1, R2, P1, P2, Q = cv2.stereoRectify(
|
| 173 |
+
cameraMatrix1=cameraMatrix_left,
|
| 174 |
+
distCoeffs1=distCoeffs_left,
|
| 175 |
+
cameraMatrix2=cameraMatrix_right,
|
| 176 |
+
distCoeffs2=distCoeffs_right,
|
| 177 |
+
imageSize=(image_width, image_height),
|
| 178 |
+
R=R,
|
| 179 |
+
T=T,
|
| 180 |
+
flags=cv2.CALIB_ZERO_DISPARITY,
|
| 181 |
+
newImageSize=(image_width, image_height),
|
| 182 |
+
alpha=0,
|
| 183 |
+
)[:5]
|
| 184 |
+
|
| 185 |
+
# Precompute maps for cv::remap()
|
| 186 |
+
map_left_x, map_left_y = cv2.initUndistortRectifyMap(
|
| 187 |
+
cameraMatrix_left,
|
| 188 |
+
distCoeffs_left,
|
| 189 |
+
R1,
|
| 190 |
+
P1,
|
| 191 |
+
(image_width, image_height),
|
| 192 |
+
cv2.CV_32FC1,
|
| 193 |
+
)
|
| 194 |
+
map_right_x, map_right_y = cv2.initUndistortRectifyMap(
|
| 195 |
+
cameraMatrix_right,
|
| 196 |
+
distCoeffs_right,
|
| 197 |
+
R2,
|
| 198 |
+
P2,
|
| 199 |
+
(image_width, image_height),
|
| 200 |
+
cv2.CV_32FC1,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
zed_calib = {
|
| 204 |
+
"map_left_x": map_left_x,
|
| 205 |
+
"map_left_y": map_left_y,
|
| 206 |
+
"map_right_x": map_right_x,
|
| 207 |
+
"map_right_y": map_right_y,
|
| 208 |
+
"pose_left": P1,
|
| 209 |
+
"pose_right": P2,
|
| 210 |
+
"baseline": baseline,
|
| 211 |
+
"image_width": image_width,
|
| 212 |
+
"image_height": image_height,
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
return zed_calib
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def visualize_batch(
|
| 219 |
+
batch_dict: dict,
|
| 220 |
+
preds: PerceptionPrediction,
|
| 221 |
+
output_dir: str,
|
| 222 |
+
ref_frame: int = 0,
|
| 223 |
+
only_foreground=False,
|
| 224 |
+
step=0,
|
| 225 |
+
sequence_name=None,
|
| 226 |
+
writer=None,
|
| 227 |
+
# -- Added by Chu King on 22nd November 2025 to fix image resolution during evaluation.
|
| 228 |
+
resolution=[480, 640]
|
| 229 |
+
):
|
| 230 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 231 |
+
|
| 232 |
+
outputs = {}
|
| 233 |
+
|
| 234 |
+
if preds.depth_map is not None:
|
| 235 |
+
device = preds.depth_map.device
|
| 236 |
+
|
| 237 |
+
pcd_global_seq = []
|
| 238 |
+
# -- H, W = batch_dict["stereo_video"].shape[3:]
|
| 239 |
+
H, W = resolution
|
| 240 |
+
|
| 241 |
+
for i in range(len(batch_dict["stereo_video"])):
|
| 242 |
+
R, T, K = opencv_from_cameras_projection(
|
| 243 |
+
preds.perspective_cameras[i],
|
| 244 |
+
torch.tensor([H, W])[None].to(device),
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1)
|
| 248 |
+
|
| 249 |
+
extr_matrix = torch.cat(
|
| 250 |
+
[
|
| 251 |
+
extrinsic_3x4_0,
|
| 252 |
+
torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device),
|
| 253 |
+
],
|
| 254 |
+
dim=0,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
inv_extr_matrix = extr_matrix.inverse().to(device)
|
| 258 |
+
pcd, colors = depth_to_pcd(
|
| 259 |
+
preds.depth_map[i, 0],
|
| 260 |
+
batch_dict["stereo_video"][..., :resolution[0], : resolution[1]][i][0].permute(1, 2, 0),
|
| 261 |
+
K[0][0][0],
|
| 262 |
+
K[0][0][2],
|
| 263 |
+
K[0][1][2],
|
| 264 |
+
step=1,
|
| 265 |
+
inv_extrinsic=inv_extr_matrix,
|
| 266 |
+
mask=batch_dict["fg_mask"][..., :resolution[0], : resolution[1]][i, 0] if only_foreground else None,
|
| 267 |
+
filter=False,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3]
|
| 271 |
+
pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i])))
|
| 272 |
+
|
| 273 |
+
raster_settings = PointsRasterizationSettings(
|
| 274 |
+
image_size=[H, W], radius=0.003, points_per_pixel=10
|
| 275 |
+
)
|
| 276 |
+
R, T, cam_ = pcd_global_seq[ref_frame][2]
|
| 277 |
+
|
| 278 |
+
median_depth = preds.depth_map.median()
|
| 279 |
+
cam_.cuda()
|
| 280 |
+
|
| 281 |
+
for mode in ["angle_15", "angle_-15", "changing_angle"]:
|
| 282 |
+
res = []
|
| 283 |
+
|
| 284 |
+
for t, (pcd, color, __) in enumerate(pcd_global_seq):
|
| 285 |
+
|
| 286 |
+
if mode == "changing_angle":
|
| 287 |
+
angle = math.cos((math.pi) * (t / 15)) * 15
|
| 288 |
+
elif mode == "angle_15":
|
| 289 |
+
angle = 15
|
| 290 |
+
elif mode == "angle_-15":
|
| 291 |
+
angle = -15
|
| 292 |
+
|
| 293 |
+
delta_x = median_depth * math.sin(math.radians(angle))
|
| 294 |
+
delta_z = median_depth * (1 - math.cos(math.radians(angle)))
|
| 295 |
+
|
| 296 |
+
cam = cam_.clone()
|
| 297 |
+
cam.R = torch.bmm(
|
| 298 |
+
cam.R,
|
| 299 |
+
RotateAxisAngle(angle=angle, axis="Y", device=device).get_matrix()[
|
| 300 |
+
:, :3, :3
|
| 301 |
+
],
|
| 302 |
+
)
|
| 303 |
+
cam.T[0, 0] = cam.T[0, 0] - delta_x
|
| 304 |
+
cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0
|
| 305 |
+
|
| 306 |
+
rasterizer = PointsRasterizer(
|
| 307 |
+
cameras=cam, raster_settings=raster_settings
|
| 308 |
+
)
|
| 309 |
+
renderer = PointsRenderer(
|
| 310 |
+
rasterizer=rasterizer,
|
| 311 |
+
compositor=AlphaCompositor(background_color=(1, 1, 1)),
|
| 312 |
+
)
|
| 313 |
+
pcd_copy = pcd.clone()
|
| 314 |
+
|
| 315 |
+
point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0])
|
| 316 |
+
images = renderer(point_cloud)
|
| 317 |
+
res.append(images[0, ..., :3].cpu())
|
| 318 |
+
res = torch.stack(res)
|
| 319 |
+
|
| 320 |
+
video = (res * 255).numpy().astype(np.uint8)
|
| 321 |
+
save_name = f"{sequence_name}_reconstruction_{step}_mode_{mode}_"
|
| 322 |
+
if writer is None:
|
| 323 |
+
outputs[mode] = video
|
| 324 |
+
if only_foreground:
|
| 325 |
+
save_name += "fg_only"
|
| 326 |
+
else:
|
| 327 |
+
save_name += "full_scene"
|
| 328 |
+
video_out = cv2.VideoWriter(
|
| 329 |
+
os.path.join(
|
| 330 |
+
output_dir,
|
| 331 |
+
f"{save_name}.mp4",
|
| 332 |
+
),
|
| 333 |
+
cv2.VideoWriter_fourcc(*"mp4v"),
|
| 334 |
+
fps=10,
|
| 335 |
+
frameSize=(res.shape[2], res.shape[1]),
|
| 336 |
+
isColor=True,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
for i in range(len(video)):
|
| 340 |
+
video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
|
| 341 |
+
video_out.release()
|
| 342 |
+
|
| 343 |
+
if writer is not None:
|
| 344 |
+
writer.add_video(
|
| 345 |
+
f"{sequence_name}_reconstruction_mode_{mode}",
|
| 346 |
+
(res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None],
|
| 347 |
+
global_step=step,
|
| 348 |
+
fps=8,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
return outputs
|
links_lite.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"real": [
|
| 3 |
+
"https://dl.fbaipublicfiles.com/dynamic_replica_v2/real/real_000.zip"
|
| 4 |
+
],
|
| 5 |
+
"test": [
|
| 6 |
+
"https://dl.fbaipublicfiles.com/dynamic_replica_v2/test/test_000.zip"
|
| 7 |
+
],
|
| 8 |
+
"valid": [
|
| 9 |
+
"https://dl.fbaipublicfiles.com/dynamic_replica_v2/valid/valid_000.zip",
|
| 10 |
+
"https://dl.fbaipublicfiles.com/dynamic_replica_v2/valid/valid_001.zip"
|
| 11 |
+
],
|
| 12 |
+
"train": [
|
| 13 |
+
"https://dl.fbaipublicfiles.com/dynamic_replica_v2/train/train_000.zip"
|
| 14 |
+
]
|
| 15 |
+
}
|
models/core/attention.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import copy
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import Module, Dropout
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
| 15 |
+
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def elu_feature_map(x):
|
| 20 |
+
return torch.nn.functional.elu(x) + 1
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PositionEncodingSine(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
|
| 29 |
+
"""
|
| 30 |
+
Args:
|
| 31 |
+
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
| 32 |
+
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
|
| 33 |
+
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
|
| 34 |
+
on the final performance. For now, we keep both impls for backward compatability.
|
| 35 |
+
We will remove the buggy impl after re-training all variants of our released models.
|
| 36 |
+
"""
|
| 37 |
+
super().__init__()
|
| 38 |
+
|
| 39 |
+
# -- d_model: embedding dimension
|
| 40 |
+
pe = torch.zeros((d_model, *max_shape))
|
| 41 |
+
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
| 42 |
+
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
| 43 |
+
if temp_bug_fix:
|
| 44 |
+
div_term = torch.exp(
|
| 45 |
+
torch.arange(0, d_model // 2, 2).float()
|
| 46 |
+
* (-math.log(10000.0) / (d_model // 2))
|
| 47 |
+
)
|
| 48 |
+
else: # a buggy implementation (for backward compatability only)
|
| 49 |
+
div_term = torch.exp(
|
| 50 |
+
torch.arange(0, d_model // 2, 2).float()
|
| 51 |
+
* (-math.log(10000.0) / d_model // 2)
|
| 52 |
+
)
|
| 53 |
+
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
| 54 |
+
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
| 55 |
+
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
| 56 |
+
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
| 57 |
+
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
| 58 |
+
|
| 59 |
+
self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
|
| 60 |
+
|
| 61 |
+
def forward(self, x):
|
| 62 |
+
"""
|
| 63 |
+
Args:
|
| 64 |
+
x: [N, C, H, W]
|
| 65 |
+
"""
|
| 66 |
+
return x + self.pe[:, :, : x.size(2), : x.size(3)].to(x.device)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class LinearAttention(Module):
|
| 70 |
+
def __init__(self, eps=1e-6):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.feature_map = elu_feature_map
|
| 73 |
+
self.eps = eps
|
| 74 |
+
|
| 75 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
| 76 |
+
"""Multi-Head linear attention proposed in "Transformers are RNNs"
|
| 77 |
+
Args:
|
| 78 |
+
queries: [N, L, H, D]
|
| 79 |
+
keys: [N, S, H, D]
|
| 80 |
+
values: [N, S, H, D]
|
| 81 |
+
q_mask: [N, L]
|
| 82 |
+
kv_mask: [N, S]
|
| 83 |
+
Returns:
|
| 84 |
+
queried_values: (N, L, H, D)
|
| 85 |
+
"""
|
| 86 |
+
Q = self.feature_map(queries)
|
| 87 |
+
K = self.feature_map(keys)
|
| 88 |
+
|
| 89 |
+
# set padded position to zero
|
| 90 |
+
if q_mask is not None:
|
| 91 |
+
Q = Q * q_mask[:, :, None, None]
|
| 92 |
+
if kv_mask is not None:
|
| 93 |
+
K = K * kv_mask[:, :, None, None]
|
| 94 |
+
values = values * kv_mask[:, :, None, None]
|
| 95 |
+
|
| 96 |
+
v_length = values.size(1)
|
| 97 |
+
values = values / v_length # prevent fp16 overflow
|
| 98 |
+
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
| 99 |
+
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
| 100 |
+
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
| 101 |
+
|
| 102 |
+
return queried_values.contiguous()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class FullAttention(Module):
|
| 106 |
+
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
| 107 |
+
super().__init__()
|
| 108 |
+
self.use_dropout = use_dropout
|
| 109 |
+
self.dropout = Dropout(attention_dropout)
|
| 110 |
+
|
| 111 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
| 112 |
+
"""Multi-head scaled dot-product attention, a.k.a full attention.
|
| 113 |
+
Args:
|
| 114 |
+
queries: [N, L, H, D]
|
| 115 |
+
keys: [N, S, H, D]
|
| 116 |
+
values: [N, S, H, D]
|
| 117 |
+
q_mask: [N, L]
|
| 118 |
+
kv_mask: [N, S]
|
| 119 |
+
Returns:
|
| 120 |
+
queried_values: (N, L, H, D)
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
# Compute the unnormalized attention and apply the masks
|
| 124 |
+
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
| 125 |
+
if kv_mask is not None:
|
| 126 |
+
QK.masked_fill_(
|
| 127 |
+
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf")
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Compute the attention and the weighted average
|
| 131 |
+
softmax_temp = 1.0 / queries.size(3) ** 0.5 # sqrt(D)
|
| 132 |
+
A = torch.softmax(softmax_temp * QK, dim=2)
|
| 133 |
+
if self.use_dropout:
|
| 134 |
+
A = self.dropout(A)
|
| 135 |
+
|
| 136 |
+
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
|
| 137 |
+
|
| 138 |
+
return queried_values.contiguous()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
|
| 142 |
+
class LoFTREncoderLayer(nn.Module):
|
| 143 |
+
def __init__(self, d_model, nhead, attention="linear"):
|
| 144 |
+
super(LoFTREncoderLayer, self).__init__()
|
| 145 |
+
|
| 146 |
+
self.dim = d_model // nhead
|
| 147 |
+
self.nhead = nhead
|
| 148 |
+
|
| 149 |
+
# multi-head attention
|
| 150 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 151 |
+
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
| 152 |
+
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
| 153 |
+
|
| 154 |
+
# -- LoFTR optionally uses linear attention (faster, avoids quadratic cost), otherwise normal softmax attention.
|
| 155 |
+
self.attention = LinearAttention() if attention == "linear" else FullAttention()
|
| 156 |
+
self.merge = nn.Linear(d_model, d_model, bias=False)
|
| 157 |
+
|
| 158 |
+
# feed-forward network
|
| 159 |
+
self.mlp = nn.Sequential(
|
| 160 |
+
nn.Linear(d_model * 2, d_model * 2, bias=False),
|
| 161 |
+
nn.ReLU(),
|
| 162 |
+
nn.Linear(d_model * 2, d_model, bias=False),
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# norm and dropout
|
| 166 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 167 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 168 |
+
|
| 169 |
+
def forward(self, x, source, x_mask=None, source_mask=None):
|
| 170 |
+
"""
|
| 171 |
+
Args:
|
| 172 |
+
x (torch.Tensor): [N, L, C]
|
| 173 |
+
source (torch.Tensor): [N, S, C]
|
| 174 |
+
x_mask (torch.Tensor): [N, L] (optional)
|
| 175 |
+
source_mask (torch.Tensor): [N, S] (optional)
|
| 176 |
+
"""
|
| 177 |
+
bs = x.size(0)
|
| 178 |
+
query, key, value = x, source, source
|
| 179 |
+
|
| 180 |
+
# multi-head attention
|
| 181 |
+
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
|
| 182 |
+
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
|
| 183 |
+
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
| 184 |
+
message = self.attention(
|
| 185 |
+
query, key, value, q_mask=x_mask, kv_mask=source_mask
|
| 186 |
+
) # [N, L, (H, D)]
|
| 187 |
+
message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
|
| 188 |
+
message = self.norm1(message)
|
| 189 |
+
|
| 190 |
+
# feed-forward network
|
| 191 |
+
message = self.mlp(torch.cat([x, message], dim=2))
|
| 192 |
+
message = self.norm2(message)
|
| 193 |
+
|
| 194 |
+
return x + message
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class LocalFeatureTransformer(nn.Module):
|
| 198 |
+
"""A Local Feature Transformer (LoFTR) module."""
|
| 199 |
+
|
| 200 |
+
def __init__(self, d_model, nhead, layer_names, attention):
|
| 201 |
+
super(LocalFeatureTransformer, self).__init__()
|
| 202 |
+
|
| 203 |
+
self.d_model = d_model
|
| 204 |
+
self.nhead = nhead
|
| 205 |
+
self.layer_names = layer_names
|
| 206 |
+
encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
|
| 207 |
+
self.layers = nn.ModuleList(
|
| 208 |
+
[copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
|
| 209 |
+
)
|
| 210 |
+
self._reset_parameters()
|
| 211 |
+
|
| 212 |
+
def _reset_parameters(self):
|
| 213 |
+
for p in self.parameters():
|
| 214 |
+
if p.dim() > 1:
|
| 215 |
+
nn.init.xavier_uniform_(p)
|
| 216 |
+
|
| 217 |
+
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
| 218 |
+
"""
|
| 219 |
+
Args:
|
| 220 |
+
feat0 (torch.Tensor): [N, L, C]
|
| 221 |
+
feat1 (torch.Tensor): [N, S, C]
|
| 222 |
+
mask0 (torch.Tensor): [N, L] (optional)
|
| 223 |
+
mask1 (torch.Tensor): [N, S] (optional)
|
| 224 |
+
"""
|
| 225 |
+
assert self.d_model == feat0.size(
|
| 226 |
+
2
|
| 227 |
+
), "the feature number of src and transformer must be equal"
|
| 228 |
+
|
| 229 |
+
for layer, name in zip(self.layers, self.layer_names):
|
| 230 |
+
|
| 231 |
+
if name == "self":
|
| 232 |
+
feat0 = layer(feat0, feat0, mask0, mask0)
|
| 233 |
+
feat1 = layer(feat1, feat1, mask1, mask1)
|
| 234 |
+
elif name == "cross":
|
| 235 |
+
feat0 = layer(feat0, feat1, mask0, mask1)
|
| 236 |
+
feat1 = layer(feat1, feat0, mask1, mask0)
|
| 237 |
+
else:
|
| 238 |
+
raise KeyError
|
| 239 |
+
|
| 240 |
+
return feat0, feat1
|
models/core/corr.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F


def bilinear_sampler(img, coords, mode="bilinear", mask=False, stereo=True):
    """Wrapper for grid_sample, uses pixel coordinates"""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    if not stereo:
        ygrid = 2 * ygrid / (H - 1) - 1
    else:
        assert torch.unique(ygrid).numel() == 1 and H == 1  # This is a stereo problem
    img = img.contiguous()
    grid = torch.cat([xgrid, ygrid], dim=-1).contiguous()
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
    )
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


class CorrBlock1D:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.coords = coords_grid(
            fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device
        )
        # all pairs correlation
        corr = CorrBlock1D.corr(fmap1, fmap2)

        batch, h1, w1, dim, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, 1, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

    def __call__(self, flow):
        r = self.radius
        coords = self.coords + flow
        coords = coords[:, :1].permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dx = dx.view(1, 1, 2 * r + 1, 1).to(coords.device)
            x0 = dx + coords.reshape(batch * h1 * w1, 1, 1, 1) / 2 ** i
            y0 = torch.zeros_like(x0)

            coords_lvl = torch.cat([x0, y0], dim=-1)
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        fmap1 = fmap1.view(B, D, H, W1)
        fmap2 = fmap2.view(B, D, H, W2)
        corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / torch.sqrt(torch.tensor(D).float())
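A minimal sketch (not part of the repository) of how `CorrBlock1D` is used: the 1D correlation pyramid is built once from left/right feature maps, then queried with the current disparity estimate inside the iterative update loop. It assumes the repository root is on `PYTHONPATH`.

```python
import torch
from models.core.corr import CorrBlock1D

B, C, H, W = 1, 128, 60, 80
fmap1 = torch.randn(B, C, H, W)   # left-view features
fmap2 = torch.randn(B, C, H, W)   # right-view features

corr_fn = CorrBlock1D(fmap1, fmap2, num_levels=4, radius=4)

flow = torch.zeros(B, 2, H, W)    # zero-initialised flow; only the x component matters for stereo
corr_feats = corr_fn(flow)        # num_levels * (2 * radius + 1) = 36 channels
print(corr_feats.shape)           # torch.Size([1, 36, 60, 80])
```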
models/core/dynamic_stereo.py
ADDED
|
@@ -0,0 +1,506 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# -- Added by Chu King on 16th November 2025 for debugging purposes.
|
| 8 |
+
import os, signal
|
| 9 |
+
import logging
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from models.core.update import (
|
| 21 |
+
BasicUpdateBlock,
|
| 22 |
+
SequenceUpdateBlock3D,
|
| 23 |
+
TimeAttnBlock,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# -- Added by Chu King on 21st November 2025
|
| 27 |
+
from models.core.sci_codec import sci_decoder
|
| 28 |
+
from models.core.extractor import BasicEncoder
|
| 29 |
+
from models.core.corr import CorrBlock1D
|
| 30 |
+
|
| 31 |
+
from models.core.attention import (
|
| 32 |
+
PositionEncodingSine,
|
| 33 |
+
LocalFeatureTransformer,
|
| 34 |
+
)
|
| 35 |
+
from models.core.utils.utils import InputPadder, interp
|
| 36 |
+
|
| 37 |
+
autocast = torch.cuda.amp.autocast
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DynamicStereo(nn.Module):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
max_disp: int = 192,
|
| 44 |
+
mixed_precision: bool = False,
|
| 45 |
+
num_frames: int = 5,
|
| 46 |
+
attention_type: str = None,
|
| 47 |
+
use_3d_update_block: bool = False,
|
| 48 |
+
different_update_blocks: bool = False,
|
| 49 |
+
):
|
| 50 |
+
super(DynamicStereo, self).__init__()
|
| 51 |
+
|
| 52 |
+
self.max_flow = max_disp
|
| 53 |
+
self.mixed_precision = mixed_precision
|
| 54 |
+
|
| 55 |
+
self.hidden_dim = 128
|
| 56 |
+
self.context_dim = 128
|
| 57 |
+
dim = 256
|
| 58 |
+
self.dim = dim
|
| 59 |
+
self.dropout = 0 # -- dropout probability
|
| 60 |
+
|
| 61 |
+
# -- decide whether to use 3D update blocks (like RAFT3D) or simpler 2D blocks.
|
| 62 |
+
self.use_3d_update_block = use_3d_update_block
|
| 63 |
+
|
| 64 |
+
# -- Modified by Chu King on 21st November 2025
|
| 65 |
+
# -- CNN encoder that extracts features from images.
|
| 66 |
+
# * output_dim: output channels
|
| 67 |
+
# * norm_fn="instance": applies instance normalization.
|
| 68 |
+
# -- self.fnet = BasicEncoder(
|
| 69 |
+
# -- output_dim=dim, norm_fn="instance", dropout=self.dropout
|
| 70 |
+
# -- )
|
| 71 |
+
self.fnet = sci_decoder(
|
| 72 |
+
n_frame=num_frames,
|
| 73 |
+
n_taps=2,
|
| 74 |
+
output_dim=dim,
|
| 75 |
+
norm_fn="instance",
|
| 76 |
+
dropout=self.dropout
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# -- Boolean flag to decide whether different update blocks are used for different resolutions.
|
| 80 |
+
self.different_update_blocks = different_update_blocks
|
| 81 |
+
|
| 82 |
+
# -- Cost volume planes (matching costs for disparity computation).
|
| 83 |
+
cor_planes = 4 * 9
|
| 84 |
+
self.depth = 4
|
| 85 |
+
self.attention_type = attention_type
|
| 86 |
+
# attention_type is a combination of the following attention types:
|
| 87 |
+
# self_stereo, temporal, update_time, update_space
|
| 88 |
+
# for example, self_stereo_temporal_update_time_update_space
|
| 89 |
+
|
| 90 |
+
if self.use_3d_update_block:
|
| 91 |
+
# -- Uses 3D convolutions for spatiotemporal processing.
|
| 92 |
+
if self.different_update_blocks:
|
| 93 |
+
# -- self.update_block08, self.update_block16, self.update_block04
|
| 94 |
+
# are update blocks for different resolution levels (i.e. 1/8, 1/16, 1/4)
|
| 95 |
+
self.update_block08 = SequenceUpdateBlock3D(
|
| 96 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 97 |
+
)
|
| 98 |
+
self.update_block16 = SequenceUpdateBlock3D(
|
| 99 |
+
hidden_dim=self.hidden_dim,
|
| 100 |
+
cor_planes=cor_planes,
|
| 101 |
+
mask_size=4,
|
| 102 |
+
attention_type=attention_type,
|
| 103 |
+
)
|
| 104 |
+
self.update_block04 = SequenceUpdateBlock3D(
|
| 105 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
self.update_block = SequenceUpdateBlock3D(
|
| 109 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
# -- Uses standard 2D update blocks.
|
| 113 |
+
if self.different_update_blocks:
|
| 114 |
+
self.update_block08 = BasicUpdateBlock(
|
| 115 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 116 |
+
)
|
| 117 |
+
self.update_block16 = BasicUpdateBlock(
|
| 118 |
+
hidden_dim=self.hidden_dim,
|
| 119 |
+
cor_planes=cor_planes,
|
| 120 |
+
mask_size=4,
|
| 121 |
+
attention_type=attention_type,
|
| 122 |
+
)
|
| 123 |
+
self.update_block04 = BasicUpdateBlock(
|
| 124 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
self.update_block = BasicUpdateBlock(
|
| 128 |
+
hidden_dim=self.hidden_dim, cor_planes=cor_planes, mask_size=4
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if attention_type is not None:
|
| 132 |
+
# -- The model incorporates several attention types.
|
| 133 |
+
if ("update_time" in attention_type) or ("temporal" in attention_type):
|
| 134 |
+
# -- This variable learns positional embeddings for different time steps in the sequence.
|
| 135 |
+
self.time_embed = nn.Parameter(torch.zeros(1, num_frames, dim))
|
| 136 |
+
|
| 137 |
+
# -- Temporal attention: processes information across different time frames.
|
| 138 |
+
if "temporal" in attention_type:
|
| 139 |
+
self.time_attn_blocks = nn.ModuleList(
|
| 140 |
+
[TimeAttnBlock(dim=dim, num_heads=8) for _ in range(self.depth)]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# -- Stereo attention: includes self-attention and cross attention blocks for processing
|
| 144 |
+
# left-right stereo image pairs.
|
| 145 |
+
if "self_stereo" in attention_type:
|
| 146 |
+
self.self_attn_blocks = nn.ModuleList(
|
| 147 |
+
[
|
| 148 |
+
LocalFeatureTransformer(
|
| 149 |
+
d_model=dim,
|
| 150 |
+
nhead=8,
|
| 151 |
+
layer_names=["self"] * 1,
|
| 152 |
+
attention="linear",
|
| 153 |
+
)
|
| 154 |
+
for _ in range(self.depth)
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.cross_attn_blocks = nn.ModuleList(
|
| 159 |
+
[
|
| 160 |
+
LocalFeatureTransformer(
|
| 161 |
+
d_model=dim,
|
| 162 |
+
nhead=8,
|
| 163 |
+
layer_names=["cross"] * 1,
|
| 164 |
+
attention="linear",
|
| 165 |
+
)
|
| 166 |
+
for _ in range(self.depth)
|
| 167 |
+
]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self.num_frames = num_frames
|
| 171 |
+
|
| 172 |
+
@torch.jit.ignore
|
| 173 |
+
def no_weight_decay(self):
|
| 174 |
+
return {"time_embed"}
|
| 175 |
+
|
| 176 |
+
def freeze_bn(self):
|
| 177 |
+
for m in self.modules():
|
| 178 |
+
if isinstance(m, nn.BatchNorm2d):
|
| 179 |
+
m.eval()
|
| 180 |
+
|
| 181 |
+
def convex_upsample(self, flow: torch.Tensor, mask: torch.Tensor, rate: int = 4):
|
| 182 |
+
"""Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
|
| 183 |
+
N, _, H, W = flow.shape
|
| 184 |
+
mask = mask.view(N, 1, 9, rate, rate, H, W)
|
| 185 |
+
mask = torch.softmax(mask, dim=2)
|
| 186 |
+
|
| 187 |
+
up_flow = F.unfold(rate * flow, [3, 3], padding=1)
|
| 188 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
| 189 |
+
|
| 190 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
| 191 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
| 192 |
+
return up_flow.reshape(N, 2, rate * H, rate * W)
|
| 193 |
+
|
| 194 |
+
def zero_init(self, fmap: torch.Tensor):
|
| 195 |
+
N, _, H, W = fmap.shape
|
| 196 |
+
_x = torch.zeros([N, 1, H, W], dtype=torch.float32)
|
| 197 |
+
_y = torch.zeros([N, 1, H, W], dtype=torch.float32)
|
| 198 |
+
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
|
| 199 |
+
return zero_flow
|
| 200 |
+
|
| 201 |
+
def forward_batch_test(
|
| 202 |
+
self, batch_dict: Dict, sci_enc_L, sci_enc_R, kernel_size: int = 14, iters: int = 20
|
| 203 |
+
):
|
| 204 |
+
stride = kernel_size // 2
|
| 205 |
+
predictions = defaultdict(list)
|
| 206 |
+
|
| 207 |
+
disp_preds = []
|
| 208 |
+
video = batch_dict["stereo_video"]
|
| 209 |
+
num_ims = len(video)
|
| 210 |
+
print("video", video.shape)
|
| 211 |
+
|
| 212 |
+
# -- Divide a single long sequence into multiple shorter, overlapping sequences.
|
| 213 |
+
# -- For SCI stereo, we only test the first sequence.
|
| 214 |
+
# -- for i in range(0, num_ims, stride):
|
| 215 |
+
for i in range(1):
|
| 216 |
+
left_ims = video[i : min(i + kernel_size, num_ims), 0]
|
| 217 |
+
# -- padder = InputPadder(left_ims.shape, divis_by=32)
|
| 218 |
+
|
| 219 |
+
right_ims = video[i : min(i + kernel_size, num_ims), 1]
|
| 220 |
+
# -- left_ims, right_ims = padder.pad(left_ims, right_ims)
|
| 221 |
+
|
| 222 |
+
# -- Modified by Chu King on 20th November 2025
|
| 223 |
+
# 0) Convert to Gray
|
| 224 |
+
def rgb_to_gray(x):
|
| 225 |
+
weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device)
|
| 226 |
+
gray = (x * weights[None, None, :, None, None]).sum(dim=2)
|
| 227 |
+
return gray # -- shape: [B, T, H, W]
|
| 228 |
+
|
| 229 |
+
video_L = rgb_to_gray(left_ims.to(next(sci_enc_L.parameters()).device)) # ~ (b, t, h, w)
|
| 230 |
+
video_R = rgb_to_gray(right_ims.to(next(sci_enc_R.parameters()).device)) # ~ (b, t, h, w)
|
| 231 |
+
|
| 232 |
+
# 1) Extract and normalize input videos.
|
| 233 |
+
# -- min_max_norm = lambda x : 2. * (x / 255.) - 1.
|
| 234 |
+
min_max_norm = lambda x: x / 255.
|
| 235 |
+
video_L = min_max_norm(video_L)
|
| 236 |
+
video_R = min_max_norm(video_R)
|
| 237 |
+
|
| 238 |
+
# 2) Make the tensors contiguous; otherwise a later .view() would raise an error.
|
| 239 |
+
video_L = video_L.contiguous()
|
| 240 |
+
video_R = video_R.contiguous()
|
| 241 |
+
|
| 242 |
+
# 3) Coded exposure modeling.
|
| 243 |
+
snapshot_L = sci_enc_L(video_L)
|
| 244 |
+
snapshot_R = sci_enc_R(video_R)
|
| 245 |
+
|
| 246 |
+
with autocast(enabled=self.mixed_precision):
|
| 247 |
+
disparities_forw = self.forward(
|
| 248 |
+
# -- Modified by Chu King on 20th November 2025
|
| 249 |
+
# -- left_ims[None].cuda(),
|
| 250 |
+
# -- right_ims[None].cuda(),
|
| 251 |
+
snapshot_L,
|
| 252 |
+
snapshot_R,
|
| 253 |
+
iters=iters,
|
| 254 |
+
test_mode=True,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# -- Padding disabled by Chu King on 20th November 2025
|
| 258 |
+
# -- disparities_forw = padder.unpad(disparities_forw[:, 0])[:, None].cpu()
|
| 259 |
+
disparities_forw = disparities_forw[:, 0][:, None].cpu()
|
| 260 |
+
|
| 261 |
+
# -- We are not doing overlapping chunks in SCI stereo.
|
| 262 |
+
disp_preds.append(disparities_forw)
|
| 263 |
+
# -- if len(disp_preds) > 0 and len(disparities_forw) >= stride:
|
| 264 |
+
# -- if len(disparities_forw) < kernel_size:
|
| 265 |
+
# -- disp_preds.append(disparities_forw[stride // 2 :])
|
| 266 |
+
# -- else:
|
| 267 |
+
# -- disp_preds.append(disparities_forw[stride // 2 : -stride // 2])
|
| 268 |
+
# -- elif len(disp_preds) == 0:
|
| 269 |
+
# -- disp_preds.append(disparities_forw[: -stride // 2])
|
| 270 |
+
|
| 271 |
+
predictions["disparity"] = (torch.cat(disp_preds).squeeze(1).abs())[:, :1]
|
| 272 |
+
|
| 273 |
+
return predictions
|
| 274 |
+
|
| 275 |
+
def forward_sst_block(
|
| 276 |
+
self, fmap1_dw16: torch.Tensor, fmap2_dw16: torch.Tensor, T: int
|
| 277 |
+
):
|
| 278 |
+
# -- fmap1_dw16 ~ (B*T, C, H, W) -- left-view features
|
| 279 |
+
# -- fmap2_dw16 ~ (B*T, C, H, W) -- right-view features
|
| 280 |
+
*_, h, w = fmap1_dw16.shape
|
| 281 |
+
|
| 282 |
+
# positional encoding and self-attention
|
| 283 |
+
pos_encoding_fn_small = PositionEncodingSine(d_model=self.dim, max_shape=(h, w))
|
| 284 |
+
fmap1_dw16 = pos_encoding_fn_small(fmap1_dw16)
|
| 285 |
+
fmap2_dw16 = pos_encoding_fn_small(fmap2_dw16)
|
| 286 |
+
|
| 287 |
+
if self.attention_type is not None:
|
| 288 |
+
# add time embeddings
|
| 289 |
+
if (
|
| 290 |
+
"temporal" in self.attention_type
|
| 291 |
+
or "update_time" in self.attention_type
|
| 292 |
+
):
|
| 293 |
+
fmap1_dw16 = rearrange(
|
| 294 |
+
fmap1_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w
|
| 295 |
+
)
|
| 296 |
+
fmap2_dw16 = rearrange(
|
| 297 |
+
fmap2_dw16, "(b t) m h w -> (b h w) t m", t=T, h=h, w=w
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# interpolate if video length doesn't match
|
| 301 |
+
if T != self.num_frames:
|
| 302 |
+
time_embed = self.time_embed.transpose(1, 2)
|
| 303 |
+
new_time_embed = F.interpolate(time_embed, size=(T), mode="nearest")
|
| 304 |
+
new_time_embed = new_time_embed.transpose(1, 2).contiguous()
|
| 305 |
+
else:
|
| 306 |
+
new_time_embed = self.time_embed
|
| 307 |
+
|
| 308 |
+
fmap1_dw16 = fmap1_dw16 + new_time_embed
|
| 309 |
+
fmap2_dw16 = fmap2_dw16 + new_time_embed
|
| 310 |
+
|
| 311 |
+
fmap1_dw16 = rearrange(
|
| 312 |
+
fmap1_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w
|
| 313 |
+
)
|
| 314 |
+
fmap2_dw16 = rearrange(
|
| 315 |
+
fmap2_dw16, "(b h w) t m -> (b t) m h w", t=T, h=h, w=w
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if ("self_stereo" in self.attention_type) or (
|
| 319 |
+
"temporal" in self.attention_type
|
| 320 |
+
):
|
| 321 |
+
for att_ind in range(self.depth):
|
| 322 |
+
if "self_stereo" in self.attention_type:
|
| 323 |
+
fmap1_dw16 = rearrange(
|
| 324 |
+
fmap1_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w
|
| 325 |
+
)
|
| 326 |
+
fmap2_dw16 = rearrange(
|
| 327 |
+
fmap2_dw16, "(b t) m h w -> (b t) (h w) m", t=T, h=h, w=w
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
fmap1_dw16, fmap2_dw16 = self.self_attn_blocks[att_ind](
|
| 331 |
+
fmap1_dw16, fmap2_dw16
|
| 332 |
+
)
|
| 333 |
+
fmap1_dw16, fmap2_dw16 = self.cross_attn_blocks[att_ind](
|
| 334 |
+
fmap1_dw16, fmap2_dw16
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
fmap1_dw16 = rearrange(
|
| 338 |
+
fmap1_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w
|
| 339 |
+
)
|
| 340 |
+
fmap2_dw16 = rearrange(
|
| 341 |
+
fmap2_dw16, "(b t) (h w) m -> (b t) m h w ", t=T, h=h, w=w
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if "temporal" in self.attention_type:
|
| 345 |
+
fmap1_dw16 = self.time_attn_blocks[att_ind](fmap1_dw16, T=T)
|
| 346 |
+
fmap2_dw16 = self.time_attn_blocks[att_ind](fmap2_dw16, T=T)
|
| 347 |
+
return fmap1_dw16, fmap2_dw16
|
| 348 |
+
|
| 349 |
+
def forward_update_block(
|
| 350 |
+
self,
|
| 351 |
+
update_block: nn.Module,
|
| 352 |
+
corr_fn: CorrBlock1D,
|
| 353 |
+
flow: torch.Tensor,
|
| 354 |
+
net: torch.Tensor,
|
| 355 |
+
inp: torch.Tensor,
|
| 356 |
+
predictions: List,
|
| 357 |
+
iters: int,
|
| 358 |
+
interp_scale: float,
|
| 359 |
+
t: int,
|
| 360 |
+
):
|
| 361 |
+
for _ in range(iters):
|
| 362 |
+
flow = flow.detach()
|
| 363 |
+
out_corrs = corr_fn(flow)
|
| 364 |
+
with autocast(enabled=self.mixed_precision):
|
| 365 |
+
net, up_mask, delta_flow = update_block(net, inp, out_corrs, flow, t=t)
|
| 366 |
+
|
| 367 |
+
flow = flow + delta_flow
|
| 368 |
+
flow_up = flow_out = self.convex_upsample(flow, up_mask, rate=4)
|
| 369 |
+
if interp_scale > 1:
|
| 370 |
+
flow_up = interp_scale * interp(
|
| 371 |
+
flow_out,
|
| 372 |
+
(
|
| 373 |
+
interp_scale * flow_out.shape[2],
|
| 374 |
+
interp_scale * flow_out.shape[3],
|
| 375 |
+
),
|
| 376 |
+
)
|
| 377 |
+
flow_up = flow_up[:, :1]
|
| 378 |
+
predictions.append(flow_up)
|
| 379 |
+
return flow_out, net
|
| 380 |
+
|
| 381 |
+
def forward(self, image1, image2, flow_init=None, iters=10, test_mode=False):
|
| 382 |
+
"""Estimate optical flow between pair of frames"""
|
| 383 |
+
b, *_ = image1.shape
|
| 384 |
+
|
| 385 |
+
hdim = self.hidden_dim
|
| 386 |
+
|
| 387 |
+
with autocast(enabled=self.mixed_precision):
|
| 388 |
+
fmap1, fmap2 = self.fnet([image1, image2])
|
| 389 |
+
|
| 390 |
+
net, inp = torch.split(fmap1, [hdim, hdim], dim=1)
|
| 391 |
+
net = torch.tanh(net)
|
| 392 |
+
inp = F.relu(inp)
|
| 393 |
+
*_, h, w = fmap1.shape
|
| 394 |
+
# 1/4 -> 1/16
|
| 395 |
+
# feature
|
| 396 |
+
fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
|
| 397 |
+
fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
|
| 398 |
+
|
| 399 |
+
fmap1_dw16, fmap2_dw16 = self.forward_sst_block(fmap1_dw16, fmap2_dw16, T=self.num_frames)
|
| 400 |
+
|
| 401 |
+
net_dw16, inp_dw16 = torch.split(fmap1_dw16, [hdim, hdim], dim=1)
|
| 402 |
+
net_dw16 = torch.tanh(net_dw16)
|
| 403 |
+
inp_dw16 = F.relu(inp_dw16)
|
| 404 |
+
|
| 405 |
+
fmap1_dw8 = (
|
| 406 |
+
F.avg_pool2d(fmap1, 2, stride=2) + interp(fmap1_dw16, (h // 2, w // 2))
|
| 407 |
+
) / 2.0
|
| 408 |
+
fmap2_dw8 = (
|
| 409 |
+
F.avg_pool2d(fmap2, 2, stride=2) + interp(fmap2_dw16, (h // 2, w // 2))
|
| 410 |
+
) / 2.0
|
| 411 |
+
|
| 412 |
+
net_dw8, inp_dw8 = torch.split(fmap1_dw8, [hdim, hdim], dim=1)
|
| 413 |
+
net_dw8 = torch.tanh(net_dw8)
|
| 414 |
+
inp_dw8 = F.relu(inp_dw8)
|
| 415 |
+
# Cascaded refinement (1/16 + 1/8 + 1/4)
|
| 416 |
+
predictions = []
|
| 417 |
+
flow = None
|
| 418 |
+
flow_up = None
|
| 419 |
+
if flow_init is not None:
|
| 420 |
+
scale = h / flow_init.shape[2]
|
| 421 |
+
flow = -scale * interp(flow_init, (h, w))
|
| 422 |
+
else:
|
| 423 |
+
# zero initialization
|
| 424 |
+
flow_dw16 = self.zero_init(fmap1_dw16) # -- (N, 2, H, W)
|
| 425 |
+
|
| 426 |
+
# Recurrent Update Module
|
| 427 |
+
# Update 1/16
|
| 428 |
+
update_block = (
|
| 429 |
+
self.update_block16
|
| 430 |
+
if self.different_update_blocks
|
| 431 |
+
else self.update_block
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
corr_fn_att_dw16 = CorrBlock1D(fmap1_dw16, fmap2_dw16)
|
| 435 |
+
flow, net_dw16 = self.forward_update_block(
|
| 436 |
+
update_block=update_block,
|
| 437 |
+
corr_fn=corr_fn_att_dw16,
|
| 438 |
+
flow=flow_dw16,
|
| 439 |
+
net=net_dw16,
|
| 440 |
+
inp=inp_dw16,
|
| 441 |
+
predictions=predictions,
|
| 442 |
+
iters=iters // 2,
|
| 443 |
+
interp_scale=4,
|
| 444 |
+
t=self.num_frames,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
scale = fmap1_dw8.shape[2] / flow.shape[2]
|
| 448 |
+
flow_dw8 = -scale * interp(flow, (fmap1_dw8.shape[2], fmap1_dw8.shape[3]))
|
| 449 |
+
|
| 450 |
+
net_dw8 = (
|
| 451 |
+
net_dw8
|
| 452 |
+
+ interp(net_dw16, (2 * net_dw16.shape[2], 2 * net_dw16.shape[3]))
|
| 453 |
+
) / 2.0
|
| 454 |
+
# Update 1/8
|
| 455 |
+
|
| 456 |
+
update_block = (
|
| 457 |
+
self.update_block08
|
| 458 |
+
if self.different_update_blocks
|
| 459 |
+
else self.update_block
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
corr_fn_dw8 = CorrBlock1D(fmap1_dw8, fmap2_dw8)
|
| 463 |
+
flow, net_dw8 = self.forward_update_block(
|
| 464 |
+
update_block=update_block,
|
| 465 |
+
corr_fn=corr_fn_dw8,
|
| 466 |
+
flow=flow_dw8,
|
| 467 |
+
net=net_dw8,
|
| 468 |
+
inp=inp_dw8,
|
| 469 |
+
predictions=predictions,
|
| 470 |
+
iters=iters // 2,
|
| 471 |
+
interp_scale=2,
|
| 472 |
+
t=self.num_frames,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
scale = h / flow.shape[2]
|
| 476 |
+
flow = -scale * interp(flow, (h, w))
|
| 477 |
+
|
| 478 |
+
net = (
|
| 479 |
+
net + interp(net_dw8, (2 * net_dw8.shape[2], 2 * net_dw8.shape[3]))
|
| 480 |
+
) / 2.0
|
| 481 |
+
# Update 1/4
|
| 482 |
+
update_block = (
|
| 483 |
+
self.update_block04 if self.different_update_blocks else self.update_block
|
| 484 |
+
)
|
| 485 |
+
corr_fn = CorrBlock1D(fmap1, fmap2)
|
| 486 |
+
flow, __ = self.forward_update_block(
|
| 487 |
+
update_block=update_block,
|
| 488 |
+
corr_fn=corr_fn,
|
| 489 |
+
flow=flow,
|
| 490 |
+
net=net,
|
| 491 |
+
inp=inp,
|
| 492 |
+
predictions=predictions,
|
| 493 |
+
iters=iters,
|
| 494 |
+
interp_scale=1,
|
| 495 |
+
t=self.num_frames,
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
predictions = torch.stack(predictions)
|
| 499 |
+
|
| 500 |
+
predictions = rearrange(predictions, "d (b t) c h w -> d t b c h w", b=b, t=self.num_frames)
|
| 501 |
+
flow_up = predictions[-1]
|
| 502 |
+
|
| 503 |
+
if test_mode:
|
| 504 |
+
return flow_up
|
| 505 |
+
|
| 506 |
+
return predictions
|
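The cascaded refinement above repeatedly moves the disparity field between the 1/16, 1/8 and 1/4 grids. The sketch below (not repository code) shows the rescaling rule it relies on: when a flow field is interpolated to a finer grid, its values are multiplied by the same resolution ratio so they remain expressed in pixels of the new grid; plain `F.interpolate` stands in for the repository's `interp` helper, and the extra sign flips in the code above follow its disparity sign convention.

```python
import torch
import torch.nn.functional as F

flow_dw16 = torch.randn(1, 2, 30, 40)      # flow estimated on the 1/16 grid
target_hw = (60, 80)                       # the 1/8 grid
scale = target_hw[0] / flow_dw16.shape[2]  # resolution ratio = 2.0

flow_dw8 = scale * F.interpolate(flow_dw16, size=target_hw, mode="bilinear", align_corners=True)
print(flow_dw8.shape)                      # torch.Size([1, 2, 60, 80])
```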
models/core/extractor.py
ADDED
|
@@ -0,0 +1,139 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
# -- Added by Chu King on 16th November 2025 for debugging purposes.
|
| 11 |
+
import os, signal
|
| 12 |
+
import logging
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
|
| 15 |
+
class ResidualBlock(nn.Module):
|
| 16 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
| 17 |
+
super(ResidualBlock, self).__init__()
|
| 18 |
+
|
| 19 |
+
self.conv1 = nn.Conv2d(
|
| 20 |
+
in_planes, planes, kernel_size=3, padding=1, stride=stride
|
| 21 |
+
)
|
| 22 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
| 23 |
+
self.relu = nn.ReLU(inplace=True)
|
| 24 |
+
|
| 25 |
+
num_groups = planes // 8
|
| 26 |
+
|
| 27 |
+
if norm_fn == "group":
|
| 28 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 29 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 30 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 31 |
+
|
| 32 |
+
elif norm_fn == "batch":
|
| 33 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 34 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 35 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 36 |
+
|
| 37 |
+
elif norm_fn == "instance":
|
| 38 |
+
self.norm1 = nn.InstanceNorm2d(planes, affine=False)
|
| 39 |
+
self.norm2 = nn.InstanceNorm2d(planes, affine=False)
|
| 40 |
+
self.norm3 = nn.InstanceNorm2d(planes, affine=False)
|
| 41 |
+
|
| 42 |
+
elif norm_fn == "none":
|
| 43 |
+
self.norm1 = nn.Sequential()
|
| 44 |
+
self.norm2 = nn.Sequential()
|
| 45 |
+
self.norm3 = nn.Sequential()
|
| 46 |
+
|
| 47 |
+
self.downsample = nn.Sequential(
|
| 48 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
y = x
|
| 53 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 54 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 55 |
+
|
| 56 |
+
# -- ensures that x is transformed to the correct shape so it can be added to y.
|
| 57 |
+
x = self.downsample(x)
|
| 58 |
+
|
| 59 |
+
return self.relu(x + y)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class BasicEncoder(nn.Module):
|
| 63 |
+
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
| 64 |
+
super(BasicEncoder, self).__init__()
|
| 65 |
+
self.norm_fn = norm_fn
|
| 66 |
+
|
| 67 |
+
if self.norm_fn == "group":
|
| 68 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
| 69 |
+
|
| 70 |
+
elif self.norm_fn == "batch":
|
| 71 |
+
self.norm1 = nn.BatchNorm2d(64)
|
| 72 |
+
|
| 73 |
+
elif self.norm_fn == "instance":
|
| 74 |
+
self.norm1 = nn.InstanceNorm2d(64, affine=False)
|
| 75 |
+
|
| 76 |
+
elif self.norm_fn == "none":
|
| 77 |
+
self.norm1 = nn.Sequential()
|
| 78 |
+
|
| 79 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
| 80 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 81 |
+
|
| 82 |
+
self.in_planes = 64
|
| 83 |
+
self.layer1 = self._make_layer(64, stride=1)
|
| 84 |
+
self.layer2 = self._make_layer(96, stride=2)
|
| 85 |
+
self.layer3 = self._make_layer(128, stride=1)
|
| 86 |
+
|
| 87 |
+
# output convolution
|
| 88 |
+
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
| 89 |
+
|
| 90 |
+
self.dropout = None
|
| 91 |
+
if dropout > 0:
|
| 92 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
| 93 |
+
|
| 94 |
+
# -- self.modules() is a PyTorch utility function that returns all submodules of this nn.Module recursively.
|
| 95 |
+
# -- This means it will loop through every layer: conv1, layer1, layer2, layer3, conv2 and so on.
|
| 96 |
+
for m in self.modules():
|
| 97 |
+
if isinstance(m, nn.Conv2d):
|
| 98 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 99 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
| 100 |
+
if m.weight is not None:
|
| 101 |
+
nn.init.constant_(m.weight, 1)
|
| 102 |
+
if m.bias is not None:
|
| 103 |
+
nn.init.constant_(m.bias, 0)
|
| 104 |
+
|
| 105 |
+
def _make_layer(self, dim, stride=1):
|
| 106 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 107 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 108 |
+
layers = (layer1, layer2)
|
| 109 |
+
|
| 110 |
+
self.in_planes = dim
|
| 111 |
+
return nn.Sequential(*layers)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
# -- x = [L, R]
|
| 115 |
+
# -- L, R ~ (b*t, c, h, w)
|
| 116 |
+
|
| 117 |
+
# if input is list, combine batch dimension
|
| 118 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
| 119 |
+
if is_list:
|
| 120 |
+
batch_dim = x[0].shape[0]
|
| 121 |
+
x = torch.cat(x, dim=0)
|
| 122 |
+
|
| 123 |
+
x = self.conv1(x)
|
| 124 |
+
x = self.norm1(x)
|
| 125 |
+
x = self.relu1(x)
|
| 126 |
+
|
| 127 |
+
x = self.layer1(x)
|
| 128 |
+
x = self.layer2(x)
|
| 129 |
+
x = self.layer3(x)
|
| 130 |
+
|
| 131 |
+
x = self.conv2(x)
|
| 132 |
+
|
| 133 |
+
if self.dropout is not None:
|
| 134 |
+
x = self.dropout(x)
|
| 135 |
+
|
| 136 |
+
if is_list:
|
| 137 |
+
x = torch.split(x, x.shape[0] // 2, dim=0)
|
| 138 |
+
|
| 139 |
+
return x
|
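A minimal usage sketch (not part of the repository) for `BasicEncoder`: the stride-2 stem plus the stride-2 `layer2` downsample the input by 4x overall, and a `[left, right]` list is split back into two feature maps along the batch dimension. It assumes the repository root is on `PYTHONPATH`.

```python
import torch
from models.core.extractor import BasicEncoder

encoder = BasicEncoder(output_dim=256, norm_fn="instance", dropout=0.0)

left = torch.randn(4, 3, 256, 320)     # (b*t, 3, H, W)
right = torch.randn(4, 3, 256, 320)

fmap_left, fmap_right = encoder([left, right])
print(fmap_left.shape)                  # torch.Size([4, 256, 64, 80]) -> 1/4 resolution
```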
models/core/model_zoo.py
ADDED
|
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
from dynamic_stereo.models.dynamic_stereo_model import DynamicStereoModel

from pytorch3d.implicitron.tools.config import get_default_args

try:
    from dynamic_stereo.models.raft_stereo_model import RAFTStereoModel

    MODELS = [RAFTStereoModel, DynamicStereoModel]
except:
    MODELS = [DynamicStereoModel]

_MODEL_NAME_TO_MODEL = {model_cls.__name__: model_cls for model_cls in MODELS}
_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG = {}
for model_cls in MODELS:
    _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG[
        model_cls.MODEL_CONFIG_NAME
    ] = get_default_args(model_cls)
MODEL_NAME_NONE = "NONE"


def model_zoo(model_name: str, **kwargs):
    if model_name.upper() == MODEL_NAME_NONE:
        return None

    model_cls = _MODEL_NAME_TO_MODEL.get(model_name)

    if model_cls is None:
        raise ValueError(f"No such model name: {model_name}")

    model_cls_params = {}
    if "model_zoo" in getattr(model_cls, "__dataclass_fields__", []):
        model_cls_params["model_zoo"] = model_zoo
    print(
        f"{model_cls.MODEL_CONFIG_NAME} model configs:",
        kwargs.get(model_cls.MODEL_CONFIG_NAME),
    )
    return model_cls(**model_cls_params, **kwargs.get(model_cls.MODEL_CONFIG_NAME, {}))


def get_all_model_default_configs():
    return copy.deepcopy(_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG)
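A hedged usage sketch (not part of the repository): it assumes the repo is importable as the `dynamic_stereo` package (matching the absolute imports above) and that the defaults returned by `get_default_args` are enough to instantiate the model; the exact config fields are not reproduced here.

```python
from dynamic_stereo.models.core.model_zoo import model_zoo, get_all_model_default_configs

default_cfgs = get_all_model_default_configs()       # keyed by each model's MODEL_CONFIG_NAME
model = model_zoo("DynamicStereoModel", **default_cfgs)

assert model_zoo("NONE") is None                     # "NONE" is the explicit null-model name
```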
models/core/sci_codec.py
ADDED
|
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.core.extractor import ResidualBlock

autocast = torch.cuda.amp.autocast

class ste_fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x > 0).float()
    @staticmethod
    def backward(ctx, grad):
        return F.hardtanh(grad)

class STE(nn.Module):
    def __init__(self):
        super(STE, self).__init__()
    def forward(self, x):
        return ste_fn.apply(x)

class sci_encoder(nn.Module):
    def __init__(
        self,
        sigma_range=[0, 1e-9],
        n_frame=8,
        in_channels=1,
        n_taps=2,
        resolution=[480, 640]):

        super(sci_encoder, self).__init__()

        assert n_taps in [1, 2], "[ERROR] n_taps should be either 1 or 2."

        self.sigma_range = sigma_range
        self.n_frame = n_frame
        self.in_channels = in_channels
        self.n_taps = n_taps
        self.resolution = resolution

        # -- Shutter code; Learnable parameters
        self.ce_weight = nn.Parameter(torch.Tensor(n_frame, in_channels, *resolution))

        # -- initialize
        nn.init.uniform_(self.ce_weight, a=-1, b=1)

        self.ste = STE()

    def forward(self, frames):

        # -- print ("[INFO] self.ce_weight.device: ", self.ce_weight.device)
        ce_code = self.ste(self.ce_weight)
        # -- print ("[INFO] ce_code.device: ", ce_code.device)

        frames = frames[..., :self.resolution[0], :self.resolution[1]]
        frames = frames.contiguous()
        frames = torch.unsqueeze(frames, 2)

        # -- print ("[INFO] ce_code.shape: ", ce_code.shape)
        # -- print ("[INFO] frames.shape: ", frames.shape)

        # -- repeat by the batch size
        ce_code = ce_code.repeat(frames.shape[0], 1, 1, 1, 1)
        # -- print ("[INFO] ce_code.shape: ", ce_code.shape)
        # -- print ("[INFO] ce_code.squeeze(2).shape: ", ce_code.squeeze(2).shape)

        ce_blur_img = torch.zeros(frames.shape[0], self.in_channels * self.n_taps, *self.resolution).to(frames.device)  # -- (b, c, h, w)

        # -- print ("[INFO] ce_blur_img.shape: ", ce_blur_img.shape)
        ce_blur_img[:, 0, ...] = torch.sum( ce_code * frames, axis=1) / self.n_frame
        ce_blur_img[:, 1, ...] = torch.sum((1. - ce_code) * frames, axis=1) / self.n_frame

        # -- add noise
        noise_level = np.random.uniform(*self.sigma_range)
        ce_blur_img_noisy = ce_blur_img + torch.tensor(noise_level).to(frames.device) * torch.randn(ce_blur_img.shape).to(frames.device)

        # -- concat snapshots and mask patterns
        out = torch.zeros(frames.shape[0], self.n_taps + self.n_frame, *self.resolution).to(frames.device)

        # -- print ("[INFO] out.shape: ", out.shape)
        out[:, :self.n_taps, :, :] = ce_blur_img_noisy
        out[:, self.n_taps:, :, :] = ce_code.squeeze(2)

        return out

class sci_decoder(nn.Module):
    def __init__(self,
                 n_frame=8,
                 n_taps=2,
                 output_dim=128,
                 norm_fn="batch",
                 dropout=.0):

        super(sci_decoder, self).__init__()

        self.norm_fn = norm_fn
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=4, num_channels=4*n_frame)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(4*n_frame)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(4*n_frame, affine=True)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()

        # -- Input Convolution
        # -- Assuming n_frame=8; n_ich=10; n_och=32
        self.conv1 = nn.Conv2d(n_taps+n_frame, 4*n_frame, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # -- Residual Blocks
        self.layer1 = self._make_layer( 4*n_frame,  4*n_frame, stride=1)
        self.layer2 = self._make_layer( 4*n_frame, 16*n_frame, stride=2)
        self.layer3 = self._make_layer(16*n_frame, 64*n_frame, stride=1)

        # -- Output Convolution
        self.conv2 = nn.Conv2d(64*n_frame, output_dim*n_frame, kernel_size=1)

        if dropout > 0.:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        # -- self.modules() is a PyTorch utility function that returns all submodules of this nn.Module recursively.
        # -- This means it will loop through every layer: conv1, layer1, layer2, layer3, conv2 and so on.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    # -- Private function to make residual blocks
    def _make_layer(self, n_ich, n_och, stride=1):
        layer1 = ResidualBlock(n_ich, n_och, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(n_och, n_och, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        return nn.Sequential(*layers)

    def forward(self, x):
        # -- x = [L, R]
        # -- L, R ~ (b, c, h, w); c=n_taps+n_frame

        # -- if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        # -- print ("[INFO] x.shape: ", x.shape)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        # -- expand the temporal dimension
        # -- (b, c, h, w) -> (b*t, c//t, h, w)
        # -- (note: the temporal factor is hard-coded to 8 here, so this reshape assumes n_frame == 8)
        x = x.contiguous()
        x = x.view(x.shape[0]*8, x.shape[1]//8, x.shape[-2], x.shape[-1])

        if self.dropout is not None:
            x = self.dropout(x)

        # -- if input is list, split the first dimension
        if is_list:
            x = torch.split(x, x.shape[0] // 2, dim=0)

        return x
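The snippet below (not repository code) is a self-contained illustration of the two-tap coded-exposure measurement that `sci_encoder` models: a binary shutter code (sampled at random here instead of being learned through the straight-through estimator) integrates the frames it keeps into one snapshot, and the complementary frames into a second snapshot.

```python
import torch

T, H, W = 8, 480, 640
video = torch.rand(1, T, H, W)                          # grayscale clip in [0, 1], batch size 1
code = (torch.rand(T, H, W) > 0.5).float()              # binary shutter code, one mask per frame

snapshot_tap0 = (code * video).sum(dim=1) / T           # pixels integrated while the code is "on"
snapshot_tap1 = ((1.0 - code) * video).sum(dim=1) / T   # complementary tap
print(snapshot_tap0.shape, snapshot_tap1.shape)         # torch.Size([1, 480, 640]) twice
```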
models/core/update.py
ADDED
|
@@ -0,0 +1,370 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from models.core.attention import LoFTREncoderLayer
|
| 13 |
+
|
| 14 |
+
# -- Added by Chu King on 16th November 2025 for debugging purposes.
|
| 15 |
+
import os, signal
|
| 16 |
+
import logging
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
|
| 19 |
+
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
|
| 20 |
+
class FlowHead(nn.Module):
|
| 21 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
| 22 |
+
super(FlowHead, self).__init__()
|
| 23 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
| 24 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
| 25 |
+
self.relu = nn.ReLU(inplace=True)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
return self.conv2(self.relu(self.conv1(x)))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SepConvGRU(nn.Module):
|
| 32 |
+
def __init__(self, hidden_dim=128, input_dim=192 + 128):
|
| 33 |
+
super(SepConvGRU, self).__init__()
|
| 34 |
+
self.convz1 = nn.Conv2d(
|
| 35 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
| 36 |
+
)
|
| 37 |
+
self.convr1 = nn.Conv2d(
|
| 38 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
| 39 |
+
)
|
| 40 |
+
self.convq1 = nn.Conv2d(
|
| 41 |
+
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.convz2 = nn.Conv2d(
|
| 45 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
| 46 |
+
)
|
| 47 |
+
self.convr2 = nn.Conv2d(
|
| 48 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
| 49 |
+
)
|
| 50 |
+
self.convq2 = nn.Conv2d(
|
| 51 |
+
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, h, x):
|
| 55 |
+
# horizontal
|
| 56 |
+
hx = torch.cat([h, x], dim=1)
|
| 57 |
+
z = torch.sigmoid(self.convz1(hx))
|
| 58 |
+
r = torch.sigmoid(self.convr1(hx))
|
| 59 |
+
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
|
| 60 |
+
h = (1 - z) * h + z * q
|
| 61 |
+
|
| 62 |
+
# vertical
|
| 63 |
+
hx = torch.cat([h, x], dim=1)
|
| 64 |
+
z = torch.sigmoid(self.convz2(hx))
|
| 65 |
+
r = torch.sigmoid(self.convr2(hx))
|
| 66 |
+
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
|
| 67 |
+
h = (1 - z) * h + z * q
|
| 68 |
+
|
| 69 |
+
return h
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ConvGRU(nn.Module):
|
| 73 |
+
def __init__(self, hidden_dim, input_dim, kernel_size=3):
|
| 74 |
+
super(ConvGRU, self).__init__()
|
| 75 |
+
self.convz = nn.Conv2d(
|
| 76 |
+
hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
|
| 77 |
+
)
|
| 78 |
+
self.convr = nn.Conv2d(
|
| 79 |
+
hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
|
| 80 |
+
)
|
| 81 |
+
self.convq = nn.Conv2d(
|
| 82 |
+
hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def forward(self, h, x):
|
| 86 |
+
hx = torch.cat([h, x], dim=1)
|
| 87 |
+
|
| 88 |
+
z = torch.sigmoid(self.convz(hx))
|
| 89 |
+
r = torch.sigmoid(self.convr(hx))
|
| 90 |
+
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
|
| 91 |
+
|
| 92 |
+
h = (1 - z) * h + z * q
|
| 93 |
+
return h
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SepConvGRU3D(nn.Module):
|
| 97 |
+
def __init__(self, hidden_dim=128, input_dim=192 + 128):
|
| 98 |
+
super(SepConvGRU3D, self).__init__()
|
| 99 |
+
self.convz1 = nn.Conv3d(
|
| 100 |
+
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
|
| 101 |
+
)
|
| 102 |
+
self.convr1 = nn.Conv3d(
|
| 103 |
+
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
|
| 104 |
+
)
|
| 105 |
+
self.convq1 = nn.Conv3d(
|
| 106 |
+
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
self.convz2 = nn.Conv3d(
|
| 110 |
+
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
|
| 111 |
+
)
|
| 112 |
+
self.convr2 = nn.Conv3d(
|
| 113 |
+
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
|
| 114 |
+
)
|
| 115 |
+
self.convq2 = nn.Conv3d(
|
| 116 |
+
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
self.convz3 = nn.Conv3d(
|
| 120 |
+
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
|
| 121 |
+
)
|
| 122 |
+
self.convr3 = nn.Conv3d(
|
| 123 |
+
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
|
| 124 |
+
)
|
| 125 |
+
self.convq3 = nn.Conv3d(
|
| 126 |
+
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
def forward(self, h, x):
|
| 130 |
+
hx = torch.cat([h, x], dim=1)
|
| 131 |
+
z = torch.sigmoid(self.convz1(hx))
|
| 132 |
+
r = torch.sigmoid(self.convr1(hx))
|
| 133 |
+
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
|
| 134 |
+
h = (1 - z) * h + z * q
|
| 135 |
+
|
| 136 |
+
# vertical
|
| 137 |
+
hx = torch.cat([h, x], dim=1)
|
| 138 |
+
z = torch.sigmoid(self.convz2(hx))
|
| 139 |
+
r = torch.sigmoid(self.convr2(hx))
|
| 140 |
+
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
|
| 141 |
+
h = (1 - z) * h + z * q
|
| 142 |
+
|
| 143 |
+
# time
|
| 144 |
+
hx = torch.cat([h, x], dim=1)
|
| 145 |
+
z = torch.sigmoid(self.convz3(hx))
|
| 146 |
+
r = torch.sigmoid(self.convr3(hx))
|
| 147 |
+
q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))
|
| 148 |
+
h = (1 - z) * h + z * q
|
| 149 |
+
|
| 150 |
+
return h
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class BasicMotionEncoder(nn.Module):
|
| 154 |
+
def __init__(self, cor_planes):
|
| 155 |
+
super(BasicMotionEncoder, self).__init__()
|
| 156 |
+
|
| 157 |
+
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
|
| 158 |
+
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
| 159 |
+
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
|
| 160 |
+
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
| 161 |
+
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
|
| 162 |
+
|
| 163 |
+
def forward(self, flow, corr):
|
| 164 |
+
cor = F.relu(self.convc1(corr))
|
| 165 |
+
cor = F.relu(self.convc2(cor))
|
| 166 |
+
flo = F.relu(self.convf1(flow))
|
| 167 |
+
flo = F.relu(self.convf2(flo))
|
| 168 |
+
|
| 169 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
| 170 |
+
out = F.relu(self.conv(cor_flo))
|
| 171 |
+
return torch.cat([out, flow], dim=1)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Attention(nn.Module):
|
| 175 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.num_heads = num_heads
|
| 178 |
+
head_dim = dim // num_heads
|
| 179 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 180 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 181 |
+
self.proj = nn.Linear(dim, dim)
|
| 182 |
+
|
| 183 |
+
def forward(self, x):
|
| 184 |
+
B, N, C = x.shape
|
| 185 |
+
# -- Bug fixed by Chu King on 22nd November 2025
|
| 186 |
+
qkv = self.qkv(x)
|
| 187 |
+
# -- qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 188 |
+
qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads)
|
| 189 |
+
qkv = qkv.permute(0, 3, 1, 2, 4) # -- (B, H, N, 3, -1)
|
| 190 |
+
# -- q, k, v = qkv, qkv, qkv
|
| 191 |
+
q, k, v = qkv.unbind(dim=3)
|
| 192 |
+
|
| 193 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 194 |
+
|
| 195 |
+
attn = attn.softmax(dim=-1)
|
| 196 |
+
|
| 197 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()
|
| 198 |
+
x = self.proj(x)
|
| 199 |
+
return x
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Mlp(nn.Module):
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
in_features,
|
| 206 |
+
hidden_features=None,
|
| 207 |
+
out_features=None,
|
| 208 |
+
act_layer=nn.GELU,
|
| 209 |
+
drop=0.0,
|
| 210 |
+
):
|
| 211 |
+
super().__init__()
|
| 212 |
+
out_features = out_features or in_features
|
| 213 |
+
hidden_features = hidden_features or in_features
|
| 214 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 215 |
+
self.act = act_layer()
|
| 216 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 217 |
+
self.drop = nn.Dropout(drop)
|
| 218 |
+
|
| 219 |
+
def forward(self, x):
|
| 220 |
+
x = self.fc1(x)
|
| 221 |
+
x = self.act(x)
|
| 222 |
+
x = self.drop(x)
|
| 223 |
+
x = self.fc2(x)
|
| 224 |
+
x = self.drop(x)
|
| 225 |
+
return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TimeAttnBlock(nn.Module):
|
| 229 |
+
def __init__(self, dim=256, num_heads=8):
|
| 230 |
+
super(TimeAttnBlock, self).__init__()
|
| 231 |
+
self.temporal_attn = Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None)
|
| 232 |
+
self.temporal_fc = nn.Linear(dim, dim)
|
| 233 |
+
self.temporal_norm1 = nn.LayerNorm(dim)
|
| 234 |
+
|
| 235 |
+
nn.init.constant_(self.temporal_fc.weight, 0)
|
| 236 |
+
nn.init.constant_(self.temporal_fc.bias, 0)
|
| 237 |
+
|
| 238 |
+
def forward(self, x, T=1):
|
| 239 |
+
_, _, h, w = x.shape
|
| 240 |
+
|
| 241 |
+
x = rearrange(x, "(b t) m h w -> (b h w) t m", h=h, w=w, t=T)
|
| 242 |
+
res_temporal1 = self.temporal_attn(self.temporal_norm1(x))
|
| 243 |
+
res_temporal1 = rearrange(
|
| 244 |
+
res_temporal1, "(b h w) t m -> b (h w t) m", h=h, w=w, t=T
|
| 245 |
+
)
|
| 246 |
+
res_temporal1 = self.temporal_fc(res_temporal1)
|
| 247 |
+
res_temporal1 = rearrange(
|
| 248 |
+
res_temporal1, " b (h w t) m -> b t m h w", h=h, w=w, t=T
|
| 249 |
+
)
|
| 250 |
+
x = rearrange(x, "(b h w) t m -> b t m h w", h=h, w=w, t=T)
|
| 251 |
+
x = x + res_temporal1
|
| 252 |
+
x = rearrange(x, "b t m h w -> (b t) m h w", h=h, w=w, t=T)
|
| 253 |
+
return x
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class SpaceAttnBlock(nn.Module):
|
| 257 |
+
def __init__(self, dim=256, num_heads=8):
|
| 258 |
+
super(SpaceAttnBlock, self).__init__()
|
| 259 |
+
self.encoder_layer = LoFTREncoderLayer(dim, nhead=num_heads, attention="linear")
|
| 260 |
+
|
| 261 |
+
def forward(self, x, T=1):
|
| 262 |
+
_, _, h, w = x.shape
|
| 263 |
+
x = rearrange(x, "(b t) m h w -> (b t) (h w) m", h=h, w=w, t=T)
|
| 264 |
+
x = self.encoder_layer(x, x)
|
| 265 |
+
x = rearrange(x, "(b t) (h w) m -> (b t) m h w", h=h, w=w, t=T)
|
| 266 |
+
return x
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class BasicUpdateBlock(nn.Module):
|
| 270 |
+
def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):
|
| 271 |
+
super(BasicUpdateBlock, self).__init__()
|
| 272 |
+
self.attention_type = attention_type
|
| 273 |
+
if attention_type is not None:
|
| 274 |
+
if "update_time" in attention_type:
|
| 275 |
+
self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
|
| 276 |
+
|
| 277 |
+
if "update_space" in attention_type:
|
| 278 |
+
self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
|
| 279 |
+
|
| 280 |
+
self.encoder = BasicMotionEncoder(cor_planes)
|
| 281 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
|
| 282 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
| 283 |
+
|
| 284 |
+
self.mask = nn.Sequential(
|
| 285 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
| 286 |
+
nn.ReLU(inplace=True),
|
| 287 |
+
nn.Conv2d(256, mask_size ** 2 * 9, 1, padding=0),
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def forward(self, net, inp, corr, flow, upsample=True, t=1):
|
| 291 |
+
motion_features = self.encoder(flow, corr)
|
| 292 |
+
inp = torch.cat((inp, motion_features), dim=1)
|
| 293 |
+
|
| 294 |
+
if self.attention_type is not None:
|
| 295 |
+
if "update_time" in self.attention_type:
|
| 296 |
+
inp = self.time_attn(inp, T=t)
|
| 297 |
+
|
| 298 |
+
if "update_space" in self.attention_type:
|
| 299 |
+
inp = self.space_attn(inp, T=t)
|
| 300 |
+
|
| 301 |
+
net = self.gru(net, inp)
|
| 302 |
+
delta_flow = self.flow_head(net)
|
| 303 |
+
|
| 304 |
+
# scale mask to balance gradients
|
| 305 |
+
mask = 0.25 * self.mask(net)
|
| 306 |
+
return net, mask, delta_flow
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class FlowHead3D(nn.Module):
|
| 310 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
| 311 |
+
super(FlowHead3D, self).__init__()
|
| 312 |
+
self.conv1 = nn.Conv3d(input_dim, hidden_dim, 3, padding=1)
|
| 313 |
+
self.conv2 = nn.Conv3d(hidden_dim, 2, 3, padding=1)
|
| 314 |
+
self.relu = nn.ReLU(inplace=True)
|
| 315 |
+
|
| 316 |
+
def forward(self, x):
|
| 317 |
+
return self.conv2(self.relu(self.conv1(x)))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class SequenceUpdateBlock3D(nn.Module):
|
| 321 |
+
def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):
|
| 322 |
+
super(SequenceUpdateBlock3D, self).__init__()
|
| 323 |
+
|
| 324 |
+
# -- Extracts motion-related features from:
|
| 325 |
+
# * current flow estimate
|
| 326 |
+
# * correlation volume
|
| 327 |
+
self.encoder = BasicMotionEncoder(cor_planes)
|
| 328 |
+
|
| 329 |
+
# -- 3D separable convolution GRU enables temporal reasoning with 3D convolutions.
|
| 330 |
+
self.gru = SepConvGRU3D(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
|
| 331 |
+
|
| 332 |
+
self.flow_head = FlowHead3D(hidden_dim, hidden_dim=256)
|
| 333 |
+
self.mask = nn.Sequential(
|
| 334 |
+
nn.Conv2d(hidden_dim, hidden_dim + 128, 3, padding=1),
|
| 335 |
+
nn.ReLU(inplace=True),
|
| 336 |
+
nn.Conv2d(hidden_dim + 128, (mask_size ** 2) * 9, 1, padding=0),
|
| 337 |
+
)
|
| 338 |
+
self.attention_type = attention_type
|
| 339 |
+
if attention_type is not None:
|
| 340 |
+
if "update_time" in attention_type:
|
| 341 |
+
self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
|
| 342 |
+
if "update_space" in attention_type:
|
| 343 |
+
self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
|
| 344 |
+
|
| 345 |
+
def forward(self, net, inp, corrs, flows, t, upsample=True):
|
| 346 |
+
inp_tensor = []
|
| 347 |
+
|
| 348 |
+
motion_features = self.encoder(flows, corrs)
|
| 349 |
+
inp_tensor = torch.cat([inp, motion_features], dim=1)
|
| 350 |
+
|
| 351 |
+
if self.attention_type is not None:
|
| 352 |
+
if "update_time" in self.attention_type:
|
| 353 |
+
inp_tensor = self.time_attn(inp_tensor, T=t)
|
| 354 |
+
|
| 355 |
+
if "update_space" in self.attention_type:
|
| 356 |
+
inp_tensor = self.space_attn(inp_tensor, T=t)
|
| 357 |
+
|
| 358 |
+
net = rearrange(net, "(b t) c h w -> b c t h w", t=t)
|
| 359 |
+
inp_tensor = rearrange(inp_tensor, "(b t) c h w -> b c t h w", t=t)
|
| 360 |
+
|
| 361 |
+
net = self.gru(net, inp_tensor)
|
| 362 |
+
|
| 363 |
+
delta_flow = self.flow_head(net)
|
| 364 |
+
|
| 365 |
+
# scale mask to balance gradients
|
| 366 |
+
net = rearrange(net, " b c t h w -> (b t) c h w")
|
| 367 |
+
mask = 0.25 * self.mask(net)
|
| 368 |
+
|
| 369 |
+
delta_flow = rearrange(delta_flow, " b c t h w -> (b t) c h w")
|
| 370 |
+
return net, mask, delta_flow
|
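Note: the temporal and spatial attention blocks above differ only in which axis is folded into the token dimension before attention runs. Below is a minimal sketch of the two einops round trips; the tensor sizes (B=2, T=5, C=256, H=W=8) are made up purely for illustration and are not taken from the model configuration.

import torch
from einops import rearrange

B, T, C, H, W = 2, 5, 256, 8, 8
x = torch.randn(B * T, C, H, W)  # features for all frames, stacked along the batch axis

# SpaceAttnBlock-style: each frame attends over its own H*W positions
tokens_space = rearrange(x, "(b t) m h w -> (b t) (h w) m", t=T)    # (B*T, H*W, C)

# TimeAttnBlock-style: each spatial position attends over the T frames
tokens_time = rearrange(x, "(b t) m h w -> (b h w) t m", t=T)       # (B*H*W, T, C)

# both round-trip back to the (B*T, C, H, W) layout expected by the GRU update
x_space = rearrange(tokens_space, "(b t) (h w) m -> (b t) m h w", t=T, h=H, w=W)
x_time = rearrange(tokens_time, "(b h w) t m -> (b t) m h w", h=H, w=W, t=T)
assert x_space.shape == x.shape and x_time.shape == x.shape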
models/core/utils/config.py
ADDED
@@ -0,0 +1,961 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import inspect
import itertools
import sys
import warnings
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin


"""
This functionality allows a configurable system to be determined in a dataclass-type
way. It is a generalization of omegaconf's "structured", in the dataclass case.
Core functionality:

- Configurable -- A base class used to label a class as being one which uses this
    system. Uses class members and __post_init__ like a dataclass.

- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.

- get_default_args -- gets an omegaconf.DictConfig for initializing a given class.

- run_auto_creation -- Initialises nested members. To be called in __post_init__.


In addition, a Configurable may contain members whose type is decided at runtime.

- ReplaceableBase -- As a base instead of Configurable, labels a class to say that
    any child class can be used instead.

- registry -- A global store of named child classes of ReplaceableBase classes.
    Used as `@registry.register` decorator on class definition.


Additional utility functions:

- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
- enable_get_default_args -- Allows get_default_args on a function or plain class.


1. The simplest usage of this functionality is as follows. First a schema is defined
in dataclass style.

    class A(Configurable):
        n: int = 9

    class B(Configurable):
        a: A

        def __post_init__(self):
            run_auto_creation(self)

Then it can be used like

    b_args = get_default_args(B)
    b = B(**b_args)

In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
something like the following. (The modification itself is done by the function
`expand_args_fields`, which is called inside `get_default_args`.)

    @dataclasses.dataclass
    class A:
        n: int = 9

    @dataclasses.dataclass
    class B:
        a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))

        def __post_init__(self):
            self.a = A(**self.a_args)

2. Pluggability. Instead of a dataclass-style member being given a concrete class,
it can be given a base class and the implementation will be looked up by name in the
global `registry` in this module. E.g.

    class A(ReplaceableBase):
        k: int = 1

    @registry.register
    class A1(A):
        m: int = 3

    @registry.register
    class A2(A):
        n: str = "2"

    class B(Configurable):
        a: A
        a_class_type: str = "A2"
        b: Optional[A]
        b_class_type: Optional[str] = "A2"

        def __post_init__(self):
            run_auto_creation(self)

will expand to

    @dataclasses.dataclass
    class A:
        k: int = 1

    @dataclasses.dataclass
    class A1(A):
        m: int = 3

    @dataclasses.dataclass
    class A2(A):
        n: str = "2"

    @dataclasses.dataclass
    class B:
        a_class_type: str = "A2"
        a_A1_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "m": 3}
        )
        a_A2_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "n": 2}
        )
        b_class_type: Optional[str] = "A2"
        b_A1_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "m": 3}
        )
        b_A2_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "n": 2}
        )

        def __post_init__(self):
            if self.a_class_type == "A1":
                self.a = A1(**self.a_A1_args)
            elif self.a_class_type == "A2":
                self.a = A2(**self.a_A2_args)
            else:
                raise ValueError(...)

            if self.b_class_type is None:
                self.b = None
            elif self.b_class_type == "A1":
                self.b = A1(**self.b_A1_args)
            elif self.b_class_type == "A2":
                self.b = A2(**self.b_A2_args)
            else:
                raise ValueError(...)

3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with `DictConfig`s and lists of them.

In addition, you can call `get_default_args` on a function or class to get
the `DictConfig` of its defaulted arguments, assuming those are all things
which `DictConfig` is happy with, so long as you add a call to
`enable_get_default_args` after its definition. If you want to use such a
thing as the default for a member of another configured class,
`get_default_args_field` is a helper.
"""


_unprocessed_warning: str = (
    " must be processed before it can be used."
    + " This is done by calling expand_args_fields "
    + "or get_default_args on it."
)

TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"


class ReplaceableBase:
    """
    Base class for dataclass-style classes which
    can be stored in the registry.
    """

    def __new__(cls, *args, **kwargs):
        """
        This function only exists to raise a
        warning if class construction is attempted
        without processing.
        """
        obj = super().__new__(cls)
        if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
            warnings.warn(cls.__name__ + _unprocessed_warning)
        return obj


class Configurable:
    """
    This indicates a class which is not ReplaceableBase
    but still needs to be
    expanded into a dataclass with expand_args_fields.
    This expansion is delayed.
    """

    def __new__(cls, *args, **kwargs):
        """
        This function only exists to raise a
        warning if class construction is attempted
        without processing.
        """
        obj = super().__new__(cls)
        if cls is not Configurable and not _is_actually_dataclass(cls):
            warnings.warn(cls.__name__ + _unprocessed_warning)
        return obj


_X = TypeVar("X", bound=ReplaceableBase)


class _Registry:
    """
    Register from names to classes. In particular, we say that direct subclasses of
    ReplaceableBase are "base classes" and we register subclasses of each base class
    in a separate namespace.
    """

    def __init__(self) -> None:
        self._mapping: Dict[
            Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
        ] = defaultdict(dict)

    def register(self, some_class: Type[_X]) -> Type[_X]:
        """
        A class decorator, to register a class in self.
        """
        name = some_class.__name__
        self._register(some_class, name=name)
        return some_class

    def _register(
        self,
        some_class: Type[ReplaceableBase],
        *,
        base_class: Optional[Type[ReplaceableBase]] = None,
        name: str,
    ) -> None:
        """
        Register a new member.

        Args:
            cls: the new member
            base_class: (optional) what the new member is a type for
            name: name for the new member
        """
        if base_class is None:
            base_class = self._base_class_from_class(some_class)
            if base_class is None:
                raise ValueError(
                    f"Cannot register {some_class}. Cannot tell what it is."
                )
        if some_class is base_class:
            raise ValueError(f"Attempted to register the base class {some_class}")
        self._mapping[base_class][name] = some_class

    def get(
        self, base_class_wanted: Type[ReplaceableBase], name: str
    ) -> Type[ReplaceableBase]:
        """
        Retrieve a class from the registry by name

        Args:
            base_class_wanted: parent type of type we are looking for.
                It determines the namespace.
                This will typically be a direct subclass of ReplaceableBase.
            name: what to look for

        Returns:
            class type
        """
        if self._is_base_class(base_class_wanted):
            base_class = base_class_wanted
        else:
            base_class = self._base_class_from_class(base_class_wanted)
            if base_class is None:
                raise ValueError(
                    f"Cannot look up {base_class_wanted}. Cannot tell what it is."
                )
        result = self._mapping[base_class].get(name)
        if result is None:
            raise ValueError(f"{name} has not been registered.")
        if not issubclass(result, base_class_wanted):
            raise ValueError(
                f"{name} resolves to {result} which does not subclass {base_class_wanted}"
            )
        return result

    def get_all(
        self, base_class_wanted: Type[ReplaceableBase]
    ) -> List[Type[ReplaceableBase]]:
        """
        Retrieve all registered implementations from the registry

        Args:
            base_class_wanted: parent type of type we are looking for.
                It determines the namespace.
                This will typically be a direct subclass of ReplaceableBase.
        Returns:
            list of class types
        """
        if self._is_base_class(base_class_wanted):
            return list(self._mapping[base_class_wanted].values())

        base_class = self._base_class_from_class(base_class_wanted)
        if base_class is None:
            raise ValueError(
                f"Cannot look up {base_class_wanted}. Cannot tell what it is."
            )
        return [
            class_
            for class_ in self._mapping[base_class].values()
            if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
        ]

    @staticmethod
    def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
        """
        Return whether the given type is a direct subclass of ReplaceableBase
        and so gets used as a namespace.
        """
        return ReplaceableBase in some_class.__bases__

    @staticmethod
    def _base_class_from_class(
        some_class: Type[ReplaceableBase],
    ) -> Optional[Type[ReplaceableBase]]:
        """
        Find the parent class of some_class which inherits ReplaceableBase, or None
        """
        for base in some_class.mro()[-3::-1]:
            if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
                return base
        return None


# Global instance of the registry
registry = _Registry()


class _ProcessType(Enum):
    """
    Type of member which gets rewritten by expand_args_fields.
    """

    CONFIGURABLE = 1
    REPLACEABLE = 2
    OPTIONAL_CONFIGURABLE = 3
    OPTIONAL_REPLACEABLE = 4


def _default_create(
    name: str, type_: Type, process_type: _ProcessType
) -> Callable[[Any], None]:
    """
    Return the default creation function for a member. This is a function which
    could be called in __post_init__ to initialise the member, and will be called
    from run_auto_creation.

    Args:
        name: name of the member
        type_: type of the member (with any Optional removed)
        process_type: Shows whether member's declared type inherits ReplaceableBase,
                    in which case the actual type to be created is decided at
                    runtime.

    Returns:
        Function taking one argument, the object whose member should be
        initialized.
    """

    def inner(self):
        expand_args_fields(type_)
        args = getattr(self, name + ARGS_SUFFIX)
        setattr(self, name, type_(**args))

    def inner_optional(self):
        expand_args_fields(type_)
        enabled = getattr(self, name + ENABLED_SUFFIX)
        if enabled:
            args = getattr(self, name + ARGS_SUFFIX)
            setattr(self, name, type_(**args))
        else:
            setattr(self, name, None)

    def inner_pluggable(self):
        type_name = getattr(self, name + TYPE_SUFFIX)
        if type_name is None:
            setattr(self, name, None)
            return

        chosen_class = registry.get(type_, type_name)
        if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
            # If this warning is raised, it means that a new definition of
            # the chosen class has been registered since our class was processed
            # (i.e. expanded). A DictConfig which comes from our get_default_args
            # (which might have triggered the processing) will contain the old default
            # values for the members of the chosen class. Changes to those defaults which
            # were made in the redefinition will not be reflected here.
            warnings.warn(f"New implementation of {type_name} is being chosen.")
        expand_args_fields(chosen_class)
        args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
        setattr(self, name, chosen_class(**args))

    if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
        return inner_optional
    return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable


def run_auto_creation(self: Any) -> None:
    """
    Run all the functions named in self._creation_functions.
    """
    for create_function in self._creation_functions:
        getattr(self, create_function)()


def _is_configurable_class(C) -> bool:
    return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))


def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
    """
    Get the DictConfig corresponding to the defaults in a dataclass or
    configurable. Normal use is to provide a dataclass as C.
    If enable_get_default_args has been called on a function or plain class,
    then that function or class can be provided as C.

    If C is a subclass of Configurable or ReplaceableBase, we make sure
    it has been processed with expand_args_fields.

    Args:
        C: the class or function to be processed
        _do_not_process: (internal use) When this function is called from
            expand_args_fields, we specify any class currently being
            processed, to make sure we don't try to process a class
            while it is already being processed.

    Returns:
        new DictConfig object, which is typed.
    """
    if C is None:
        return DictConfig({})

    if _is_configurable_class(C):
        if C in _do_not_process:
            raise ValueError(
                f"Internal recursion error. Need processed {C},"
                f" but cannot get it. _do_not_process={_do_not_process}"
            )
        # This is safe to run multiple times. It will return
        # straight away if C has already been processed.
        expand_args_fields(C, _do_not_process=_do_not_process)

    if dataclasses.is_dataclass(C):
        # Note that if get_default_args_field is used somewhere in C,
        # this call is recursive. No special care is needed,
        # because in practice get_default_args_field is used for
        # separate types than the outer type.

        out: DictConfig = OmegaConf.structured(C)
        exclude = getattr(C, "_processed_members", ())
        with open_dict(out):
            for field in exclude:
                out.pop(field, None)
        return out

    if _is_configurable_class(C):
        raise ValueError(f"Failed to process {C}")

    if not inspect.isfunction(C) and not inspect.isclass(C):
        raise ValueError(f"Unexpected {C}")

    dataclass_name = _dataclass_name_for_function(C)
    dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)
    if dataclass is None:
        raise ValueError(
            f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
        )

    return OmegaConf.structured(dataclass)


def _dataclass_name_for_function(C: Any) -> str:
    """
    Returns the name of the dataclass which enable_get_default_args(C)
    creates.
    """
    name = f"_{C.__name__}_default_args_"
    return name


def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
    """
    If C is a function or a plain class with an __init__ function,
    and you want get_default_args(C) to work, then add
    `enable_get_default_args(C)` straight after the definition of C.
    This makes a dataclass corresponding to the default arguments of C
    and stores it in the same module as C.

    Args:
        C: a function, or a class with an __init__ function. Must
            have types for all its defaulted args.
        overwrite: whether to allow calling this a second time on
            the same function.
    """
    if not inspect.isfunction(C) and not inspect.isclass(C):
        raise ValueError(f"Unexpected {C}")

    field_annotations = []
    for pname, defval in _params_iter(C):
        default = defval.default
        if default == inspect.Parameter.empty:
            # we do not have a default value for the parameter
            continue

        if defval.annotation == inspect._empty:
            raise ValueError(
                "All arguments of the input callable have to be typed."
                + f" Argument '{pname}' does not have a type annotation."
            )

        _, annotation = _resolve_optional(defval.annotation)

        if isinstance(default, set):  # force OmegaConf to convert it to ListConfig
            default = tuple(default)

        if isinstance(default, (list, dict)):
            # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
            field_ = dataclasses.field(default_factory=lambda default=default: default)
        elif not _is_immutable_type(annotation, default):
            continue
        else:
            # we can use a simple default argument for dataclass.field
            field_ = dataclasses.field(default=default)
        field_annotations.append((pname, defval.annotation, field_))

    name = _dataclass_name_for_function(C)
    module = sys.modules[C.__module__]
    if hasattr(module, name):
        if overwrite:
            warnings.warn(f"Overwriting {name} in {C.__module__}.")
        else:
            raise ValueError(f"Cannot overwrite {name} in {C.__module__}.")
    dc = dataclasses.make_dataclass(name, field_annotations)
    dc.__module__ = C.__module__
    setattr(module, name, dc)


def _params_iter(C):
    """Returns dict of keyword args of a class or function C."""
    if inspect.isclass(C):
        return itertools.islice(  # exclude `self`
            inspect.signature(C.__init__).parameters.items(), 1, None
        )

    return inspect.signature(C).parameters.items()


def _is_immutable_type(type_: Type, val: Any) -> bool:
    PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
    # sometimes type can be too relaxed (e.g. Any), so we also check values
    if isinstance(val, PRIMITIVE_TYPES):
        return True

    return type_ in PRIMITIVE_TYPES or (
        inspect.isclass(type_) and issubclass(type_, Enum)
    )


# copied from OmegaConf
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
    """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
    if get_origin(type_) is Union:
        args = get_args(type_)
        if len(args) == 2 and args[1] == type(None):  # noqa E721
            return True, args[0]
    if type_ is Any:
        return True, Any

    return False, type_


def _is_actually_dataclass(some_class) -> bool:
    # Return whether the class some_class has been processed with
    # the dataclass annotation. This is more specific than
    # dataclasses.is_dataclass which returns True on anything
    # deriving from a dataclass.

    # Checking for __init__ would also work for our purpose.
    return "__dataclass_fields__" in some_class.__dict__


def expand_args_fields(
    some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
    """
    This expands a class which inherits Configurable or ReplaceableBase classes,
    including dataclass processing. some_class is modified in place by this function.
    For classes of type ReplaceableBase, you can add some_class to the registry before
    or after calling this function. But potential inner classes need to be registered
    before this function is run on the outer class.

    The transformations this function makes, before the concluding
    dataclasses.dataclass, are as follows. If X is a base class with registered
    subclasses Y and Z, replace a class member

        x: X

    and optionally

        x_class_type: str = "Y"
        def create_x(self):...

    with

        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
        def create_x(self):
            self.x = registry.get(X, self.x_class_type)(
                **self.getattr(f"x_{self.x_class_type}_args)
            )
        x_class_type: str = "UNDEFAULTED"

    without adding the optional attributes if they are already there.

    Similarly, replace

        x: Optional[X]

    and optionally

        x_class_type: Optional[str] = "Y"
        def create_x(self):...

    with

        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
        def create_x(self):
            if self.x_class_type is None:
                self.x = None
                return

            self.x = registry.get(X, self.x_class_type)(
                **self.getattr(f"x_{self.x_class_type}_args)
            )
        x_class_type: Optional[str] = "UNDEFAULTED"

    without adding the optional attributes if they are already there.

    Similarly, if X is a subclass of Configurable,

        x: X

    and optionally

        def create_x(self):...

    will be replaced with

        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
        def create_x(self):
            self.x = X(self.x_args)

    Similarly, replace,

        x: Optional[X]

    and optionally

        def create_x(self):...
        x_enabled: bool = ...

    with

        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
        x_enabled: bool = False
        def create_x(self):
            if self.x_enabled:
                self.x = X(self.x_args)
            else:
                self.x = None


    Also adds the following class members, unannotated so that dataclass
    ignores them.
    - _creation_functions: Tuple[str] of all the create_ functions,
        including those from base classes.
    - _known_implementations: Dict[str, Type] containing the classes which
        have been found from the registry.
        (used only to raise a warning if one has been overwritten)
    - _processed_members: a Dict[str, Any] of all the members which have been
        transformed, with values giving the types they were declared to have.
        (E.g. {"x": X} or {"x": Optional[X]} in the cases above.)

    Args:
        some_class: the class to be processed
        _do_not_process: Internal use for get_default_args: Because get_default_args calls
            and is called by this function, we let it specify any class currently
            being processed, to make sure we don't try to process a class while
            it is already being processed.


    Returns:
        some_class itself, which has been modified in place. This
        allows this function to be used as a class decorator.
    """
    if _is_actually_dataclass(some_class):
        return some_class

    # The functions this class's run_auto_creation will run.
    creation_functions: List[str] = []
    # The classes which this type knows about from the registry
    # We could use a weakref.WeakValueDictionary here which would mean
    # that we don't warn if the class we should have expected is elsewhere
    # unused.
    known_implementations: Dict[str, Type] = {}
    # Names of members which have been processed.
    processed_members: Dict[str, Any] = {}

    # For all bases except ReplaceableBase and Configurable and object,
    # we need to process them before our own processing. This is
    # because dataclasses expect to inherit dataclasses and not unprocessed
    # dataclasses.
    for base in some_class.mro()[-3:0:-1]:
        if base is ReplaceableBase:
            continue
        if base is Configurable:
            continue
        if not issubclass(base, (Configurable, ReplaceableBase)):
            continue
        expand_args_fields(base, _do_not_process=_do_not_process)
        if "_creation_functions" in base.__dict__:
            creation_functions.extend(base._creation_functions)
        if "_known_implementations" in base.__dict__:
            known_implementations.update(base._known_implementations)
        if "_processed_members" in base.__dict__:
            processed_members.update(base._processed_members)

    to_process: List[Tuple[str, Type, _ProcessType]] = []
    if "__annotations__" in some_class.__dict__:
        for name, type_ in some_class.__annotations__.items():
            underlying_and_process_type = _get_type_to_process(type_)
            if underlying_and_process_type is None:
                continue
            underlying_type, process_type = underlying_and_process_type
            to_process.append((name, underlying_type, process_type))

    for name, underlying_type, process_type in to_process:
        processed_members[name] = some_class.__annotations__[name]
        _process_member(
            name=name,
            type_=underlying_type,
            process_type=process_type,
            some_class=some_class,
            creation_functions=creation_functions,
            _do_not_process=_do_not_process,
            known_implementations=known_implementations,
        )

    for key, count in Counter(creation_functions).items():
        if count > 1:
            warnings.warn(f"Clash with {key} in a base class.")
    some_class._creation_functions = tuple(creation_functions)
    some_class._processed_members = processed_members
    some_class._known_implementations = known_implementations

    dataclasses.dataclass(eq=False)(some_class)
    return some_class


def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
    """
    Get a dataclass field which defaults to get_default_args(...)

    Args:
        As for get_default_args.

    Returns:
        function to return new DictConfig object
    """

    def create():
        return get_default_args(C, _do_not_process=_do_not_process)

    return dataclasses.field(default_factory=create)


def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
    """
    If a member is annotated as `type_`, and that should be expanded in
    expand_args_fields, return how it should be expanded.
    """
    if get_origin(type_) == Union:
        # We look for Optional[X] which is a Union of X with None.
        args = get_args(type_)
        if len(args) != 2 or all(a is not type(None) for a in args):  # noqa: E721
            return
        underlying = args[0] if args[1] is type(None) else args[1]  # noqa: E721
        if (
            isinstance(underlying, type)
            and issubclass(underlying, ReplaceableBase)
            and ReplaceableBase in underlying.__bases__
        ):
            return underlying, _ProcessType.OPTIONAL_REPLACEABLE

        if isinstance(underlying, type) and issubclass(underlying, Configurable):
            return underlying, _ProcessType.OPTIONAL_CONFIGURABLE

    if not isinstance(type_, type):
        # e.g. any other Union or Tuple
        return

    if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
        return type_, _ProcessType.REPLACEABLE

    if issubclass(type_, Configurable):
        return type_, _ProcessType.CONFIGURABLE


def _process_member(
    *,
    name: str,
    type_: Type,
    process_type: _ProcessType,
    some_class: Type,
    creation_functions: List[str],
    _do_not_process: Tuple[type, ...],
    known_implementations: Dict[str, Type],
) -> None:
    """
    Make the modification (of expand_args_fields) to some_class for a single member.

    Args:
        name: member name
        type_: member type (with Optional removed if needed)
        process_type: whether member has dynamic type
        some_class: (MODIFIED IN PLACE) the class being processed
        creation_functions: (MODIFIED IN PLACE) the names of the create functions
        _do_not_process: as for expand_args_fields.
        known_implementations: (MODIFIED IN PLACE) known types from the registry
    """
    # Because we are adding defaultable members, make
    # sure they go at the end of __annotations__ in case
    # there are non-defaulted standard class members.
    del some_class.__annotations__[name]

    if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
        type_name = name + TYPE_SUFFIX
        if type_name not in some_class.__annotations__:
            if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
                some_class.__annotations__[type_name] = Optional[str]
            else:
                some_class.__annotations__[type_name] = str
            setattr(some_class, type_name, "UNDEFAULTED")

        for derived_type in registry.get_all(type_):
            if derived_type in _do_not_process:
                continue
            if issubclass(derived_type, some_class):
                # When derived_type is some_class we have a simple
                # recursion to avoid. When it's a strict subclass the
                # situation is even worse.
                continue
            known_implementations[derived_type.__name__] = derived_type
            args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
            if args_name in some_class.__annotations__:
                raise ValueError(
                    f"Cannot generate {args_name} because it is already present."
                )
            some_class.__annotations__[args_name] = DictConfig
            setattr(
                some_class,
                args_name,
                get_default_args_field(
                    derived_type, _do_not_process=_do_not_process + (some_class,)
                ),
            )
    else:
        args_name = name + ARGS_SUFFIX
        if args_name in some_class.__annotations__:
            raise ValueError(
                f"Cannot generate {args_name} because it is already present."
            )
        if issubclass(type_, some_class) or type_ in _do_not_process:
            raise ValueError(f"Cannot process {type_} inside {some_class}")

        some_class.__annotations__[args_name] = DictConfig
        setattr(
            some_class,
            args_name,
            get_default_args_field(
                type_,
                _do_not_process=_do_not_process + (some_class,),
            ),
        )
        if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
            enabled_name = name + ENABLED_SUFFIX
            if enabled_name not in some_class.__annotations__:
                some_class.__annotations__[enabled_name] = bool
                setattr(some_class, enabled_name, False)

    creation_function_name = f"create_{name}"
    if not hasattr(some_class, creation_function_name):
        setattr(
            some_class,
            creation_function_name,
            _default_create(name, type_, process_type),
        )
    creation_functions.append(creation_function_name)


def remove_unused_components(dict_: DictConfig) -> None:
    """
    Assuming dict_ represents the state of a configurable,
    modify it to remove all the portions corresponding to
    pluggable parts which are not in use.
    For example, if renderer_class_type is SignedDistanceFunctionRenderer,
    the renderer_MultiPassEmissionAbsorptionRenderer_args will be
    removed. Also, if chocolate_enabled is False, then chocolate_args will
    be removed.

    Args:
        dict_: (MODIFIED IN PLACE) a DictConfig instance
    """
    keys = [key for key in dict_ if isinstance(key, str)]
    suffix_length = len(TYPE_SUFFIX)
    replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
    args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
    for replaceable in replaceables:
        selected_type = dict_[replaceable + TYPE_SUFFIX]
        if selected_type is None:
            expect = ""
        else:
            expect = replaceable + "_" + selected_type + ARGS_SUFFIX
        with open_dict(dict_):
            for key in args_keys:
                if key.startswith(replaceable + "_") and key != expect:
                    del dict_[key]

    suffix_length = len(ENABLED_SUFFIX)
    enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
    for enableable in enableables:
        enabled = dict_[enableable + ENABLED_SUFFIX]
        if not enabled:
            with open_dict(dict_):
                dict_.pop(enableable + ARGS_SUFFIX, None)

    for key in dict_:
        if isinstance(dict_.get(key), DictConfig):
            remove_unused_components(dict_[key])
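Note: as a quick illustration of how the configuration system above is meant to be used (mirroring the module docstring), here is a small hypothetical sketch. The names Backbone, ResNetBackbone and Pipeline are placeholders invented for the example; they are not classes from this repository.

from omegaconf import DictConfig

class Backbone(ReplaceableBase):          # base class: defines a pluggable slot
    out_dim: int = 128

@registry.register
class ResNetBackbone(Backbone):           # one concrete implementation, found by name
    depth: int = 50

class Pipeline(Configurable):
    backbone: Backbone                    # filled in at runtime from backbone_class_type
    backbone_class_type: str = "ResNetBackbone"

    def __post_init__(self):
        run_auto_creation(self)           # instantiates self.backbone from backbone_ResNetBackbone_args

cfg: DictConfig = get_default_args(Pipeline)   # contains backbone_class_type and backbone_ResNetBackbone_args
pipeline = Pipeline(**cfg)
assert isinstance(pipeline.backbone, ResNetBackbone)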
models/core/utils/utils.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn.functional as F


def interp(tensor, size):
    return F.interpolate(
        tensor,
        size=size,
        mode="bilinear",
        align_corners=True,
    )


class InputPadder:
    """Pads images such that dimensions are divisible by 8"""

    def __init__(self, dims, mode="sintel", divis_by=8):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        if mode == "sintel":
            self._pad = [
                pad_wd // 2,
                pad_wd - pad_wd // 2,
                pad_ht // 2,
                pad_ht - pad_ht // 2,
            ]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        assert all((x.ndim == 4) for x in inputs)
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        assert x.ndim == 4
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]
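Note: a small usage sketch of InputPadder (the image size below is arbitrary and only for illustration): it pads both views of a stereo pair so that height and width become divisible by divis_by, and then crops a prediction at the padded resolution back to the original size. The model call is commented out because it is hypothetical here.

import torch

left = torch.randn(1, 3, 437, 621)     # arbitrary size not divisible by 8
right = torch.randn(1, 3, 437, 621)

padder = InputPadder(left.shape, divis_by=8)
left_p, right_p = padder.pad(left, right)
assert left_p.shape[-2] % 8 == 0 and left_p.shape[-1] % 8 == 0

# disparity = model(left_p, right_p)           # hypothetical network call at the padded resolution
disparity = torch.zeros_like(left_p[:, :1])    # stand-in prediction with the padded shape
disparity = padder.unpad(disparity)
assert disparity.shape[-2:] == left.shape[-2:]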
models/dynamic_stereo_model.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import ClassVar

import torch
from pytorch3d.implicitron.tools.config import Configurable

from dynamic_stereo.models.core.dynamic_stereo import DynamicStereo


class DynamicStereoModel(Configurable, torch.nn.Module):

    MODEL_CONFIG_NAME: ClassVar[str] = "DynamicStereoModel"

    # model_weights: str = "./checkpoints/dynamic_stereo_sf.pth"
    model_weights: str = "./checkpoints/dynamic_stereo_dr_sf.pth"
    kernel_size: int = 20

    def __post_init__(self):
        super().__init__()

        self.mixed_precision = False
        model = DynamicStereo(
            mixed_precision=self.mixed_precision,
            num_frames=5,
            attention_type="self_stereo_temporal_update_time_update_space",
            use_3d_update_block=True,
            different_update_blocks=True,
        )

        state_dict = torch.load(self.model_weights, map_location="cpu")
        if "model" in state_dict:
            state_dict = state_dict["model"]
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        state_dict = {"module." + k: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict, strict=False)

        self.model = model
        self.model.to("cuda")
        self.model.eval()

    def forward(self, batch_dict, iters=20):
        return self.model.forward_batch_test(
            batch_dict, kernel_size=self.kernel_size, iters=iters
        )
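Note: a hedged sketch of how this wrapper can be driven. The batch layout is an assumption based on the RAFTStereoModel wrapper below, which reads batch_dict["stereo_video"]; in practice the batches come from the dataset classes in datasets/, and a CUDA device plus the checkpoint file are required.

import torch
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(DynamicStereoModel)       # process the Configurable into a dataclass
model = DynamicStereoModel(kernel_size=20)   # model_weights keeps its default path

T, H, W = 5, 256, 320                        # illustrative clip length and resolution
batch_dict = {"stereo_video": torch.randn(T, 2, 3, H, W)}  # assumed key, see note above

with torch.no_grad():
    predictions = model(batch_dict, iters=20)  # delegates to forward_batch_test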
models/raft_stereo_model.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from types import SimpleNamespace
from typing import ClassVar

import torch
from pytorch3d.implicitron.tools.config import Configurable

import importlib
import sys

sys.path.append("third_party/RAFT-Stereo")
raft_stereo = importlib.import_module(
    "dynamic_stereo.third_party.RAFT-Stereo.core.raft_stereo"
)
raft_stereo_utils = importlib.import_module(
    "dynamic_stereo.third_party.RAFT-Stereo.core.utils.utils"
)
autocast = torch.cuda.amp.autocast


class RAFTStereoModel(Configurable, torch.nn.Module):
    MODEL_CONFIG_NAME: ClassVar[str] = "RAFTStereoModel"
    model_weights: str = "./third_party/RAFT-Stereo/models/raftstereo-middlebury.pth"

    def __post_init__(self):
        super().__init__()

        model_args = SimpleNamespace(
            hidden_dims=[128] * 3,
            corr_implementation="reg",
            shared_backbone=False,
            corr_levels=4,
            corr_radius=4,
            n_downsample=2,
            slow_fast_gru=False,
            n_gru_layers=3,
            mixed_precision=False,
            context_norm="batch",
        )
        self.args = model_args
        model = torch.nn.DataParallel(
            raft_stereo.RAFTStereo(model_args), device_ids=[0]
        )

        state_dict = torch.load(self.model_weights, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        state_dict = {"module." + k: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

        self.model = model.module
        self.model.to("cuda")
        self.model.eval()

    def forward(self, batch_dict, iters=32):
        predictions = defaultdict(list)
        for stereo_pair in batch_dict["stereo_video"]:
            left_image_rgb = stereo_pair[None, 0].cuda()
            right_image_rgb = stereo_pair[None, 1].cuda()

            padder = raft_stereo_utils.InputPadder(left_image_rgb.shape, divis_by=32)
            left_image_rgb, right_image_rgb = padder.pad(
                left_image_rgb, right_image_rgb
            )

            with autocast(enabled=self.args.mixed_precision):
                _, flow_up = self.model.forward(
                    left_image_rgb,
                    right_image_rgb,
                    iters=iters,
                    test_mode=True,
                )
            flow_up = padder.unpad(flow_up)
            predictions["disparity"].append(flow_up)
        predictions["disparity"] = (
            torch.stack(predictions["disparity"]).squeeze(1).abs()
        )
        return predictions
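Note: the tensor layout assumed by this wrapper can be summarised with a small sketch (illustrative shapes only): each element of batch_dict["stereo_video"] is one stereo pair of shape (2, 3, H, W); the wrapper pads each pair to a multiple of 32, runs RAFT-Stereo in test mode, and stacks the absolute horizontal flow as per-frame disparity.

import torch

T, H, W = 4, 256, 320
batch_dict = {"stereo_video": torch.randn(T, 2, 3, H, W)}  # T independent stereo pairs

# predictions = RAFTStereoModel()(batch_dict)      # requires the RAFT-Stereo checkpoint and a GPU
# predictions["disparity"]                          # expected to stack to (T, 1, H, W), see forward() above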
notebooks/Dynamic_Replica_demo.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

notebooks/evaluate.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

requirements.txt
ADDED
@@ -0,0 +1,20 @@
hydra-core==1.1
einops==0.4.1
flow_vis==0.1
imageio==2.21.1
matplotlib==3.5.3
munch==2.5.0
numpy==1.23.5
omegaconf==2.1.0
opencv_python==4.6.0.66
opt_einsum==3.3.0
Pillow==9.5.0
pytorch_lightning==1.6.0
requests
scikit_image==0.19.2
scipy==1.10.0
setuptools==65.6.3
tabulate==0.8.10
tqdm==4.64.1
moviepy
jupyter
scripts/checksum_check.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import glob
import argparse
import hashlib
import json

from typing import Optional
from multiprocessing import Pool
from tqdm import tqdm


DEFAULT_SHA256S_FILE = os.path.join(__file__.rsplit(os.sep, 2)[0], "dr_sha256.json")
BLOCKSIZE = 65536


def main(
    download_folder: str,
    sha256s_file: str,
    dump: bool = False,
    n_sha256_workers: int = 4,
):
    if not os.path.isfile(sha256s_file):
        raise ValueError(f"The SHA256 file does not exist ({sha256s_file}).")

    expected_sha256s = get_expected_sha256s(sha256s_file=sha256s_file)

    zipfiles = sorted(glob.glob(os.path.join(download_folder, "*.zip")))
    print(f"Extracting SHA256 hashes for {len(zipfiles)} files in {download_folder}.")
    extracted_sha256s_list = []
    with Pool(processes=n_sha256_workers) as sha_pool:
        for extracted_hash in tqdm(
            sha_pool.imap(_sha256_file_and_print, zipfiles),
            total=len(zipfiles),
        ):
            extracted_sha256s_list.append(extracted_hash)

    extracted_sha256s = dict(
        zip([os.path.split(z)[-1] for z in zipfiles], extracted_sha256s_list)
    )

    if dump:
        print(extracted_sha256s)
        with open(sha256s_file, "w") as f:
            json.dump(extracted_sha256s, f, indent=2)

    missing_keys, invalid_keys = [], []
    for k in expected_sha256s.keys():
        if k not in extracted_sha256s:
            print(f"{k} missing!")
            missing_keys.append(k)
        elif expected_sha256s[k] != extracted_sha256s[k]:
            print(
                f"'{k}' does not match!"
                + f" ({expected_sha256s[k]} != {extracted_sha256s[k]})"
            )
            invalid_keys.append(k)
    if len(invalid_keys) + len(missing_keys) > 0:
        raise ValueError(
            f"Checksum checker failed!"
            + f" Non-matching checksums: {str(invalid_keys)};"
            + f" missing files: {str(missing_keys)}."
        )


def get_expected_sha256s(sha256s_file: str):
    with open(sha256s_file, "r") as f:
        expected_sha256s = json.load(f)
    return expected_sha256s


def check_dr_sha256(
    path: str,
    sha256s_file: str,
    expected_sha256s: Optional[dict] = None,
    do_assertion: bool = True,
):
    zipname = os.path.split(path)[-1]
    if expected_sha256s is None:
        expected_sha256s = get_expected_sha256s(sha256s_file=sha256s_file)
    extracted_hash = sha256_file(path)
    if do_assertion:
        assert (
            extracted_hash == expected_sha256s[zipname]
        ), f"{zipname}: ({extracted_hash} != {expected_sha256s[zipname]})"
    else:
        return extracted_hash == expected_sha256s[zipname]


def sha256_file(path: str):
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as f:
        file_buffer = f.read(BLOCKSIZE)
        while len(file_buffer) > 0:
            sha256_hash.update(file_buffer)
            file_buffer = f.read(BLOCKSIZE)
    digest_ = sha256_hash.hexdigest()
    return digest_


def _sha256_file_and_print(path: str):
    digest_ = sha256_file(path)
    print(f"{path}: {digest_}")
    return digest_


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check SHA256 hashes of the Dynamic Replica dataset."
    )
    parser.add_argument(
        "--download_folder",
        type=str,
        help="A local target folder for downloading the dataset files.",
    )
    parser.add_argument(
        "--sha256s_file",
        type=str,
        help="A file with the expected SHA256 hashes of the dataset files.",
        default=DEFAULT_SHA256S_FILE,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="The number of sha256 extraction workers.",
    )
    parser.add_argument(
        "--dump_sha256s",
        action="store_true",
        help="Store sha256s hashes.",
    )

    args = parser.parse_args()
    main(
        str(args.download_folder),
        dump=bool(args.dump_sha256s),
        n_sha256_workers=int(args.num_workers),
        sha256s_file=str(args.sha256s_file),
    )
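A hedged example of driving the checker programmatically rather than from the command line; the download folder below is a placeholder path.

# Hypothetical folder-wide verification using the functions defined above.
from checksum_check import main, DEFAULT_SHA256S_FILE

main(
    download_folder="/data/dynamic_replica",  # placeholder: folder containing the downloaded *.zip archives
    sha256s_file=DEFAULT_SHA256S_FILE,
    n_sha256_workers=4,
)
# Prints each archive's digest and raises ValueError if any archive is missing or mismatched.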
scripts/download_dynamic_replica.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys

sys.path.append("./scripts/")
from download_utils import build_arg_parser, download_dataset


DEFAULT_LINK_LIST_FILE = os.path.join(os.path.dirname(__file__), "links.json")
DEFAULT_SHA256S_FILE = os.path.join(os.path.dirname(__file__), "dr_sha256.json")


if __name__ == "__main__":
    parser = build_arg_parser(
        "dynamic_replica", DEFAULT_LINK_LIST_FILE, DEFAULT_SHA256S_FILE
    )

    args = parser.parse_args()
    os.makedirs(args.download_folder, exist_ok=True)
    download_dataset(
        str(args.link_list_file),
        str(args.download_folder),
        n_download_workers=int(args.n_download_workers),
        n_extract_workers=int(args.n_extract_workers),
        download_splits=args.download_splits,
        checksum_check=bool(args.checksum_check),
        clear_archives_after_unpacking=bool(args.clear_archives_after_unpacking),
        sha256s_file=str(args.sha256_file),
        skip_downloaded_archives=not bool(args.redownload_existing_archives),
    )
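The same download can also be triggered directly from Python; a sketch under the assumption that the target folder is a placeholder and a `links.json` file exists next to the scripts.

# Hypothetical programmatic download of only the validation split.
import os
from download_utils import download_dataset

os.makedirs("/data/dynamic_replica", exist_ok=True)  # placeholder target folder
download_dataset(
    link_list_file="scripts/links.json",
    download_folder="/data/dynamic_replica",
    download_splits=["valid"],
    checksum_check=True,
    sha256s_file="scripts/dr_sha256.json",
)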
scripts/download_utils.py
ADDED
@@ -0,0 +1,280 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import requests
import functools
import json
import warnings

from argparse import ArgumentParser
from typing import List, Optional
from multiprocessing import Pool
from tqdm import tqdm

import sys

sys.path.append("./scripts/")

from checksum_check import check_dr_sha256


def download_dataset(
    link_list_file: str,
    download_folder: str,
    n_download_workers: int = 4,
    n_extract_workers: int = 4,
    download_splits: List[str] = ['real', 'valid', 'test', 'train'],
    checksum_check: bool = False,
    clear_archives_after_unpacking: bool = False,
    skip_downloaded_archives: bool = True,
    sha256s_file: Optional[str] = None,
):
    """
    Downloads and unpacks the dataset in CO3D format.
    Note: The script will make a folder `<download_folder>/_in_progress`, which
    stores files whose download is in progress. The folder can be safely deleted
    once the download is finished.
    Args:
        link_list_file: A text file with the list of zip file download links.
        download_folder: A local target folder for downloading the
            dataset files.
        n_download_workers: The number of parallel workers
            for downloading the dataset files.
        n_extract_workers: The number of parallel workers
            for extracting the dataset files.
        download_splits: A list of data splits to download.
            Must be in ['real', 'valid', 'test', 'train'].
        checksum_check: Enable validation of the downloaded file's checksum before
            extraction.
        clear_archives_after_unpacking: Delete the unnecessary downloaded archive files
            after unpacking.
        skip_downloaded_archives: Skip re-downloading already downloaded archives.
    """

    if checksum_check and not sha256s_file:
        raise ValueError(
            "checksum_check is requested but ground-truth SHA256 file not provided!"
        )

    if not os.path.isfile(link_list_file):
        raise ValueError(
            "Please specify `link_list_file` with a valid path to a json"
            " with zip file download links."
            # " The file is stored in the DynamicStereo github:"
            # " https://github.com/facebookresearch/dynamic_stereo/blob/main/dynamic_stereo/links.json"
        )

    if not os.path.isdir(download_folder):
        raise ValueError(
            "Please specify `download_folder` with a valid path to a target folder"
            + " for downloading the dataset."
            + f" {download_folder} does not exist."
        )

    # read the link file
    with open(link_list_file, "r") as f:
        links = json.load(f)

    for split in download_splits:
        if split not in ['real', 'valid', 'test', 'train']:
            raise ValueError(f"Download split {str(split)} is not valid")

    data_links = []
    for split_name, urls in links.items():
        if split_name in download_splits:
            for url in urls:
                link_name = os.path.split(url)[-1]
                data_links.append((split_name, link_name, url))

    with Pool(processes=n_download_workers) as download_pool:
        download_ok = {}
        for link_name, ok in tqdm(
            download_pool.imap(
                functools.partial(
                    _download_split_file,
                    download_folder,
                    checksum_check,
                    sha256s_file,
                    skip_downloaded_archives,
                ),
                data_links,
            ),
            total=len(data_links),
        ):
            download_ok[link_name] = ok

    with Pool(processes=n_extract_workers) as extract_pool:
        for _ in tqdm(
            extract_pool.imap(
                functools.partial(
                    _unpack_split_file,
                    download_folder,
                    clear_archives_after_unpacking,
                ),
                data_links,
            ),
            total=len(data_links),
        ):
            pass
    print("Done")


def build_arg_parser(
    dataset_name: str,
    default_link_list_file: str,
    default_sha256_file: str,
) -> ArgumentParser:
    parser = ArgumentParser(description=f"Download the {dataset_name} dataset.")
    parser.add_argument(
        "--download_folder",
        type=str,
        required=True,
        help="A local target folder for downloading the dataset files.",
    )
    parser.add_argument(
        "--n_download_workers",
        type=int,
        default=4,
        help="The number of parallel workers for downloading the dataset files.",
    )
    parser.add_argument(
        "--n_extract_workers",
        type=int,
        default=4,
        help="The number of parallel workers for extracting the dataset files.",
    )
    parser.add_argument(
        "--download_splits",
        default=['real', 'valid', 'test', 'train'],
        nargs='+',
        help=f"A comma-separated list of {dataset_name} splits to download.",
    )
    parser.add_argument(
        "--link_list_file",
        type=str,
        default=default_link_list_file,
        help=(
            f"The file with html links to the {dataset_name} dataset files."
            + " In most cases the default local file `links.json` should be used."
        ),
    )
    parser.add_argument(
        "--sha256_file",
        type=str,
        default=default_sha256_file,
        help=(
            f"The file with SHA256 hashes of {dataset_name} dataset files."
            + " In most cases the default local file `dr_sha256.json` should be used."
        ),
    )
    parser.add_argument(
        "--checksum_check",
        action="store_true",
        default=True,
        help="Check the SHA256 checksum of each downloaded file before extraction.",
    )
    parser.add_argument(
        "--no_checksum_check",
        action="store_false",
        dest="checksum_check",
        default=False,
        help="Does not check the SHA256 checksum of each downloaded file before extraction.",
    )
    parser.set_defaults(checksum_check=True)
    parser.add_argument(
        "--clear_archives_after_unpacking",
        action="store_true",
        default=False,
        help="Delete the unnecessary downloaded archive files after unpacking.",
    )
    parser.add_argument(
        "--redownload_existing_archives",
        action="store_true",
        default=False,
        help="Redownload the already-downloaded archives.",
    )

    return parser


def _unpack_split_file(
    download_folder: str,
    clear_archive: bool,
    link: str,
):
    split, link_name, url = link
    local_fl = os.path.join(download_folder, link_name)
    print(f"Unpacking dataset file {local_fl} ({link_name}) to {download_folder}.")

    download_folder_split = os.path.join(download_folder, split)
    # os.makedirs(download_folder_split, exist_ok=True)
    shutil.unpack_archive(local_fl, download_folder_split)
    if clear_archive:
        os.remove(local_fl)


def _download_split_file(
    download_folder: str,
    checksum_check: bool,
    sha256s_file: Optional[str],
    skip_downloaded_files: bool,
    link: str,
):
    __, link_name, url = link
    local_fl_final = os.path.join(download_folder, link_name)

    if skip_downloaded_files and os.path.isfile(local_fl_final):
        print(f"Skipping {local_fl_final}, already downloaded!")
        return link_name, True

    in_progress_folder = os.path.join(download_folder, "_in_progress")
    os.makedirs(in_progress_folder, exist_ok=True)
    local_fl = os.path.join(in_progress_folder, link_name)

    print(f"Downloading dataset file {link_name} ({url}) to {local_fl}.")
    _download_with_progress_bar(url, local_fl, link_name)
    if checksum_check:
        print(f"Checking SHA256 for {local_fl}.")
        try:
            check_dr_sha256(
                local_fl,
                sha256s_file=sha256s_file,
            )
        except AssertionError:
            warnings.warn(
                f"Checksums for {local_fl} did not match!"
                + " This is likely due to a network failure,"
                + " please restart the download script."
            )
            return link_name, False

    os.rename(local_fl, local_fl_final)
    return link_name, True


def _download_with_progress_bar(url: str, fname: str, filename: str):
    # taken from https://stackoverflow.com/a/62113293/986477
    resp = requests.get(url, stream=True)
    print(url)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for datai, data in enumerate(resp.iter_content(chunk_size=1024)):
            size = file.write(data)
            bar.update(size)
            if datai % max((max(total // 1024, 1) // 20), 1) == 0:
                print(f"{filename}: Downloaded {100.0*(float(bar.n)/max(total, 1)):3.1f}%.")
        print(bar)
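The link list read by `download_dataset` is a JSON object keyed by split name, with each split mapping to a list of archive URLs; a minimal sketch of that layout (the URLs are placeholders, not real download links).

# Hypothetical links.json layout consumed by download_dataset above.
import json

links = {
    "valid": ["https://example.com/dynamic_replica/valid_000.zip"],
    "test": ["https://example.com/dynamic_replica/test_000.zip"],
}
with open("scripts/links.json", "w") as f:
    json.dump(links, f, indent=2)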
scripts/dr_sha256.json
ADDED
@@ -0,0 +1,106 @@
| 1 |
+
{
|
| 2 |
+
"real_000.zip": "e5c2aac04146d783c64f76d0ef7a9e8d49d80ffac99d2a795563517f15943a6f",
|
| 3 |
+
"valid_000.zip": "0f35bee47030ae1a30289beb92ba69c5336491e0f07aab0a05cb5505173d1faf",
|
| 4 |
+
"valid_001.zip": "cb37d3b1f643118ae22840b4212b00c00a8fe137099d3730a07796a5fefab24a",
|
| 5 |
+
"valid_002.zip": "5535f2a98e06c68cf97e3259e962e3d44465a1820369e4425c4ef2a719b01ad0",
|
| 6 |
+
"valid_003.zip": "e19db94514d22829743aa363698f407ecfd98d8f08eab037289a420939ef5143",
|
| 7 |
+
"valid_004.zip": "953328f24ba0c3e8709df3829cce238305a8998bf7ae938c80069fab6f513862",
|
| 8 |
+
"valid_005.zip": "27ce4c7424292dcf3e8e0b370fbbc848bd6d73ae28ea5832fddfa8e9c17d6011",
|
| 9 |
+
"test_000.zip": "a56fa676a7a3dc52b33f1571d41fb0221e289735acccb7b9ad42dfb13fdac68c",
|
| 10 |
+
"test_001.zip": "43580e89331826182f41d2ce9f06f62da46617fea9e612a16b2610de8ffdc10b",
|
| 11 |
+
"test_002.zip": "33551fb68979d3d2f20e1976d9169a84ad58658c459aba4d7a2671c8d66904b9",
|
| 12 |
+
"test_003.zip": "45ad28d7555e3579225d26dfcb8244b65de0d1ee749560cc6dd84f121b4b40de",
|
| 13 |
+
"test_004.zip": "d736b56fe15410525deda1c16c0b8da4497383480a4328da92bc0ddb64a62d52",
|
| 14 |
+
"test_005.zip": "3ae331047019a39c6306a17407c72e40dc15b5113f6f9ef72aba2da0b859ea7d",
|
| 15 |
+
"test_006.zip": "94341c8ac8ed1d7f11816ad121e6c5821a751fdc3d3122a653c86f7b5845ca80",
|
| 16 |
+
"test_007.zip": "4e18facbd507e16fc41d90d5c2ce1b44c511d3e2986e1ccdf5d264748d5d7e15",
|
| 17 |
+
"test_008.zip": "e4d5aa0c25eb01863bbced477e17fddd9d8217d23d238bb06b0b688a0f6ed8e3",
|
| 18 |
+
"test_009.zip": "5a413411cfc376078ed0357708084f71949159c13119aabb5c9ae1ffde33b6b7",
|
| 19 |
+
"test_010.zip": "82ea42c7544385aa2d41271e63399534a398dbbef8a06cb990c8bb34296928c8",
|
| 20 |
+
"train_000.zip": "e9fd9af579b0d08d538551c0ab6f7231a1fd162139667803e672cc0dc8b98b03",
|
| 21 |
+
"train_001.zip": "65cb438c7a48567f85db8e54109db6c25d2a621fcbd3267c542a8a640e1dad56",
|
| 22 |
+
"train_002.zip": "c3d9a76a955dd9feb0275837a17133a1d7ee76c963f1c6fa7630deb0ca8209b2",
|
| 23 |
+
"train_003.zip": "13e108f78c7da1f1c1469dd87fab55a6e4ec79f1fcb6d7d16cc9006a933979f4",
|
| 24 |
+
"train_004.zip": "171b92a62b46a68f1d89c2326ba67b6433faf087bc1eecc7a947c19d0f90d3e6",
|
| 25 |
+
"train_005.zip": "75461ffe13cfbd87b4f0f9ffc83002b8381f5a0a212ece18b8961012f865a46e",
|
| 26 |
+
"train_006.zip": "7546f94817814031a738082e6b30858d0057710af052a88fa505a961b6699886",
|
| 27 |
+
"train_007.zip": "371dd100b215bcd41129def1c8fd07f974af11a9b3d3b9966ce5d9700b9929ad",
|
| 28 |
+
"train_008.zip": "313f5c2089c6afc1691edf054e8b2af9eb8b2d91f791153763758c8d91abee48",
|
| 29 |
+
"train_009.zip": "9cbb9f44bb6b7dcc74f00a51db4d2a8797c95a0d880d63ef1612d3883b16b995",
|
| 30 |
+
"train_010.zip": "eb158fccc23a4b41358ec94be203f49a677f86626af7a88f0e649454c409c706",
|
| 31 |
+
"train_011.zip": "f8b3f8c738cdcdbbdf346a4dd78b99883b5d4ab74c11b64ec7b4f8ccd3b68ffc",
|
| 32 |
+
"train_012.zip": "b364ba9d35d7e55019d3554cf65b295d2358859c222b3b847b0f2cced948cfce",
|
| 33 |
+
"train_013.zip": "c8a50efbd93e6e422eabf1846dac2d75e81dfcfcd4d785fe18b01526af9695f6",
|
| 34 |
+
"train_014.zip": "52a768ce76310861cf1fc990ebb8d16f0c27fceff02c12b11638d36ca1c3a927",
|
| 35 |
+
"train_015.zip": "67bf0ba775948997f5ab3cc810b6d0e8149758334210ace6f5cdfc529fe7d26e",
|
| 36 |
+
"train_016.zip": "d5b9a26736421d8f330fd5e531d26071531501a88609d29d580b9d56b6bc17a3",
|
| 37 |
+
"train_017.zip": "5f2d2c93e7944baf1e6d3dee671b12abb7476a75cbd6f572af86fe5c22472fa6",
|
| 38 |
+
"train_018.zip": "77aa801b6b0359b970466329e4a05b937df94b650228cf4797a2a029606b8e5b",
|
| 39 |
+
"train_019.zip": "30934c91cc0ae69acef6a89e4a5180686bd04080e2384a8bde5877cbaaadc575",
|
| 40 |
+
"train_020.zip": "901d5c08705a70053a3e865354a4e7149c35f026b6ed166fee029d829d88c124",
|
| 41 |
+
"train_021.zip": "f27019ff58e54a004ed2cf2106ed459a31b010ed82d32028b0e196dd365b8b0e",
|
| 42 |
+
"train_022.zip": "0600346a2ce162f7e9824e90c553b69a656d4731c86d903e300d932ec8ba7600",
|
| 43 |
+
"train_023.zip": "660d768e4b1bfe742a42ae6ee84f5e91c930789488a7c7f118e5d0edd1f1a010",
|
| 44 |
+
"train_024.zip": "1f8792002baceaba8f93f93be1bee7c83a48c677e4b2d025b6f0047a796e94cd",
|
| 45 |
+
"train_025.zip": "0b92b3f41c18fded8fcb7aba44e7d8738750b8155c907924200fdf4dc1718794",
|
| 46 |
+
"train_026.zip": "4dc401639317527231abfef07221b8d7db2d0950008828104cd1f72092325d05",
|
| 47 |
+
"train_027.zip": "e8313eaa21163f9dd2ff4558d16b1c9cf4962c2e4c0403d6a315955660a98b14",
|
| 48 |
+
"train_028.zip": "d73edf1c500b4311795aaae0a03b3bc04a2c266e2a20b27ba9b6e72fb27fd277",
|
| 49 |
+
"train_029.zip": "c5e4d302c62e693626445aba19638711108049235b0075558e7949b189050c56",
|
| 50 |
+
"train_030.zip": "506b9ba7a740b0bf84159546f797437a48a24e468cb949f2189e51cf404c6170",
|
| 51 |
+
"train_031.zip": "f36bb4b77fdb255dae2050884cf59cd3f8e46e77ea2984b4b219b799c4aac089",
|
| 52 |
+
"train_032.zip": "fddca4efc40ed8d05adf9d519e4fb5b486ac77e8fa08c98d5c4be15867fda8a0",
|
| 53 |
+
"train_033.zip": "c24d2b5c04f3e90b265fd0762e7ae19fb01a7c1948a4c09451383a9eec9f640f",
|
| 54 |
+
"train_034.zip": "5828fbd615c4476f6107fe844cbf81632eff2f9c75194cb84d749630d9359e14",
|
| 55 |
+
"train_035.zip": "7b60fe125fd1a9ba7991e2accd0f2b212968983b4631d43eccff9836a0c35ba8",
|
| 56 |
+
"train_036.zip": "0f4eaf464a2afc62447a802159b3844487b80e9d1c9b0a7d324b0d5914514d60",
|
| 57 |
+
"train_037.zip": "ba85a6692d86e48c4c787b334d4384c08b914e4cee7f3d2692dcae1bbac55878",
|
| 58 |
+
"train_038.zip": "c67b0f5305560d8089bdc2f6212c05256c044e50a715d59b864fbef705bc6b5c",
|
| 59 |
+
"train_039.zip": "f4b66c9e1360a8d6d8337c94eefb1132d865c2735c6b78ba726a590073174aad",
|
| 60 |
+
"train_040.zip": "2c64b76d028fcc153f267925b79a24cf3bb0e42cc7716773df2139f5cec5e319",
|
| 61 |
+
"train_041.zip": "22b1c0ab99a7f8bd0d36c2d2511d3d469cc390776c38132d1e8f1ad7aae5d4ff",
|
| 62 |
+
"train_042.zip": "8f2afaecb9f90947c9071111fde9c015acfceb432ae0bf94deff3ecd581b26c8",
|
| 63 |
+
"train_043.zip": "adf7ea7c356339b10b797c49163252704b4e6b0cebcc741d3374f8c9467f6b43",
|
| 64 |
+
"train_044.zip": "3d0fe4a85fd22ff9c8ed468ca8173d93406a72fadf800d9e6bbf209348cf8965",
|
| 65 |
+
"train_045.zip": "70874eca6bce66cb7681092755d066968e9c8fc32a266d7c0d2f29c01b2b2669",
|
| 66 |
+
"train_046.zip": "01adcdbba0a25383e2281ce02a946f6bc824e1b8e16cf88e85a4ad275203884c",
|
| 67 |
+
"train_047.zip": "50ed632ae330acf60c1b2e22b28fbfab5ccf0e8f23320b2911dcc2d43db048b6",
|
| 68 |
+
"train_048.zip": "f302984f486df60d7a281e2b0a9b6d32456fc6042eb596cb5ef54ee919ccd7bb",
|
| 69 |
+
"train_049.zip": "8e8e0a426796f76dfb2d29cb855894fd01cc954b017aa1d06ae1a121fb310088",
|
| 70 |
+
"train_050.zip": "051f0dd8e612e7073dd20585c42681daeff853a6ee0de6f2e8ff4581cdf4f83b",
|
| 71 |
+
"train_051.zip": "3f39b3732c32b960aef4bf3f152b1a72195dc4ab4bbc10116a05875ca8d40417",
|
| 72 |
+
"train_052.zip": "361b9bcd3364c63c8f2814dfacf91489b79c9cedf03ffcb03b3dacfb77cee3a1",
|
| 73 |
+
"train_053.zip": "f6afe23b3005b1889f76ea9c10ac42f7c4f07cefbe737781229640b834f8ede2",
|
| 74 |
+
"train_054.zip": "ef993bd657104770df8e07a9d7c8ac1d1c3ac57b91f66796bea97f03e5a01df2",
|
| 75 |
+
"train_055.zip": "ec0dea8199e1db7bd8e19f85b0d1a9ab9e8fc2be2c5da5b3455f96e074ad7f22",
|
| 76 |
+
"train_056.zip": "44259829f6832c3dc14b893d5f5b7b6f784a09570f26e9cc9749807a1b05b21e",
|
| 77 |
+
"train_057.zip": "263b712fe2ded353cb248324305f831d8b14aa0858f005067bb27e88decd7f32",
|
| 78 |
+
"train_058.zip": "c44fb44365bc4cd8c4c9bb13d70fa9bb290708b7d3fe44fd79c6eed42702ed70",
|
| 79 |
+
"train_059.zip": "43dd65609afb3992273f914b4d0108187f85eaf1f252f85556f10e40816d5e6c",
|
| 80 |
+
"train_060.zip": "97b2abe90259f4629d7c1c1cec2427f155252403f5dcfea563e2d1338ae63150",
|
| 81 |
+
"train_061.zip": "9d8c790d1806659617ddd6dd99ae56388b5eb9f311c47a079ac8fa5df8f44f57",
|
| 82 |
+
"train_062.zip": "5b4398d6a8709ddf1b050b03b19dfe8aacf3378a4879402f457f12bd97ab99df",
|
| 83 |
+
"train_063.zip": "05024f1b0671cb3026db0b9e801c9aab000b828784839f970a8ad0bc23125435",
|
| 84 |
+
"train_064.zip": "b9bba3999971745ea2cdce69c00c49b109ba02c9f3169614d1d229e468bebc68",
|
| 85 |
+
"train_065.zip": "ff4084dd7c017478b872fd7c9152df5271a7088489d3b86cc21968db272356ef",
|
| 86 |
+
"train_066.zip": "9d8158fd6691065c1cb76ac36c3be90b065e8848856a66b10475b11e1261dd4d",
|
| 87 |
+
"train_067.zip": "3e4b9ebef2bdecab5774a72037d9f1f7c40359e6a2d00851c0c40bdd686373c5",
|
| 88 |
+
"train_068.zip": "a89d53ce7c79af32a659a2a59138568ada1395c56c6063f4f49c1d4e052cf9cd",
|
| 89 |
+
"train_069.zip": "3f66206486af3f0bfa04ce8f664b6af6aa7fd2ad8ebadd5c75039de8c5ffea91",
|
| 90 |
+
"train_070.zip": "e8a95aad5f81e7185a7dacb9031a5c27010ec17302e2e35f7f1de3dc88e02a7b",
|
| 91 |
+
"train_071.zip": "677bf42f8d576c79189cd5af2abf420990368d9c7d768a21a10fc0939dde121f",
|
| 92 |
+
"train_072.zip": "f8d5ea223dc13663bbaae6c5bbd732db15f1c249e7fe2da44b5a6ba5b7dbf505",
|
| 93 |
+
"train_073.zip": "3057bda88ebd5bffb0da030d1126e1fb4fed4b5fbfc547dc0be669ece39979c1",
|
| 94 |
+
"train_074.zip": "f3a01d19e6fedd44679d76ee93051b91b616a55b6b22861db126b8d2bfdba7ce",
|
| 95 |
+
"train_075.zip": "0faa29f3f712f744e003da29b249896cc770fb9b357e8a4c447eeb6ad2798ce2",
|
| 96 |
+
"train_076.zip": "d9943f9b72be89dd8f1273bd02133ab24b81e3c3f794e13362a96b0826518696",
|
| 97 |
+
"train_077.zip": "cfab28d27c1532a91980b65baa4d40c8e13144788b9ae7a4c36ce8b909e51e55",
|
| 98 |
+
"train_078.zip": "b06277baadbe60b2019d0f7b6ed637b23957b6320797bf4b6b9099dc4df0cc7e",
|
| 99 |
+
"train_079.zip": "2163ef05752f7a8813fa9cd5661547bc280239fd3bd903b94a8aef37182e9645",
|
| 100 |
+
"train_080.zip": "13ae6b86afe4aa00ce19f4f7a8df24d11742340c5775fca02f6e1f70cd9a3be7",
|
| 101 |
+
"train_081.zip": "a2512084c16220e0acd207f5e330dd319a30c3445b5034f2c14f9a65111628a3",
|
| 102 |
+
"train_082.zip": "d9615ac989465bc85cf990167ce176af55b8affeebb58d5021c215c1f7235c8a",
|
| 103 |
+
"train_083.zip": "539710fcc33b043dd24499d3987852a35c8a1c5fb75f7530a9caebf57fd5f324",
|
| 104 |
+
"train_084.zip": "33232eb1d68e493a25126f22e31326b7c1195ea511c332a1413e83a0245bdae6",
|
| 105 |
+
"train_085.zip": "13e575f24a77278b7de25e3d186f6201692b3e45ed4701b071d5a770c0e1d590"
|
| 106 |
+
}
|
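These per-archive hashes are what `scripts/checksum_check.py` compares against; a small sketch of spot-checking a single archive (the zip path is a placeholder).

# Hypothetical spot-check of one downloaded archive against dr_sha256.json.
import json
from checksum_check import sha256_file

with open("scripts/dr_sha256.json") as f:
    expected = json.load(f)

digest = sha256_file("/data/dynamic_replica/real_000.zip")  # placeholder path
print(digest == expected["real_000.zip"])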
setup.csh
ADDED
@@ -0,0 +1,9 @@
#!/bin/csh

python -m virtualenv venv

# -- CUDA 12.6
pip install torch==2.1.0+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
python -m pip install pip==24.0
pip install -r requirements.txt
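After running the setup script, a quick sanity check of the resulting environment can be done from Python; this sketch only assumes the packages installed above.

# Hypothetical post-install sanity check for the environment created by setup.csh.
import torch
import pytorch3d  # noqa: F401 -- only to confirm the install resolved

print(torch.__version__)          # expected to report a 2.1.0+cu121 build
print(torch.cuda.is_available())  # True if a compatible CUDA driver is visible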
train.csh
ADDED
@@ -0,0 +1,34 @@
#!/bin/csh

set arg_count = $#argv
if ( $arg_count >= 1 ) then
    if ( "$argv[1]" == "-clean" || "$argv[1]" == "-clean_only" ) then
        echo "[INFO] Killing all other GPU processes to free up resources."

        sh -c 'ps | grep python | sed "s/ pts.\+$//g" > .tmp.csh'
        chmod +x .tmp.csh
        sed -i "s/^/kill -9 /g" .tmp.csh
        source .tmp.csh
        rm -rf .tmp.csh
        rm -rf debug_rank_*
        rm -rf dynamicstereo_sf_dr
    endif

    if ( "$argv[1]" == "-clean_only" ) then
        exit 0
    endif
endif

setenv PYTORCH_CUDA_ALLOC_CONF "max_split_size_mb:32,garbage_collection_threshold:0.5,expandable_segments:False"
setenv CUDA_LAUNCH_BLOCKING 1
setenv PYTORCH_NO_CUDA_MEMORY_CACHING 1
setenv CUBLAS_WORKSPACE_CONFIG ":16:8"
setenv CUDA_VISIBLE_DEVICES 3

# -- GPU OOM Error when trained with sample_len=8 on kilby.
python train.py --batch_size 1 \
    --image_size 480 640 --saturation_range 0 1.4 --num_steps 200000 \
    --ckpt_path dynamicstereo_sf_dr \
    --sample_len 8 --lr 0.0003 --train_iters 8 --valid_iters 8 \
    --num_workers 28 --save_freq 100 --update_block_3d --different_update_blocks \
    --attention_type self_stereo_temporal_update_time_update_space --train_datasets dynamic_replica
train.py
ADDED
@@ -0,0 +1,565 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import os
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
|
| 16 |
+
from munch import DefaultMunch
|
| 17 |
+
import json
|
| 18 |
+
from pytorch_lightning.lite import LightningLite
|
| 19 |
+
from torch.cuda.amp import GradScaler
|
| 20 |
+
|
| 21 |
+
from train_utils.utils import (
|
| 22 |
+
run_test_eval,
|
| 23 |
+
save_ims_to_tb,
|
| 24 |
+
count_parameters,
|
| 25 |
+
)
|
| 26 |
+
from train_utils.logger import Logger
|
| 27 |
+
from models.core.dynamic_stereo import DynamicStereo
|
| 28 |
+
from models.core.sci_codec import sci_encoder
|
| 29 |
+
from evaluation.core.evaluator import Evaluator
|
| 30 |
+
from train_utils.losses import sequence_loss
|
| 31 |
+
import datasets.dynamic_stereo_datasets as datasets
|
| 32 |
+
|
| 33 |
+
class wrapper(nn.Module):
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
sigma_range=[0, 1e-9],
|
| 37 |
+
num_frames=8,
|
| 38 |
+
in_channels=1,
|
| 39 |
+
n_taps=2,
|
| 40 |
+
resolution=[480, 640],
|
| 41 |
+
mixed_precision=True,
|
| 42 |
+
attention_type="self_stereo_temporal_update_time_update_space",
|
| 43 |
+
update_block_3d=True,
|
| 44 |
+
different_update_blocks=True,
|
| 45 |
+
train_iters=16):
|
| 46 |
+
|
| 47 |
+
super(wrapper, self).__init__()
|
| 48 |
+
|
| 49 |
+
self.train_iters = train_iters
|
| 50 |
+
|
| 51 |
+
self.sci_enc_L = sci_encoder(sigma_range=sigma_range,
|
| 52 |
+
n_frame=num_frames,
|
| 53 |
+
in_channels=in_channels,
|
| 54 |
+
n_taps=n_taps,
|
| 55 |
+
resolution=resolution)
|
| 56 |
+
self.sci_enc_R = sci_encoder(sigma_range=sigma_range,
|
| 57 |
+
n_frame=num_frames,
|
| 58 |
+
in_channels=in_channels,
|
| 59 |
+
n_taps=n_taps,
|
| 60 |
+
resolution=resolution)
|
| 61 |
+
|
| 62 |
+
self.stereo = DynamicStereo(max_disp=256,
|
| 63 |
+
mixed_precision=mixed_precision,
|
| 64 |
+
num_frames=num_frames,
|
| 65 |
+
attention_type=attention_type,
|
| 66 |
+
use_3d_update_block=update_block_3d,
|
| 67 |
+
different_update_blocks=different_update_blocks)
|
| 68 |
+
|
| 69 |
+
def forward(self, batch):
|
| 70 |
+
# ---- ---- FORWARD PASS ---- ----
|
| 71 |
+
# -- Modified by Chu King on 20th November 2025
|
| 72 |
+
|
| 73 |
+
# -- print ("[INFO] batch[\"img\"].device: ", batch["img"].device)
|
| 74 |
+
|
| 75 |
+
# 0) Convert to Gray
|
| 76 |
+
def rgb_to_gray(x):
|
| 77 |
+
weights = torch.tensor([0.2989, 0.5870, 0.1140], dtype=x.dtype, device=x.device)
|
| 78 |
+
gray = (x * weights[None, None, :, None, None]).sum(dim=2)
|
| 79 |
+
return gray # -- shape: [B, T, H, W]
|
| 80 |
+
|
| 81 |
+
video_L = rgb_to_gray(batch["img"][:, :, 0]).cuda() # ~ (b, t, h, w)
|
| 82 |
+
video_R = rgb_to_gray(batch["img"][:, :, 1]).cuda() # ~ (b, t, h, w)
|
| 83 |
+
|
| 84 |
+
# -- print ("[INFO] video_L.device: ", video_L.device)
|
| 85 |
+
|
| 86 |
+
# 1) Extract and normalize input videos.
|
| 87 |
+
# -- min_max_norm = lambda x : 2. * (x / 255.) - 1.
|
| 88 |
+
min_max_norm = lambda x: x / 255.
|
| 89 |
+
video_L = min_max_norm(video_L) # ~ (b, t, h, w)
|
| 90 |
+
video_R = min_max_norm(video_R) # ~ (b, t, h, w)
|
| 91 |
+
# -- print ("[INFO] video_L.device: ", video_L.device)
|
| 92 |
+
|
| 93 |
+
# 2) If the tensor is non-contiguous and we try .view() later, PyTorch will raise an error:
|
| 94 |
+
video_L = video_L.contiguous()
|
| 95 |
+
video_R = video_R.contiguous()
|
| 96 |
+
|
| 97 |
+
# -- print ("[INFO] video_L.device: ", video_L.device)
|
| 98 |
+
|
| 99 |
+
# 3) Coded exposure modeling.
|
| 100 |
+
snapshot_L = self.sci_enc_L(video_L) # ~ (b, c, h, w) -- c=2 for 2 taps
|
| 101 |
+
snapshot_R = self.sci_enc_R(video_R) # ~ (b, c, h, w) -- c=2 for 2 taps
|
| 102 |
+
|
| 103 |
+
# -- print ("[INFO] self.sci_enc_L.device: ", next(self.sci_enc_R.parameters()).device)
|
| 104 |
+
# -- print ("[INFO] snapshot_L.device: ", snapshot_L.device)
|
| 105 |
+
|
| 106 |
+
# 4) Dynamic Stereo
|
| 107 |
+
output = {}
|
| 108 |
+
|
| 109 |
+
disparities = self.stereo(
|
| 110 |
+
snapshot_L,
|
| 111 |
+
snapshot_R,
|
| 112 |
+
iters=self.train_iters,
|
| 113 |
+
test_mode=False
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
n_views = len(batch["disp"][0]) # -- sample_len
|
| 117 |
+
for i in range(n_views):
|
| 118 |
+
seq_loss, metrics = sequence_loss(
|
| 119 |
+
disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
output[f"disp_{i}"] = {"loss": seq_loss / n_views, "metrics": metrics}
|
| 123 |
+
output["disparity"] = {
|
| 124 |
+
"predictions": torch.cat(
|
| 125 |
+
[disparities[-1, i, 0] for i in range(n_views)], dim=1
|
| 126 |
+
).detach(),
|
| 127 |
+
}
|
| 128 |
+
return output
|
| 129 |
+
|
| 130 |
+
def fetch_optimizer(args, model):
|
| 131 |
+
"""Create the optimizer and learning rate scheduler"""
|
| 132 |
+
optimizer = optim.AdamW(
|
| 133 |
+
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
|
| 134 |
+
)
|
| 135 |
+
scheduler = optim.lr_scheduler.OneCycleLR(
|
| 136 |
+
optimizer,
|
| 137 |
+
args.lr,
|
| 138 |
+
args.num_steps + 100,
|
| 139 |
+
pct_start=0.01,
|
| 140 |
+
cycle_momentum=False,
|
| 141 |
+
anneal_strategy="linear",
|
| 142 |
+
)
|
| 143 |
+
return optimizer, scheduler
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# -- Modified by Chu King on 20th November 2025
|
| 147 |
+
# -- Take snapshots instead of videos as input.
|
| 148 |
+
# -- def forward_batch(batch, model, args):
|
| 149 |
+
def forward_batch(snapshot_L, snapshot_R, model, args):
|
| 150 |
+
output = {}
|
| 151 |
+
|
| 152 |
+
disparities = model(
|
| 153 |
+
# -- batch["img"][:, :, 0],
|
| 154 |
+
# -- batch["img"][:, :, 1],
|
| 155 |
+
snapshot_L,
|
| 156 |
+
snapshot_R,
|
| 157 |
+
iters=args.train_iters,
|
| 158 |
+
test_mode=False,
|
| 159 |
+
)
|
| 160 |
+
num_traj = len(batch["disp"][0])
|
| 161 |
+
for i in range(num_traj):
|
| 162 |
+
seq_loss, metrics = sequence_loss(
|
| 163 |
+
disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
output[f"disp_{i}"] = {"loss": seq_loss / num_traj, "metrics": metrics}
|
| 167 |
+
output["disparity"] = {
|
| 168 |
+
"predictions": torch.cat(
|
| 169 |
+
[disparities[-1, i, 0] for i in range(num_traj)], dim=1
|
| 170 |
+
).detach(),
|
| 171 |
+
}
|
| 172 |
+
return output
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class Lite(LightningLite):
|
| 176 |
+
def run(self, args):
|
| 177 |
+
self.seed_everything(0)
|
| 178 |
+
|
| 179 |
+
# ----------------------------------------- Loading Dataset -----------------------------------------------
|
| 180 |
+
# -- Modified by Chu King on 15th November 2025 to allow quick testing with only 1 training video on the workstation.
|
| 181 |
+
# -- The number of subframes should be fixed for SCI stereo.
|
| 182 |
+
eval_dataloader_dr = datasets.DynamicReplicaDataset(
|
| 183 |
+
# -- split="valid", sample_len=40, only_first_n_samples=1, VERBOSE=False
|
| 184 |
+
split="valid", sample_len=args.sample_len, only_first_n_samples=1, VERBOSE=False
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
eval_dataloader_sintel_clean = datasets.SequenceSintelStereo(dstype="clean")
|
| 188 |
+
eval_dataloader_sintel_final = datasets.SequenceSintelStereo(dstype="final")
|
| 189 |
+
|
| 190 |
+
eval_dataloaders = [
|
| 191 |
+
("sintel_clean", eval_dataloader_sintel_clean),
|
| 192 |
+
("sintel_final", eval_dataloader_sintel_final),
|
| 193 |
+
("dynamic_replica", eval_dataloader_dr),
|
| 194 |
+
]
|
| 195 |
+
|
| 196 |
+
evaluator = Evaluator()
|
| 197 |
+
|
| 198 |
+
eval_vis_cfg = {
|
| 199 |
+
"visualize_interval": 1, # Use 0 for no visualization
|
| 200 |
+
"exp_dir": args.ckpt_path,
|
| 201 |
+
}
|
| 202 |
+
eval_vis_cfg = DefaultMunch.fromDict(eval_vis_cfg, object())
|
| 203 |
+
evaluator.setup_visualization(eval_vis_cfg)
|
| 204 |
+
|
| 205 |
+
# ----------------------------------------- Model Instantiation -----------------------------------------------
|
| 206 |
+
# -- Added by Chu King on 20th November 2025
|
| 207 |
+
# -- Instantiate the model
|
| 208 |
+
model = wrapper(sigma_range=[0, 1e-9],
|
| 209 |
+
num_frames=args.sample_len,
|
| 210 |
+
in_channels=1,
|
| 211 |
+
n_taps=2,
|
| 212 |
+
resolution=args.image_size,
|
| 213 |
+
mixed_precision=args.mixed_precision,
|
| 214 |
+
attention_type=args.attention_type,
|
| 215 |
+
update_block_3d=args.update_block_3d,
|
| 216 |
+
different_update_blocks=args.different_update_blocks,
|
| 217 |
+
train_iters=args.train_iters)
|
| 218 |
+
|
| 219 |
+
with open(args.ckpt_path + "/meta.json", "w") as file:
|
| 220 |
+
json.dump(vars(args), file, sort_keys=True, indent=4)
|
| 221 |
+
|
| 222 |
+
model.cuda()
|
| 223 |
+
|
| 224 |
+
logging.info("count_parameters(model): {}".format(count_parameters(model)))
|
| 225 |
+
|
| 226 |
+
train_loader = datasets.fetch_dataloader(args)
|
| 227 |
+
train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
|
| 228 |
+
|
| 229 |
+
logging.info(f"Train loader size: {len(train_loader)}")
|
| 230 |
+
|
| 231 |
+
optimizer, scheduler = fetch_optimizer(args, model)
|
| 232 |
+
|
| 233 |
+
total_steps = 0
|
| 234 |
+
logger = Logger(model, scheduler, args.ckpt_path)
|
| 235 |
+
|
| 236 |
+
# ----------------------------------------- Loading Checkpoint -----------------------------------------------
|
| 237 |
+
folder_ckpts = [
|
| 238 |
+
f
|
| 239 |
+
for f in os.listdir(args.ckpt_path)
|
| 240 |
+
if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
|
| 241 |
+
]
|
| 242 |
+
if len(folder_ckpts) > 0:
|
| 243 |
+
ckpt_path = sorted(folder_ckpts)[-1]
|
| 244 |
+
ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
|
| 245 |
+
logging.info(f"Loading checkpoint {ckpt_path}")
|
| 246 |
+
if "model" in ckpt:
|
| 247 |
+
model.load_state_dict(ckpt["model"])
|
| 248 |
+
else:
|
| 249 |
+
model.load_state_dict(ckpt)
|
| 250 |
+
if "optimizer" in ckpt:
|
| 251 |
+
logging.info("Load optimizer")
|
| 252 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 253 |
+
if "scheduler" in ckpt:
|
| 254 |
+
logging.info("Load scheduler")
|
| 255 |
+
scheduler.load_state_dict(ckpt["scheduler"])
|
| 256 |
+
if "total_steps" in ckpt:
|
| 257 |
+
total_steps = ckpt["total_steps"]
|
| 258 |
+
logging.info(f"Load total_steps {total_steps}")
|
| 259 |
+
|
| 260 |
+
elif args.restore_ckpt is not None:
|
| 261 |
+
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
|
| 262 |
+
".pt"
|
| 263 |
+
)
|
| 264 |
+
logging.info("Loading checkpoint...")
|
| 265 |
+
strict = True
|
| 266 |
+
|
| 267 |
+
state_dict = self.load(args.restore_ckpt)
|
| 268 |
+
if "model" in state_dict:
|
| 269 |
+
state_dict = state_dict["model"]
|
| 270 |
+
# -- Since we wrapped the model in torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel,
|
| 271 |
+
# PyTorch automatically prefixes all parameter names with "module.":
|
| 272 |
+
# state_dict = {
|
| 273 |
+
# 'module.conv1.weight': tensor(...),
|
| 274 |
+
# 'module.conv1.bias': tensor(...),
|
| 275 |
+
# 'module.fc.weight': tensor(...),
|
| 276 |
+
# 'module.fc.bias': tensor(...),
|
| 277 |
+
# }
|
| 278 |
+
# -- So we need to strip the "module." prefix:
|
| 279 |
+
if list(state_dict.keys())[0].startswith("module."):
|
| 280 |
+
state_dict = {
|
| 281 |
+
k.replace("module.", ""): v for k, v in state_dict.items()
|
| 282 |
+
}
|
| 283 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 284 |
+
|
| 285 |
+
logging.info(f"Done loading checkpoint")
|
| 286 |
+
# ----------------------------------------- Optimzer, Scheduler -----------------------------------------------
|
| 287 |
+
|
| 288 |
+
model, optimizer = self.setup(model, optimizer, move_to_device=False)
|
| 289 |
+
model.cuda()
|
| 290 |
+
model.train()
|
| 291 |
+
model.module.module.stereo.freeze_bn() # -- We keep BatchNorm frozen
|
| 292 |
+
|
| 293 |
+
save_freq = args.save_freq
|
| 294 |
+
scaler = GradScaler(enabled=args.mixed_precision)
|
| 295 |
+
|
| 296 |
+
# ----------------------------------------- Training Loop -----------------------------------------------
|
| 297 |
+
should_keep_training = True
|
| 298 |
+
global_batch_num = 0
|
| 299 |
+
epoch = -1
|
| 300 |
+
while should_keep_training:
|
| 301 |
+
epoch += 1
|
| 302 |
+
|
| 303 |
+
for i_batch, batch in enumerate(tqdm(train_loader)):
|
| 304 |
+
optimizer.zero_grad()
|
| 305 |
+
if batch is None:
|
| 306 |
+
print("batch is None")
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
for k, v in batch.items():
|
| 310 |
+
batch[k] = v.cuda()
|
| 311 |
+
|
| 312 |
+
assert model.training
|
| 313 |
+
|
| 314 |
+
# ---- ---- FORWARD PASS ---- ----
|
| 315 |
+
# -- Modified by Chu King on 20th November 2025
|
| 316 |
+
output = model(batch)
|
| 317 |
+
|
| 318 |
+
loss = 0
|
| 319 |
+
logger.update()
|
| 320 |
+
for k, v in output.items():
|
| 321 |
+
if "loss" in v:
|
| 322 |
+
loss += v["loss"]
|
| 323 |
+
logger.writer.add_scalar(
|
| 324 |
+
f"live_{k}_loss", v["loss"].item(), total_steps
|
| 325 |
+
)
|
| 326 |
+
if "metrics" in v:
|
| 327 |
+
logger.push(v["metrics"], k)
|
| 328 |
+
|
| 329 |
+
if self.global_rank == 0:
|
| 330 |
+
if total_steps % save_freq == save_freq - 1:
|
| 331 |
+
save_ims_to_tb(logger.writer, batch, output, total_steps)
|
| 332 |
+
if len(output) > 1:
|
| 333 |
+
logger.writer.add_scalar(
|
| 334 |
+
f"live_total_loss", loss.item(), total_steps
|
| 335 |
+
)
|
| 336 |
+
logger.writer.add_scalar(
|
| 337 |
+
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
|
| 338 |
+
)
|
| 339 |
+
global_batch_num += 1
|
| 340 |
+
self.barrier()
|
| 341 |
+
|
| 342 |
+
# ---- ---- BACKWARD PASS ---- ----
|
| 343 |
+
self.backward(scaler.scale(loss))
|
| 344 |
+
scaler.unscale_(optimizer)
|
| 345 |
+
|
| 346 |
+
# -- Prevent exploding gradients in RNNs or very deep networks
|
| 347 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 348 |
+
|
| 349 |
+
scaler.step(optimizer)
|
| 350 |
+
scheduler.step()
|
| 351 |
+
scaler.update()
|
| 352 |
+
total_steps += 1
|
| 353 |
+
|
| 354 |
+
if self.global_rank == 0:
|
| 355 |
+
|
| 356 |
+
if (i_batch >= len(train_loader) - 1) or (
|
| 357 |
+
total_steps == 1 and args.validate_at_start
|
| 358 |
+
):
|
| 359 |
+
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
|
| 360 |
+
save_path = Path(
|
| 361 |
+
f"{args.ckpt_path}/model_{args.name}_{ckpt_iter}.pth"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
save_dict = {
|
| 365 |
+
"model": model.module.module.state_dict(),
|
| 366 |
+
"optimizer": optimizer.state_dict(),
|
| 367 |
+
"scheduler": scheduler.state_dict(),
|
| 368 |
+
"total_steps": total_steps,
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
logging.info(f"Saving file {save_path}")
|
| 372 |
+
self.save(save_dict, save_path)
|
| 373 |
+
|
| 374 |
+
# ---- ---- EVALUATION ---- ----
|
| 375 |
+
if epoch % args.evaluate_every_n_epoch == 0:
|
| 376 |
+
# -- Added by Chu King on 21st November 2025
|
| 377 |
+
model.eval()
|
| 378 |
+
|
| 379 |
+
logging.info(f"Evaluation at epoch {epoch}")
|
| 380 |
+
run_test_eval(
|
| 381 |
+
args.ckpt_path,
|
| 382 |
+
"valid",
|
| 383 |
+
evaluator,
|
| 384 |
+
model.module.module.sci_enc_L,
|
| 385 |
+
model.module.module.sci_enc_R,
|
| 386 |
+
model.module.module.stereo,
|
| 387 |
+
eval_dataloaders,
|
| 388 |
+
logger.writer,
|
| 389 |
+
total_steps,
|
| 390 |
+
resolution=args.image_size
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# -- Added by Chu King on 20th November 2025 for SCI stereo
|
| 394 |
+
model.train()
|
| 395 |
+
|
| 396 |
+
model.module.module.stereo.freeze_bn()
|
| 397 |
+
|
| 398 |
+
self.barrier()
|
| 399 |
+
if total_steps > args.num_steps:
|
| 400 |
+
should_keep_training = False
|
| 401 |
+
break
|
| 402 |
+
|
| 403 |
+
logger.close()
|
| 404 |
+
# ----------------------------------------- Save models after training -----------------------------------------------
|
| 405 |
+
# -- Modified by Chu King on 20th November 2025 to save SCI encoders' models.
|
| 406 |
+
# -- PATH = f"{args.ckpt_path}/{args.name}_final.pth"
|
| 407 |
+
PATH = f"{args.ckpt_path}/{args.name}_model_final.pth"
|
| 408 |
+
torch.save(model.module.module.state_dict(), PATH)
|
| 409 |
+
|
| 410 |
+
# ----------------------------------------- Testing -----------------------------------------------
|
| 411 |
+
# -- Modified by Chu King on 20th November 2025
|
| 412 |
+
test_dataloader_dr = datasets.DynamicStereoDataset(
|
| 413 |
+
# -- The number of subframes should be fixed for SCI stereo
|
| 414 |
+
# -- split="test", sample_len=150, only_first_n_samples=1
|
| 415 |
+
split="test", sample_len=args.sample_len, only_first_n_samples=1
|
| 416 |
+
)
|
| 417 |
+
test_dataloaders = [
|
| 418 |
+
("sintel_clean", eval_dataloader_sintel_clean),
|
| 419 |
+
("sintel_final", eval_dataloader_sintel_final),
|
| 420 |
+
("dynamic_replica", test_dataloader_dr),
|
| 421 |
+
]
|
| 422 |
+
|
| 423 |
+
# -- Modifed by Chu King on 21st November 2025
|
| 424 |
+
model.eval()
|
| 425 |
+
run_test_eval(
|
| 426 |
+
args.ckpt_path,
|
| 427 |
+
"test",
|
| 428 |
+
evaluator,
|
| 429 |
+
model.module.module.sci_enc_L,
|
| 430 |
+
model.module.module.sci_enc_R,
|
| 431 |
+
model.module.module.stereo,
|
| 432 |
+
test_dataloaders,
|
| 433 |
+
logger.writer,
|
| 434 |
+
total_steps,
|
| 435 |
+
resolution=args.image_size
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="dynamic-stereo", help="name your experiment")
    parser.add_argument("--restore_ckpt", help="restore checkpoint")
    parser.add_argument("--ckpt_path", help="path to save checkpoints")
    parser.add_argument(
        "--mixed_precision", action="store_true", help="use mixed precision"
    )

    # Training parameters
    parser.add_argument(
        "--batch_size", type=int, default=6, help="batch size used during training."
    )
    parser.add_argument(
        "--train_datasets",
        nargs="+",
        default=["things", "monkaa", "driving"],
        help="training datasets.",
    )
    parser.add_argument("--lr", type=float, default=0.0002, help="max learning rate.")

    parser.add_argument(
        "--num_steps", type=int, default=100000, help="length of training schedule."
    )
    parser.add_argument(
        "--image_size",
        type=int,
        nargs="+",
        default=[320, 720],
        help="size of the random image crops used during training.",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=16,
        help="number of updates to the disparity field in each forward pass.",
    )
    parser.add_argument(
        "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
    )

    parser.add_argument(
        "--sample_len", type=int, default=2, help="length of training video samples"
    )
    parser.add_argument(
        "--validate_at_start", action="store_true", help="validate the model at start"
    )
    parser.add_argument("--save_freq", type=int, default=100, help="save frequency")

    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="evaluate every n epoch",
    )

    parser.add_argument(
        "--num_workers", type=int, default=6, help="number of dataloader workers."
    )
    # Validation parameters
    parser.add_argument(
        "--valid_iters",
        type=int,
        default=32,
        help="number of updates to the disparity field in each forward pass during validation.",
    )
    # Architecture choices
    parser.add_argument(
        "--different_update_blocks",
        action="store_true",
        help="use different update blocks for each resolution",
    )
    parser.add_argument(
        "--attention_type",
        type=str,
        help="attention type of the SST and update blocks. \
            Any combination of 'self_stereo', 'temporal', 'update_time', 'update_space' connected by an underscore.",
    )
    parser.add_argument(
        "--update_block_3d", action="store_true", help="use Conv3D update block"
    )
    # Data augmentation
    parser.add_argument(
        "--img_gamma", type=float, nargs="+", default=None, help="gamma range"
    )
    parser.add_argument(
        "--saturation_range",
        type=float,
        nargs="+",
        default=None,
        help="color saturation",
    )
    parser.add_argument(
        "--do_flip",
        default=False,
        choices=["h", "v"],
        help="flip the images horizontally or vertically",
    )
    parser.add_argument(
        "--spatial_scale",
        type=float,
        nargs="+",
        default=[0, 0],
        help="re-scale the images randomly",
    )
    parser.add_argument(
        "--noyjitter",
        action="store_true",
        help="don't simulate imperfect rectification",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        # -- strategy=DDPStrategy(find_unused_parameters=True),
        strategy=DDPStrategy(find_unused_parameters=False),
        devices="auto",
        accelerator="gpu",
        precision=32,
    ).run(args)
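For orientation, here is a minimal launch sketch assembled only from the flags defined above. The checkpoint path, dataset list, and attention-type string are illustrative placeholders, not a prescribed configuration; note that `--ckpt_path` has no default, so it must be supplied before `Path(args.ckpt_path).mkdir(...)` runs.

# Hypothetical launch sketch; every value below is illustrative only.
import subprocess

cmd = [
    "python", "train.py",
    "--name", "dynamic-stereo",
    "--ckpt_path", "./checkpoints/dynamic-stereo",  # required: the parser defines no default
    "--batch_size", "6",
    "--train_datasets", "things", "monkaa", "driving",
    "--image_size", "320", "720",
    "--sample_len", "2",
    "--num_steps", "100000",
    "--attention_type", "self_stereo_temporal_update_time_update_space",
]
subprocess.run(cmd, check=True)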
train_utils/logger.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

from torch.utils.tensorboard import SummaryWriter


class Logger:

    SUM_FREQ = 100

    def __init__(self, model, scheduler, ckpt_path):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.ckpt_path = ckpt_path
        self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))

    def _print_training_status(self):
        metrics_data = [
            self.running_loss[k] / Logger.SUM_FREQ
            for k in sorted(self.running_loss.keys())
        ]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(
            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
        )

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))

        for k in self.running_loss:
            self.writer.add_scalar(
                k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
            )
            self.running_loss[k] = 0.0

    def push(self, metrics, task):
        for key in metrics:
            task_key = str(key) + "_" + task
            if task_key not in self.running_loss:
                self.running_loss[task_key] = 0.0
            self.running_loss[task_key] += metrics[key]

    def update(self):
        self.total_steps += 1
        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()
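A minimal usage sketch of this Logger follows; the linear model, SGD optimizer, scheduler, and checkpoint directory are placeholders chosen only to exercise push/update/write_dict.

# Placeholder model/scheduler; only the Logger calls matter here.
import torch

from train_utils.logger import Logger

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

logger = Logger(model, scheduler, ckpt_path="./checkpoints/demo")
for step in range(300):
    # accumulate a running metric under the "things" task
    logger.push({"epe": 1.0 / (step + 1)}, task="things")
    # every SUM_FREQ steps this prints the averaged metrics and writes them to TensorBoard
    logger.update()
logger.write_dict({"valid/epe": 0.5})
logger.close()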
train_utils/losses.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

# -- Added by Chu King on 23rd November 2025 to check for NaNs
import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import hsv_to_rgb


def flow_to_rgb(flow):
    # flow: [2, H, W]
    u = flow[0]
    v = flow[1]
    rad = np.sqrt(u ** 2 + v ** 2)
    ang = np.arctan2(v, u)

    hsv = np.zeros((flow.shape[1], flow.shape[2], 3), dtype=np.float32)
    hsv[..., 0] = (ang + np.pi) / (2 * np.pi)
    hsv[..., 1] = 1.0
    hsv[..., 2] = np.clip(rad / np.percentile(rad, 99), 0, 1)

    # convert the per-pixel HSV values into an RGB image
    rgb = hsv_to_rgb(hsv)
    return rgb


def visualize_flow_debug(flow_pred, flow_gt, epe, step=0, save_path="debug"):
    flow_pred_np = flow_pred.detach().cpu().numpy()
    flow_gt_np = flow_gt.detach().cpu().numpy()
    epe_np = epe

    flow_pred0 = flow_pred_np[0, 0, :, :]
    flow_gt0 = flow_gt_np[0, 0, :, :]
    epe0 = epe_np

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))

    axs[0].imshow(flow_to_rgb(flow_pred0))
    axs[0].set_title("Predicted Flow")
    axs[0].axis("off")

    axs[1].imshow(flow_to_rgb(flow_gt0))
    axs[1].set_title("Ground Truth Flow")
    axs[1].axis("off")

    # -- axs[2].imshow(epe0, cmap="inferno")
    # -- axs[2].set_title("EPE heatmap")
    # -- axs[2].axis("off")

    fig.suptitle(f"STEP = {step}")

    plt.tight_layout()
    plt.savefig(f"{save_path}/flow_debug_{step}.png")
    plt.close()


def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
    """Loss function defined over sequence of flow predictions"""
    n_predictions = len(flow_preds)
    assert n_predictions >= 1
    flow_loss = 0.0
    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)

    if len(valid.shape) != len(flow_gt.shape):
        valid = valid.unsqueeze(1)

    valid = (valid >= 0.5) & (mag < max_flow)

    if valid.shape != flow_gt.shape:
        valid = torch.cat([valid, valid], dim=1)
    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
    assert not torch.isinf(flow_gt[valid.bool()]).any()

    for i in range(n_predictions):
        assert (
            not torch.isnan(flow_preds[i]).any()
            and not torch.isinf(flow_preds[i]).any()
        )

        if n_predictions == 1:
            i_weight = 1
        else:
            # We adjust the loss_gamma so it is consistent for any number of iterations
            adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)

        flow_pred = flow_preds[i].clone()
        if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:
            flow_pred = flow_pred[:, :1]

        i_loss = (flow_pred - flow_gt).abs()

        assert i_loss.shape == valid.shape, [
            i_loss.shape,
            valid.shape,
            flow_gt.shape,
            flow_pred.shape,
        ]
        flow_loss += i_weight * i_loss[valid.bool()].mean()

    epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()

    valid = valid[:, 0]
    epe = epe.view(-1)
    epe = epe[valid.reshape(epe.shape)]

    # -- Added by Chu King to deal with the case when there is no valid disparity.
    if valid.sum().item() == 0:
        metrics = {"epe": 0.0, "1px": 0.0, "3px": 0.0, "5px": 0.0}
    else:
        metrics = {
            "epe": epe.mean().item(),
            "1px": (epe < 1).float().mean().item(),
            "3px": (epe < 3).float().mean().item(),
            "5px": (epe < 5).float().mean().item(),
        }

    for k, v in metrics.items():
        if math.isnan(v):
            print("[ERROR] NaN detected for k:", k)
            if torch.isnan(flow_preds[-1]).any(): print("[WARNING] NaN in flow_preds")
            if torch.isinf(flow_preds[-1]).any(): print("[WARNING] Inf in flow_preds")
            if torch.isnan(flow_gt).any(): print("[WARNING] NaN in flow_gt")
            if torch.isinf(flow_gt).any(): print("[WARNING] Inf in flow_gt")

            raw_diff = flow_preds[-1] - flow_gt
            if torch.isnan(raw_diff).any(): print("[WARNING] NaN in flow_diff")

            sq = raw_diff ** 2
            if torch.isnan(sq).any(): print("[WARNING] NaN in square")

            sum_sq = torch.sum(sq, dim=1)
            if torch.isnan(sum_sq).any(): print("[WARNING] NaN in sum")

            epe = sum_sq.sqrt()
            if torch.isnan(epe).any(): print("[WARNING] NaN in sqrt")
            if torch.isinf(epe).any(): print("[WARNING] Inf in sqrt")

            num_valid = valid.sum().item()
            print("[INFO] Valid pixels:", num_valid)
            if num_valid == 0:
                print("[WARNING] No valid pixels; metrics will be NaN.")

            if (epe > 1e6).any():
                print("[INFO] Large EPE values detected:", epe.max().item())

            print("[INFO] Flow pred sample:", flow_preds[-1].view(-1)[:10])
            print("[INFO] Flow gt sample:", flow_gt.view(-1)[:10])
            print("[INFO] EPE sample:", epe.view(-1)[:10])
            print("[INFO] Valid sample:", valid.view(-1)[:10])

            visualize_flow_debug(flow_preds[-1], flow_gt, v, step=0, save_path="debug")
            raise SystemExit("NaN detected.")

    return flow_loss, metrics
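The per-iteration weights in sequence_loss follow an exponential schedule: with N predictions, prediction i receives weight adjusted_loss_gamma ** (N - i - 1), where adjusted_loss_gamma = loss_gamma ** (15 / (N - 1)), so later refinements dominate the loss. A toy sanity check, with arbitrary shapes and values and single-channel disparity tensors, might look like this:

# Toy sanity check for sequence_loss; shapes and values are arbitrary illustrations.
import torch

from train_utils.losses import sequence_loss

B, H, W = 2, 32, 64
flow_gt = torch.randn(B, 1, H, W)            # ground-truth disparity
valid = torch.ones(B, 1, H, W)               # every pixel counted as valid
flow_preds = [flow_gt + 0.5 * torch.randn(B, 1, H, W) for _ in range(4)]  # 4 refinement iterations

loss, metrics = sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9)
print(loss.item(), metrics)                  # the last prediction carries the largest weight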
train_utils/utils.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import os
import torch

import json
import flow_vis
import matplotlib.pyplot as plt

import datasets.dynamic_stereo_datasets as datasets
from evaluation.utils.utils import aggregate_and_print_results


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def run_test_eval(ckpt_path, eval_type, evaluator, sci_enc_L, sci_enc_R, model, dataloaders, writer, step, resolution=[480, 640]):

    # -- Evaluation of real scenes disabled by Chu King on 16th November 2025 as depth data
    # -- are not available.
    # -- for real_sequence_name in ["teddy_static", "ignacio_waving", "nikita_reading"]:
    # --     seq_len_real = 50
    # --     ds_path = f"./dynamic_replica_data/real/{real_sequence_name}"
    # --     real_dataset = datasets.DynamicReplicaDataset(
    # --         split="test", root=ds_path, sample_len=seq_len_real, only_first_n_samples=1,
    # --         VERBOSE=False  # -- Added by Chu King on 16th November 2025 for debugging purposes
    # --     )

    # --     evaluator.evaluate_sequence(
    # --         model=model.module.module,
    # --         test_dataloader=real_dataset,
    # --         writer=writer,
    # --         step=step,
    # --         train_mode=True,
    # --     )

    for ds_name, dataloader in dataloaders:
        evaluator.visualize_interval = 1 if "sintel" not in ds_name else 0

        evaluate_result = evaluator.evaluate_sequence(
            sci_enc_L=sci_enc_L,
            sci_enc_R=sci_enc_R,
            model=model,
            test_dataloader=dataloader,
            writer=writer if "sintel" not in ds_name else None,
            step=step,
            train_mode=True,
            resolution=resolution,
        )

        aggregate_result = aggregate_and_print_results(
            evaluate_result,
        )

        save_metrics = [
            "flow_mean_accuracy_5px",
            "flow_mean_accuracy_3px",
            "flow_mean_accuracy_1px",
            "flow_epe_traj_mean",
        ]
        for epe_name in ("epe", "temp_epe", "temp_epe_r"):
            for m in [
                f"disp_{epe_name}_bad_0.5px",
                f"disp_{epe_name}_bad_1px",
                f"disp_{epe_name}_bad_2px",
                f"disp_{epe_name}_bad_3px",
                f"disp_{epe_name}_mean",
            ]:
                save_metrics.append(m)

        for k, v in aggregate_result.items():
            if k in save_metrics:
                writer.add_scalars(
                    f"{ds_name}_{k.rsplit('_', 1)[0]}",
                    {f"{ds_name}_{k}": v},
                    step,
                )

        result_file = os.path.join(
            ckpt_path,
            f"result_{ds_name}_{eval_type}_{step}_mimo.json",
        )
        print(f"Dumping {eval_type} results to {result_file}.")
        with open(result_file, "w") as f:
            json.dump(aggregate_result, f)


def fig2data(fig):
    """
    fig = plt.figure()
    image = fig2data(fig)
    @brief Convert a Matplotlib figure to a numpy array with RGB channels and return it
    @param fig a matplotlib figure
    @return a numpy 3D array of RGB values
    """
    import PIL.Image as Image

    # draw the renderer
    fig.canvas.draw()

    # Get the RGB buffer from the figure
    w, h = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    buf.shape = (h, w, 3)  # rows correspond to image height, columns to width

    image = Image.frombytes("RGB", (w, h), buf.tobytes())
    image = np.asarray(image)
    return image


def save_ims_to_tb(writer, batch, output, total_steps):
    writer.add_image(
        "train_im",
        torch.cat([torch.cat([im[0], im[1]], dim=-1) for im in batch["img"][0]], dim=-2)
        / 255.0,
        total_steps,
        dataformats="CHW",
    )
    if "disp" in batch and len(batch["disp"]) > 0:
        disp_im = [
            (torch.cat([im[0], im[1]], dim=-1) * torch.cat([val[0], val[1]], dim=-1))
            for im, val in zip(batch["disp"][0], batch["valid_disp"][0])
        ]

        disp_im = torch.cat(disp_im, dim=1)

        figure = plt.figure()
        plt.imshow(disp_im.cpu()[0])
        disp_im = fig2data(figure).copy()

        writer.add_image(
            "train_disp",
            disp_im,
            total_steps,
            dataformats="HWC",
        )

    for k, v in output.items():
        if "predictions" in v:
            pred = v["predictions"]
            if k == "disparity":
                figure = plt.figure()
                plt.imshow(pred.cpu()[0])
                pred = fig2data(figure).copy()
                dataformat = "HWC"
            else:
                pred = torch.tensor(
                    flow_vis.flow_to_color(
                        pred.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
                    )
                    / 255.0
                )
                dataformat = "HWC"
            writer.add_image(
                f"pred_{k}",
                pred,
                total_steps,
                dataformats=dataformat,
            )
        if "gt" in v:
            gt = v["gt"]
            gt = torch.tensor(
                flow_vis.flow_to_color(
                    gt.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
                )
                / 255.0
            )
            dataformat = "HWC"
            writer.add_image(
                f"gt_{k}",
                gt,
                total_steps,
                dataformats=dataformat,
            )