# Scraped file-viewer header (not part of the program): Spaces status "Build error";
# file size 4,894 bytes; revision 4d588ce. The viewer's line-number gutter (1-152) is omitted.
import os
import glob
import h5py
import numpy as np
import argparse
from joblib import Parallel, delayed
import random
from scipy.spatial import cKDTree as KDTree
import time
import sys
sys.path.append("..")
from utils import read_ply
from cadlib.visualize import vec2CADsolid, CADsolid2pc
# Root directory of the reference point clouds. process_one() looks for
# <PC_ROOT>/<first 4 characters of the data id>/<data id>.ply below it.
PC_ROOT = "../data/pc_cad"
# Data ids that cannot be processed. run() skips any file whose data id is
# listed here and records None as its result.
SKIP_DATA = [""]
def chamfer_dist(gt_points, gen_points, offset=0, scale=1):
    """Return the symmetric, squared Chamfer distance between two point sets.

    The generated points are first mapped with ``gen_points / scale - offset``.
    The result is the mean squared nearest-neighbour distance from every
    ground-truth point to the generated set, plus the mean squared
    nearest-neighbour distance from every generated point to the
    ground-truth set.

    Args:
        gt_points: array-like of reference points, one point per row.
        gen_points: array-like of generated points, one point per row.
        offset: value (scalar or broadcastable array-like) subtracted from
            the generated points after scaling.
        scale: value the generated points are divided by.

    Returns:
        The sum of the two directional mean squared distances.
    """
    # Coerce up front so nested lists work as well as numpy arrays.
    gt_points = np.asarray(gt_points, dtype=float)
    gen_points = np.asarray(gen_points, dtype=float) / scale - offset

    # Direction 1: each ground-truth point to its nearest generated point.
    gt_to_gen_dists, _ = KDTree(gen_points).query(gt_points)
    gt_to_gen_chamfer = np.mean(np.square(gt_to_gen_dists))

    # Direction 2: each generated point to its nearest ground-truth point.
    gen_to_gt_dists, _ = KDTree(gt_points).query(gen_points)
    gen_to_gt_chamfer = np.mean(np.square(gen_to_gt_dists))

    return gt_to_gen_chamfer + gen_to_gt_chamfer
def normalize_pc(points):
    """Scale a point cloud so that its largest absolute coordinate becomes 1.

    Args:
        points: array-like of coordinates.

    Returns:
        A float array equal to ``points / max(abs(points))``. An empty input,
        or an input whose coordinates are all zero, is returned unscaled
        because no meaningful scale factor exists for it.
    """
    points = np.asarray(points, dtype=float)
    if points.size == 0:
        # np.max would raise on an empty array.
        return points

    scale = np.max(np.abs(points))
    if scale == 0:
        # Dividing by zero would turn every coordinate into NaN.
        return points
    return points / scale
def process_one(path):
    """Compute the Chamfer distance for one generated-CAD result file.

    Reads the ``out_vec`` dataset from the HDF5 file at ``path``, converts it
    to a solid with ``vec2CADsolid``, samples ``args.n_points`` points from it
    with ``CADsolid2pc`` and compares them with a random sample of the
    reference point cloud stored under ``PC_ROOT``.

    Relies on the module-level ``args`` namespace for ``n_points``.

    Args:
        path: path of the ``.h5`` file; the first 8 characters of its base
            name are used as the data id.

    Returns:
        The Chamfer distance, or ``None`` when the reference ``.ply`` file is
        missing or when building the solid / sampling its points fails.
    """
    with h5py.File(path, 'r') as fp:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        out_vec = fp["out_vec"][:].astype(np.float64)

    data_id = os.path.basename(path).split('.')[0][:8]
    truck_id = data_id[:4]
    gt_pc_path = os.path.join(PC_ROOT, truck_id, data_id + '.ply')
    if not os.path.exists(gt_pc_path):
        return None

    # Both conversion steps are best-effort: any failure marks the sample as
    # invalid (None) instead of aborting the whole evaluation.
    try:
        shape = vec2CADsolid(out_vec)
    except Exception:
        print("create_CAD failed", data_id)
        return None

    try:
        out_pc = CADsolid2pc(shape, args.n_points, data_id)
    except Exception:
        print("convert pc failed:", data_id)
        return None

    if np.max(np.abs(out_pc)) > 2:  # normalize out-of-bound data
        out_pc = normalize_pc(out_pc)

    gt_pc = read_ply(gt_pc_path)
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the sample size at the number of reference points available.
    n_sample = min(args.n_points, gt_pc.shape[0])
    sample_idx = random.sample(range(gt_pc.shape[0]), n_sample)
    gt_pc = gt_pc[sample_idx]

    return chamfer_dist(gt_pc, out_pc)
def _load_recorded_dists(save_path):
    """Read the per-sample records of an earlier run, keyed by data id.

    A record is a line of three tab-separated fields: index, data id and
    distance (or the text ``None``). Every other line is ignored.
    """
    recorded = {}
    with open(save_path, 'r') as fp:
        for line in fp:
            fields = line.rstrip('\n').split('\t')
            # The "total: ..." summary line also contains two tabs, so a
            # numeric first field is required to recognise a real record.
            if len(fields) != 3 or not fields[0].isdigit():
                continue
            _, data_id, value = fields
            try:
                # float() instead of eval(): safe on file content and also
                # accepts "nan" / "inf".
                recorded[data_id] = None if value == 'None' else float(value)
            except ValueError:
                continue
    return recorded


def _summarize_dists(dists):
    """Return (sorted valid distances, mean, 10%-trimmed mean, median)."""
    valid_dists = sorted(x for x in dists if x is not None)
    n_valid = len(valid_dists)
    if n_valid == 0:
        nan = float('nan')
        return valid_dists, nan, nan, nan

    n_trim = int(n_valid * 0.1)
    # Slice with an explicit end index: "[n_trim:-n_trim]" is empty when
    # n_trim is 0, i.e. whenever there are fewer than 10 valid samples.
    trimmed = valid_dists[n_trim:n_valid - n_trim]
    return valid_dists, np.mean(valid_dists), np.mean(trimmed), np.median(valid_dists)


def _write_summary(n_total, n_invalid, avg_dist, trim_avg_dist, med_dist, stream=None):
    """Print the three summary lines to ``stream`` (stdout when None)."""
    invalid_ratio = n_invalid / n_total if n_total else 0.0
    print("#####" * 10, file=stream)
    print("total:", n_total, "\t invalid:", n_invalid, "\t invalid ratio:", invalid_ratio,
          file=stream)
    print("avg dist:", avg_dist, "trim_avg_dist:", trim_avg_dist, "med dist:", med_dist,
          file=stream)


def run(args):
    """Evaluate every ``*.h5`` file in ``args.src`` and report Chamfer statistics.

    Results go to stdout and to ``args.src + '_pc_stat.txt'``. If that file
    already exists the user is asked whether to overwrite it; otherwise the
    samples recorded in it are reused instead of being recomputed (sequential
    mode only).

    Args:
        args: namespace with
            ``src``      directory holding the ``.h5`` files,
            ``num``      maximum number of files to process (-1 for all),
            ``parallel`` whether to process files with joblib.
    """
    filepaths = sorted(glob.glob(os.path.join(args.src, "*.h5")))
    if args.num != -1:
        filepaths = filepaths[:args.num]

    save_path = args.src + '_pc_stat.txt'
    recorded = {}
    if os.path.exists(save_path):
        response = input(save_path + ' already exists, overwrite? (y/n) ')
        if response == 'y':
            os.remove(save_path)
        else:
            recorded = _load_recorded_dists(save_path)

    if args.parallel:
        dists = Parallel(n_jobs=8, verbose=2)(delayed(process_one)(x) for x in filepaths)
    else:
        dists = []
        for i, filepath in enumerate(filepaths):
            print("processing[{}] {}".format(i, filepath))
            data_id = os.path.basename(filepath).split('.')[0]

            if data_id in recorded:
                dists.append(recorded[data_id])
                continue

            if data_id in SKIP_DATA:
                print("skip {}".format(data_id))
                res = None
            else:
                res = process_one(filepath)

            # Append after every sample so an interrupted run can be resumed.
            with open(save_path, 'a') as fp:
                print("{}\t{}\t{}".format(i, data_id, res), file=fp)
            dists.append(res)

    valid_dists, avg_dist, trim_avg_dist, med_dist = _summarize_dists(dists)
    print("top 20 largest error:")
    print(valid_dists[-20:][::-1])
    n_invalid = len(dists) - len(valid_dists)

    _write_summary(len(filepaths), n_invalid, avg_dist, trim_avg_dist, med_dist)
    with open(save_path, "a") as fp:
        _write_summary(len(filepaths), n_invalid, avg_dist, trim_avg_dist, med_dist,
                       stream=fp)
# Command-line interface. `args` stays a module-level name because
# process_one() reads args.n_points from it.
parser = argparse.ArgumentParser()
# Directory that is searched for *.h5 files; also the prefix of the stats file.
parser.add_argument(
    '--src',
    type=str,
    default=None,
    required=True,
)
# Number of points sampled from each generated solid and reference cloud.
parser.add_argument(
    '--n_points',
    type=int,
    default=2000,
)
# Upper bound on the number of files to evaluate; -1 evaluates all of them.
parser.add_argument(
    '--num',
    type=int,
    default=-1,
)
parser.add_argument(
    '--parallel',
    action='store_true',
    help="use parallelization",
)
args = parser.parse_args()

print(args.src)
print("SKIP DATA:", SKIP_DATA)

# Time the whole evaluation and report the wall-clock duration.
since = time.time()
run(args)
elapsed = time.time() - since
print("running time: {}s".format(elapsed))
|