import torch
import torch.nn as nn
from torch.nn import LayerNorm  # may also be re-exported by the wildcard imports below

from .layers.transformer import *
from .layers.improved_transformer import *
from .layers.positional_encoding import *
from .model_utils import _make_seq_first, _make_batch_first, \
    _get_padding_mask, _get_key_padding_mask, _get_group_mask


class CADEmbedding(nn.Module):
    """Embedding: positional embed + command embed + parameter embed + group embed (optional)"""
    def __init__(self, cfg, seq_len, use_group=False, group_len=None):
        super().__init__()
        self.command_embed = nn.Embedding(cfg.n_commands, cfg.d_model)

        args_dim = cfg.args_dim + 1  # one extra slot so the shifted -1 padding value maps to index 0
        self.arg_embed = nn.Embedding(args_dim, 64, padding_idx=0)
        self.embed_fcn = nn.Linear(64 * cfg.n_args, cfg.d_model)

        # use_group: additional embedding for each sketch-extrusion pair
        self.use_group = use_group
        if use_group:
            if group_len is None:
                group_len = cfg.max_num_groups
            self.group_embed = nn.Embedding(group_len + 2, cfg.d_model)

        self.pos_encoding = PositionalEncodingLUT(cfg.d_model, max_len=seq_len + 2)

    def forward(self, commands, args, groups=None):
        S, N = commands.shape

        src = self.command_embed(commands.long()) + \
            self.embed_fcn(self.arg_embed((args + 1).long()).view(S, N, -1))  # shift due to -1 PAD_VAL

        if self.use_group:
            src = src + self.group_embed(groups.long())

        src = self.pos_encoding(src)
        return src
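
# Shape sketch for CADEmbedding.forward (descriptive only, inferred from the code above):
#   commands: (S, N) integer command tokens; args: (S, N, n_args) integer parameters, -1 = unused
#   command_embed(commands)                                   -> (S, N, d_model)
#   arg_embed(args + 1) -> (S, N, n_args, 64), view/embed_fcn -> (S, N, d_model)
#   optional group_embed(groups) and the positional LUT are added element-wise -> (S, N, d_model)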


class ConstEmbedding(nn.Module):
    """Learned constant embedding"""
    def __init__(self, cfg, seq_len):
        super().__init__()
        self.d_model = cfg.d_model
        self.seq_len = seq_len
        self.PE = PositionalEncodingLUT(cfg.d_model, max_len=seq_len)

    def forward(self, z):
        N = z.size(1)
        src = self.PE(z.new_zeros(self.seq_len, N, self.d_model))
        return src
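
# ConstEmbedding ignores the content of z and only reads its batch size: it effectively
# returns a fixed-length sequence of positional encodings of shape (seq_len, N, d_model),
# which serves as the decoder input while z itself is passed to the decoder separately.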


class Encoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        seq_len = cfg.max_total_len
        self.use_group = cfg.use_group_emb
        self.embedding = CADEmbedding(cfg, seq_len, use_group=self.use_group)

        encoder_layer = TransformerEncoderLayerImproved(cfg.d_model, cfg.n_heads, cfg.dim_feedforward, cfg.dropout)
        encoder_norm = LayerNorm(cfg.d_model)
        self.encoder = TransformerEncoder(encoder_layer, cfg.n_layers, encoder_norm)

    def forward(self, commands, args):
        padding_mask, key_padding_mask = _get_padding_mask(commands, seq_dim=0), _get_key_padding_mask(commands, seq_dim=0)
        group_mask = _get_group_mask(commands, seq_dim=0) if self.use_group else None

        src = self.embedding(commands, args, group_mask)

        memory = self.encoder(src, mask=None, src_key_padding_mask=key_padding_mask)

        # masked average pooling over the sequence dimension
        z = (memory * padding_mask).sum(dim=0, keepdim=True) / padding_mask.sum(dim=0, keepdim=True)  # (1, N, d_model); projected to dim_z by Bottleneck
        return z
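
# The padding mask (presumably 1 for valid timesteps and 0 for padding, as produced by
# _get_padding_mask) turns the sum into a masked average: per-step features of shape
# (S, N, d_model) are pooled into a single (1, N, d_model) vector per sequence.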


class FCN(nn.Module):
    def __init__(self, d_model, n_commands, n_args, args_dim=256):
        super().__init__()
        self.n_args = n_args
        self.args_dim = args_dim

        self.command_fcn = nn.Linear(d_model, n_commands)
        self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        S, N, _ = out.shape

        command_logits = self.command_fcn(out)  # Shape [S, N, n_commands]

        args_logits = self.args_fcn(out)  # Shape [S, N, n_args * args_dim]
        args_logits = args_logits.reshape(S, N, self.n_args, self.args_dim)  # Shape [S, N, n_args, args_dim]

        return command_logits, args_logits
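
# At inference time these logits are typically converted back to discrete tokens with an
# argmax (a sketch, not code from this file): commands = command_logits.argmax(-1) and
# args = args_logits.argmax(-1) - 1, undoing the +1 shift applied in CADEmbedding.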


class Decoder(nn.Module):
    def __init__(self, cfg):
        super(Decoder, self).__init__()

        self.embedding = ConstEmbedding(cfg, cfg.max_total_len)

        decoder_layer = TransformerDecoderLayerGlobalImproved(cfg.d_model, cfg.dim_z, cfg.n_heads, cfg.dim_feedforward, cfg.dropout)
        decoder_norm = LayerNorm(cfg.d_model)
        self.decoder = TransformerDecoder(decoder_layer, cfg.n_layers_decode, decoder_norm)

        args_dim = cfg.args_dim + 1
        self.fcn = FCN(cfg.d_model, cfg.n_commands, cfg.n_args, args_dim)

    def forward(self, z):
        src = self.embedding(z)
        out = self.decoder(src, z, tgt_mask=None, tgt_key_padding_mask=None)

        command_logits, args_logits = self.fcn(out)
        out_logits = (command_logits, args_logits)
        return out_logits
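
# The decoder runs in a single pass over the constant positional input (no target mask or
# key padding mask), with the latent code z passed as the second argument so that
# TransformerDecoderLayerGlobalImproved can use it as global conditioning at every position.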


class Bottleneck(nn.Module):
    def __init__(self, cfg):
        super(Bottleneck, self).__init__()

        self.bottleneck = nn.Sequential(nn.Linear(cfg.d_model, cfg.dim_z),
                                        nn.Tanh())

    def forward(self, z):
        return self.bottleneck(z)


class CADTransformer(nn.Module):
    def __init__(self, cfg):
        super(CADTransformer, self).__init__()

        self.args_dim = cfg.args_dim + 1

        self.encoder = Encoder(cfg)
        self.bottleneck = Bottleneck(cfg)
        self.decoder = Decoder(cfg)

    def forward(self, commands_enc, args_enc,
                z=None, return_tgt=True, encode_mode=False):
        # inputs are batch-first; convert to sequence-first for the encoder
        commands_enc_, args_enc_ = _make_seq_first(commands_enc, args_enc)  # Possibly None, None

        if z is None:
            z = self.encoder(commands_enc_, args_enc_)
            z = self.bottleneck(z)
        else:
            z = _make_seq_first(z)

        if encode_mode:
            return _make_batch_first(z)

        out_logits = self.decoder(z)
        out_logits = _make_batch_first(*out_logits)

        res = {
            "command_logits": out_logits[0],
            "args_logits": out_logits[1]
        }

        if return_tgt:
            res["tgt_commands"] = commands_enc
            res["tgt_args"] = args_enc

        return res
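

# --------------------------------------------------------------------------------------
# Usage sketch (illustrative only). The hyper-parameter values below are hypothetical
# placeholders, not the project's real configuration; the snippet assumes the
# package-relative imports above (layers, model_utils) resolve, e.g. when run as a
# module inside its package, and disables group embeddings for simplicity.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        n_commands=6, n_args=16, args_dim=256, d_model=256, dim_z=256,
        n_heads=8, dim_feedforward=512, dropout=0.1,
        n_layers=4, n_layers_decode=4,
        max_total_len=60, max_num_groups=30, use_group_emb=False,
    )
    model = CADTransformer(cfg)

    N, S = 2, cfg.max_total_len
    commands = torch.zeros(N, S, dtype=torch.long)               # batch-first (N, S) dummy tokens
    args = torch.full((N, S, cfg.n_args), -1, dtype=torch.long)  # -1 marks unused parameters

    res = model(commands, args)
    print(res["command_logits"].shape)  # expected (N, S, n_commands)
    print(res["args_logits"].shape)     # expected (N, S, n_args, cfg.args_dim + 1)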