| import json |
| import os |
| import subprocess |
| import time |
| import uuid |
| import zipfile |
| from dataclasses import fields |
| from urllib.request import urlretrieve |
|
|
| import gradio as gr |
| import torch.multiprocessing as mp |
| import transformers |
| from legogpt.models import LegoGPT, LegoGPTConfig |
|
|
|
|
def setup():
    """
    Prepare the process environment for running inside a Hugging Face Space.

    1. Builds a Gurobi licence file (``gurobi.lic``) from the ``WLSACCESSID``,
       ``WLSSECRET`` and ``LICENSEID`` environment variables and points
       ``GRB_LICENSE_FILE`` at its absolute path.
    2. Downloads the complete LDraw parts library archive, extracts it into the
       current working directory and points ``LDRAW_LIBRARY_PATH`` at the
       extracted ``ldraw`` folder.

    :raises ValueError: If any of the three licence variables is unset or empty.
        The check happens before anything is written to disk.
    """
    # All secrets are validated up front so no partial licence file is written.
    required_secrets = ('WLSACCESSID', 'WLSSECRET', 'LICENSEID')
    entries = []
    for name in required_secrets:
        value = os.environ.get(name)
        if not value:
            raise ValueError(f'Env variable {name} not found. Please set it in the Hugging Face Space settings.')
        entries.append(f'{name}={value}\n')

    licence_path = 'gurobi.lic'
    with open(licence_path, 'w') as licence_file:
        licence_file.write(''.join(entries))
    os.environ['GRB_LICENSE_FILE'] = os.path.abspath(licence_path)

    # LDraw parts library: fetched as one archive and unpacked in place.
    archive_path = 'complete.zip'
    urlretrieve('https://library.ldraw.org/library/updates/complete.zip', archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall()
    os.environ['LDRAW_LIBRARY_PATH'] = os.path.abspath('ldraw')
|
|
|
|
def main():
    """
    Build the LegoGPT Gradio interface and launch it.

    When the ``IS_HF_SPACE`` environment variable equals ``'1'``, :func:`setup`
    runs first to prepare the Space (Gurobi licence, LDraw library). The queue's
    default concurrency limit is taken from ``CONCURRENCY_LIMIT`` and falls back
    to 2 when that variable is unset.
    """
    running_on_space = os.environ.get('IS_HF_SPACE') == '1'
    if running_on_space:
        print('Running in Hugging Face Space, setting up environment...')
        setup()

    # Model and the request handler wrapped around it.
    model_cfg = LegoGPTConfig(max_regenerations=5)
    generator = LegoGenerator(LegoGPT(model_cfg))

    def config_number(field_name, label, minimum):
        """Integer input whose default and help text come from the LegoGPTConfig field."""
        return gr.Number(
            value=getattr(model_cfg, field_name),
            label=label,
            info=get_help_string(field_name),
            precision=0,
            minimum=minimum,
            step=1,
        )

    # Input components.
    in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a LEGO model.')
    in_temperature = gr.Slider(
        0.01,
        2.0,
        value=model_cfg.temperature,
        step=0.01,
        label='Temperature',
        info=get_help_string('temperature'),
    )
    in_seed = gr.Number(value=42, label='Seed', info='Random seed for generation.', precision=0, step=1)
    in_bricks = config_number('max_bricks', 'Max bricks', minimum=1)
    in_rejections = config_number('max_brick_rejections', 'Max brick rejections', minimum=0)
    in_regenerations = config_number('max_regenerations', 'Max regenerations', minimum=0)

    # Output components.
    out_img = gr.Image(label='Output image', format='png')
    out_txt = gr.Textbox(
        label='Output LEGO bricks',
        lines=5,
        max_lines=5,
        show_copy_button=True,
        info=(
            'The LEGO structure in text format. Each line of the form "hxw (x,y,z)" represents a '
            '1-unit-tall rectangular brick with dimensions hxw placed at coordinates (x,y,z).'
        ),
    )

    description = (
        'Official demo for [LegoGPT](https://avalovelace1.github.io/LegoGPT/), the first approach for generating physically stable LEGO brick models from text prompts.\n\n'
        'The model is restricted to creating structures made of 1-unit-tall cuboid bricks on a 20x20x20 grid. It was trained on a dataset of 21 object categories: '
        '*basket, bed, bench, birdhouse, bookshelf, bottle, bowl, bus, camera, car, chair, guitar, jar, mug, piano, pot, sofa, table, tower, train, vessel.* '
        'Performance on prompts from outside these categories may be limited. This demo does not include texturing or coloring.'
    )
    demo = gr.Interface(
        fn=generator.generate_lego_subprocess,
        title='LegoGPT Demo',
        description=description,
        inputs=[in_prompt],
        additional_inputs=[in_temperature, in_seed, in_bricks, in_rejections, in_regenerations],
        outputs=[out_img, out_txt],
        flagging_mode='never',
    )

    # Clickable examples: clicking one fills the prompt/settings and shows the
    # stored image and brick text for that example instead of generating anew.
    with demo:
        with gr.Row():
            examples = get_examples()
            dummy_name = gr.Textbox(visible=False, label='Name')
            dummy_out_img = gr.Image(visible=False, label='Result')
            example_rows = []
            for row_name, entry in examples.items():
                example_rows.append(
                    [row_name, entry['prompt'], entry['temperature'], entry['seed'], entry['output_img']]
                )
            gr.Examples(
                examples=example_rows,
                inputs=[dummy_name, in_prompt, in_temperature, in_seed, dummy_out_img],
                outputs=[out_img, out_txt],
                # First value is the example name, last is its stored image.
                fn=lambda *values: (values[-1], examples[values[0]]['output_txt']),
                run_on_click=True,
            )

    raw_limit = os.environ.get('CONCURRENCY_LIMIT')
    concurrency_limit = 2 if raw_limit is None else int(raw_limit)
    demo.queue(default_concurrency_limit=concurrency_limit)
    demo.launch(share=True)
|
|
|
|
class LegoGenerator:
    """
    Generate and render LEGO structures with a LegoGPT model for the Gradio demo.

    :meth:`generate_lego` does the work in the current process;
    :meth:`generate_lego_subprocess` runs it in a freshly spawned worker process
    so that several requests can be served concurrently.
    """

    def __init__(self, model: LegoGPT):
        """
        :param model: The LegoGPT model used for every request.
        """
        self.model = model
        # 'spawn' starts each worker in a fresh interpreter instead of forking this
        # process; presumably chosen so torch/CUDA state is not inherited — confirm.
        self.ctx = mp.get_context('spawn')

    def generate_lego(
        self,
        prompt: str,
        temperature: float | None,
        seed: int | None,
        max_bricks: int | None,
        max_brick_rejections: int | None,
        max_regenerations: int | None,
    ):
        """
        Generate a LEGO structure for ``prompt``, save it as an LDraw file and render it.

        Each setting left as ``None`` keeps the model's current value. Overrides are
        assigned onto ``self.model``, so they persist for later direct calls on the
        same instance. In the subprocess path, each call works on its own pickled
        copy of the model.

        Both files go to ``./out`` under a fresh UUID stem, so concurrent requests
        never overwrite each other.

        :param prompt: Text prompt describing the object to build.
        :param temperature: Sampling temperature override.
        :param seed: Random seed passed to ``transformers.set_seed``.
        :param max_bricks: Override for the model's ``max_bricks``.
        :param max_brick_rejections: Override for the model's ``max_brick_rejections``.
        :param max_regenerations: Override for the model's ``max_regenerations``.
        :return: ``(img_filename, lego)``: the absolute path of the rendered PNG and
            the ``'lego'`` entry of the model output.
        :raises subprocess.CalledProcessError: If the render script exits non-zero.
        """
        import sys  # needed only to launch the renderer with this same interpreter

        # Apply per-request overrides.
        if temperature is not None:
            self.model.temperature = temperature
        if max_bricks is not None:
            self.model.max_bricks = max_bricks
        if max_brick_rejections is not None:
            self.model.max_brick_rejections = max_brick_rejections
        if max_regenerations is not None:
            self.model.max_regenerations = max_regenerations
        if seed is not None:
            # set_seed also seeds NumPy, which rejects float seeds such as 42.0.
            transformers.set_seed(int(seed))

        # Generate the structure.
        print(f'Generating LEGO for prompt: "{prompt}"')
        start_time = time.perf_counter()
        output = self.model(prompt)
        lego = output['lego']

        # Save it in LDraw format under a unique name.
        output_dir = os.path.abspath('out')
        output_uuid = str(uuid.uuid4())
        os.makedirs(output_dir, exist_ok=True)
        ldr_filename = os.path.join(output_dir, f'{output_uuid}.ldr')
        with open(ldr_filename, 'w') as f:
            f.write(lego.to_ldr())
        print(f'Finished generation in {time.perf_counter() - start_time:.1f}s!')

        # Render the LDraw file to a PNG in a separate process. sys.executable
        # guarantees the same interpreter/environment; a bare 'python' on PATH may
        # be absent or belong to a different installation.
        print('Rendering image...')
        start_time = time.perf_counter()
        img_filename = os.path.join(output_dir, f'{output_uuid}.png')
        subprocess.run(
            [sys.executable, 'render_lego.py', '--in_file', ldr_filename, '--out_file', img_filename],
            check=True,
        )
        print(f'Finished rendering in {time.perf_counter() - start_time:.1f}s!')

        return img_filename, lego

    def generate_lego_subprocess(self, *args):
        """
        Run generation as a subprocess so that multiple requests can be handled concurrently.

        Each call starts a single-worker pool and runs :meth:`generate_lego` there
        with ``args``. Any exception raised in the worker is re-raised here. Because
        ``self`` (including the model) is pickled into the worker, setting overrides
        never leak back into this process.

        :param args: Positional arguments for :meth:`generate_lego`, in its order.
        :return: The ``(img_filename, lego)`` tuple returned by :meth:`generate_lego`.
        """
        with self.ctx.Pool(1) as pool:
            return pool.apply(self.generate_lego, args)
|
|
|
|
| def get_help_string(field_name: str) -> str: |
| """ |
| :param field_name: Name of a field in LegoGPTConfig. |
| :return: Help string for the field. |
| """ |
| data_fields = fields(LegoGPTConfig) |
| name_field = next(f for f in data_fields if f.name == field_name) |
| return name_field.metadata['help'] |
|
|
|
|
| def get_examples(example_dir: str = os.path.abspath('examples')) -> dict[str, dict[str, str]]: |
| examples_file = os.path.join(example_dir, 'examples.json') |
| with open(examples_file) as f: |
| examples = json.load(f) |
|
|
| for example in examples.values(): |
| example['output_img'] = os.path.join(example_dir, example['output_img']) |
| return examples |
|
|
|
|
# Build and launch the demo only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()
|
|