SAIL-Recon / eval /utils /geometry.py
hengli
first
b7f83b0
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# geometry utilitary functions
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from plyfile import PlyData, PlyElement
from scipy.spatial import cKDTree as KDTree
from tqdm import tqdm
from eval.utils.device import to_numpy
from eval.utils.misc import invalid_to_nans, invalid_to_zeros
def xy_grid(
W,
H,
device=None,
origin=(0, 0),
unsqueeze=None,
cat_dim=-1,
homogeneous=False,
**arange_kw,
):
"""Output a (H,W,2) array of int32
with output[j,i,0] = i + origin[0]
output[j,i,1] = j + origin[1]
"""
if device is None:
# numpy
arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
else:
# torch
arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
meshgrid, stack = torch.meshgrid, torch.stack
ones = lambda *a: torch.ones(*a, device=device)
tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
grid = meshgrid(tw, th, indexing="xy")
if homogeneous:
grid = grid + (ones((H, W)),)
if unsqueeze is not None:
grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
if cat_dim is not None:
grid = stack(grid, cat_dim)
return grid
def geotrf(Trf, pts, ncol=None, norm=False):
"""Apply a geometric transformation to a list of 3-D points.
H: 3x3 or 4x4 projection matrix (typically a Homography)
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
ncol: int. number of columns of the result (2 or 3)
norm: float. if != 0, the resut is projected on the z=norm plane.
Returns an array of projected 2d points.
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (
isinstance(Trf, torch.Tensor)
and isinstance(pts, torch.Tensor)
and Trf.ndim == 3
and pts.ndim == 4
):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d + 1:
pts = (
torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
+ Trf[:, None, None, :d, d]
)
else:
raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
else:
if Trf.ndim >= 3:
n = Trf.ndim - 2
assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1] + 1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
def inv(mat):
"""Invert a torch or numpy matrix"""
if isinstance(mat, torch.Tensor):
return torch.linalg.inv(mat)
if isinstance(mat, np.ndarray):
return np.linalg.inv(mat)
raise ValueError(f"bad matrix type = {type(mat)}")
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
"""
Args:
- depthmap (BxHxW array):
- pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
Returns:
pointmap of absolute coordinates (BxHxWx3 array)
"""
if len(depth.shape) == 4:
B, H, W, n = depth.shape
else:
B, H, W = depth.shape
n = None
if len(pseudo_focal.shape) == 3: # [B,H,W]
pseudo_focalx = pseudo_focaly = pseudo_focal
elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
pseudo_focalx = pseudo_focal[:, 0]
if pseudo_focal.shape[1] == 2:
pseudo_focaly = pseudo_focal[:, 1]
else:
pseudo_focaly = pseudo_focalx
else:
raise NotImplementedError("Error, unknown input focal shape format.")
assert pseudo_focalx.shape == depth.shape[:3]
assert pseudo_focaly.shape == depth.shape[:3]
grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
# set principal point
if pp is None:
grid_x = grid_x - (W - 1) / 2
grid_y = grid_y - (H - 1) / 2
else:
grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
if n is None:
pts3d = torch.empty((B, H, W, 3), device=depth.device)
pts3d[..., 0] = depth * grid_x / pseudo_focalx
pts3d[..., 1] = depth * grid_y / pseudo_focaly
pts3d[..., 2] = depth
else:
pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
pts3d[..., 2, :] = depth
return pts3d
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
Returns:
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
"""
camera_intrinsics = np.float32(camera_intrinsics)
H, W = depthmap.shape
# Compute 3D ray associated with each pixel
# Strong assumption: there are no skew terms
assert camera_intrinsics[0, 1] == 0.0
assert camera_intrinsics[1, 0] == 0.0
if pseudo_focal is None:
fu = camera_intrinsics[0, 0]
fv = camera_intrinsics[1, 1]
else:
assert pseudo_focal.shape == (H, W)
fu = fv = pseudo_focal
cu = camera_intrinsics[0, 2]
cv = camera_intrinsics[1, 2]
u, v = np.meshgrid(np.arange(W), np.arange(H))
z_cam = depthmap
x_cam = (u - cu) * z_cam / fu
y_cam = (v - cv) * z_cam / fv
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
# Mask for valid coordinates
valid_mask = depthmap > 0.0
return X_cam, valid_mask
def depthmap_to_absolute_camera_coordinates(
depthmap, camera_intrinsics, camera_pose, **kw
):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
- camera_pose: a 4x3 or 4x4 cam2world matrix
Returns:
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
"""
X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
# R_cam2world = np.float32(camera_params["R_cam2world"])
# t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
R_cam2world = camera_pose[:3, :3]
t_cam2world = camera_pose[:3, 3]
# Express in absolute coordinates (invalid depth values)
X_world = (
np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
)
return X_world, valid_mask
def colmap_to_opencv_intrinsics(K):
"""
Modify camera intrinsics to follow a different convention.
Coordinates of the center of the top-left pixels are by default:
- (0.5, 0.5) in Colmap
- (0,0) in OpenCV
"""
K = K.copy()
K[0, 2] -= 0.5
K[1, 2] -= 0.5
return K
def opencv_to_colmap_intrinsics(K):
"""
Modify camera intrinsics to follow a different convention.
Coordinates of the center of the top-left pixels are by default:
- (0.5, 0.5) in Colmap
- (0,0) in OpenCV
"""
K = K.copy()
K[0, 2] += 0.5
K[1, 2] += 0.5
return K
def normalize_pointcloud(pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None):
"""renorm pointmaps pts1, pts2 with norm_mode"""
assert pts1.ndim >= 3 and pts1.shape[-1] == 3
assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
norm_mode, dis_mode = norm_mode.split("_")
if norm_mode == "avg":
# gather all points together (joint normalization)
nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
nan_pts2, nnz2 = (
invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
)
all_pts = (
torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
)
# compute distance to origin
all_dis = all_pts.norm(dim=-1)
if dis_mode == "dis":
pass # do nothing
elif dis_mode == "log1p":
all_dis = torch.log1p(all_dis)
elif dis_mode == "warp-log1p":
# actually warp input points before normalizing them
log_dis = torch.log1p(all_dis)
warp_factor = log_dis / all_dis.clip(min=1e-8)
H1, W1 = pts1.shape[1:-1]
pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
if pts2 is not None:
H2, W2 = pts2.shape[1:-1]
pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
all_dis = log_dis # this is their true distance afterwards
else:
raise ValueError(f"bad {dis_mode=}")
norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
else:
# gather all points together (joint normalization)
nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
all_pts = (
torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
)
# compute distance to origin
all_dis = all_pts.norm(dim=-1)
if norm_mode == "avg":
norm_factor = all_dis.nanmean(dim=1)
elif norm_mode == "median":
norm_factor = all_dis.nanmedian(dim=1).values.detach()
elif norm_mode == "sqrt":
norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
else:
raise ValueError(f"bad {norm_mode=}")
norm_factor = norm_factor.clip(min=1e-8)
while norm_factor.ndim < pts1.ndim:
norm_factor.unsqueeze_(-1)
res = pts1 / norm_factor
if pts2 is not None:
res = (res, pts2 / norm_factor)
return res
@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
# set invalid points to NaN
_z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
_z2 = (
invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
if z2 is not None
else None
)
_z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
# compute median depth overall (ignoring nans)
if quantile == 0.5:
shift_z = torch.nanmedian(_z, dim=-1).values
else:
shift_z = torch.nanquantile(_z, quantile, dim=-1)
return shift_z # (B,)
@torch.no_grad()
def get_joint_pointcloud_center_scale(
pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
):
# set invalid points to NaN
_pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
_pts2 = (
invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
if pts2 is not None
else None
)
_pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
# compute median center
_center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
if z_only:
_center[..., :2] = 0 # do not center X and Y
# compute median norm
_norm = ((_pts - _center) if center else _pts).norm(dim=-1)
scale = torch.nanmedian(_norm, dim=1).values
return _center[:, None, :, :], scale[:, None, None, None]
def find_reciprocal_matches(P1, P2):
"""
returns 3 values:
1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
3 - reciprocal_in_P2.sum(): the number of matches
"""
tree1 = KDTree(P1)
tree2 = KDTree(P2)
_, nn1_in_P2 = tree2.query(P1, workers=8)
_, nn2_in_P1 = tree1.query(P2, workers=8)
reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))
reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))
assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
def get_med_dist_between_poses(poses):
from scipy.spatial.distance import pdist
return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
def save_pointcloud_with_plyfile(result, filename="output.ply", downsample_ratio=10):
all_points = []
all_colors = []
for view in result:
pts = view["point_map_by_unprojection"] # (1, H, W, 3)
rgbs = view["rgbs"] # (1, 3, H, W)
dpt_cnf = view["dpt_cnf"]
# Remove batch dimension
pts = pts.squeeze(0) # (H, W, 3)
rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
# Flatten
pts = pts.reshape(-1, 3) # (N, 3)
rgbs = rgbs.reshape(-1, 3) # (N, 3)
# Remove invalid points
valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
pts = pts[valid]
rgbs = rgbs[valid]
# Downsample this view
N = pts.shape[0]
if downsample_ratio > 1 and N >= downsample_ratio:
idx = torch.randperm(N)[: N // downsample_ratio]
pts = pts[idx]
rgbs = rgbs[idx]
all_points.append(pts)
all_colors.append(rgbs)
# Merge all views
all_points = torch.cat(all_points, dim=0).cpu().numpy()
all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
# Normalize color
if all_colors.max() <= 1.0:
all_colors = (all_colors * 255).astype(np.uint8)
else:
all_colors = all_colors.astype(np.uint8)
# Build structured array
vertex_data = np.empty(
len(all_points),
dtype=[
("x", "f4"),
("y", "f4"),
("z", "f4"),
("red", "u1"),
("green", "u1"),
("blue", "u1"),
],
)
vertex_data["x"] = all_points[:, 0]
vertex_data["y"] = all_points[:, 1]
vertex_data["z"] = all_points[:, 2]
vertex_data["red"] = all_colors[:, 0]
vertex_data["green"] = all_colors[:, 1]
vertex_data["blue"] = all_colors[:, 2]
# Save with plyfile
el = PlyElement.describe(vertex_data, "vertex")
PlyData([el], text=False).write(filename)
print(f"[PLY] Saved {len(all_points)} points to {filename}")
def save_pointcloud_with_plyfile_each_frame(
result, filename="output.ply", downsample_ratio=10
):
for frame_number, view in enumerate(result):
all_points = []
all_colors = []
pts = view["point_map_by_unprojection"] # (1, H, W, 3)
rgbs = view["rgbs"] # (1, 3, H, W)
dpt_cnf = view["dpt_cnf"]
# Remove batch dimension
pts = pts.squeeze(0) # (H, W, 3)
rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
# Flatten
pts = pts.reshape(-1, 3) # (N, 3)
rgbs = rgbs.reshape(-1, 3) # (N, 3)
# Remove invalid points
valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
pts = pts[valid]
rgbs = rgbs[valid]
# Downsample this view
N = pts.shape[0]
if downsample_ratio > 1 and N >= downsample_ratio:
idx = torch.randperm(N)[: N // downsample_ratio]
pts = pts[idx]
rgbs = rgbs[idx]
all_points.append(pts)
all_colors.append(rgbs)
# Merge all views
all_points = torch.cat(all_points, dim=0).cpu().numpy()
all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
# Normalize color
if all_colors.max() <= 1.0:
all_colors = (all_colors * 255).astype(np.uint8)
else:
all_colors = all_colors.astype(np.uint8)
# Build structured array
vertex_data = np.empty(
len(all_points),
dtype=[
("x", "f4"),
("y", "f4"),
("z", "f4"),
("red", "u1"),
("green", "u1"),
("blue", "u1"),
],
)
vertex_data["x"] = all_points[:, 0]
vertex_data["y"] = all_points[:, 1]
vertex_data["z"] = all_points[:, 2]
vertex_data["red"] = all_colors[:, 0]
vertex_data["green"] = all_colors[:, 1]
vertex_data["blue"] = all_colors[:, 2]
# Save with plyfile
el = PlyElement.describe(vertex_data, "vertex")
PlyData([el], text=False).write(
filename.split(".ply")[0] + f"_{frame_number:05d}.ply"
)
print(
f"[PLY] Saved {len(all_points)} points to {filename} idx {frame_number:05d}"
)
def save_concatenated_images(samples, save_path):
imgs = []
for sample in samples:
img = sample["img"] # (1, C, H, W)
img = F.interpolate(
img, scale_factor=0.25, mode="bilinear", align_corners=False
)
imgs.append(img)
imgs = torch.cat(imgs, dim=0).cpu() # (N, C, H, W)
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
vutils.save_image(imgs, save_path, normalize=True)
print(f"[Image] Saved concatenated image to {save_path}")