Spaces:
Running on Zero
Running on Zero
| import os | |
| import sys | |
| import re | |
| import json | |
| import random | |
| import logging | |
| import warnings | |
| import traceback | |
| import threading | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
# ==================== spaces compatibility shim ====================
# The `spaces` package exists on HuggingFace Spaces. When it cannot be imported
# (for example on a local run) a no-op stand-in is used so the app still starts.
try:
    import spaces  # type: ignore
except Exception:
    class _SpacesFallback:
        """Minimal stand-in for the `spaces` module when it is not installed.

        The methods are static because they are called on an instance
        (``spaces.GPU(...)``); without ``@staticmethod`` the instance would be
        passed as the first argument and the calls would fail.
        """

        @staticmethod
        def GPU(fn=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both ``@spaces.GPU`` (function passed directly) and
            ``@spaces.GPU(...)`` (keyword options, returns a decorator).
            """
            if fn is None:
                return lambda f: f
            return fn

        @staticmethod
        def aoti_blocks_load(*args, **kwargs):
            """Always raise: AoTI block loading needs the real `spaces` package."""
            raise RuntimeError("spaces.aoti_blocks_load is unavailable outside HuggingFace Spaces.")

    spaces = _SpacesFallback()  # type: ignore
| from diffusers import ( | |
| AutoencoderKL, | |
| DiffusionPipeline, | |
| FlowMatchEulerDiscreteScheduler, | |
| ) | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# ------------------------- Optional dependency: prompt enhancer template -------------------------
# If a pe.py module sits next to this file its `prompt_template` is used;
# otherwise a built-in default template is used and nothing fails.
try:
    _app_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(_app_dir)
    from pe import prompt_template  # type: ignore
except Exception:
    prompt_template = " ".join(
        (
            "You are a helpful prompt engineer.",
            "Expand the user prompt into a richer, detailed prompt.",
            "Return JSON with key revised_prompt.",
        )
    )
# ==================== Environment Variables ====================
def env_flag(name: str, default: str = "false") -> bool:
    """Return True when environment variable *name* equals "true" (case-insensitive)."""
    return os.environ.get(name, default).lower() == "true"


def env_int(name: str, default: int, minimum: Optional[int] = None) -> int:
    """Read an integer environment variable without crashing on bad input.

    Falls back to *default* when the variable is unset or not an integer, and
    raises the result to *minimum* when a lower bound is given.
    """
    raw = os.environ.get(name)
    try:
        value = int(raw) if raw is not None else default
    except (TypeError, ValueError):
        print(f"[Config] Invalid integer for {name}: {raw!r}, using default {default}.")
        value = default
    if minimum is not None and value < minimum:
        value = minimum
    return value


MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")

# Compile is off by default to avoid first-load timeouts, compile failures and
# ZeroGPU compatibility problems. Set ENABLE_COMPILE=true once the environment
# is known to be stable.
ENABLE_COMPILE = env_flag("ENABLE_COMPILE")

# Warmup is off by default because warming up at startup can make startup fail.
ENABLE_WARMUP = env_flag("ENABLE_WARMUP")

# "native" is the safest default; set ATTENTION_BACKEND=flash_3 if supported.
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "native")

# ZeroGPU AoTI: attempted by default, a failure does not affect the main flow.
ENABLE_AOTI = env_flag("ENABLE_AOTI", "true")

# The safety checker costs extra memory, so it is off by default.
# Set ENABLE_SAFETY_CHECKER=true to enable it.
ENABLE_SAFETY_CHECKER = env_flag("ENABLE_SAFETY_CHECKER")

# Prefer DiffusionPipeline loading; fall back to manual component loading.
USE_DIFFUSION_PIPELINE = env_flag("USE_DIFFUSION_PIPELINE", "true")

# Number of images kept in the gallery so it does not grow without bound.
# A malformed value falls back to 8, and values below 1 are raised to 1 so the
# newest image is always shown.
MAX_GALLERY_HISTORY = env_int("MAX_GALLERY_HISTORY", 8, minimum=1)

DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# ===============================================================
# Process-wide runtime setup: silence tokenizer parallelism, Python warnings
# and transformers logging below ERROR.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Initial device / dtype guess; refresh_runtime_device() recomputes both later.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Lazily populated singletons and the lock guarding model loading.
pipe = None
prompt_expander = None
model_lock = threading.Lock()
MODEL_LOAD_ERROR = ""
# Resolution presets shown in the UI, grouped by base-size category.
# Each entry is a "WIDTHxHEIGHT ( ratio )" label; get_resolution() parses the
# leading WIDTHxHEIGHT part.
RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1536x864 ( 16:9 )",
        "864x1536 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
    "1536": [
        "1536x1536 ( 1:1 )",
        "1728x1344 ( 9:7 )",
        "1344x1728 ( 7:9 )",
        "1728x1296 ( 4:3 )",
        "1296x1728 ( 3:4 )",
        "1872x1248 ( 3:2 )",
        "1248x1872 ( 2:3 )",
        "2048x1152 ( 16:9 )",
        "1152x2048 ( 9:16 )",
        "2016x864 ( 21:9 )",
        "864x2016 ( 9:21 )",
    ],
}

# Flat list of every preset label, in category order.
RESOLUTION_SET: List[str] = [
    label for labels in RES_CHOICES.values() for label in labels
]
# Example prompts offered through gr.Examples; each inner list holds the single
# value for the prompt textbox.
EXAMPLE_PROMPTS = [
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
]
def refresh_runtime_device() -> Tuple[str, torch.dtype]:
    """Re-detect the compute device and dtype and store them in the globals.

    On ZeroGPU, CUDA may be invisible while the app starts and only become
    visible inside a GPU-allocated call, so DEVICE / DTYPE are recomputed at
    generation time instead of being trusted from import time.

    Returns:
        The refreshed ``(DEVICE, DTYPE)`` pair.
    """
    global DEVICE, DTYPE
    if torch.cuda.is_available():
        DEVICE, DTYPE = "cuda", torch.bfloat16
    else:
        DEVICE, DTYPE = "cpu", torch.float32
    print(f"[Runtime] DEVICE={DEVICE}, DTYPE={DTYPE}")
    return DEVICE, DTYPE
def cuda_cleanup():
    """Best-effort release of cached CUDA memory after a failure.

    Any error raised while cleaning up is deliberately swallowed so that the
    cleanup itself can never mask the original problem.
    """
    try:
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    except Exception:
        pass
def is_local_model_path(model_path: str) -> bool:
    """Return True when *model_path* is an existing directory on disk."""
    exists_as_directory = os.path.isdir(model_path)
    return exists_as_directory
def hf_token_candidates() -> List[Dict[str, str]]:
    """Return the auth keyword sets to try when calling ``from_pretrained``.

    Newer transformers / diffusers releases take ``token`` while older ones
    take ``use_auth_token``; the two are never combined in one call because
    some environments reject that. The final empty dict is an anonymous
    attempt. Without HF_TOKEN only the anonymous attempt is returned.
    """
    if HF_TOKEN:
        return [{"token": HF_TOKEN}, {"use_auth_token": HF_TOKEN}, {}]
    return [{}]
def get_resolution(resolution: str) -> Tuple[int, int]:
    """Parse a "WIDTHxHEIGHT" label (``x`` or ``×``) into ``(width, height)``.

    Returns ``(1024, 1024)`` when no such pattern is found.
    """
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", str(resolution))
    if not found:
        return 1024, 1024
    width_text, height_text = found.groups()
    return int(width_text), int(height_text)
def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker") -> Image.Image:
    """Build a dark placeholder image with a red banner carrying *text*."""
    canvas = Image.new("RGB", (width, height), (20, 20, 20))
    painter = ImageDraw.Draw(canvas)

    # Fall back to Pillow's implicit font if the default font cannot be loaded.
    try:
        banner_font = ImageFont.load_default()
    except Exception:
        banner_font = None

    painter.rectangle([0, 0, width, 90], fill=(160, 0, 0))
    painter.text((20, 30), text, fill=(255, 255, 255), font=banner_font)
    return canvas
def _load_nsfw_placeholder(width=1024, height=1024) -> Image.Image:
    """Return the image shown in place of a blocked (NSFW) result.

    ``nsfw.png`` in the working directory is preferred; when it is missing or
    unreadable a generated placeholder of the requested size is returned, so a
    missing file never causes a second error.
    """
    placeholder_path = "nsfw.png"
    if os.path.exists(placeholder_path):
        try:
            return Image.open(placeholder_path).convert("RGB")
        except Exception:
            pass
    return _make_blocked_image(width, height, "NSFW blocked")
def _move_pipeline_to_device(p) -> Any:
    """Move pipeline *p* to the current DEVICE, trying several ``.to()`` forms.

    Different diffusers versions accept different call shapes, so a list of
    call signatures is tried in order and the first one that succeeds wins.
    If every attempt fails a warning is printed and *p* is returned unchanged.
    """
    if p is None:
        return p

    # A pipeline loaded with a device map should not be moved again.
    device_map = getattr(p, "hf_device_map", None)
    if device_map:
        print(f"[Init] Pipeline already has hf_device_map: {getattr(p, 'hf_device_map', None)}")
        return p

    # (positional args, keyword args) for each .to() attempt, in priority order.
    if DEVICE == "cuda":
        call_specs = [
            (("cuda",), {}),
            ((), {"torch_dtype": DTYPE}),
            ((), {"device": "cuda"}),
            (("cuda",), {"torch_dtype": DTYPE}),
        ]
    else:
        call_specs = [
            (("cpu",), {}),
            ((), {"torch_dtype": torch.float32}),
            ((), {"device": "cpu"}),
        ]

    last_error = None
    for args, kwargs in call_specs:
        try:
            return p.to(*args, **kwargs)
        except Exception as exc:
            last_error = exc

    print(f"[Init] Warning: pipeline.to(...) failed, continue anyway. Error: {last_error}")
    return p
def _set_attention_backend_if_possible(p, backend: str) -> None:
    """Select the transformer's attention backend without ever raising.

    Not every environment supports every backend. If *backend* is rejected the
    "native" backend is tried next; if that fails too the error is only logged.
    """
    transformer = getattr(p, "transformer", None) if p else None
    if transformer is None:
        return

    if not hasattr(transformer, "set_attention_backend"):
        print("[Init] Transformer has no set_attention_backend method, skip.")
        return

    try:
        transformer.set_attention_backend(backend)
    except Exception as e:
        print(f"[Init] set_attention_backend('{backend}') failed: {e}")
    else:
        print(f"[Init] Attention backend set to: {backend}")
        return

    # Requested backend was rejected: fall back to the native implementation.
    try:
        transformer.set_attention_backend("native")
        print("[Init] Attention backend fallback to: native")
    except Exception as e:
        print(f"[Init] set_attention_backend('native') also failed, ignored: {e}")
def _compile_transformer_if_possible(p) -> Any:
    """Optionally wrap the pipeline's transformer with ``torch.compile``.

    Compilation can speed up inference but can also make the first run very
    slow or fail outright, so it only happens when ENABLE_COMPILE is set and
    the device is CUDA. Any failure is logged and the pipeline is returned
    uncompiled.
    """
    if not ENABLE_COMPILE:
        return p
    if DEVICE != "cuda":
        print("[Init] ENABLE_COMPILE=true but DEVICE is not cuda, skip compile.")
        return p

    module = getattr(p, "transformer", None)
    if module is None:
        print("[Init] No transformer found, skip compile.")
        return p

    try:
        print("[Init] Enabling torch.compile optimizations...")
        inductor_cfg = torch._inductor.config
        # Inductor options are applied in this fixed order before compiling.
        for option, value in (
            ("conv_1x1_as_mm", True),
            ("coordinate_descent_tuning", True),
            ("epilogue_fusion", False),
            ("coordinate_descent_check_all_directions", True),
            ("max_autotune_gemm", True),
            ("max_autotune_gemm_backends", "TRITON,ATEN"),
        ):
            setattr(inductor_cfg, option, value)
        inductor_cfg.triton.cudagraphs = False

        p.transformer = torch.compile(
            module,
            mode="max-autotune-no-cudagraphs",
            fullgraph=False,
        )
        print("[Init] Transformer compiled.")
    except Exception:
        print("[Init] torch.compile failed, continue without compile:")
        traceback.print_exc()
    return p
def try_enable_aoti(p) -> None:
    """Try to load ZeroGPU AoTI blocks for the transformer; skip on failure.

    Does nothing when ENABLE_AOTI is off or *p* is None. Every error is logged
    and swallowed because AoTI is an optional acceleration.
    """
    if not ENABLE_AOTI:
        print("[Init] ENABLE_AOTI=false, skip AoTI.")
        return
    if p is None:
        return

    try:
        transformer = getattr(p, "transformer", None)
        if transformer is None:
            print("[Init] No transformer found, skip AoTI.")
            return

        # Prefer the transformer's `layers` container; otherwise use the
        # transformer itself.
        target = transformer.layers if hasattr(transformer, "layers") else transformer

        # NOTE(review): `_repeated_blocks` is only overwritten when the target
        # already has that attribute — confirm whether aoti_blocks_load needs
        # it to be set unconditionally.
        if hasattr(target, "_repeated_blocks"):
            target._repeated_blocks = ["ZImageTransformerBlock"]

        if target is None:
            return
        spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3")
        print("[Init] AoTI blocks loaded.")
    except Exception:
        print("[Init] AoTI not enabled, safe to ignore:")
        traceback.print_exc()
def _load_safety_checker_if_enabled(p) -> Any:
    """Attach the Stable Diffusion safety checker to *p* when enabled.

    The checker is off by default because it costs extra memory. When it is
    disabled, or when loading fails for every token candidate, both safety
    attributes on *p* are set to None so later checks are skipped. The main
    pipeline is returned in every case.
    """

    def _detach_checker() -> None:
        # Best effort: some pipeline objects may refuse attribute assignment.
        try:
            p.safety_feature_extractor = None
            p.safety_checker = None
        except Exception:
            pass

    if not ENABLE_SAFETY_CHECKER:
        print("[Init] ENABLE_SAFETY_CHECKER=false, skip safety checker.")
        _detach_checker()
        return p

    try:
        from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
        try:
            from transformers import CLIPImageProcessor as _CLIPProcessor
        except Exception:
            from transformers import CLIPFeatureExtractor as _CLIPProcessor  # type: ignore

        safety_model_id = "CompVis/stable-diffusion-safety-checker"
        checker_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        last_error = None

        for token_kwargs in hf_token_candidates():
            try:
                extractor = _CLIPProcessor.from_pretrained(
                    safety_model_id,
                    **token_kwargs,
                )
                checker = StableDiffusionSafetyChecker.from_pretrained(
                    safety_model_id,
                    torch_dtype=checker_dtype,
                    **token_kwargs,
                )
                checker = checker.to(DEVICE)
                p.safety_feature_extractor = extractor
                p.safety_checker = checker
                print("[Init] Safety checker loaded.")
                return p
            except Exception as e:
                last_error = e

        # Every candidate failed; the handler below turns this into a skip.
        raise RuntimeError(f"Safety checker load failed: {last_error}")
    except Exception:
        print("[Init] Safety checker init failed. NSFW filtering will be skipped:")
        traceback.print_exc()
        _detach_checker()
        return p
def _load_with_diffusion_pipeline(model_path: str) -> Any:
    """Load the model through ``DiffusionPipeline.from_pretrained``.

    Several keyword combinations are tried so that both old and new diffusers
    releases work. The first successful load is moved to the current device
    and returned.

    Args:
        model_path: Local directory or Hub repo id of the model.

    Returns:
        The loaded pipeline.

    Raises:
        RuntimeError: When every attempt fails. The message lists the last
            eight failures with auth tokens redacted.
    """
    print("[Init] Trying DiffusionPipeline loading strategy...")

    # A local directory needs no authentication.
    token_candidates = [{}] if is_local_model_path(model_path) else hf_token_candidates()

    dtype_candidates: List[Dict[str, Any]]
    if DEVICE == "cuda":
        dtype_candidates = [
            # Some newer diffusers docs use `dtype`.
            {"dtype": DTYPE, "device_map": "cuda"},
            {"dtype": DTYPE},
            # Older releases commonly use `torch_dtype`.
            {"torch_dtype": DTYPE, "device_map": "cuda"},
            {"torch_dtype": DTYPE},
            # Some Z-Image examples need low_cpu_mem_usage=False.
            {"torch_dtype": DTYPE, "low_cpu_mem_usage": False},
            {"dtype": DTYPE, "low_cpu_mem_usage": False},
            # Compatibility with custom pipelines.
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the model repo; only acceptable for trusted MODEL_PATH values.
            {"torch_dtype": DTYPE, "trust_remote_code": True},
            {"dtype": DTYPE, "trust_remote_code": True},
        ]
    else:
        dtype_candidates = [
            {"torch_dtype": torch.float32},
            {"dtype": torch.float32},
            {"torch_dtype": torch.float32, "low_cpu_mem_usage": False},
            {},
        ]

    # Keys whose values must never appear in logs or error messages; these
    # messages end up in the Space logs and in the UI status box.
    secret_keys = ("token", "use_auth_token")

    errors: List[str] = []
    for token_kwargs in token_candidates:
        for extra_kwargs in dtype_candidates:
            kwargs: Dict[str, Any] = {**token_kwargs, **extra_kwargs}
            try:
                print(f"[Init] DiffusionPipeline.from_pretrained kwargs={list(kwargs.keys())}")
                p = DiffusionPipeline.from_pretrained(model_path, **kwargs)
                print("[Init] DiffusionPipeline loaded.")
                return _move_pipeline_to_device(p)
            except Exception as e:
                safe_kwargs = {
                    key: ("***" if key in secret_keys else value)
                    for key, value in kwargs.items()
                }
                err = f"kwargs={safe_kwargs} -> {type(e).__name__}: {e}"
                print(f"[Init] DiffusionPipeline attempt failed: {err}")
                errors.append(err)

    raise RuntimeError(
        "All DiffusionPipeline loading attempts failed.\n"
        + "\n".join(errors[-8:])
    )
def _from_pretrained_component(cls, path_or_repo: str, subfolder: Optional[str], torch_dtype: Optional[torch.dtype]) -> Any:
    """Load one pipeline component with ``cls.from_pretrained``.

    For a local model directory the subfolder is joined onto the path and
    loaded directly. For a Hub repo each auth-token variant is tried in turn.

    Raises:
        RuntimeError: When every remote attempt fails (last four errors are
            included). Errors from a local load propagate unchanged.
    """
    dtype_kwargs: Dict[str, Any] = {}
    if torch_dtype is not None:
        dtype_kwargs["torch_dtype"] = torch_dtype

    if is_local_model_path(path_or_repo):
        load_path = path_or_repo
        if subfolder:
            load_path = os.path.join(path_or_repo, subfolder)
        return cls.from_pretrained(load_path, **dtype_kwargs)

    remote_kwargs: Dict[str, Any] = {}
    if subfolder:
        remote_kwargs["subfolder"] = subfolder
    remote_kwargs.update(dtype_kwargs)

    failures = []
    for token_kwargs in hf_token_candidates():
        attempt_kwargs = {**remote_kwargs, **token_kwargs}
        try:
            return cls.from_pretrained(path_or_repo, **attempt_kwargs)
        except Exception as e:
            failures.append(f"{type(e).__name__}: {e}")

    raise RuntimeError(
        f"Failed to load component {cls} subfolder={subfolder}. "
        + " | ".join(failures[-4:])
    )
def _load_with_manual_components(model_path: str) -> Any:
    """Fallback loader: assemble the pipeline from individually loaded parts.

    Loads the VAE, text encoder, tokenizer and transformer one at a time and
    builds a ZImagePipeline from them.

    Raises:
        RuntimeError: When the installed diffusers does not provide
            ZImagePipeline / ZImageTransformer2DModel.
    """
    print("[Init] Trying manual component loading strategy...")
    try:
        from diffusers import ZImagePipeline  # type: ignore
        from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel  # type: ignore
    except Exception as e:
        raise RuntimeError(
            "Current diffusers does not provide ZImagePipeline / ZImageTransformer2DModel. "
            "Please upgrade diffusers, for example: pip install -U diffusers transformers accelerate"
        ) from e

    model_dtype = DTYPE if DEVICE == "cuda" else torch.float32

    vae = _from_pretrained_component(AutoencoderKL, model_path, "vae", model_dtype)

    text_encoder = _from_pretrained_component(
        AutoModelForCausalLM, model_path, "text_encoder", model_dtype
    )
    text_encoder = text_encoder.eval()

    tokenizer = _from_pretrained_component(AutoTokenizer, model_path, "tokenizer", None)
    tokenizer.padding_side = "left"

    # The pipeline is created without scheduler and transformer; the
    # transformer is attached just below and generate_image() assigns a
    # scheduler on every call.
    p = ZImagePipeline(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=None,
    )

    transformer = _from_pretrained_component(
        ZImageTransformer2DModel, model_path, "transformer", None
    )
    p.transformer = transformer.to(DEVICE, DTYPE)

    p = _move_pipeline_to_device(p)
    print("[Init] Manual component loading finished.")
    return p
def load_models(model_path: str) -> Any:
    """Single entry point for model loading.

    Tries the DiffusionPipeline strategy first (when enabled) and falls back
    to manual component loading. After either load the attention backend,
    optional compilation and optional safety checker are applied.

    Raises:
        RuntimeError: When both strategies fail; the message carries both
            errors.
    """

    def _post_load(loaded):
        # Steps shared by both loading strategies.
        _set_attention_backend_if_possible(loaded, ATTENTION_BACKEND)
        loaded = _compile_transformer_if_possible(loaded)
        return _load_safety_checker_if_enabled(loaded)

    banner = "=" * 80
    print(banner)
    print(f"[Init] Loading model from: {model_path}")
    print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}")
    print(f"[Init] USE_DIFFUSION_PIPELINE={USE_DIFFUSION_PIPELINE}")
    print(f"[Init] ENABLE_COMPILE={ENABLE_COMPILE}")
    print(f"[Init] ENABLE_WARMUP={ENABLE_WARMUP}")
    print(f"[Init] ATTENTION_BACKEND={ATTENTION_BACKEND}")
    print(banner)

    first_error = None
    if USE_DIFFUSION_PIPELINE:
        try:
            return _post_load(_load_with_diffusion_pipeline(model_path))
        except Exception as e:
            first_error = e
            print("[Init] DiffusionPipeline strategy failed:")
            traceback.print_exc()
            cuda_cleanup()

    try:
        return _post_load(_load_with_manual_components(model_path))
    except Exception as e:
        print("[Init] Manual component strategy failed:")
        traceback.print_exc()
        cuda_cleanup()
        raise RuntimeError(
            "Model loading failed in all strategies. "
            f"First error: {first_error}. "
            f"Second error: {e}"
        ) from e
def generate_image(
    p,
    prompt: str,
    resolution: str = "1024x1024",
    seed: int = 42,
    guidance_scale: float = 0.0,
    num_inference_steps: int = 9,
    shift: float = 3.0,
    max_sequence_length: int = 512,
) -> Image.Image:
    """Generate a single RGB image with pipeline *p*.

    A fresh FlowMatchEulerDiscreteScheduler with the given *shift* is assigned
    before each call; if that fails the existing scheduler is kept. When the
    pipeline rejects the call with TypeError it is retried once without
    ``max_sequence_length``.
    """
    width, height = get_resolution(resolution)

    rng = torch.Generator(device="cuda") if DEVICE == "cuda" else torch.Generator()
    generator = rng.manual_seed(int(seed))

    try:
        p.scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=1000,
            shift=float(shift),
        )
    except Exception:
        print("[Generate] Failed to assign scheduler, continue with existing scheduler:")
        traceback.print_exc()

    base_kwargs = {
        "prompt": prompt,
        "height": int(height),
        "width": int(width),
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
    }
    sequence_limit = int(max_sequence_length)

    # Pipeline versions differ in the parameters they accept.
    try:
        out = p(**base_kwargs, max_sequence_length=sequence_limit)
    except TypeError:
        out = p(**base_kwargs)

    first = out.images[0]
    if isinstance(first, Image.Image):
        return first.convert("RGB")
    return Image.fromarray(first).convert("RGB")
def warmup_model(p) -> None:
    """Run one minimal warmup generation when ENABLE_WARMUP is set.

    Only a single 1024x1024 image with two inference steps is generated. Any
    failure is logged, CUDA caches are cleared and the error is ignored.
    """
    if not ENABLE_WARMUP:
        return

    warmup_args = {
        "prompt": "warmup",
        "resolution": "1024x1024",
        "num_inference_steps": 2,
        "guidance_scale": 0.0,
        "seed": 42,
    }
    try:
        print("[Warmup] Starting minimal warmup...")
        generate_image(p, **warmup_args)
        print("[Warmup] Completed.")
    except Exception:
        print("[Warmup] Failed, ignored:")
        traceback.print_exc()
        cuda_cleanup()
| # ==================== Prompt Expander ==================== | |
@dataclass
class PromptOutput:
    """Result of a prompt-expansion request.

    Declared as a dataclass because callers build it positionally in field
    order: ``PromptOutput(status, prompt, seed, system_prompt, message)``.
    """

    # True when the expansion request succeeded.
    status: bool
    # Expanded prompt text (empty string on failure).
    prompt: str
    # Seed value passed through from the caller.
    seed: int
    # System prompt that was used for the request.
    system_prompt: str
    # Raw model reply on success, or an error description on failure.
    message: str
class PromptExpander:
    """Base class for prompt expanders; records which backend is in use."""

    def __init__(self, backend="api", **kwargs):
        """Store the backend name; extra keyword arguments are accepted and ignored."""
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        """Return the module-level prompt template (*template_name* is unused)."""
        return prompt_template
class APIPromptExpander(PromptExpander):
    """Prompt expander backed by an OpenAI-compatible chat completion API."""

    def __init__(self, api_config=None, **kwargs):
        """Store *api_config* (a dict or None) and create the API client."""
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        self.client = self._init_api_client()

    def _init_api_client(self):
        """Create the OpenAI client, or return None when it is unavailable.

        None is returned (with a log line) when the openai package is missing,
        no API key is configured, or client construction fails.
        """
        try:
            from openai import OpenAI

            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get(
                "base_url",
                "https://dashscope.aliyuncs.com/compatible-mode/v1",
            )
            if not api_key:
                print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("[PE] openai package not installed. Prompt enhance unavailable.")
            return None
        except Exception:
            print("[PE] Failed to initialize API client:")
            traceback.print_exc()
            return None

    @staticmethod
    def _extract_revised_prompt(content):
        """Pull ``revised_prompt`` out of a model reply.

        Handles a fenced ```json block as well as a reply that is a bare JSON
        object. Returns *content* unchanged when no JSON object can be parsed
        or the key is absent.
        """
        candidates = []
        fence_start = content.find("```json")
        if fence_start != -1:
            fence_end = content.find("```", fence_start + 7)
            if fence_end != -1:
                candidates.append(content[fence_start + 7: fence_end].strip())
        else:
            # The default system prompt asks for plain JSON, so a reply may be
            # an unfenced JSON object.
            stripped = content.strip()
            if stripped.startswith("{"):
                candidates.append(stripped)

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except Exception:
                continue
            if isinstance(data, dict):
                return data.get("revised_prompt", content)
        return content

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Shorthand for :meth:`extend`."""
        return self.extend(prompt, system_prompt, seed, **kwargs)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Expand *prompt* through the chat API.

        Returns:
            A PromptOutput. ``status`` is False when the client is missing or
            the request fails; ``message`` then holds the reason.
        """
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized")

        if system_prompt is None:
            system_prompt = self.decide_system_prompt()

        # A template with a {prompt} placeholder carries the user text itself,
        # so the user message is reduced to a single space.
        if "{prompt}" in system_prompt:
            system_prompt = system_prompt.format(prompt=prompt)
            prompt = " "

        try:
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content or ""
            expanded_prompt = self._extract_revised_prompt(content)
            return PromptOutput(True, expanded_prompt, seed, system_prompt, content)
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))
def create_prompt_expander(backend="api", **kwargs):
    """Build a prompt expander for *backend*.

    Raises:
        ValueError: For any backend other than ``"api"``.
    """
    if backend != "api":
        raise ValueError("Only 'api' backend is supported.")
    return APIPromptExpander(**kwargs)
def get_or_create_prompt_expander():
    """Return the shared prompt expander, creating it on first use.

    Creation is lazy so that a missing openai package or API key cannot
    affect application startup. Returns None when creation fails.
    """
    global prompt_expander
    if prompt_expander is None:
        try:
            prompt_expander = create_prompt_expander(
                backend="api",
                api_config={"model": "qwen3-max-preview"},
            )
            print("[PE] Prompt expander ready.")
        except Exception:
            print("[PE] Prompt expander init failed:")
            traceback.print_exc()
            prompt_expander = None
    return prompt_expander
def prompt_enhance(prompt: str, enable_enhance: bool) -> Tuple[str, str]:
    """Optionally expand *prompt* and return ``(prompt_to_use, status_message)``.

    The original prompt is returned whenever enhancement is disabled,
    unavailable, or fails; the message says which.
    """
    if not enable_enhance:
        return prompt, "Enhancement disabled."

    expander = get_or_create_prompt_expander()
    if not expander:
        return prompt, "Prompt expander unavailable."

    if not prompt.strip():
        return "", "Please enter a prompt."

    try:
        outcome = expander(prompt)
    except Exception as e:
        return prompt, f"Error: {str(e)}"

    if outcome.status:
        return outcome.prompt, outcome.message
    return prompt, f"Enhancement failed: {outcome.message}"
def get_or_load_pipe():
    """Return the shared pipeline, loading it on first use.

    The model is not loaded at application startup; it is loaded the first
    time generation is requested. The device and dtype are refreshed on every
    call, and loading is serialized with ``model_lock``.

    Raises:
        gr.Error: When loading fails. The full traceback is printed and kept
            in MODEL_LOAD_ERROR.
    """
    global pipe, MODEL_LOAD_ERROR
    refresh_runtime_device()

    if pipe is None:
        with model_lock:
            # Re-check under the lock: another caller may have loaded it.
            if pipe is None:
                try:
                    MODEL_LOAD_ERROR = ""
                    loaded_pipe = load_models(MODEL_PATH)
                    try_enable_aoti(loaded_pipe)
                    warmup_model(loaded_pipe)
                    pipe = loaded_pipe
                    print("[Init] Model loaded successfully.")
                except Exception:
                    MODEL_LOAD_ERROR = traceback.format_exc()
                    print("[Init] Model loading failed with full traceback:")
                    print(MODEL_LOAD_ERROR)
                    pipe = None
                    cuda_cleanup()
                    raise gr.Error(
                        "Model loading failed. Please open the Space logs to view the full traceback. "
                        "Common fixes: upgrade diffusers/transformers/accelerate, disable compile/warmup, "
                        "or check MODEL_PATH / HF_TOKEN."
                    )
    return pipe
def normalize_gallery_items(gallery_images) -> List[Any]:
    """Convert a Gradio Gallery value into a list of PIL images.

    Accepted item shapes are a PIL image, an existing file path, a tuple/list
    whose first element is one of those, or a dict carrying one of those under
    "image", "path" or "name". Items that cannot be converted are skipped. At
    most ``MAX_GALLERY_HISTORY - 1`` images are returned.
    """
    if not gallery_images:
        return []

    def _to_image(candidate):
        # Return a PIL image for a supported candidate, otherwise None.
        if isinstance(candidate, Image.Image):
            return candidate
        if isinstance(candidate, str) and os.path.exists(candidate):
            return Image.open(candidate).convert("RGB")
        return None

    images = []
    for entry in gallery_images:
        try:
            if isinstance(entry, (Image.Image, str)):
                loaded = _to_image(entry)
            elif isinstance(entry, (tuple, list)) and len(entry) > 0:
                loaded = _to_image(entry[0])
            elif isinstance(entry, dict):
                loaded = _to_image(entry.get("image") or entry.get("path") or entry.get("name"))
            else:
                loaded = None
        except Exception:
            continue
        if loaded is not None:
            images.append(loaded)

    keep = max(0, MAX_GALLERY_HISTORY - 1)
    return images[:keep]
def run_safety_check_if_available(p, image: Image.Image, width: int, height: int) -> Image.Image:
    """Run the post-generation safety check when a checker is attached to *p*.

    Args:
        p: Pipeline that may carry ``safety_feature_extractor`` and
            ``safety_checker`` attributes.
        image: The generated image.
        width: Width used for the placeholder if the image is blocked.
        height: Height used for the placeholder if the image is blocked.

    Returns:
        *image* when no checker is attached, the check passes, or the check
        itself fails; otherwise the NSFW placeholder image.
    """
    try:
        extractor = getattr(p, "safety_feature_extractor", None)
        if extractor is None:
            return image
        checker = getattr(p, "safety_checker", None)
        if checker is None:
            return image

        import numpy as np

        clip_inputs = extractor([image], return_tensors="pt")
        clip_input = clip_inputs.pixel_values.to(DEVICE)

        # The checker may be loaded in half precision while the extractor
        # yields float32 tensors; a mismatch makes the checker raise, and that
        # error would be swallowed below, silently disabling the filter.
        checker_dtype = getattr(checker, "dtype", None)
        if isinstance(checker_dtype, torch.dtype):
            clip_input = clip_input.to(checker_dtype)

        img_np = np.array(image).astype("float32") / 255.0
        img_np = img_np[None, ...]
        _checked_images, has_nsfw = checker(
            images=img_np,
            clip_input=clip_input,
        )
        if isinstance(has_nsfw, (list, tuple)) and len(has_nsfw) > 0 and bool(has_nsfw[0]):
            return _load_nsfw_placeholder(width, height)
        return image
    except Exception:
        # Deliberately fail open: a broken checker must not block generation.
        print("[Safety] Check failed, ignored:")
        traceback.print_exc()
        return image
# NOTE(review): the docstring below says this entry point triggers ZeroGPU
# allocation, but no @spaces.GPU decorator is visible on it — confirm whether
# one is expected here.
def generate(
    prompt,
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=8,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    enhance=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio entry point for image generation.

    In a ZeroGPU environment this function is what triggers dynamic GPU
    allocation.

    Returns:
        ``(gallery, seed_text, seed_number, status_text)``.

    Raises:
        gr.Error: For an empty prompt, a model loading failure, or any error
            during generation.
    """
    try:
        final_prompt = str(prompt or "").strip()
        if not final_prompt:
            raise gr.Error("Please enter a prompt.")

        current_pipe = get_or_load_pipe()

        # Seed: random when requested (or when -1 is given), otherwise the
        # user's value, falling back to 42 when it is not an integer.
        if random_seed:
            new_seed = random.randint(1, 1_000_000)
        else:
            try:
                new_seed = int(seed)
            except Exception:
                new_seed = 42
            if new_seed == -1:
                new_seed = random.randint(1, 1_000_000)

        pe_msg = ""
        if enhance:
            final_prompt, pe_msg = prompt_enhance(final_prompt, True)
            print(f"[PE] Enhanced prompt: {final_prompt}")
            print(f"[PE] Message: {pe_msg}")

        # Keep only the "WIDTHxHEIGHT" part of a "WIDTHxHEIGHT ( ratio )" label.
        try:
            resolution_str = str(resolution).split(" ")[0]
        except Exception:
            resolution_str = "1024x1024"
        width, height = get_resolution(resolution_str)

        # Clamp the UI value; the pipeline is called with one extra step.
        safe_steps = max(1, min(int(steps), 100))
        total_steps = safe_steps + 1

        image = generate_image(
            p=current_pipe,
            prompt=final_prompt,
            resolution=resolution_str,
            seed=new_seed,
            guidance_scale=0.0,
            num_inference_steps=total_steps,
            shift=float(shift),
        )
        image = run_safety_check_if_available(current_pipe, image, width, height)

        # Newest image first, history capped at MAX_GALLERY_HISTORY.
        gallery = ([image] + normalize_gallery_items(gallery_images))[:MAX_GALLERY_HISTORY]

        status = (
            f"Done. DEVICE={DEVICE}, resolution={resolution_str}, "
            f"steps={total_steps}, seed={new_seed}"
        )
        if pe_msg:
            status += f"\nPrompt Enhance: {pe_msg[:300]}"
        return gallery, str(new_seed), int(new_seed), status
    except gr.Error:
        raise
    except Exception as e:
        print("[Generate] Failed:")
        traceback.print_exc()
        cuda_cleanup()
        raise gr.Error(f"Generation failed: {type(e).__name__}: {e}")
def update_res_choices(_res_cat):
    """Return a dropdown update listing the resolutions of the chosen category.

    Unknown categories fall back to the "1024" presets; the first preset of
    the category becomes the selected value.
    """
    res_choices = RES_CHOICES.get(str(_res_cat), RES_CHOICES["1024"])
    return gr.update(value=res_choices[0], choices=res_choices)
def get_model_status():
    """Describe the current model state for the "Model Status" button.

    Reports the loaded configuration, the tail (last 4000 characters) of the
    last loading error, or a note that the model has not been loaded yet.
    """
    if pipe is not None:
        return f"Model loaded. DEVICE={DEVICE}, DTYPE={DTYPE}, MODEL_PATH={MODEL_PATH}"
    if not MODEL_LOAD_ERROR:
        return (
            "Model not loaded yet. This is normal. "
            "The model will be loaded when you click Generate."
        )
    return "Model not loaded. Last loading error:\n" + MODEL_LOAD_ERROR[-4000:]
# Custom CSS limiting the maximum page width.
# NOTE(review): `css` is not passed to gr.Blocks(...) or demo.launch(...) in
# the code visible here — confirm whether it is meant to be applied.
css = """
.fillable {
max-width: 1230px !important;
}
.gradio-container {
max-width: 1280px !important;
}
"""
# ==================== Gradio UI ====================
# Two-column layout: inputs and controls on the left, results on the right.
# NOTE(review): nesting of rows/columns is reconstructed from syntax — confirm
# the placement of `enhance`, the buttons and `status_box` against the
# intended layout.
with gr.Blocks(title="Z-Image Demo") as demo:
    gr.Markdown(
        """<div align="center">
# Z-Image Generation Demo
*ZeroGPU friendly lazy-loading version*
</div>"""
    )
    with gr.Row():
        # Left column: prompt, generation settings, actions, status, examples.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="Enter your prompt here...",
            )
            with gr.Row():
                # Category values are ints here; update_res_choices() converts
                # the selection back to a string key.
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(
                    value=1024,
                    choices=choices,
                    label="Resolution Category",
                )
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(
                    value=initial_res_choices[0],
                    choices=initial_res_choices,
                    label="Width x Height (Ratio)",
                )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
                random_seed = gr.Checkbox(
                    label="Random Seed",
                    value=True,
                )
            with gr.Row():
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    value=8,
                    step=1,
                    interactive=True,
                )
                shift = gr.Slider(
                    label="Time Shift",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.1,
                    interactive=True,
                )
            enhance = gr.Checkbox(
                label="Enhance Prompt with DashScope",
                value=False,
                info="Requires DASHSCOPE_API_KEY and openai package. Keep disabled if not needed.",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                status_btn = gr.Button("Model Status")
            status_box = gr.Textbox(
                label="Status / Logs",
                lines=6,
                interactive=False,
            )
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=prompt_input,
                label=None,
            )
        # Right column: generated images and the seed that produced them.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=600,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(
                label="Seed Used",
                interactive=False,
            )
    # Changing the category refreshes the resolution dropdown.
    res_cat.change(
        update_res_choices,
        inputs=res_cat,
        outputs=resolution,
    )
    # The gallery is both an input (existing history) and an output (new
    # history); the seed field is overwritten with the seed actually used.
    generate_btn.click(
        generate,
        inputs=[
            prompt_input,
            resolution,
            seed,
            steps,
            shift,
            random_seed,
            output_gallery,
            enhance,
        ],
        outputs=[
            output_gallery,
            used_seed,
            seed,
            status_box,
        ],
    )
    status_btn.click(
        get_model_status,
        inputs=[],
        outputs=[status_box],
    )
if __name__ == "__main__":
    # Stay compatible with different Gradio versions: first launch with the
    # `mcp_server` keyword, and launch without it when that call raises
    # TypeError.
    try:
        demo.launch(mcp_server=True)
    except TypeError:
        demo.launch()