Spaces:
Build error
Build error
| import os | |
| from utils import ensure_dirs | |
| import argparse | |
| import json | |
| import shutil | |
| from cadlib.macro import * | |
class ConfigAE(object):
    """Experiment configuration for the autoencoder.

    Combines the fixed model hyperparameters from ``set_configuration`` with
    the command-line options from ``parse``, exposes both as instance
    attributes, prepares the experiment directories and, when training,
    saves the parsed options to ``config.txt`` in the experiment directory.
    """

    def __init__(self, phase):
        """Build the configuration for ``phase``.

        Args:
            phase: ``"train"`` selects training behaviour (overwrite prompt and
                config dump). Any other value additionally registers the
                ``--mode`` / ``--outputs`` / ``--z_path`` options.

        Raises:
            SystemExit: if the experiment directory already exists, training
                is not being continued, and the user declines to overwrite it
                (argparse also exits this way on invalid arguments).
        """
        self.is_train = phase == "train"

        self.set_configuration()

        # init hyperparameters and parse from command-line
        _, args = self.parse()

        # set as attributes
        print("----Experiment Configuration-----")
        for k, v in vars(args).items():
            print("{0:20}".format(k), v)
            setattr(self, k, v)

        # experiment paths
        self.exp_dir = os.path.join(self.proj_dir, self.exp_name)
        if self.is_train and args.cont is not True and os.path.exists(self.exp_dir):
            response = input('Experiment log/model already exists, overwrite? (y/n) ')
            # Accept "Y" / surrounding whitespace as confirmation; anything
            # else aborts before the existing experiment is deleted.
            if response.strip().lower() != 'y':
                # Raise directly instead of calling the site-provided exit()
                # helper, which is not guaranteed to exist in every runtime.
                raise SystemExit
            shutil.rmtree(self.exp_dir)

        self.log_dir = os.path.join(self.exp_dir, 'log')
        self.model_dir = os.path.join(self.exp_dir, 'model')
        ensure_dirs([self.log_dir, self.model_dir])

        # GPU usage
        if args.gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

        # save this configuration
        if self.is_train:
            with open(os.path.join(self.exp_dir, 'config.txt'), 'w') as f:
                json.dump(vars(args), f, indent=2)

    def set_configuration(self):
        """Set the fixed (non command-line) model hyperparameters.

        The upper-case names are module-level constants; in this file they
        can only come from the ``cadlib.macro`` star import.
        """
        self.args_dim = ARGS_DIM  # 256
        self.n_args = N_ARGS
        self.n_commands = len(ALL_COMMANDS)  # line, arc, circle, EOS, SOS

        self.n_layers = 4  # Number of Encoder blocks
        self.n_layers_decode = 4  # Number of Decoder blocks
        self.n_heads = 8  # Transformer config: number of heads
        self.dim_feedforward = 512  # Transformer config: FF dimensionality
        self.d_model = 256  # Transformer config: model dimensionality
        self.dropout = 0.1  # Dropout rate used in basic layers and Transformers
        self.dim_z = 256  # Latent vector dimensionality
        self.use_group_emb = True

        self.max_n_ext = MAX_N_EXT
        self.max_n_loops = MAX_N_LOOPS
        self.max_n_curves = MAX_N_CURVES

        self.max_num_groups = 30
        self.max_total_len = MAX_TOTAL_LEN

        self.loss_weights = {
            "loss_cmd_weight": 1.0,
            "loss_args_weight": 2.0
        }

    def parse(self):
        """Initialize the argument parser and parse ``sys.argv``.

        Defines the default hyperparameters and collects overrides from the
        command line. When ``self.is_train`` is false, the ``--mode``,
        ``--outputs`` and ``--z_path`` options are registered as well.

        Returns:
            A ``(parser, args)`` tuple: the ``argparse.ArgumentParser`` and
            the parsed ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument('--proj_dir', type=str, default="proj_log", help="path to project folder where models and logs will be saved")
        parser.add_argument('--data_root', type=str, default="data", help="path to source data folder")
        # basename() rather than split('/') so the default is also correct on
        # platforms whose path separator is not '/'.
        parser.add_argument('--exp_name', type=str, default=os.path.basename(os.getcwd()), help="name of this experiment")
        parser.add_argument('-g', '--gpu_ids', type=str, default='0', help="gpu to use, e.g. 0 0,1,2. CPU not supported.")

        parser.add_argument('--batch_size', type=int, default=512, help="batch size")
        parser.add_argument('--num_workers', type=int, default=8, help="number of workers for data loading")

        parser.add_argument('--nr_epochs', type=int, default=1000, help="total number of epochs to train")
        parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
        parser.add_argument('--grad_clip', type=float, default=1.0, help="gradient clipping value")
        parser.add_argument('--warmup_step', type=int, default=2000, help="step size for learning rate warm up")
        parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
        parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore")
        parser.add_argument('--vis', action='store_true', default=False, help="visualize output in training")
        parser.add_argument('--save_frequency', type=int, default=500, help="save models every x epochs")
        parser.add_argument('--val_frequency', type=int, default=10, help="run validation every x iterations")
        parser.add_argument('--vis_frequency', type=int, default=2000, help="visualize output every x iterations")
        parser.add_argument('--augment', action='store_true', help="use random data augmentation")

        if not self.is_train:
            parser.add_argument('-m', '--mode', type=str, choices=['rec', 'enc', 'dec'])
            parser.add_argument('-o', '--outputs', type=str, default=None)
            parser.add_argument('--z_path', type=str, default=None)

        args = parser.parse_args()
        return parser, args