"""Gradio demo that decodes random latent vectors into DeepCAD-style sequences.

On import this module tries to download a pretrained checkpoint, builds a
``DeepCADModel`` wrapper around a small MLP decoder, and constructs the Gradio
``demo`` Blocks app. Running the file as a script launches the app.
"""

import contextlib
import json
import os
import sys
import tarfile
from pathlib import Path

import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle, FancyBboxPatch
import h5py
import requests
from tqdm import tqdm

# Add spaces GPU decorator if available
try:
    import spaces
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    # Create dummy decorator if spaces not available
    class spaces:
        """No-op stand-in for the ``spaces`` package."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """Return the decorated function unchanged.

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``;
            this file uses the parameterised form, which a stub accepting
            only ``func`` would reject with a ``TypeError``.
            """
            if callable(func):
                return func
            return lambda wrapped: wrapped


# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(__file__))

# Import DeepCAD modules (assuming they're in the repository)
try:
    from config import ConfigAE
    from trainer import TrainerAE
    from dataset import CADDataset
except ImportError:
    print("Warning: Could not import DeepCAD modules. Creating minimal config...")

    # Create minimal config class if import fails
    class ConfigAE:
        """Minimal fallback configuration used when DeepCAD is not importable."""

        def __init__(self):
            self.exp_name = "pretrained"
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.n_commands = 60
            self.n_args = 256
            self.dim = 256
            self.n_layers = 4


# Constants
MODEL_URL = "http://www.cs.columbia.edu/cg/deepcad/pretrained.tar"
CHECKPOINT_DIR = "pretrained_model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# (connect, read) timeouts in seconds so a stalled server cannot hang startup.
DOWNLOAD_TIMEOUT = (10, 60)
# Commands whose largest absolute argument is at or below this are treated
# as empty, both when drawing and when counting.
NEAR_ZERO_THRESHOLD = 0.01


def download_pretrained_model():
    """Download the pretrained DeepCAD model if not already present.

    Returns:
        The path ``CHECKPOINT_DIR/ckpt_1000.pt`` when that file exists
        (already cached or found after extraction), otherwise ``None``.
        Any failure is reported on stdout and never raised, so the app can
        still start with a randomly initialised model.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, "ckpt_1000.pt")

    if os.path.exists(checkpoint_path):
        print(f"✓ Pretrained model found at {checkpoint_path}")
        return checkpoint_path

    print("Downloading pretrained model...")
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    tar_path = os.path.join(CHECKPOINT_DIR, "pretrained.tar")

    try:
        # The context manager releases the connection even on failure.
        with requests.get(
            MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT
        ) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(tar_path, 'wb') as f, tqdm(
                # ``None`` makes tqdm show an unbounded bar when the server
                # does not report a content length.
                total=total_size or None,
                unit='B',
                unit_scale=True,
                desc="Downloading"
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

        # Extract tar file
        with tarfile.open(tar_path, 'r') as tar:
            if hasattr(tarfile, "data_filter"):
                # Rejects absolute paths, ".." traversal and special files.
                tar.extractall(CHECKPOINT_DIR, filter="data")
            else:
                # NOTE(security): this Python has no extraction filters, so
                # member paths in the downloaded archive are not validated.
                tar.extractall(CHECKPOINT_DIR)
    except Exception as e:
        # Deliberate best-effort boundary: a failed download must not stop
        # the demo from starting.
        print(f"Error downloading model: {e}")
        return None
    finally:
        # Remove the archive whether or not download/extraction succeeded,
        # so a partial file is never left behind.
        with contextlib.suppress(OSError):
            os.remove(tar_path)

    if not os.path.exists(checkpoint_path):
        # Do not report success for a checkpoint that is not actually there.
        print(
            f"Warning: archive extracted but {checkpoint_path} was not found."
        )
        return None

    print("✓ Model downloaded and extracted successfully!")
    return checkpoint_path


class SimpleCADDecoder(torch.nn.Module):
    """Simplified CAD decoder for inference.

    A three-layer MLP mapping a ``dim``-sized latent vector to an output of
    shape ``(batch, n_commands, n_args)``.
    """

    def __init__(self, dim=256, n_commands=60, n_args=256):
        super().__init__()
        self.dim = dim
        self.n_commands = n_commands
        self.n_args = n_args

        # Simple decoder architecture
        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(dim * 2, dim * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(dim * 2, n_commands * n_args)
        )

    def forward(self, z):
        """Decode latent vector ``z`` of shape ``(batch, dim)`` to a CAD sequence."""
        batch_size = z.size(0)
        out = self.fc_layers(z)
        return out.view(batch_size, self.n_commands, self.n_args)


class DeepCADModel:
    """Wrapper for DeepCAD model with inference capabilities."""

    def __init__(self, checkpoint_path=None):
        """Build the decoder on ``DEVICE`` and optionally load weights.

        Args:
            checkpoint_path: Path to a checkpoint file, or ``None``. Missing
                files and load failures are reported on stdout and leave the
                decoder randomly initialised.
        """
        self.device = DEVICE
        self.dim = 256
        self.n_commands = 60
        self.n_args = 256

        # Initialize model
        self.model = SimpleCADDecoder(
            dim=self.dim,
            n_commands=self.n_commands,
            n_args=self.n_args
        ).to(self.device)

        # Load checkpoint if provided
        if checkpoint_path and os.path.exists(checkpoint_path):
            self._load_checkpoint(checkpoint_path)

        self.model.eval()

    def _load_checkpoint(self, checkpoint_path):
        """Load decoder weights from ``checkpoint_path``, reporting the outcome."""
        try:
            # NOTE(security): torch.load unpickles the file, and the
            # checkpoint is fetched over plain HTTP. Only use trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # Try to load decoder weights
            state_dict = None
            for key in ('decoder', 'model'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

            if state_dict is None:
                # Nothing was loaded, so do not claim success.
                print(
                    "Warning: checkpoint has no 'decoder' or 'model' entry."
                )
                print("Using randomly initialized model...")
                return

            self.model.load_state_dict(state_dict)
            print(f"✓ Loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"Warning: Could not load checkpoint: {e}")
            print("Using randomly initialized model...")

    def generate_from_latent(self, z):
        """Generate CAD sequence from latent vector.

        Args:
            z: A ``numpy.ndarray`` or ``torch.Tensor`` of shape ``(dim,)`` or
                ``(batch, dim)``.

        Returns:
            A ``numpy.ndarray`` of shape ``(batch, n_commands, n_args)``.
        """
        with torch.no_grad():
            if isinstance(z, np.ndarray):
                z = torch.from_numpy(z).float()
            z = z.to(self.device)

            if z.dim() == 1:
                # Promote a single vector to a batch of one.
                z = z.unsqueeze(0)

            output = self.model(z)
            return output.cpu().numpy()

    def random_generation(self, seed=None):
        """Generate a random CAD sequence, optionally seeding the RNGs first."""
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Sample random latent vector from normal distribution
        z = torch.randn(1, self.dim).to(self.device)
        return self.generate_from_latent(z)


def visualize_cad_sequence(cad_sequence, title="Generated CAD Model"):
    """
    Visualize CAD sequence as a 2D projection.

    Since we can't use pythonocc-core, we'll create a simplified visualization.

    Args:
        cad_sequence: Array of shape ``(n_commands, n_args)`` or
            ``(batch, n_commands, n_args)``; only the first batch item is
            drawn. Each command needs at least four arguments.
        title: Title for the main panel.

    Returns:
        A ``matplotlib.figure.Figure`` with three panels: the rectangle
        projection, per-command magnitudes, and summary statistics.
    """
    # Built without pyplot so the figure is not kept in pyplot's global
    # registry, which would otherwise grow with every request.
    fig = Figure(figsize=(12, 8))

    # Main plot: 2D projection of CAD operations
    ax1 = fig.add_subplot(2, 2, (1, 3))
    ax1.set_title(title, fontsize=14, fontweight='bold')
    ax1.set_xlim(-5, 5)
    ax1.set_ylim(-5, 5)
    ax1.set_aspect('equal')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')

    # Parse and visualize the sequence
    sequence = cad_sequence[0] if cad_sequence.ndim == 3 else cad_sequence

    # One colour per command, ordered by position in the sequence.
    colors = plt.cm.viridis(np.linspace(0, 1, len(sequence)))

    for i, command in enumerate(sequence):
        # Interpret command parameters as geometric primitives.
        # This is a simplified interpretation.
        if np.abs(command).max() <= NEAR_ZERO_THRESHOLD:
            continue  # Skip near-zero commands

        # Extract position and size parameters
        x = command[0] * 4  # Scale to viewport
        y = command[1] * 4
        width = np.abs(command[2]) * 2 + 0.3
        height = np.abs(command[3]) * 2 + 0.3

        # Draw a rectangle representing this operation
        rect = FancyBboxPatch(
            (x - width / 2, y - height / 2), width, height,
            boxstyle="round,pad=0.05",
            edgecolor=colors[i],
            facecolor=colors[i],
            alpha=0.3,
            linewidth=2
        )
        ax1.add_patch(rect)

        # Add operation number
        if i % 5 == 0:  # Label every 5th operation
            ax1.text(x, y, str(i), ha='center', va='center',
                     fontsize=8, color='black', fontweight='bold')

    # Command histogram
    ax2 = fig.add_subplot(2, 2, 2)
    ax2.set_title('Command Distribution', fontsize=12)
    command_magnitudes = np.linalg.norm(sequence, axis=1)
    ax2.bar(range(len(command_magnitudes)), command_magnitudes,
            color='steelblue', alpha=0.7)
    ax2.set_xlabel('Command Index')
    ax2.set_ylabel('Magnitude')
    ax2.grid(True, alpha=0.3)

    # Parameter statistics
    ax3 = fig.add_subplot(2, 2, 4)
    ax3.set_title('Parameter Statistics', fontsize=12)
    abs_sequence = np.abs(sequence)
    param_stats = {
        'Mean': np.mean(abs_sequence),
        'Std': np.std(sequence),
        'Max': np.max(abs_sequence),
        'Non-zero': np.sum(abs_sequence > NEAR_ZERO_THRESHOLD) / sequence.size
    }
    bars = ax3.bar(list(param_stats.keys()), list(param_stats.values()),
                   color='coral', alpha=0.7)
    ax3.set_ylabel('Value')
    ax3.grid(True, alpha=0.3)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    return fig


def save_cad_sequence(cad_sequence, filename="generated_cad.h5"):
    """Save CAD sequence to H5 file and return the filename.

    The array is stored as dataset ``cad_sequence``; the file attribute
    ``format`` records the representation name.
    """
    with h5py.File(filename, 'w') as f:
        f.create_dataset('cad_sequence', data=cad_sequence)
        f.attrs['format'] = 'DeepCAD vectorized representation'
    return filename


# Initialize model globally
print("Initializing DeepCAD model...")
checkpoint_path = download_pretrained_model()
model = DeepCADModel(checkpoint_path)
print(f"✓ Model initialized on {DEVICE}")


def _error_figure():
    """Return a small figure telling the user that generation failed."""
    fig = Figure(figsize=(8, 6))
    ax = fig.add_subplot()
    ax.text(0.5, 0.5, "Generation Failed\nSee console for details",
            ha='center', va='center', fontsize=14, color='red')
    ax.axis('off')
    return fig


# Add GPU decorator to the generation function
@spaces.GPU(duration=60)  # Reserve GPU for 60 seconds
def generate_cad(seed, temperature):
    """Generate CAD model from seed and temperature.

    Args:
        seed: Random seed; a non-negative value seeds NumPy and PyTorch.
            Cast to ``int`` because a UI slider may deliver a float, which
            the seeding functions reject.
        temperature: Multiplier applied to the sampled latent vector.

    Returns:
        A tuple ``(figure, h5_path, info_markdown)``. On failure the figure
        is an error placeholder, ``h5_path`` is ``None`` and the text holds
        the error message with traceback.
    """
    try:
        seed = int(seed)

        # Set random seed for reproducibility
        if seed >= 0:
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Generate random latent vector with temperature scaling
        z = torch.randn(1, model.dim) * temperature

        # Generate CAD sequence
        cad_sequence = model.generate_from_latent(z)

        # Create visualization
        fig = visualize_cad_sequence(
            cad_sequence,
            title=f"Generated CAD Model (seed={seed}, temp={temperature:.2f})"
        )

        # Save to H5 file
        h5_filename = f"generated_cad_seed{seed}.h5"
        save_cad_sequence(cad_sequence, h5_filename)

        # Count commands, not individual arguments, using the same
        # near-zero test the visualisation applies per command.
        active_commands = int(
            np.sum(np.abs(cad_sequence).max(axis=-1) > NEAR_ZERO_THRESHOLD)
        )

        # Create info text
        info_text = "\n".join([
            "",
            "**Generation Info:**",
            f"- Seed: {seed}",
            f"- Temperature: {temperature:.2f}",
            f"- Device: {DEVICE}",
            f"- Sequence shape: {cad_sequence.shape}",
            f"- Non-zero commands: {active_commands}",
            "",
            "**Note:** The visualization shows a 2D projection of the CAD "
            "operations. Download the H5 file to use with full DeepCAD "
            "evaluation tools.",
            "",
        ])

        return fig, h5_filename, info_text

    except Exception as e:
        import traceback
        error_msg = (
            f"Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        )
        print(error_msg)

        # Return empty plot and error message
        return _error_figure(), None, error_msg


# Create Gradio interface
with gr.Blocks(title="DeepCAD: CAD Model Generation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔧 DeepCAD: Deep Generative Network for CAD Models

    Generate Computer-Aided Design (CAD) models using deep learning! This demo uses the DeepCAD model from the ICCV 2021 paper by Wu, Xiao, and Zheng.

    **How it works:**
    1. Adjust the seed for different random generations
    2. Control temperature to adjust variation (higher = more creative, lower = more conservative)
    3. Click Generate to create a new CAD model
    4. Download the H5 file for use with full DeepCAD tools

    **Note:** Visualization is simplified (2D projection). For full 3D STEP export, use the downloaded H5 file with the original DeepCAD repository tools.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎛️ Generation Parameters")

            seed_input = gr.Slider(
                minimum=0,
                maximum=10000,
                value=42,
                step=1,
                label="Random Seed",
                info="Set seed for reproducible generation"
            )

            temperature_input = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature",
                info="Controls generation diversity"
            )

            generate_btn = gr.Button(
                "🚀 Generate CAD Model", variant="primary", size="lg"
            )

            gr.Markdown("### 📊 Quick Stats")
            info_output = gr.Markdown()

            gr.Markdown("### 💾 Download")
            h5_output = gr.File(label="Download H5 File")

        with gr.Column(scale=2):
            gr.Markdown("### 🎨 Visualization")
            plot_output = gr.Plot(label="CAD Model Visualization")

    gr.Markdown("""
    ---
    ### 📚 References
    - **Paper:** [DeepCAD: A Deep Generative Network for Computer-Aided Design Models](https://arxiv.org/abs/2105.09492)
    - **Authors:** Rundi Wu, Chang Xiao, Changxi Zheng (Columbia University)
    - **Conference:** ICCV 2021
    - **Code:** [GitHub Repository](https://github.com/ChrisWu1997/DeepCAD)

    ### ℹ️ About
    This is a simplified deployment for demonstration. For full functionality including:
    - 3D STEP file export
    - Complete evaluation metrics
    - Training your own models

    Please visit the [official GitHub repository](https://github.com/ChrisWu1997/DeepCAD).
    """)

    # Connect the generate button
    generate_btn.click(
        fn=generate_cad,
        inputs=[seed_input, temperature_input],
        outputs=[plot_output, h5_output, info_output]
    )

    # Add examples
    gr.Examples(
        examples=[
            [42, 1.0],
            [123, 0.8],
            [999, 1.2],
            [2024, 1.5],
        ],
        inputs=[seed_input, temperature_input],
        outputs=[plot_output, h5_output, info_output],
        fn=generate_cad,
        cache_examples=False,
    )


# Launch the app
if __name__ == "__main__":
    print("\n" + "=" * 50)
    print("🚀 Starting DeepCAD Gradio Interface")
    print(f"📍 Device: {DEVICE}")
    print("=" * 50 + "\n")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )