import os
import sys
import json
import random
import shutil
import argparse

import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

sys.path.append("..")
from trainer.base import BaseTrainer
from utils import cycle, ensure_dirs, ensure_dir, read_ply, write_ply

try:
    from pointnet2_ops.pointnet2_modules import PointnetSAModule
except ImportError:
    print("need to install https://github.com/erikwijmans/Pointnet2_PyTorch")
    sys.exit(1)


class Config(object):
    """Hyper-parameters and experiment paths for the point-cloud-to-shape-code regressor."""
    n_points = 2048
    batch_size = 128
    num_workers = 4
    nr_epochs = 200
    lr = 1e-4
    lr_step_size = 50
    # beta1 = 0.5
    grad_clip = None

    save_frequency = 100
    val_frequency = 10

    def __init__(self, args):
        # ground-truth shape codes exported by the pretrained autoencoder checkpoint
        self.data_root = os.path.join(args.proj_dir, args.exp_name,
                                      "results/all_zs_ckpt{}.h5".format(args.ae_ckpt))
        self.pc_root = args.pc_root
        self.split_path = args.split_path

        self.exp_dir = os.path.join(args.proj_dir, args.exp_name, "pc2cad")
        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        self.gpu_ids = args.gpu_ids

        if (not args.test) and (not args.cont) and os.path.exists(self.exp_dir):
            response = input('Experiment log/model already exists, overwrite? (y/n) ')
            if response != 'y':
                sys.exit(0)
            shutil.rmtree(self.exp_dir)
        ensure_dirs([self.log_dir, self.model_dir])

        if not args.test:
            # keep a snapshot of this script and the config alongside the logs
            shutil.copy(__file__, self.exp_dir)
            with open('{}/config.txt'.format(self.exp_dir), 'w') as f:
                json.dump(self.__dict__, f, indent=2)


class PointNet2(nn.Module):
    """PointNet++ (SSG) encoder that regresses a 256-d latent shape code from a point cloud."""

    def __init__(self):
        super(PointNet2, self).__init__()
        self.use_xyz = True
        self._build_model()

    def _build_model(self):
        # three set-abstraction levels with growing radius, then one global pooling level
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=512,
                radius=0.1,
                nsample=64,
                mlp=[0, 32, 32, 64],
                use_xyz=self.use_xyz,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=256,
                radius=0.2,
                nsample=64,
                mlp=[64, 64, 64, 128],
                use_xyz=self.use_xyz,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=128,
                radius=0.4,
                nsample=64,
                mlp=[128, 128, 128, 256],
                use_xyz=self.use_xyz,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                mlp=[256, 256, 512, 1024],
                use_xyz=self.use_xyz,
            )
        )

        # regression head; the final Tanh keeps the predicted code in (-1, 1)
        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 256),
            nn.Tanh(),
        )

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
        return xyz, features

    def forward(self, pointcloud):
        r"""Forward pass of the network.

        Parameters
        ----------
        pointcloud : torch.cuda.FloatTensor
            (B, N, 3 + input_channels) tensor to run predictions on.
            Each point MUST be formatted as (x, y, z, features...).
        """
        xyz, features = self._break_up_pc(pointcloud)
        for module in self.SA_modules:
            xyz, features = module(xyz, features)
        return self.fc_layer(features.squeeze(-1))
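
# A minimal sanity-check sketch (hypothetical helper, not called anywhere in this
# script): it assumes a CUDA device is available and verifies that the encoder
# maps a (B, N, 3) xyz-only cloud to a (B, 256) code.
def _example_shape_check():
    net = PointNet2().cuda()
    pts = torch.rand(4, 2048, 3).cuda()  # 4 random point clouds, xyz only
    z = net(pts)                         # codes lie in (-1, 1) due to the final Tanh
    assert z.shape == (4, 256)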
""" xyz, features = self._break_up_pc(pointcloud) for module in self.SA_modules: xyz, features = module(xyz, features) return self.fc_layer(features.squeeze(-1)) class TrainAgent(BaseTrainer): def build_net(self, config): self.net = PointNet2().cuda() def set_loss_function(self): self.criterion = nn.MSELoss().cuda() def set_optimizer(self, config): """set optimizer and lr scheduler used in training""" self.optimizer = torch.optim.Adam(self.net.parameters(), config.lr) # , betas=(config.beta1, 0.9)) self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, config.lr_step_size) def forward(self, data): points = data["points"].cuda() code = data["code"].cuda() pred_code = self.net(points) loss = self.criterion(pred_code, code) return pred_code, {"mse": loss} class ShapeCodesDataset(Dataset): def __init__(self, phase, config): super(ShapeCodesDataset, self).__init__() self.n_points = config.n_points self.data_root = config.data_root self.pc_root = config.pc_root self.path = config.split_path with open(self.path, "r") as fp: self.all_data = json.load(fp)[phase] with h5py.File(self.data_root, 'r') as fp: self.zs = fp["{}_zs".format(phase)][:] def __getitem__(self, index): data_id = self.all_data[index] pc_path = os.path.join(self.pc_root, data_id + '.ply') if not os.path.exists(pc_path): return self.__getitem__(index + 1) pc = read_ply(pc_path) sample_idx = random.sample(list(range(pc.shape[0])), self.n_points) pc = pc[sample_idx] pc = torch.tensor(pc, dtype=torch.float32) shape_code = torch.tensor(self.zs[index], dtype=torch.float32) return {"points": pc, "code": shape_code, "id": data_id} def __len__(self): return len(self.zs) def get_dataloader(phase, config, shuffle=None): is_shuffle = phase == 'train' if shuffle is None else shuffle dataset = ShapeCodesDataset(phase, config) dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=is_shuffle, num_workers=config.num_workers) return dataloader parser = argparse.ArgumentParser() parser.add_argument('--proj_dir', type=str, default="proj_log", help="path to project folder where models and logs will be saved") parser.add_argument('--pc_root', type=str, default="path_to_pc_data", help="path to point clouds data folder") parser.add_argument('--split_path', type=str, default="data/train_val_test_split.json", help="path to train-val-test split") parser.add_argument('--exp_name', type=str, required=True, help="name of this experiment") parser.add_argument('--ae_ckpt', type=str, required=True, help="desired checkpoint to restore") parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint") parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore") parser.add_argument('--test',action='store_true', help="test mode") parser.add_argument('--n_samples', type=int, default=100, help="number of samples to generate when testing") parser.add_argument('-g', '--gpu_ids', type=str, default="0", help="gpu to use, e.g. 0 0,1,2. 
CPU not supported.") args = parser.parse_args() if args.gpu_ids is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids) cfg = Config(args) print("data path:", cfg.data_root) agent = TrainAgent(cfg) if not args.test: # load from checkpoint if provided if args.cont: agent.load_ckpt(args.ckpt) # create dataloader train_loader = get_dataloader('train', cfg) val_loader = get_dataloader('validation', cfg) val_loader = cycle(val_loader) # start training clock = agent.clock for e in range(clock.epoch, cfg.nr_epochs): # begin iteration pbar = tqdm(train_loader) for b, data in enumerate(pbar): # train step outputs, losses = agent.train_func(data) pbar.set_description("EPOCH[{}][{}]".format(e, b)) pbar.set_postfix({k: v.item() for k, v in losses.items()}) # validation step if clock.step % cfg.val_frequency == 0: data = next(val_loader) outputs, losses = agent.val_func(data) clock.tick() clock.tock() if clock.epoch % cfg.save_frequency == 0: agent.save_ckpt() # if clock.epoch % 10 == 0: agent.save_ckpt('latest') else: # load trained weights agent.load_ckpt(args.ckpt) test_loader = get_dataloader('test', cfg) save_dir = os.path.join(cfg.exp_dir, "results/fake_z_ckpt{}_num{}_pc".format(args.ckpt, args.n_samples)) if not os.path.exists(save_dir): os.makedirs(save_dir) all_zs = [] pbar = tqdm(test_loader) cnt = 0 for i, data in enumerate(pbar): with torch.no_grad(): pred_z, _ = agent.forward(data) pred_z = pred_z.detach().cpu().numpy() # print(pred_z.shape) all_zs.append(pred_z) pts = data['points'].detach().cpu().numpy() for j in range(pred_z.shape[0]): save_path = os.path.join(save_dir, "{}.ply".format(data['id'][j])) write_ply(pts[j], save_path) cnt += pred_z.shape[0] if cnt > args.n_samples: break all_zs = np.concatenate(all_zs, axis=0) # save generated z save_path = os.path.join(cfg.exp_dir, "results/fake_z_ckpt{}_num{}.h5".format(args.ckpt, args.n_samples)) ensure_dir(os.path.dirname(save_path)) with h5py.File(save_path, 'w') as fp: fp.create_dataset("zs", shape=all_zs.shape, data=all_zs)