"""Train / test a point-cloud encoder that regresses CAD shape codes.

The script pairs noisy point clouds (read from PLY files) with latent codes
stored in an HDF5 file, trains a PointNet++ style encoder to predict the code
from the points with an MSE loss, and in ``--test`` mode writes the predicted
codes and the matching sample ids to the experiment's ``results`` folder.
"""
import torch.nn as nn
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils import TrainClock, cycle, ensure_dirs, ensure_dir
import argparse
import h5py
import shutil
import json
import random
from plyfile import PlyData, PlyElement
import sys
sys.path.append("..")  # must run before the imports below so they resolve
from agent import BaseAgent
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule


def write_ply(points, filename, text=False):
    """Write an Nx3 point array to ``filename`` in PLY format.

    Args:
        points: array indexable as ``points[:, k]``; only the first three
            columns (x, y, z) are written.
        filename: destination path; opened in binary write mode.
        text: passed to ``PlyData`` to choose ASCII instead of binary output.
    """
    rows = list(zip(points[:, 0], points[:, 1], points[:, 2]))
    vertex = np.array(rows, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    with open(filename, mode='wb') as f:
        PlyData([el], text=text).write(f)


class Config(object):
    """Hyper-parameters plus experiment paths derived from the CLI arguments.

    Constructing an instance has side effects: it may prompt on stdin and
    delete an existing experiment folder, it creates the log/model folders,
    and (when not testing) it snapshots the script and the config to disk.
    """

    n_points = 2048        # points sampled from each cloud
    batch_size = 128
    num_workers = 8
    nr_epochs = 200
    lr = 1e-4
    lr_step_size = 50      # StepLR period
    # beta1 = 0.5
    grad_clip = None
    noise = 0.02           # max displacement along the normal (see dataset)
    save_frequency = 100   # epochs between numbered checkpoints
    val_frequency = 10     # training steps between validation batches

    def __init__(self, args):
        """Resolve paths from ``args`` and prepare the experiment folder.

        Args:
            args: parsed CLI namespace providing ``proj_dir``, ``exp_name``,
                ``ae_ckpt``, ``gpu_ids``, ``test`` and ``cont``.
        """
        self.data_root = os.path.join(args.proj_dir, args.exp_name,
                                      "results/all_zs_ckpt{}.h5".format(args.ae_ckpt))
        self.exp_dir = os.path.join(args.proj_dir, args.exp_name,
                                    "pc2cad_tune_noise{}_{}_new".format(self.n_points, self.noise))
        print(self.exp_dir)
        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        self.gpu_ids = args.gpu_ids

        # A fresh training run over an existing folder needs explicit consent
        # because the whole folder is wiped.
        if (not args.test) and args.cont is not True and os.path.exists(self.exp_dir):
            response = input('Experiment log/model already exists, overwrite? (y/n) ')
            if response != 'y':
                # sys.exit is always available; the bare exit() helper is
                # injected by the site module and may be missing.
                sys.exit()
            shutil.rmtree(self.exp_dir)
        ensure_dirs([self.log_dir, self.model_dir])

        if not args.test:
            # Snapshot the training script next to the logs. shutil.copy
            # avoids building a shell command from a path; the existence
            # check keeps the snapshot best-effort, as the shell `cp` was.
            script_path = "pc2cad.py"
            if os.path.isfile(script_path):
                shutil.copy(script_path, self.exp_dir)
            else:
                print("warning: {} not found in the working directory, "
                      "skipping source snapshot".format(script_path))
            # Only instance attributes are dumped; class-level defaults
            # above are not part of self.__dict__.
            with open('{}/config.txt'.format(self.exp_dir), 'w') as f:
                json.dump(self.__dict__, f, indent=2)


class PointNet2(nn.Module):
    """Set-abstraction encoder followed by an MLP head ending in Tanh.

    The head's last layer is ``Linear(256, 256)``, so the output has 256
    channels, each squashed to (-1, 1) by the final Tanh.
    """

    def __init__(self):
        super(PointNet2, self).__init__()

        self.use_xyz = True

        self._build_model()

    def _build_model(self):
        """Create the four set-abstraction stages and the FC head."""
        # (npoint, radius, nsample, mlp) for the three local stages.
        # The leading 0 in the first mlp presumably means "no input feature
        # channels besides xyz" -- TODO confirm against pointnet2_ops.
        local_stage_specs = [
            (512, 0.1, 64, [0, 32, 32, 64]),
            (256, 0.2, 64, [64, 64, 64, 128]),
            (128, 0.4, 64, [128, 128, 128, 256]),
        ]

        self.SA_modules = nn.ModuleList()
        for npoint, radius, nsample, mlp in local_stage_specs:
            self.SA_modules.append(
                PointnetSAModule(
                    npoint=npoint,
                    radius=radius,
                    nsample=nsample,
                    mlp=mlp,
                    # bn=False,
                    use_xyz=self.use_xyz,
                )
            )

        # Last stage is built without npoint/radius/nsample.
        self.SA_modules.append(
            PointnetSAModule(
                mlp=[256, 256, 512, 1024],
                # bn=False,
                use_xyz=self.use_xyz
            )
        )

        # NOTE(review): the first positional argument of nn.LeakyReLU is
        # `negative_slope`, not `inplace`, so LeakyReLU(True) has slope 1.0
        # and acts as an identity. It is kept as-is because changing it
        # would alter the outputs of already-trained checkpoints; switch to
        # nn.LeakyReLU(inplace=True) only together with a retrain.
        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(True),
            nn.Linear(512, 256),
            nn.LeakyReLU(True),
            nn.Linear(256, 256),
            nn.Tanh()
        )

    def _break_up_pc(self, pc):
        """Split ``pc`` into coordinates and (optional) per-point features.

        Returns:
            ``(xyz, features)`` where ``xyz`` is the first three channels of
            the last axis and ``features`` is the remaining channels with
            axes 1 and 2 swapped, or ``None`` when there are only three.
        """
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None

        return xyz, features

    def forward(self, pointcloud):
        r"""
            Forward pass of the network

            Parameters
            ----------
            pointcloud: Variable(torch.cuda.FloatTensor)
                (B, N, 3 + input_channels) tensor
                Point cloud to run predicts on
                Each point in the point-cloud MUST
                be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        for module in self.SA_modules:
            xyz, features = module(xyz, features)

        # Drop the trailing singleton axis before the FC head.
        return self.fc_layer(features.squeeze(-1))


class EncoderPointNet(nn.Module):
    """Point-wise Conv1d encoder with mean pooling (alternative to PointNet2).

    Not selected by ``TrainAgent.build_net`` in this file.
    """

    def __init__(self, n_filters=(128, 256, 512, 1024), bn=True):
        """
        Args:
            n_filters: output channels of the successive 1x1 convolutions.
                The FC head expects the last entry to be 1024.
            bn: insert BatchNorm1d after every convolution when True.
        """
        super(EncoderPointNet, self).__init__()
        self.n_filters = list(n_filters)  # + [latent_dim]
        # self.latent_dim = latent_dim

        model = []
        prev_nf = 3  # the input carries xyz only
        for nf in self.n_filters:
            model.append(nn.Conv1d(prev_nf, nf, kernel_size=1, stride=1))

            if bn:
                model.append(nn.BatchNorm1d(nf))

            model.append(nn.LeakyReLU(inplace=True))
            prev_nf = nf

        self.model = nn.Sequential(*model)

        # NOTE(review): LeakyReLU(True) sets negative_slope=1.0 (identity);
        # kept unchanged for compatibility with existing behaviour.
        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(True),
            nn.Linear(512, 256),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x`` by convolving per point, then averaging over points.

        ``x`` is permuted from (0, 1, 2) to (0, 2, 1) so channels come
        before points, as Conv1d requires.
        """
        x = x.permute(0, 2, 1)
        x = self.model(x)
        x = torch.mean(x, dim=2)
        x = self.fc_layer(x)
        return x


class TrainAgent(BaseAgent):
    """Agent that trains the point-cloud encoder with an MSE objective."""

    def build_net(self, config):
        """Return the PointNet2 encoder, wrapped in DataParallel if needed."""
        net = PointNet2()
        # NOTE(review): gpu_ids arrives as the raw CLI string (e.g. "0,1"),
        # so len() counts characters, not devices -- confirm this is intended.
        if len(config.gpu_ids) > 1:
            net = nn.DataParallel(net)
        # net = EncoderPointNet()
        return net

    def set_loss_function(self):
        """Use mean squared error between predicted and target codes."""
        self.criterion = nn.MSELoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = torch.optim.Adam(self.net.parameters(), config.lr)  # , betas=(config.beta1, 0.9))
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, config.lr_step_size)

    def forward(self, data):
        """Run the network on one batch.

        Args:
            data: dict with tensors under ``"points"`` and ``"code"``.

        Returns:
            ``(pred_code, {"mse": loss})``.
        """
        points = data["points"].cuda()
        code = data["code"].cuda()

        pred_code = self.net(points)

        loss = self.criterion(pred_code, code)
        return pred_code, {"mse": loss}


def read_ply(path, with_normal=False):
    """Read vertex positions (and optionally normals) from a PLY file.

    Args:
        path: PLY file to read.
        with_normal: also read the ``nx``/``ny``/``nz`` vertex properties.

    Returns:
        An (N, 3) array of xyz, or an (N, 6) array of xyz followed by the
        normal components when ``with_normal`` is True.
    """
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
        vertex_el = plydata['vertex']
        vertex = np.stack([np.array(vertex_el[k]) for k in ('x', 'y', 'z')], axis=1)
        if with_normal:
            normals = np.stack([np.array(vertex_el[k]) for k in ('nx', 'ny', 'nz')], axis=1)
            return np.concatenate([vertex, normals], axis=1)
    return vertex


class ShapeCodesDataset(Dataset):
    """Pairs of (noisy sampled point cloud, target shape code)."""

    def __init__(self, phase, config):
        """
        Args:
            phase: split name; used as the key into the split JSON and to
                build the HDF5 dataset name ``"<phase>_zs"``.
            config: object providing ``n_points``, ``data_root`` and ``noise``.
        """
        super(ShapeCodesDataset, self).__init__()
        self.n_points = config.n_points
        self.data_root = config.data_root
        # self.abc_root = "/mnt/disk6/wurundi/abc"
        self.abc_root = "/home/rundi/data/abc"
        self.pc_root = self.abc_root + "/pc_v5a_processed_merge"
        self.path = os.path.join(self.abc_root, "cad_e10_l6_c15_len60_min0_t100.json")
        with open(self.path, "r") as fp:
            self.all_data = json.load(fp)[phase]

        # Codes are loaded fully into memory; they are indexed in parallel
        # with self.all_data.
        with h5py.File(self.data_root, 'r') as fp:
            self.zs = fp["{}_zs".format(phase)][:]

        self.noise = config.noise

    def _resolve_sample(self, index):
        """Return ``(index, data_id, pc_path)`` of the first sample at or after ``index`` whose PLY file exists."""
        n_items = len(self)
        data_id = self.all_data[index]  # IndexError here keeps out-of-range semantics
        pc_path = os.path.join(self.pc_root, data_id + '.ply')
        attempts = 1
        while not os.path.exists(pc_path):
            if attempts >= n_items:
                raise FileNotFoundError(
                    "no point cloud file found under {}".format(self.pc_root))
            # Wrap around instead of running past the last sample.
            index = (index + 1) % n_items
            data_id = self.all_data[index]
            pc_path = os.path.join(self.pc_root, data_id + '.ply')
            attempts += 1
        return index, data_id, pc_path

    def __getitem__(self, index):
        """Load one sample.

        If the point cloud for ``index`` is missing on disk, the next
        available sample is returned instead (searching forward with
        wrap-around), so the returned ``id`` may differ from the one at
        ``index``.

        Returns:
            dict with ``"points"`` (float32 tensor, ``n_points`` x 3),
            ``"code"`` (float32 tensor) and ``"id"``.

        Raises:
            FileNotFoundError: if no sample has a point cloud file.
            ValueError: from ``random.sample`` if the cloud has fewer than
                ``n_points`` points.
        """
        index, data_id, pc_path = self._resolve_sample(index)

        pc_n = read_ply(pc_path, with_normal=True)
        pc = pc_n[:, :3]
        normal = pc_n[:, -3:]

        # Sample n_points distinct points (without replacement).
        sample_idx = random.sample(range(pc.shape[0]), self.n_points)
        pc = pc[sample_idx]
        normal = normal[sample_idx]

        # Normalise the normals (epsilon guards zero-length vectors), then
        # shift each point along its normal by a uniform random amount in
        # [-noise, noise).
        normal = normal / (np.linalg.norm(normal, axis=1, keepdims=True) + 1e-6)
        pc = pc + np.random.uniform(-self.noise, self.noise, (pc.shape[0], 1)) * normal

        pc = torch.tensor(pc, dtype=torch.float32)
        shape_code = torch.tensor(self.zs[index], dtype=torch.float32)
        return {"points": pc, "code": shape_code, "id": data_id}

    def __len__(self):
        return len(self.zs)


def get_dataloader(phase, config, shuffle=None):
    """Build a DataLoader over ``ShapeCodesDataset`` for ``phase``.

    Args:
        phase: split name forwarded to the dataset.
        config: provides ``batch_size`` and ``num_workers`` (plus whatever
            the dataset reads).
        shuffle: explicit override; when None, only the ``'train'`` phase
            is shuffled.
    """
    is_shuffle = phase == 'train' if shuffle is None else shuffle

    dataset = ShapeCodesDataset(phase, config)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=is_shuffle,
                            num_workers=config.num_workers)
    return dataloader


parser = argparse.ArgumentParser()
# parser.add_argument('--proj_dir', type=str, default="/mnt/disk6/wurundi/cad_gen",
#                     help="path to project folder where models and logs will be saved")
parser.add_argument('--proj_dir', type=str, default="/home/rundi/project_log/cad_gen",
                    help="path to project folder where models and logs will be saved")
parser.add_argument('--exp_name', type=str, required=True, help="name of this experiment")
parser.add_argument('--ae_ckpt', type=str, required=True, help="desired checkpoint to restore")
parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--n_samples', type=int, default=100, help="number of samples to generate when testing")
parser.add_argument('-g', '--gpu_ids', type=str, default="0",
                    help="gpu to use, e.g. 0  0,1,2. CPU not supported.")
args = parser.parse_args()

# Restrict the visible devices before any CUDA work happens.
if args.gpu_ids is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

cfg = Config(args)
print("data path:", cfg.data_root)
agent = TrainAgent(cfg)

if not args.test:
    # load from checkpoint if provided
    if args.cont:
        agent.load_ckpt(args.ckpt)
        # for g in agent.optimizer.param_groups:
        #     g['lr'] = 1e-5

    # create dataloader
    train_loader = get_dataloader('train', cfg)
    val_loader = get_dataloader('validation', cfg)
    val_loader = cycle(val_loader)  # lets next() be called on it repeatedly

    # start training
    clock = agent.clock

    for e in range(clock.epoch, cfg.nr_epochs):
        # begin iteration
        pbar = tqdm(train_loader)
        for b, data in enumerate(pbar):
            # train step
            outputs, losses = agent.train_func(data)

            pbar.set_description("EPOCH[{}][{}]".format(e, b))
            pbar.set_postfix({k: v.item() for k, v in losses.items()})

            # validation step
            if clock.step % cfg.val_frequency == 0:
                data = next(val_loader)
                outputs, losses = agent.val_func(data)

            clock.tick()

        clock.tock()

        if clock.epoch % cfg.save_frequency == 0:
            agent.save_ckpt()

        # if clock.epoch % 10 == 0:
        agent.save_ckpt('latest')
else:
    # load trained weights
    agent.load_ckpt(args.ckpt)
    test_loader = get_dataloader('test', cfg)

    # save_dir = os.path.join(cfg.exp_dir, "results/fake_z_ckpt{}_num{}_pc".format(args.ckpt, args.n_samples))
    # The folder is still created although per-sample export is disabled.
    save_dir = os.path.join(cfg.exp_dir, "results/pc2cad_ckpt{}_num{}".format(args.ckpt, args.n_samples))
    os.makedirs(save_dir, exist_ok=True)

    all_zs = []
    all_ids = []
    cnt = 0
    for data in tqdm(test_loader):
        with torch.no_grad():
            pred_z, _ = agent.forward(data)
            pred_z = pred_z.detach().cpu().numpy()
            # print(pred_z.shape)
            all_zs.append(pred_z)
            all_ids.extend(data['id'])

        # Per-sample export (one .ply / .h5 per id under save_dir) is
        # disabled; only the aggregated files below are written.

        # Whole batches are kept, and the loop stops only once the count
        # exceeds n_samples, so more than n_samples codes may be saved.
        cnt += pred_z.shape[0]
        if cnt > args.n_samples:
            break

    all_zs = np.concatenate(all_zs, axis=0)

    # save generated z
    save_path = os.path.join(cfg.exp_dir, "results/pc2cad_z_ckpt{}_num{}.h5".format(args.ckpt, args.n_samples))
    ensure_dir(os.path.dirname(save_path))
    with h5py.File(save_path, 'w') as fp:
        fp.create_dataset("zs", shape=all_zs.shape, data=all_zs)

    # ids are written in the same order as the rows of "zs"
    save_path = os.path.join(cfg.exp_dir, "results/pc2cad_z_ckpt{}_num{}_ids.json".format(args.ckpt, args.n_samples))
    with open(save_path, 'w') as fp:
        json.dump(all_ids, fp)