import gradio as gr
import numpy as np
from PIL import Image
import torch
import gc
import warnings

# Suppress noisy deprecation warnings from diffusers/transformers
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*torch_dtype.*deprecated.*')
warnings.filterwarnings('ignore', message='.*CLIPFeatureExtractor.*deprecated.*')

# Performance optimizations
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"🖥️ Device: {device} | dtype: {dtype}")

# Diffusers / ControlNet imports
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
)
from controlnet_aux import (
    LineartDetector, LineartAnimeDetector, OpenposeDetector, MidasDetector,
    CannyDetector, MLSDdetector, HEDdetector, PidiNetDetector,
    NormalBaeDetector, ZoeDetector, MediapipeFaceDetector,
)

# Memory optimization
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.95)
    print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("⚠️ Running on CPU - Image generation will be significantly slower")

# ===== Model & Config =====
CURRENT_CONTROLNET_PIPE = None
CURRENT_CONTROLNET_KEY = None

# SD1.5 models only
SD15_MODELS = [
    "digiplay/ChikMix_V3",
    "digiplay/chilloutmix_NiPrunedFp16Fix",
    "gsdf/Counterfeit-V2.5",
    "stablediffusionapi/anything-v5",
    "runwayml/stable-diffusion-v1-5",
    "stablediffusionapi/realistic-vision-v51",
    "stablediffusionapi/dreamshaper-v8",
    "stablediffusionapi/henmix-real-v11",
    "stablediffusionapi/rev-animated-v122",
    "stablediffusionapi/cyberrealistic-v33",
    "stablediffusionapi/meinamix-meina-v11",
    "prompthero/openjourney-v4",
    "wavymulder/Analog-Diffusion",
    "dreamlike-art/dreamlike-photoreal-2.0",
    "SG161222/Realistic_Vision_V5.1_noVAE",
    "Lykon/dreamshaper-8",
    "hakurei/waifu-diffusion",
    "andite/anything-v4.0",
]

# ControlNet models
CONTROLNET_MODELS = {
    "lineart": "lllyasviel/control_v11p_sd15_lineart",
    "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
    "canny": "lllyasviel/control_v11p_sd15_canny",
    "depth": "lllyasviel/control_v11p_sd15_depth",
    "normal": "lllyasviel/control_v11p_sd15_normalbae",
    "openpose": "lllyasviel/control_v11p_sd15_openpose",
    "softedge": "lllyasviel/control_v11p_sd15_softedge",
    "scribble": "lllyasviel/control_v11p_sd15_scribble",
    "tile": "lllyasviel/control_v11f1e_sd15_tile",
}

# Lazily instantiated detector instances, keyed by type
DETECTORS = {}

# Detector type -> constructor. Most load pretrained weights from
# lllyasviel/Annotators; Canny and the Mediapipe face detector need none.
_DETECTOR_FACTORIES = {
    "lineart": lambda: LineartDetector.from_pretrained("lllyasviel/Annotators"),
    "lineart_anime": lambda: LineartAnimeDetector.from_pretrained("lllyasviel/Annotators"),
    "openpose": lambda: OpenposeDetector.from_pretrained("lllyasviel/Annotators"),
    "depth": lambda: MidasDetector.from_pretrained("lllyasviel/Annotators"),
    "canny": lambda: CannyDetector(),
    "normal": lambda: NormalBaeDetector.from_pretrained("lllyasviel/Annotators"),
    "hed": lambda: HEDdetector.from_pretrained("lllyasviel/Annotators"),
    "pidi": lambda: PidiNetDetector.from_pretrained("lllyasviel/Annotators"),
    "mlsd": lambda: MLSDdetector.from_pretrained("lllyasviel/Annotators"),
    "zoe": lambda: ZoeDetector.from_pretrained("lllyasviel/Annotators"),
    "face": lambda: MediapipeFaceDetector(),
}


def load_detector(detector_type: str):
    """Lazily load and cache a detector instance; returns None on failure."""
    if detector_type in DETECTORS:
        return DETECTORS[detector_type]
    factory = _DETECTOR_FACTORIES.get(detector_type)
    if factory is None:
        print(f"❌ Unknown detector type: {detector_type}")
        return None
    print(f"📥 Loading {detector_type} detector...")
    try:
        DETECTORS[detector_type] = factory()
        return DETECTORS[detector_type]
    except Exception as e:
        print(f"❌ Error loading {detector_type} detector: {e}")
        return None
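
# Illustrative sketch (not executed): a detector can also be used standalone to
# preview a condition image outside the app. "sketch.png" is a hypothetical path.
#
#   detector = load_detector("lineart_anime")
#   preview = detector(Image.open("sketch.png"),
#                      detect_resolution=512, image_resolution=512)
#   preview.save("lineart_preview.png")
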
def get_controlnet_model(controlnet_type: str):
    """Resolve the ControlNet checkpoint name for a given type."""
    if controlnet_type in CONTROLNET_MODELS:
        return CONTROLNET_MODELS[controlnet_type]
    raise ValueError(f"Unknown ControlNet type: {controlnet_type}")


# ControlNet type -> detector used to build the condition image.
# "tile" (and any unmapped type) conditions directly on the input image.
CONDITION_DETECTORS = {
    "lineart": "lineart",
    "lineart_anime": "lineart_anime",
    "canny": "canny",
    "depth": "depth",
    "normal": "normal",
    "openpose": "openpose",
    "softedge": "pidi",   # the soft-edge ControlNet expects HED/PiDiNet-style edges
    "scribble": "hed",
}


def prepare_condition_image(image, controlnet_type):
    """Prepare the condition image for ControlNet."""
    detector_type = CONDITION_DETECTORS.get(controlnet_type)
    if detector_type is None:
        return image  # e.g. "tile" uses the input image as-is
    detector = load_detector(detector_type)
    if detector is None:
        return image
    kwargs = {"detect_resolution": 512, "image_resolution": 512}
    if controlnet_type == "scribble":
        kwargs["scribble"] = True  # HED's scribble mode yields binarized strokes
    result = detector(image, **kwargs)
    return Image.fromarray(result) if isinstance(result, np.ndarray) else result


def get_pipeline(model_name: str, controlnet_type: str = "lineart"):
    """Get or create a ControlNet pipeline (only one is kept in memory)."""
    global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
    key = (model_name, controlnet_type)
    if CURRENT_CONTROLNET_KEY == key and CURRENT_CONTROLNET_PIPE is not None:
        print(f"✅ Reusing existing ControlNet pipeline: {model_name}, type: {controlnet_type}")
        return CURRENT_CONTROLNET_PIPE

    if CURRENT_CONTROLNET_PIPE is not None:
        print(f"🗑️ Unloading old ControlNet pipeline: {CURRENT_CONTROLNET_KEY}")
        del CURRENT_CONTROLNET_PIPE
        CURRENT_CONTROLNET_PIPE = None
        CURRENT_CONTROLNET_KEY = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"📥 Loading ControlNet pipeline for model: {model_name}, type: {controlnet_type}")
    try:
        controlnet_model_name = get_controlnet_model(controlnet_type)
        controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name,
            torch_dtype=dtype,
        )
        load_kwargs = dict(
            controlnet=controlnet,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        try:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                model_name,
                use_safetensors=True,
                variant="fp16" if dtype == torch.float16 else None,
                **load_kwargs,
            )
        except (OSError, ValueError):
            # Not every community checkpoint ships fp16/safetensors variants
            pipe = StableDiffusionControlNetPipeline.from_pretrained(model_name, **load_kwargs)

        # Optimizations
        pipe.enable_attention_slicing(slice_size="max")
        try:
            pipe.enable_vae_slicing()
        except Exception:
            pass

        if device.type == "cuda":
            try:
                pipe.enable_xformers_memory_efficient_attention()
                print("✅ xFormers enabled")
            except Exception:
                pass
            # Model CPU offload manages device placement itself,
            # so the pipeline must NOT also be moved with .to(device)
            pipe.enable_model_cpu_offload()
        else:
            pipe = pipe.to(device)

        try:
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
            print("✅ Using Euler Ancestral scheduler")
        except Exception:
            pass

        CURRENT_CONTROLNET_PIPE = pipe
        CURRENT_CONTROLNET_KEY = key
        return pipe
    except Exception as e:
        print(f"❌ Error loading ControlNet pipeline: {e}")
        CURRENT_CONTROLNET_PIPE = None
        CURRENT_CONTROLNET_KEY = None
        raise
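
# Optional warm-up (a sketch, disabled by default): pre-loading the default
# pipeline at startup hides the first request's download/load latency at the
# cost of startup time and memory. Uncomment to enable.
#
#   try:
#       get_pipeline("digiplay/ChikMix_V3", "lineart_anime")
#   except Exception as e:
#       print(f"⚠️ Warm-up failed: {e}")
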
def colorize(sketch, base_model, controlnet_type, prompt, negative_prompt,
             seed, steps, scale, cn_weight):
    try:
        pipe = get_pipeline(base_model, controlnet_type)
        print(f"🎨 Using: {base_model} + {controlnet_type}")

        condition_img = prepare_condition_image(sketch, controlnet_type)

        seed = int(seed)
        if seed < 0:  # honor the UI's "-1 for random" convention
            seed = int(np.random.randint(0, 2**31 - 1))
        gen = torch.Generator(device=device).manual_seed(seed)

        with torch.inference_mode():
            out = pipe(
                prompt,
                negative_prompt=negative_prompt,
                image=condition_img,
                num_inference_steps=int(steps),
                guidance_scale=float(scale),
                controlnet_conditioning_scale=float(cn_weight),
                generator=gen,
                height=512,
                width=512,
            ).images[0]

        if device.type == "cuda":
            torch.cuda.empty_cache()
        return out, condition_img
    except Exception as e:
        print(f"❌ Error in colorize: {e}")
        error_img = Image.new('RGB', (512, 512), color='red')
        return error_img, Image.new('RGB', (512, 512), color='gray')


def unload_all_models():
    global CURRENT_CONTROLNET_PIPE, CURRENT_CONTROLNET_KEY
    print("🗑️ Unloading all models from memory...")
    try:
        if CURRENT_CONTROLNET_PIPE is not None:
            del CURRENT_CONTROLNET_PIPE
            CURRENT_CONTROLNET_PIPE = None
    except Exception:
        pass
    CURRENT_CONTROLNET_KEY = None

    DETECTORS.clear()

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"💾 GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    return "✅ All models unloaded from memory!"
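
# Minimal scripted entry point (a sketch): exercises the same code path as the
# UI for quick smoke tests. It is never called by the app; the file names and
# parameter values below are hypothetical placeholders.
def smoke_test(image_path: str = "input.png", output_path: str = "output.png"):
    """Run a single generation without the Gradio UI."""
    img = Image.open(image_path).convert("RGB")
    result, condition = colorize(
        img,
        base_model="runwayml/stable-diffusion-v1-5",
        controlnet_type="canny",
        prompt="masterpiece, best quality, colorful illustration",
        negative_prompt="lowres, bad anatomy, blurry",
        seed=42,
        steps=20,
        scale=7.5,
        cn_weight=1.0,
    )
    result.save(output_path)
    return result, condition
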
# ===== Gradio UI =====
with gr.Blocks(title="🎨 AI Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 AI Image Generator - ControlNet Edition")
    gr.Markdown("### Transform sketches/images into detailed artwork with ControlNet")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        gr.Markdown(f"**GPU:** {gpu_name} ({gpu_memory:.1f} GB)")
    else:
        gr.Markdown("**⚠️ Running on CPU** - Generation will be slower")

    with gr.Row():
        unload_btn = gr.Button("🗑️ Unload All Models", variant="stop", scale=1)
        status_text = gr.Textbox(label="Status", interactive=False, scale=3)
    unload_btn.click(unload_all_models, outputs=status_text)

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label="Input Sketch/Image", type="pil")
            gr.Markdown("### Model Settings")
            base_model = gr.Dropdown(
                choices=SD15_MODELS,
                value="digiplay/ChikMix_V3",
                label="Base Model",
            )
            controlnet_type = gr.Dropdown(
                choices=list(CONTROLNET_MODELS.keys()),
                value="lineart_anime",
                label="ControlNet Type",
            )
        with gr.Column(scale=1):
            out = gr.Image(label="Generated Output")
            condition_out = gr.Image(label="Processed Condition", type="pil")

    gr.Markdown("### Generation Parameters")
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="masterpiece, best quality, 1girl, beautiful detailed eyes, long hair",
            lines=3,
        )
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="lowres, bad anatomy, bad hands, text, error, missing fingers",
            lines=3,
        )
    with gr.Row():
        seed = gr.Number(value=-1, label="Seed (-1 for random)")
        steps = gr.Slider(10, 150, 30, step=1, label="Steps")
        scale = gr.Slider(1, 30, 7.5, step=0.5, label="CFG Scale")
        cn_weight = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="ControlNet Weight")

    run = gr.Button("🎨 Generate", variant="primary", size="lg")
    run.click(
        colorize,
        [inp, base_model, controlnet_type, prompt, negative_prompt, seed, steps, scale, cn_weight],
        [out, condition_out],
    )

    gr.Markdown("""
    ### Tips for Better Results:
    - Use detailed prompts for better control
    - Adjust ControlNet weight to balance condition fidelity against creativity
    - Try different ControlNet types for different inputs
    - Higher steps = better quality but slower generation
    - For line drawings: use **lineart** or **lineart_anime**
    - For photos: use **canny** or **depth**
    - For human poses: use **openpose**
    """)

try:
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False,
    )
except Exception as e:
    print(f"❌ Error launching Gradio app: {e}")