Spaces:
Build error
Build error
| import os | |
| import sys | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Rectangle, FancyBboxPatch | |
| import h5py | |
| import requests | |
| from tqdm import tqdm | |
| import json | |
| from pathlib import Path | |
# Use the Hugging Face ``spaces`` package (ZeroGPU) when it is installed;
# otherwise fall back to a no-op stand-in so the app also runs locally.
try:
    import spaces
    # True only means the ``spaces`` package could be imported; it says
    # nothing about whether a CUDA device is present.
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    class spaces:
        """Minimal stand-in for the ``spaces`` package when it is not installed."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both decorator forms used with the real package:
            ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``. Any keyword
            arguments are accepted and ignored.

            Returns:
                ``func`` unchanged, or (when called with options only) a
                decorator that returns its argument unchanged.
            """
            if func is None:
                # Called with options, e.g. ``@spaces.GPU(duration=60)``.
                return lambda f: f
            return func
# Put this script's directory first on sys.path so the DeepCAD modules
# (``config``, ``trainer``, ``dataset``), if they sit next to it, are importable.
sys.path.insert(0, os.path.dirname(__file__))

# Prefer the real DeepCAD modules; fall back to a tiny config otherwise.
try:
    from config import ConfigAE
    from trainer import TrainerAE
    from dataset import CADDataset
except ImportError:
    print("Warning: Could not import DeepCAD modules. Creating minimal config...")

    class ConfigAE:
        """Fallback stand-in for DeepCAD's ``ConfigAE`` holding basic hyperparameters.

        Only defined when the DeepCAD modules cannot be imported; in that case
        ``TrainerAE`` and ``CADDataset`` remain undefined.
        """

        def __init__(self):
            self.exp_name = "pretrained"
            if torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"
            self.n_commands = 60
            self.n_args = 256
            self.dim = 256
            self.n_layers = 4
# Constants
# Location of the pretrained DeepCAD archive.
MODEL_URL = "http://www.cs.columbia.edu/cg/deepcad/pretrained.tar"
# Local directory the archive is downloaded to and extracted into.
CHECKPOINT_DIR = "pretrained_model"
# Torch device string used throughout the app: CPU unless CUDA is usable.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
| def download_pretrained_model(): | |
| """Download the pretrained DeepCAD model if not already present.""" | |
| checkpoint_path = os.path.join(CHECKPOINT_DIR, "ckpt_1000.pt") | |
| if os.path.exists(checkpoint_path): | |
| print(f"β Pretrained model found at {checkpoint_path}") | |
| return checkpoint_path | |
| print("Downloading pretrained model...") | |
| os.makedirs(CHECKPOINT_DIR, exist_ok=True) | |
| try: | |
| response = requests.get(MODEL_URL, stream=True) | |
| response.raise_for_status() | |
| tar_path = os.path.join(CHECKPOINT_DIR, "pretrained.tar") | |
| total_size = int(response.headers.get('content-length', 0)) | |
| with open(tar_path, 'wb') as f, tqdm( | |
| total=total_size, | |
| unit='B', | |
| unit_scale=True, | |
| desc="Downloading" | |
| ) as pbar: | |
| for chunk in response.iter_content(chunk_size=8192): | |
| if chunk: | |
| f.write(chunk) | |
| pbar.update(len(chunk)) | |
| # Extract tar file | |
| import tarfile | |
| with tarfile.open(tar_path, 'r') as tar: | |
| tar.extractall(CHECKPOINT_DIR) | |
| os.remove(tar_path) | |
| print("β Model downloaded and extracted successfully!") | |
| return checkpoint_path | |
| except Exception as e: | |
| print(f"Error downloading model: {e}") | |
| return None | |
class SimpleCADDecoder(torch.nn.Module):
    """MLP decoder turning a latent vector into a grid of command arguments.

    Three linear layers (dim -> 2*dim -> 2*dim -> n_commands*n_args) with ReLU
    in between; the flat output is reshaped to ``(batch, n_commands, n_args)``.
    """

    def __init__(self, dim=256, n_commands=60, n_args=256):
        super().__init__()
        self.dim, self.n_commands, self.n_args = dim, n_commands, n_args
        hidden = 2 * dim
        # Layer order fixes the state_dict keys (fc_layers.0/2/4.*), which a
        # loaded checkpoint must match.
        layers = [
            torch.nn.Linear(dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, n_commands * n_args),
        ]
        self.fc_layers = torch.nn.Sequential(*layers)

    def forward(self, z):
        """Decode a (batch, dim) latent batch into (batch, n_commands, n_args)."""
        flat = self.fc_layers(z)
        return flat.view(z.shape[0], self.n_commands, self.n_args)
class DeepCADModel:
    """Inference wrapper around ``SimpleCADDecoder``.

    Builds the decoder on a torch device, optionally loads weights from a
    checkpoint, and decodes latent vectors into numpy arrays.
    """

    def __init__(self, checkpoint_path=None, dim=256, n_commands=60, n_args=256,
                 device=None):
        """Create the decoder and try to load ``checkpoint_path``.

        Args:
            checkpoint_path: Optional path to a file written by ``torch.save``.
                If it is missing or unusable, the decoder keeps its random
                initialization and a message says so.
            dim: Latent size fed to the decoder.
            n_commands: Rows in each decoded sequence.
            n_args: Values per row.
            device: Torch device string; ``None`` uses the module-level ``DEVICE``.
        """
        self.device = DEVICE if device is None else device
        self.dim = dim
        self.n_commands = n_commands
        self.n_args = n_args
        self.model = SimpleCADDecoder(
            dim=self.dim,
            n_commands=self.n_commands,
            n_args=self.n_args,
        ).to(self.device)
        if checkpoint_path and os.path.exists(checkpoint_path):
            self._load_checkpoint(checkpoint_path)
        else:
            print("No checkpoint available; using randomly initialized model...")
        self.model.eval()

    def _load_checkpoint(self, checkpoint_path):
        """Load decoder weights from the ``decoder`` or ``model`` entry of a checkpoint.

        Any failure (unreadable file, missing entry, shape/key mismatch) is
        reported and the randomly initialized weights are kept.
        """
        try:
            # weights_only=True: the checkpoint comes from a plain-HTTP
            # download, so refuse to unpickle arbitrary objects.
            checkpoint = torch.load(checkpoint_path, map_location=self.device,
                                    weights_only=True)
            state_dict = None
            if isinstance(checkpoint, dict):
                for key in ('decoder', 'model'):
                    if key in checkpoint:
                        state_dict = checkpoint[key]
                        break
            if state_dict is None:
                raise KeyError("checkpoint has no 'decoder' or 'model' entry")
            self.model.load_state_dict(state_dict)
            print(f"✓ Loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"Warning: Could not load checkpoint: {e}")
            print("Using randomly initialized model...")

    def generate_from_latent(self, z):
        """Decode latent vector(s) into CAD sequence array(s).

        Args:
            z: One latent of shape (dim,) or a batch (batch, dim), given as a
                numpy array, torch tensor, or nested sequence of numbers. It is
                converted to float32 on the model's device.

        Returns:
            numpy array of shape (batch, n_commands, n_args).
        """
        with torch.no_grad():
            z = torch.as_tensor(z, dtype=torch.float32, device=self.device)
            if z.dim() == 1:
                z = z.unsqueeze(0)
            return self.model(z).cpu().numpy()

    def random_generation(self, seed=None):
        """Decode a standard-normal latent; ``seed`` (if given) seeds numpy and torch first."""
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
        z = torch.randn(1, self.dim)
        return self.generate_from_latent(z)
def _draw_operations(ax, sequence, title):
    """Fill the main panel with one translucent rounded box per active command.

    A command is drawn when its largest absolute argument exceeds 0.01.
    Arguments 0/1 give the box centre and 2/3 its size, scaled into the fixed
    [-5, 5] viewport. This is a display heuristic, not DeepCAD's command
    semantics.
    """
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    # Rows with fewer than four arguments cannot supply position and size.
    if sequence.shape[1] < 4:
        return

    colors = plt.cm.viridis(np.linspace(0, 1, len(sequence)))
    for i, command in enumerate(sequence):
        if not (np.abs(command).max() > 0.01):
            continue  # near-zero (or NaN) command: nothing to draw
        x = command[0] * 4  # scale to viewport
        y = command[1] * 4
        width = np.abs(command[2]) * 2 + 0.3
        height = np.abs(command[3]) * 2 + 0.3
        ax.add_patch(FancyBboxPatch(
            (x - width / 2, y - height / 2),
            width, height,
            boxstyle="round,pad=0.05",
            edgecolor=colors[i],
            facecolor=colors[i],
            alpha=0.3,
            linewidth=2,
        ))
        if i % 5 == 0:  # label every 5th command index
            ax.text(x, y, str(i), ha='center', va='center',
                    fontsize=8, color='black', fontweight='bold')


def _draw_command_magnitudes(ax, sequence):
    """Bar chart of the L2 norm of each command row."""
    ax.set_title('Command Distribution', fontsize=12)
    magnitudes = np.linalg.norm(sequence, axis=1)
    ax.bar(range(len(magnitudes)), magnitudes, color='steelblue', alpha=0.7)
    ax.set_xlabel('Command Index')
    ax.set_ylabel('Magnitude')
    ax.grid(True, alpha=0.3)


def _draw_parameter_stats(ax, sequence):
    """Bar chart of summary statistics over all arguments, each bar value-labelled."""
    ax.set_title('Parameter Statistics', fontsize=12)
    if sequence.size:
        abs_seq = np.abs(sequence)
        param_stats = {
            'Mean': np.mean(abs_seq),
            'Std': np.std(sequence),
            'Max': np.max(abs_seq),
            'Non-zero': np.count_nonzero(abs_seq > 0.01) / sequence.size,
        }
    else:
        # np.max has no identity for empty input; show zeros instead of crashing.
        param_stats = dict.fromkeys(('Mean', 'Std', 'Max', 'Non-zero'), 0.0)
    bars = ax.bar(list(param_stats), list(param_stats.values()),
                  color='coral', alpha=0.7)
    ax.set_ylabel('Value')
    ax.grid(True, alpha=0.3)
    for bar in bars:
        bar_height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., bar_height,
                f'{bar_height:.3f}', ha='center', va='bottom', fontsize=9)


def visualize_cad_sequence(cad_sequence, title="Generated CAD Model"):
    """Render a CAD sequence as a simplified 2-D overview figure.

    pythonocc-core is not available, so instead of a real 3-D render the
    figure has three panels: boxes for the individual commands (left), per-command
    magnitudes (top right) and overall argument statistics (bottom right).

    Args:
        cad_sequence: Array-like of shape (n_commands, n_args) or
            (batch, n_commands, n_args); only the first batch item is drawn.
            A 1-D input is treated as a single command.
        title: Title of the main panel.

    Returns:
        A ``matplotlib.figure.Figure``. It is created outside pyplot's figure
        registry, so repeated calls from the web app do not pile up open
        figures.
    """
    from matplotlib.figure import Figure

    sequence = np.asarray(cad_sequence)
    if sequence.ndim == 3:
        sequence = sequence[0]
    sequence = np.atleast_2d(sequence)

    fig = Figure(figsize=(12, 8))
    _draw_operations(fig.add_subplot(2, 2, (1, 3)), sequence, title)
    _draw_command_magnitudes(fig.add_subplot(2, 2, 2), sequence)
    _draw_parameter_stats(fig.add_subplot(2, 2, 4), sequence)
    fig.tight_layout()
    return fig
def save_cad_sequence(cad_sequence, filename="generated_cad.h5"):
    """Write ``cad_sequence`` to an HDF5 file and hand back its name.

    The array goes into the ``cad_sequence`` dataset; the file carries a
    ``format`` attribute naming the encoding. Existing files are overwritten.
    """
    with h5py.File(filename, mode='w') as h5_file:
        h5_file.attrs['format'] = 'DeepCAD vectorized representation'
        h5_file.create_dataset('cad_sequence', data=cad_sequence)
    return filename
# Initialize model globally
# Built once at import time; generate_cad uses this shared ``model``.
# download_pretrained_model() may return None, in which case DeepCADModel
# skips checkpoint loading and keeps its random initialization.
print("Initializing DeepCAD model...")
checkpoint_path = download_pretrained_model()
model = DeepCADModel(checkpoint_path)
print(f"✓ Model initialized on {DEVICE}")
# On Hugging Face ZeroGPU hardware a GPU is attached only while a function
# decorated with ``spaces.GPU`` runs; without the ``spaces`` package the
# decorator is a no-op. The bare form uses the package's default duration.
@spaces.GPU
def generate_cad(seed, temperature):
    """Sample a latent vector, decode it and package the result for the UI.

    Args:
        seed: RNG seed. Gradio sliders may deliver it as a float, so it is
            converted to ``int``. Negative values leave the RNGs unseeded.
        temperature: Scale factor applied to the standard-normal latent.

    Returns:
        ``(figure, h5_path, markdown)`` on success, or
        ``(figure, None, markdown_error)`` on failure. The full traceback is
        printed to the server console only.
    """
    import tempfile
    import traceback
    from matplotlib.figure import Figure

    try:
        # np.random.seed rejects floats such as 42.0; int() also keeps the
        # title and file name free of a trailing ".0".
        seed = int(seed)
        if seed >= 0:
            np.random.seed(seed)
            torch.manual_seed(seed)
        # Standard-normal latent scaled by the temperature.
        z = torch.randn(1, model.dim) * temperature
        cad_sequence = model.generate_from_latent(z)
        fig = visualize_cad_sequence(
            cad_sequence,
            title=f"Generated CAD Model (seed={seed}, temp={temperature:.2f})",
        )
        # Fresh directory per request: concurrent users generating the same
        # seed would otherwise overwrite each other's download.
        out_dir = tempfile.mkdtemp(prefix="deepcad_")
        h5_filename = save_cad_sequence(
            cad_sequence, os.path.join(out_dir, f"generated_cad_seed{seed}.h5"))
        # Count command rows (not individual values) whose largest |arg|
        # exceeds 0.01 -- the same test the visualization uses to skip rows.
        n_active = int(np.count_nonzero(np.abs(cad_sequence).max(axis=-1) > 0.01))
        info_text = "\n".join([
            "**Generation Info:**",
            "",
            f"- Seed: {seed}",
            f"- Temperature: {temperature:.2f}",
            f"- Device: {DEVICE}",
            f"- Sequence shape: {cad_sequence.shape}",
            f"- Non-zero commands: {n_active}",
            "",
            "**Note:** The visualization shows a 2D projection of the CAD operations.",
            "Download the H5 file to use with full DeepCAD evaluation tools.",
        ])
        return fig, h5_filename, info_text
    except Exception as e:
        # UI boundary: report the failure instead of failing the request.
        print(f"Error during generation:\n{e}\n\n{traceback.format_exc()}")
        fig = Figure(figsize=(8, 6))
        ax = fig.add_subplot()
        ax.text(0.5, 0.5, "Generation Failed\nSee console for details",
                ha='center', va='center', fontsize=14, color='red',
                transform=ax.transAxes)
        ax.axis('off')
        # Only a short message reaches the browser; the traceback stays in the log.
        return fig, None, f"**Error during generation:** {e}"
# Create Gradio interface
# Components appear in the order they are created inside each context
# manager, so the statement order below defines the page layout.
with gr.Blocks(title="DeepCAD: CAD Model Generation", theme=gr.themes.Soft()) as demo:
    # Header and usage instructions.
    gr.Markdown("""
# π§ DeepCAD: Deep Generative Network for CAD Models
Generate Computer-Aided Design (CAD) models using deep learning! This demo uses the DeepCAD model
from the ICCV 2021 paper by Wu, Xiao, and Zheng.
**How it works:**
1. Adjust the seed for different random generations
2. Control temperature to adjust variation (higher = more creative, lower = more conservative)
3. Click Generate to create a new CAD model
4. Download the H5 file for use with full DeepCAD tools
**Note:** Visualization is simplified (2D projection). For full 3D STEP export, use the downloaded
H5 file with the original DeepCAD repository tools.
""")
    with gr.Row():
        # Left column: controls, run summary and download slot.
        with gr.Column(scale=1):
            gr.Markdown("### ποΈ Generation Parameters")
            # Integer seed in [0, 10000]; generate_cad only reseeds the RNGs
            # for seed >= 0, which this range always satisfies.
            seed_input = gr.Slider(
                minimum=0,
                maximum=10000,
                value=42,
                step=1,
                label="Random Seed",
                info="Set seed for reproducible generation"
            )
            # Multiplier applied to the sampled latent vector in generate_cad.
            temperature_input = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature",
                info="Controls generation diversity"
            )
            generate_btn = gr.Button("π Generate CAD Model", variant="primary", size="lg")
            gr.Markdown("### π Quick Stats")
            # Filled with generate_cad's markdown summary (or error message).
            info_output = gr.Markdown()
            gr.Markdown("### πΎ Download")
            # Receives the H5 path returned by generate_cad (None on failure).
            h5_output = gr.File(label="Download H5 File")
        # Right column, twice as wide: the matplotlib figure.
        with gr.Column(scale=2):
            gr.Markdown("### π¨ Visualization")
            plot_output = gr.Plot(label="CAD Model Visualization")
    # Footer: references and scope of this demo.
    gr.Markdown("""
---
### π References
- **Paper:** [DeepCAD: A Deep Generative Network for Computer-Aided Design Models](https://arxiv.org/abs/2105.09492)
- **Authors:** Rundi Wu, Chang Xiao, Changxi Zheng (Columbia University)
- **Conference:** ICCV 2021
- **Code:** [GitHub Repository](https://github.com/ChrisWu1997/DeepCAD)
### βΉοΈ About
This is a simplified deployment for demonstration. For full functionality including:
- 3D STEP file export
- Complete evaluation metrics
- Training your own models
Please visit the [official GitHub repository](https://github.com/ChrisWu1997/DeepCAD).
""")
    # Connect the generate button
    # Output order must match generate_cad's return tuple: (figure, h5 path, markdown).
    generate_btn.click(
        fn=generate_cad,
        inputs=[seed_input, temperature_input],
        outputs=[plot_output, h5_output, info_output]
    )
    # Add examples
    # Each row is (seed, temperature); with cache_examples=False the examples
    # are not precomputed at startup.
    gr.Examples(
        examples=[
            [42, 1.0],
            [123, 0.8],
            [999, 1.2],
            [2024, 1.5],
        ],
        inputs=[seed_input, temperature_input],
        outputs=[plot_output, h5_output, info_output],
        fn=generate_cad,
        cache_examples=False,
    )
# Launch the app when run as a script (not when imported).
if __name__ == "__main__":
    rule = "=" * 50
    # Start-up banner.
    print("\n" + rule)
    print("π Starting DeepCAD Gradio Interface")
    print(f"π Device: {DEVICE}")
    print(rule + "\n")
    # Listen on all interfaces at port 7860; no public Gradio share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)