# DeepCAD / config / configAE.py
# turiya-ai's picture
# Upload 51 files
# 4d588ce verified
import os
from utils import ensure_dirs
import argparse
import json
import shutil
from cadlib.macro import *
class ConfigAE(object):
    """Experiment configuration for the autoencoder (AE).

    Constructing an instance:

    1. stores the fixed model hyperparameters (``set_configuration``),
    2. parses the command line (``parse``) and copies every parsed option
       onto the instance as an attribute of the same name,
    3. derives ``exp_dir`` / ``log_dir`` / ``model_dir`` under
       ``<proj_dir>/<exp_name>`` and creates the log and model directories,
    4. sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_ids``,
    5. when training, dumps the parsed options to ``<exp_dir>/config.txt``.
    """

    def __init__(self, phase):
        """Build the configuration for the given phase.

        Args:
            phase: ``"train"`` selects training behavior; any other value is
                treated as a non-training run, which adds the
                ``--mode`` / ``--outputs`` / ``--z_path`` options.

        Raises:
            SystemExit: if an experiment directory already exists for a fresh
                training run and the user does not answer ``y`` to the
                overwrite prompt.
        """
        self.is_train = phase == "train"

        self.set_configuration()

        # init hyperparameters and parse from command-line
        _parser, args = self.parse()

        # set as attributes
        print("----Experiment Configuration-----")
        for k, v in vars(args).items():
            print("{0:20}".format(k), v)
            setattr(self, k, v)

        # experiment paths
        self.exp_dir = os.path.join(self.proj_dir, self.exp_name)
        if self.is_train and not args.cont and os.path.exists(self.exp_dir):
            # A fresh (non --continue) training run would mix with old
            # logs/checkpoints, so ask before wiping the directory.
            response = input('Experiment log/model already exists, overwrite? (y/n) ')
            if response != 'y':
                # Same exit status as the interactive-only ``exit()`` helper,
                # but does not depend on the ``site`` module being loaded.
                raise SystemExit(0)
            shutil.rmtree(self.exp_dir)

        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        ensure_dirs([self.log_dir, self.model_dir])

        # GPU usage
        if args.gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

        # NOTE: creating a 'train_log' soft link to the experiment directory
        # is intentionally disabled:
        # if not os.path.exists('train_log'):
        #     os.symlink(self.exp_dir, 'train_log')

        # save this configuration
        if self.is_train:
            with open(os.path.join(self.exp_dir, 'config.txt'), 'w') as f:
                json.dump(vars(args), f, indent=2)

    def set_configuration(self):
        """Set the fixed (non command-line) model hyperparameters.

        Values in upper case come from ``cadlib.macro`` (star-imported at the
        top of the file).
        """
        self.args_dim = ARGS_DIM  # 256
        self.n_args = N_ARGS
        self.n_commands = len(ALL_COMMANDS)  # line, arc, circle, EOS, SOS

        self.n_layers = 4  # Number of Encoder blocks
        self.n_layers_decode = 4  # Number of Decoder blocks
        self.n_heads = 8  # Transformer config: number of heads
        self.dim_feedforward = 512  # Transformer config: FF dimensionality
        self.d_model = 256  # Transformer config: model dimensionality
        self.dropout = 0.1  # Dropout rate used in basic layers and Transformers
        self.dim_z = 256  # Latent vector dimensionality
        self.use_group_emb = True

        self.max_n_ext = MAX_N_EXT
        self.max_n_loops = MAX_N_LOOPS
        self.max_n_curves = MAX_N_CURVES

        self.max_num_groups = 30
        self.max_total_len = MAX_TOTAL_LEN

        self.loss_weights = {
            "loss_cmd_weight": 1.0,
            "loss_args_weight": 2.0
        }

    def parse(self):
        """Initialize the argument parser and parse the command line.

        Defines the default hyperparameters and collects overrides from
        ``sys.argv``. When ``self.is_train`` is false, the extra options
        ``--mode``, ``--outputs`` and ``--z_path`` are registered as well.

        Returns:
            A ``(parser, args)`` tuple: the ``argparse.ArgumentParser`` and
            the parsed ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument('--proj_dir', type=str, default="proj_log", help="path to project folder where models and logs will be saved")
        parser.add_argument('--data_root', type=str, default="data", help="path to source data folder")
        # basename() is the portable way to take the last path component
        # (splitting on '/' returns the whole path on Windows).
        parser.add_argument('--exp_name', type=str, default=os.path.basename(os.getcwd()), help="name of this experiment")
        parser.add_argument('-g', '--gpu_ids', type=str, default='0', help="gpu to use, e.g. 0  0,1,2. CPU not supported.")

        parser.add_argument('--batch_size', type=int, default=512, help="batch size")
        parser.add_argument('--num_workers', type=int, default=8, help="number of workers for data loading")

        parser.add_argument('--nr_epochs', type=int, default=1000, help="total number of epochs to train")
        parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
        parser.add_argument('--grad_clip', type=float, default=1.0, help="gradient clipping threshold")
        parser.add_argument('--warmup_step', type=int, default=2000, help="step size for learning rate warm up")
        # Stored as ``cont`` because ``continue`` is a reserved word and
        # could not be read as ``args.continue``.
        parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
        parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore")
        parser.add_argument('--vis', action='store_true', default=False, help="visualize output in training")
        parser.add_argument('--save_frequency', type=int, default=500, help="save models every x epochs")
        parser.add_argument('--val_frequency', type=int, default=10, help="run validation every x iterations")
        parser.add_argument('--vis_frequency', type=int, default=2000, help="visualize output every x iterations")
        parser.add_argument('--augment', action='store_true', help="use random data augmentation")

        if not self.is_train:
            # Options that only exist for non-training runs.
            parser.add_argument('-m', '--mode', type=str, choices=['rec', 'enc', 'dec'])
            parser.add_argument('-o', '--outputs', type=str, default=None)
            parser.add_argument('--z_path', type=str, default=None)

        args = parser.parse_args()
        return parser, args