# DeepCAD/trainer/trainerAE.py
import torch
import torch.optim as optim
import numpy as np
from tqdm import tqdm

from model import CADTransformer
from .base import BaseTrainer
from .loss import CADLoss
from .scheduler import GradualWarmupScheduler
from cadlib.macro import *


class TrainerAE(BaseTrainer):
    def build_net(self, cfg):
        # CAD sequence autoencoder network
        self.net = CADTransformer(cfg).cuda()

    def set_optimizer(self, cfg):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), cfg.lr)
        self.scheduler = GradualWarmupScheduler(self.optimizer, 1.0, cfg.warmup_step)

    def set_loss_function(self):
        self.loss_func = CADLoss(self.cfg).cuda()

    def forward(self, data):
        commands = data['command'].cuda()  # (N, S)
        args = data['args'].cuda()  # (N, S, N_ARGS)

        outputs = self.net(commands, args)
        loss_dict = self.loss_func(outputs)

        return outputs, loss_dict

    def encode(self, data, is_batch=False):
        """encode into latent vectors"""
        commands = data['command'].cuda()
        args = data['args'].cuda()
        if not is_batch:
            commands = commands.unsqueeze(0)
            args = args.unsqueeze(0)
        z = self.net(commands, args, encode_mode=True)
        return z

    def decode(self, z):
        """decode given latent vectors"""
        outputs = self.net(None, None, z=z, return_tgt=False)
        return outputs

    def logits2vec(self, outputs, refill_pad=True, to_numpy=True):
        """network outputs (logits) to final CAD vector"""
        out_command = torch.argmax(torch.softmax(outputs['command_logits'], dim=-1), dim=-1)  # (N, S)
        out_args = torch.argmax(torch.softmax(outputs['args_logits'], dim=-1), dim=-1) - 1  # (N, S, N_ARGS)
        if refill_pad:  # fill all unused argument slots with -1
            # CMD_ARGS_MASK marks which arguments each command type actually uses
            mask = ~torch.tensor(CMD_ARGS_MASK).bool().cuda()[out_command.long()]
            out_args[mask] = -1

        out_cad_vec = torch.cat([out_command.unsqueeze(-1), out_args], dim=-1)
        if to_numpy:
            out_cad_vec = out_cad_vec.detach().cpu().numpy()
        return out_cad_vec

    def evaluate(self, test_loader):
        """evaluation during training"""
        self.net.eval()
        pbar = tqdm(test_loader)
        pbar.set_description("EVALUATE[{}]".format(self.clock.epoch))

        all_ext_args_comp = []
        all_line_args_comp = []
        all_arc_args_comp = []
        all_circle_args_comp = []

        for i, data in enumerate(pbar):
            with torch.no_grad():
                commands = data['command'].cuda()
                args = data['args'].cuda()
                outputs = self.net(commands, args)
                out_args = torch.argmax(torch.softmax(outputs['args_logits'], dim=-1), dim=-1) - 1
                out_args = out_args.long().detach().cpu().numpy()  # (N, S, n_args)

            gt_commands = commands.squeeze(1).long().detach().cpu().numpy()  # (N, S)
            gt_args = args.squeeze(1).long().detach().cpu().numpy()  # (N, S, n_args)

            # positions of each command type in the ground-truth sequences
            ext_pos = np.where(gt_commands == EXT_IDX)
            line_pos = np.where(gt_commands == LINE_IDX)
            arc_pos = np.where(gt_commands == ARC_IDX)
            circle_pos = np.where(gt_commands == CIRCLE_IDX)

            # element-wise correctness of predicted arguments against ground truth
            args_comp = (gt_args == out_args).astype(int)
            all_ext_args_comp.append(args_comp[ext_pos][:, -N_ARGS_EXT:])
            all_line_args_comp.append(args_comp[line_pos][:, :2])
            all_arc_args_comp.append(args_comp[arc_pos][:, :4])
            all_circle_args_comp.append(args_comp[circle_pos][:, [0, 1, 4]])

        all_ext_args_comp = np.concatenate(all_ext_args_comp, axis=0)

        # per-group argument accuracies: sketch plane, translation, extrusion extent,
        # and the curve commands (line, arc, circle)
        sket_plane_acc = np.mean(all_ext_args_comp[:, :N_ARGS_PLANE])
        sket_trans_acc = np.mean(all_ext_args_comp[:, N_ARGS_PLANE:N_ARGS_PLANE + N_ARGS_TRANS])
        extent_one_acc = np.mean(all_ext_args_comp[:, -N_ARGS_EXT_PARAM])
        line_acc = np.mean(np.concatenate(all_line_args_comp, axis=0))
        arc_acc = np.mean(np.concatenate(all_arc_args_comp, axis=0))
        circle_acc = np.mean(np.concatenate(all_circle_args_comp, axis=0))

        self.val_tb.add_scalars("args_acc",
                                {"line": line_acc, "arc": arc_acc, "circle": circle_acc,
                                 "plane": sket_plane_acc, "trans": sket_trans_acc, "extent": extent_one_acc},
                                global_step=self.clock.epoch)
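

# Illustrative usage sketch: assumes BaseTrainer takes the config in its
# constructor and that batches are dicts with 'command' (N, S) and
# 'args' (N, S, N_ARGS) tensors, as expected by forward().
#
#   trainer = TrainerAE(cfg)
#   outputs, loss_dict = trainer.forward(batch)    # reconstruction logits + losses
#   z = trainer.encode(batch, is_batch=True)       # latent codes
#   decoded = trainer.decode(z)                    # logits decoded back from z
#   cad_vec = trainer.logits2vec(decoded)          # (N, S, 1 + N_ARGS) CAD vectors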