SAIL-Recon / eval /utils /image.py
hengli
first
b7f83b0
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for images (loading/converting...)
# --------------------------------------------------------
import os
from typing import Dict, Optional
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as tvf
from PIL.ImageOps import exif_transpose
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
try:
from pillow_heif import register_heif_opener
register_heif_opener()
heif_support_enabled = True
except ImportError:
heif_support_enabled = False
# Shared preprocessing pipeline: ToTensor followed by a per-channel
# normalisation with mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5, so
# values in [0, 1] end up in [-1, 1].
ImgNorm = tvf.Compose(
    [
        tvf.ToTensor(),
        tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """Open an image or a depthmap with opencv-python.

    Args:
        path: Location of the file to read (string or ``os.PathLike``).
        options: Flags forwarded to ``cv2.imread``. They are overridden with
            ``cv2.IMREAD_ANYDEPTH`` when the file has an ``.exr`` extension.

    Returns:
        The decoded array. Three-dimensional results are converted from
        BGR to RGB channel order; other results are returned as read.

    Raises:
        IOError: If ``cv2.imread`` returns ``None`` for this path.
    """
    path = os.fspath(path)
    # Compare the extension case-insensitively so ".exr", ".EXR" and ".Exr"
    # are all recognised, and names that merely end in "EXR" are not.
    if path.lower().endswith(".exr"):
        options = cv2.IMREAD_ANYDEPTH
    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f"Could not load image={path} with {options=}")
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def rgb(ftensor, true_shape=None):
    """Convert an image tensor/array (or a list of them) to float RGB in [0, 1].

    Args:
        ftensor: A ``torch.Tensor``, a numpy array, or a list of either.
            Channel-first inputs (``3xHxW`` or ``Bx3xHxW``) are moved to
            channel-last layout. ``uint8`` data is divided by 255; any other
            dtype is mapped with ``x * 0.5 + 0.5`` (the inverse of a
            mean=0.5 / std=0.5 normalisation).
        true_shape: Optional ``(H, W)`` pair. When given, the result is
            cropped to its top-left ``H x W`` region.

    Returns:
        A numpy array clipped to ``[0, 1]``, or a list of such arrays when
        ``ftensor`` is a list.
    """
    if isinstance(ftensor, list):
        return [rgb(x, true_shape=true_shape) for x in ftensor]
    if isinstance(ftensor, torch.Tensor):
        ftensor = ftensor.detach().cpu().numpy()

    batched = ftensor.ndim == 4
    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
        ftensor = ftensor.transpose(1, 2, 0)  # 3,H,W -> H,W,3
    elif batched and ftensor.shape[1] == 3:
        ftensor = ftensor.transpose(0, 2, 3, 1)  # B,3,H,W -> B,H,W,3

    if true_shape is not None:
        H, W = true_shape
        if batched:
            # Leave the batch axis alone; crop only the spatial axes.
            ftensor = ftensor[:, :H, :W]
        else:
            ftensor = ftensor[:H, :W]

    if ftensor.dtype == np.uint8:
        img = np.float32(ftensor) / 255
    else:
        img = (ftensor * 0.5) + 0.5
    return img.clip(min=0, max=1)
def _resize_pil_image(img, long_edge_size):
    """Rescale a PIL image so that its longest edge becomes ``long_edge_size``.

    Both sides are scaled by the same factor and rounded to integers.
    Lanczos resampling is used when the image shrinks, bicubic otherwise.
    """
    longest = max(img.size)
    shrinking = longest > long_edge_size
    resample = PIL.Image.LANCZOS if shrinking else PIL.Image.BICUBIC
    target_w, target_h = (
        int(round(side * long_edge_size / longest)) for side in img.size
    )
    return img.resize((target_w, target_h), resample)
def load_images(
    folder_or_list,
    size,
    square_ok=False,
    verbose=True,
    rotate_clockwise_90=False,
    crop_to_landscape=False,
):
    """Open and convert all images in a list or folder to proper input format for DUSt3R.

    Args:
        folder_or_list: Either a directory path (``str`` or ``os.PathLike``),
            whose entries are processed in sorted order, or a list of image
            paths. Entries without a supported image extension are skipped.
        size: Target resolution. With ``224`` the short side is resized to
            224 and the image is center-cropped to a square. With any other
            value the long side is resized to ``size`` and the image is
            center-cropped so that both sides are multiples of 16.
        square_ok: When False (and ``size != 224``), an image that is square
            after resizing is cropped to a 4:3 aspect ratio.
        verbose: Print progress information.
        rotate_clockwise_90: Rotate every image by 90 degrees clockwise
            before any cropping or resizing.
        crop_to_landscape: Center-crop every image to a 4:3 aspect ratio
            before resizing.

    Returns:
        A list with one dict per loaded image, holding ``img`` (the
        ``ImgNorm`` output with a leading batch axis), ``true_shape`` (an
        int32 array ``[[H, W]]``), ``idx`` and ``instance`` (the position of
        the image in the returned list, as int and as str).

    Raises:
        ValueError: If ``folder_or_list`` is neither a path nor a list.
        AssertionError: If no image could be loaded.
    """
    if isinstance(folder_or_list, (str, os.PathLike)):
        root = os.fspath(folder_or_list)
        if verbose:
            print(f">> Loading images from {root}")
        folder_content = sorted(os.listdir(root))
    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list
    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        path = os.fspath(path)
        if not path.lower().endswith(supported_images_extensions):
            continue

        # Close the underlying file as soon as the pixels are decoded;
        # exif_transpose/convert return images independent of the file.
        with PIL.Image.open(os.path.join(root, path)) as raw:
            img = exif_transpose(raw).convert("RGB")

        if rotate_clockwise_90:
            # PIL rotates counter-clockwise for positive angles.
            img = img.rotate(-90, expand=True)

        if crop_to_landscape:
            # Center-crop to a 4:3 landscape aspect ratio.
            desired_aspect_ratio = 4 / 3
            width, height = img.size
            current_aspect_ratio = width / height
            if current_aspect_ratio > desired_aspect_ratio:
                # Wider than the target ratio: crop width.
                new_width = int(height * desired_aspect_ratio)
                left = (width - new_width) // 2
                right = left + new_width
                top = 0
                bottom = height
            else:
                # Taller than the target ratio: crop height.
                new_height = int(width / desired_aspect_ratio)
                top = (height - new_height) // 2
                bottom = top + new_height
                left = 0
                right = width
            img = img.crop((left, top, right, bottom))

        W1, H1 = img.size
        if size == 224:
            # Resize the short side to 224 (the square crop follows below).
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # Resize the long side to `size`.
            img = _resize_pil_image(img, size)

        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # Half-extents are multiples of 8, so the crop is a multiple of 16.
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not square_ok and W == H:
                # halfw is a multiple of 8, so this division is exact.
                halfh = 3 * halfw // 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    # Kept as an assert (and with its original message) so callers that
    # expect AssertionError keep working.
    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
def get_image_vggt_augmentation(
    color_jitter: Optional[Dict[str, float]] = None,
    gray_scale: bool = True,
    gau_blur: bool = False,
) -> Optional[tvf.Compose]:
    """Build the composed image augmentation pipeline.

    The pipeline always starts with a randomly applied colour jitter, then
    optionally adds random grayscale and a randomly applied gaussian blur.
    No ``ToTensor`` step is appended here.

    Args:
        color_jitter: Overrides for the colour-jitter settings. Missing keys
            fall back to these defaults:
            - brightness: 0.5
            - contrast: 0.5
            - saturation: 0.5
            - hue: 0.1
            - p: 0.9 (probability of applying the jitter)
            ``None`` uses the defaults unchanged.
        gray_scale: Add ``RandomGrayscale`` with probability 0.05.
        gau_blur: Add a gaussian blur (kernel size 5, sigma in
            ``(0.1, 1.0)``) applied with probability 0.05.

    Returns:
        A ``Compose`` of the selected transforms. The ``None`` fallback only
        covers an empty transform list, which cannot occur because the
        colour jitter is always added.

    Raises:
        ValueError: If ``color_jitter`` is neither a dict nor ``None``.
    """
    if color_jitter is not None and not isinstance(color_jitter, dict):
        raise ValueError("color_jitter must be a dictionary or None")

    # Start from the defaults and let caller-supplied keys take precedence.
    jitter_cfg = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9,
    }
    if color_jitter is not None:
        jitter_cfg.update(color_jitter)

    jitter = tvf.ColorJitter(
        brightness=jitter_cfg["brightness"],
        contrast=jitter_cfg["contrast"],
        saturation=jitter_cfg["saturation"],
        hue=jitter_cfg["hue"],
    )
    pipeline = [tvf.RandomApply([jitter], p=jitter_cfg["p"])]

    if gray_scale:
        pipeline.append(tvf.RandomGrayscale(p=0.05))

    if gau_blur:
        blur = tvf.GaussianBlur(5, sigma=(0.1, 1.0))
        pipeline.append(tvf.RandomApply([blur], p=0.05))

    if not pipeline:
        return None
    return tvf.Compose(pipeline)