turiya-ai commited on
Commit
0290efc
Β·
verified Β·
1 Parent(s): ec88f90

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.patches import Rectangle, FancyBboxPatch
8
+ import h5py
9
+ import requests
10
+ from tqdm import tqdm
11
+ import json
12
+ from pathlib import Path
13
+
14
+ # Add current directory to path for imports
15
+ sys.path.insert(0, os.path.dirname(__file__))
16
+
17
+ # Import DeepCAD modules (assuming they're in the repository)
18
+ try:
19
+ from config import ConfigAE
20
+ from trainer import TrainerAE
21
+ from dataset import CADDataset
22
+ except ImportError:
23
+ print("Warning: Could not import DeepCAD modules. Creating minimal config...")
24
+ # Create minimal config class if import fails
25
+ class ConfigAE:
26
+ def __init__(self):
27
+ self.exp_name = "pretrained"
28
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ self.n_commands = 60
30
+ self.n_args = 256
31
+ self.dim = 256
32
+ self.n_layers = 4
33
+
34
+ # Constants
35
+ MODEL_URL = "http://www.cs.columbia.edu/cg/deepcad/pretrained.tar"
36
+ CHECKPOINT_DIR = "pretrained_model"
37
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ def download_pretrained_model():
40
+ """Download the pretrained DeepCAD model if not already present."""
41
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, "ckpt_1000.pt")
42
+
43
+ if os.path.exists(checkpoint_path):
44
+ print(f"βœ“ Pretrained model found at {checkpoint_path}")
45
+ return checkpoint_path
46
+
47
+ print("Downloading pretrained model...")
48
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
49
+
50
+ try:
51
+ response = requests.get(MODEL_URL, stream=True)
52
+ response.raise_for_status()
53
+
54
+ tar_path = os.path.join(CHECKPOINT_DIR, "pretrained.tar")
55
+ total_size = int(response.headers.get('content-length', 0))
56
+
57
+ with open(tar_path, 'wb') as f, tqdm(
58
+ total=total_size,
59
+ unit='B',
60
+ unit_scale=True,
61
+ desc="Downloading"
62
+ ) as pbar:
63
+ for chunk in response.iter_content(chunk_size=8192):
64
+ if chunk:
65
+ f.write(chunk)
66
+ pbar.update(len(chunk))
67
+
68
+ # Extract tar file
69
+ import tarfile
70
+ with tarfile.open(tar_path, 'r') as tar:
71
+ tar.extractall(CHECKPOINT_DIR)
72
+
73
+ os.remove(tar_path)
74
+ print("βœ“ Model downloaded and extracted successfully!")
75
+ return checkpoint_path
76
+
77
+ except Exception as e:
78
+ print(f"Error downloading model: {e}")
79
+ return None
80
+
81
+ class SimpleCADDecoder(torch.nn.Module):
82
+ """Simplified CAD decoder for inference."""
83
+
84
+ def __init__(self, dim=256, n_commands=60, n_args=256):
85
+ super().__init__()
86
+ self.dim = dim
87
+ self.n_commands = n_commands
88
+ self.n_args = n_args
89
+
90
+ # Simple decoder architecture
91
+ self.fc_layers = torch.nn.Sequential(
92
+ torch.nn.Linear(dim, dim * 2),
93
+ torch.nn.ReLU(),
94
+ torch.nn.Linear(dim * 2, dim * 2),
95
+ torch.nn.ReLU(),
96
+ torch.nn.Linear(dim * 2, n_commands * n_args)
97
+ )
98
+
99
+ def forward(self, z):
100
+ """Decode latent vector to CAD sequence."""
101
+ batch_size = z.size(0)
102
+ out = self.fc_layers(z)
103
+ out = out.view(batch_size, self.n_commands, self.n_args)
104
+ return out
105
+
106
+ class DeepCADModel:
107
+ """Wrapper for DeepCAD model with inference capabilities."""
108
+
109
+ def __init__(self, checkpoint_path=None):
110
+ self.device = DEVICE
111
+ self.dim = 256
112
+ self.n_commands = 60
113
+ self.n_args = 256
114
+
115
+ # Initialize model
116
+ self.model = SimpleCADDecoder(
117
+ dim=self.dim,
118
+ n_commands=self.n_commands,
119
+ n_args=self.n_args
120
+ ).to(self.device)
121
+
122
+ # Load checkpoint if provided
123
+ if checkpoint_path and os.path.exists(checkpoint_path):
124
+ try:
125
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
126
+ # Try to load decoder weights
127
+ if 'decoder' in checkpoint:
128
+ self.model.load_state_dict(checkpoint['decoder'])
129
+ elif 'model' in checkpoint:
130
+ self.model.load_state_dict(checkpoint['model'])
131
+ print(f"βœ“ Loaded checkpoint from {checkpoint_path}")
132
+ except Exception as e:
133
+ print(f"Warning: Could not load checkpoint: {e}")
134
+ print("Using randomly initialized model...")
135
+
136
+ self.model.eval()
137
+
138
+ def generate_from_latent(self, z):
139
+ """Generate CAD sequence from latent vector."""
140
+ with torch.no_grad():
141
+ if isinstance(z, np.ndarray):
142
+ z = torch.from_numpy(z).float()
143
+ z = z.to(self.device)
144
+ if len(z.shape) == 1:
145
+ z = z.unsqueeze(0)
146
+
147
+ output = self.model(z)
148
+ return output.cpu().numpy()
149
+
150
+ def random_generation(self, seed=None):
151
+ """Generate a random CAD sequence."""
152
+ if seed is not None:
153
+ np.random.seed(seed)
154
+ torch.manual_seed(seed)
155
+
156
+ # Sample random latent vector from normal distribution
157
+ z = torch.randn(1, self.dim).to(self.device)
158
+ return self.generate_from_latent(z)
159
+
160
+ def visualize_cad_sequence(cad_sequence, title="Generated CAD Model"):
161
+ """
162
+ Visualize CAD sequence as a 2D projection.
163
+ Since we can't use pythonocc-core, we'll create a simplified visualization.
164
+ """
165
+ fig = plt.figure(figsize=(12, 8))
166
+
167
+ # Main plot: 2D projection of CAD operations
168
+ ax1 = plt.subplot(2, 2, (1, 3))
169
+ ax1.set_title(title, fontsize=14, fontweight='bold')
170
+ ax1.set_xlim(-5, 5)
171
+ ax1.set_ylim(-5, 5)
172
+ ax1.set_aspect('equal')
173
+ ax1.grid(True, alpha=0.3)
174
+ ax1.set_xlabel('X')
175
+ ax1.set_ylabel('Y')
176
+
177
+ # Parse and visualize the sequence
178
+ sequence = cad_sequence[0] if len(cad_sequence.shape) == 3 else cad_sequence
179
+
180
+ # Extract meaningful features from the sequence
181
+ # Each command has multiple arguments representing geometric operations
182
+ colors = plt.cm.viridis(np.linspace(0, 1, len(sequence)))
183
+
184
+ for i, command in enumerate(sequence):
185
+ # Interpret command parameters as geometric primitives
186
+ # This is a simplified interpretation
187
+ if np.abs(command).max() > 0.01: # Skip near-zero commands
188
+ # Extract position and size parameters
189
+ x = command[0] * 4 # Scale to viewport
190
+ y = command[1] * 4
191
+ width = np.abs(command[2]) * 2 + 0.3
192
+ height = np.abs(command[3]) * 2 + 0.3
193
+
194
+ # Draw a rectangle representing this operation
195
+ rect = FancyBboxPatch(
196
+ (x - width/2, y - height/2),
197
+ width, height,
198
+ boxstyle="round,pad=0.05",
199
+ edgecolor=colors[i],
200
+ facecolor=colors[i],
201
+ alpha=0.3,
202
+ linewidth=2
203
+ )
204
+ ax1.add_patch(rect)
205
+
206
+ # Add operation number
207
+ if i % 5 == 0: # Label every 5th operation
208
+ ax1.text(x, y, str(i), ha='center', va='center',
209
+ fontsize=8, color='black', fontweight='bold')
210
+
211
+ # Command histogram
212
+ ax2 = plt.subplot(2, 2, 2)
213
+ ax2.set_title('Command Distribution', fontsize=12)
214
+ command_magnitudes = np.linalg.norm(sequence, axis=1)
215
+ ax2.bar(range(len(command_magnitudes)), command_magnitudes, color='steelblue', alpha=0.7)
216
+ ax2.set_xlabel('Command Index')
217
+ ax2.set_ylabel('Magnitude')
218
+ ax2.grid(True, alpha=0.3)
219
+
220
+ # Parameter statistics
221
+ ax3 = plt.subplot(2, 2, 4)
222
+ ax3.set_title('Parameter Statistics', fontsize=12)
223
+ param_stats = {
224
+ 'Mean': np.mean(np.abs(sequence)),
225
+ 'Std': np.std(sequence),
226
+ 'Max': np.max(np.abs(sequence)),
227
+ 'Non-zero': np.sum(np.abs(sequence) > 0.01) / sequence.size
228
+ }
229
+ bars = ax3.bar(param_stats.keys(), param_stats.values(), color='coral', alpha=0.7)
230
+ ax3.set_ylabel('Value')
231
+ ax3.grid(True, alpha=0.3)
232
+
233
+ # Add value labels on bars
234
+ for bar in bars:
235
+ height = bar.get_height()
236
+ ax3.text(bar.get_x() + bar.get_width()/2., height,
237
+ f'{height:.3f}',
238
+ ha='center', va='bottom', fontsize=9)
239
+
240
+ plt.tight_layout()
241
+ return fig
242
+
243
+ def save_cad_sequence(cad_sequence, filename="generated_cad.h5"):
244
+ """Save CAD sequence to H5 file."""
245
+ with h5py.File(filename, 'w') as f:
246
+ f.create_dataset('cad_sequence', data=cad_sequence)
247
+ f.attrs['format'] = 'DeepCAD vectorized representation'
248
+ return filename
249
+
250
+ # Initialize model globally
251
+ print("Initializing DeepCAD model...")
252
+ checkpoint_path = download_pretrained_model()
253
+ model = DeepCADModel(checkpoint_path)
254
+ print(f"βœ“ Model initialized on {DEVICE}")
255
+
256
+ def generate_cad(seed, temperature):
257
+ """Generate CAD model from seed and temperature."""
258
+ try:
259
+ # Set random seed for reproducibility
260
+ if seed >= 0:
261
+ np.random.seed(seed)
262
+ torch.manual_seed(seed)
263
+
264
+ # Generate random latent vector with temperature scaling
265
+ z = torch.randn(1, model.dim) * temperature
266
+
267
+ # Generate CAD sequence
268
+ cad_sequence = model.generate_from_latent(z)
269
+
270
+ # Create visualization
271
+ fig = visualize_cad_sequence(
272
+ cad_sequence,
273
+ title=f"Generated CAD Model (seed={seed}, temp={temperature:.2f})"
274
+ )
275
+
276
+ # Save to H5 file
277
+ h5_filename = f"generated_cad_seed{seed}.h5"
278
+ save_cad_sequence(cad_sequence, h5_filename)
279
+
280
+ # Create info text
281
+ info_text = f"""
282
+ **Generation Info:**
283
+ - Seed: {seed}
284
+ - Temperature: {temperature:.2f}
285
+ - Device: {DEVICE}
286
+ - Sequence shape: {cad_sequence.shape}
287
+ - Non-zero commands: {np.sum(np.abs(cad_sequence) > 0.01)}
288
+
289
+ **Note:** The visualization shows a 2D projection of the CAD operations.
290
+ Download the H5 file to use with full DeepCAD evaluation tools.
291
+ """
292
+
293
+ return fig, h5_filename, info_text
294
+
295
+ except Exception as e:
296
+ import traceback
297
+ error_msg = f"Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
298
+ print(error_msg)
299
+ # Return empty plot and error message
300
+ fig = plt.figure(figsize=(8, 6))
301
+ plt.text(0.5, 0.5, "Generation Failed\nSee console for details",
302
+ ha='center', va='center', fontsize=14, color='red')
303
+ plt.axis('off')
304
+ return fig, None, error_msg
305
+
306
+ # Create Gradio interface
307
+ with gr.Blocks(title="DeepCAD: CAD Model Generation", theme=gr.themes.Soft()) as demo:
308
+ gr.Markdown("""
309
+ # πŸ”§ DeepCAD: Deep Generative Network for CAD Models
310
+
311
+ Generate Computer-Aided Design (CAD) models using deep learning! This demo uses the DeepCAD model
312
+ from the ICCV 2021 paper by Wu, Xiao, and Zheng.
313
+
314
+ **How it works:**
315
+ 1. Adjust the seed for different random generations
316
+ 2. Control temperature to adjust variation (higher = more creative, lower = more conservative)
317
+ 3. Click Generate to create a new CAD model
318
+ 4. Download the H5 file for use with full DeepCAD tools
319
+
320
+ **Note:** Visualization is simplified (2D projection). For full 3D STEP export, use the downloaded
321
+ H5 file with the original DeepCAD repository tools.
322
+ """)
323
+
324
+ with gr.Row():
325
+ with gr.Column(scale=1):
326
+ gr.Markdown("### πŸŽ›οΈ Generation Parameters")
327
+
328
+ seed_input = gr.Slider(
329
+ minimum=0,
330
+ maximum=10000,
331
+ value=42,
332
+ step=1,
333
+ label="Random Seed",
334
+ info="Set seed for reproducible generation"
335
+ )
336
+
337
+ temperature_input = gr.Slider(
338
+ minimum=0.1,
339
+ maximum=2.0,
340
+ value=1.0,
341
+ step=0.1,
342
+ label="Temperature",
343
+ info="Controls generation diversity"
344
+ )
345
+
346
+ generate_btn = gr.Button("πŸš€ Generate CAD Model", variant="primary", size="lg")
347
+
348
+ gr.Markdown("### πŸ“Š Quick Stats")
349
+ info_output = gr.Markdown()
350
+
351
+ gr.Markdown("### πŸ’Ύ Download")
352
+ h5_output = gr.File(label="Download H5 File")
353
+
354
+ with gr.Column(scale=2):
355
+ gr.Markdown("### 🎨 Visualization")
356
+ plot_output = gr.Plot(label="CAD Model Visualization")
357
+
358
+ gr.Markdown("""
359
+ ---
360
+ ### πŸ“š References
361
+ - **Paper:** [DeepCAD: A Deep Generative Network for Computer-Aided Design Models](https://arxiv.org/abs/2105.09492)
362
+ - **Authors:** Rundi Wu, Chang Xiao, Changxi Zheng (Columbia University)
363
+ - **Conference:** ICCV 2021
364
+ - **Code:** [GitHub Repository](https://github.com/ChrisWu1997/DeepCAD)
365
+
366
+ ### ℹ️ About
367
+ This is a simplified deployment for demonstration. For full functionality including:
368
+ - 3D STEP file export
369
+ - Complete evaluation metrics
370
+ - Training your own models
371
+
372
+ Please visit the [official GitHub repository](https://github.com/ChrisWu1997/DeepCAD).
373
+ """)
374
+
375
+ # Connect the generate button
376
+ generate_btn.click(
377
+ fn=generate_cad,
378
+ inputs=[seed_input, temperature_input],
379
+ outputs=[plot_output, h5_output, info_output]
380
+ )
381
+
382
+ # Add examples
383
+ gr.Examples(
384
+ examples=[
385
+ [42, 1.0],
386
+ [123, 0.8],
387
+ [999, 1.2],
388
+ [2024, 1.5],
389
+ ],
390
+ inputs=[seed_input, temperature_input],
391
+ outputs=[plot_output, h5_output, info_output],
392
+ fn=generate_cad,
393
+ cache_examples=False,
394
+ )
395
+
396
+ # Launch the app
397
+ if __name__ == "__main__":
398
+ print("\n" + "="*50)
399
+ print("πŸš€ Starting DeepCAD Gradio Interface")
400
+ print(f"πŸ“ Device: {DEVICE}")
401
+ print("="*50 + "\n")
402
+
403
+ demo.launch(
404
+ server_name="0.0.0.0",
405
+ server_port=7860,
406
+ share=False
407
+ )