import os import sys import re import json import random import logging import warnings import traceback import threading from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple import gradio as gr import torch from PIL import Image, ImageDraw, ImageFont # ==================== spaces 兼容处理 ==================== # 在 HuggingFace Spaces 上会有 spaces 包; # 本地运行时如果没有 spaces,也不会直接崩溃。 try: import spaces # type: ignore except Exception: class _SpacesFallback: @staticmethod def GPU(fn=None, **kwargs): if fn is None: return lambda f: f return fn @staticmethod def aoti_blocks_load(*args, **kwargs): raise RuntimeError("spaces.aoti_blocks_load is unavailable outside HuggingFace Spaces.") spaces = _SpacesFallback() # type: ignore from diffusers import ( AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler, ) from transformers import AutoModelForCausalLM, AutoTokenizer # ------------------------- 可选依赖:Prompt Enhancer 模板 ------------------------- # 如果你的工程里有 pe.py,会自动使用; # 没有也不会报错,Prompt Enhance 默认关闭。 try: sys.path.append(os.path.dirname(os.path.abspath(__file__))) from pe import prompt_template # type: ignore except Exception: prompt_template = ( "You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. " "Return JSON with key revised_prompt." ) # ==================== Environment Variables ==================== MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo") # 关键修复: # 1. 默认关闭 compile,避免首轮加载超时、编译失败、ZeroGPU 兼容问题。 # 2. 如确认环境稳定,可在 Space Variables 中设置 ENABLE_COMPILE=true。 ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true" # 关键修复: # 默认关闭 warmup。原代码会遍历大量分辨率进行预热,非常容易导致启动失败。 ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "false").lower() == "true" # 默认 native 最稳。若你的环境确认支持 flash_3,可设置 ATTENTION_BACKEND=flash_3。 ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "native") # ZeroGPU AoTI:默认尝试启用,但失败不会影响主流程。 ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "true").lower() == "true" # Safety checker 会额外占用内存,默认关闭,防止把主模型加载拖死。 # 如需要可设置 ENABLE_SAFETY_CHECKER=true。 ENABLE_SAFETY_CHECKER = os.environ.get("ENABLE_SAFETY_CHECKER", "false").lower() == "true" # 优先使用 DiffusionPipeline 加载;失败后再回退到手动组件加载。 USE_DIFFUSION_PIPELINE = os.environ.get("USE_DIFFUSION_PIPELINE", "true").lower() == "true" # 生成历史图片数量,避免 Gallery 越堆越多占内存。 MAX_GALLERY_HISTORY = int(os.environ.get("MAX_GALLERY_HISTORY", "8")) DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY") HF_TOKEN = os.environ.get("HF_TOKEN") # =============================================================== os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings("ignore") logging.getLogger("transformers").setLevel(logging.ERROR) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 pipe = None prompt_expander = None model_lock = threading.Lock() MODEL_LOAD_ERROR = "" RES_CHOICES = { "1024": [ "1024x1024 ( 1:1 )", "1152x896 ( 9:7 )", "896x1152 ( 7:9 )", "1152x864 ( 4:3 )", "864x1152 ( 3:4 )", "1248x832 ( 3:2 )", "832x1248 ( 2:3 )", "1280x720 ( 16:9 )", "720x1280 ( 9:16 )", "1344x576 ( 21:9 )", "576x1344 ( 9:21 )", ], "1280": [ "1280x1280 ( 1:1 )", "1440x1120 ( 9:7 )", "1120x1440 ( 7:9 )", "1472x1104 ( 4:3 )", "1104x1472 ( 3:4 )", "1536x1024 ( 3:2 )", "1024x1536 ( 2:3 )", "1536x864 ( 16:9 )", "864x1536 ( 9:16 )", "1680x720 ( 21:9 )", "720x1680 ( 9:21 )", ], "1536": [ "1536x1536 ( 1:1 )", "1728x1344 ( 9:7 )", "1344x1728 ( 7:9 )", "1728x1296 ( 4:3 )", "1296x1728 ( 3:4 )", "1872x1248 ( 3:2 )", "1248x1872 ( 2:3 )", "2048x1152 ( 16:9 )", "1152x2048 ( 9:16 )", "2016x864 ( 21:9 )", "864x2016 ( 9:21 )", ], } RESOLUTION_SET: List[str] = [] for _k, _items in RES_CHOICES.items(): RESOLUTION_SET.extend(_items) EXAMPLE_PROMPTS = [ ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"], ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"], ] def refresh_runtime_device() -> Tuple[str, torch.dtype]: """ 关键修复: 在 ZeroGPU 环境中,应用启动阶段可能没有 CUDA; 只有进入 @spaces.GPU 函数后,CUDA 才可能可见。 因此必须在生成函数内部重新判断 DEVICE / DTYPE。 """ global DEVICE, DTYPE DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 print(f"[Runtime] DEVICE={DEVICE}, DTYPE={DTYPE}") return DEVICE, DTYPE def cuda_cleanup(): """ 出错后尽量释放 CUDA 缓存,避免后续请求继续失败。 """ try: if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() except Exception: pass def is_local_model_path(model_path: str) -> bool: return os.path.isdir(model_path) def hf_token_candidates() -> List[Dict[str, str]]: """ 兼容不同版本的 transformers / diffusers。 新版本一般使用 token; 旧版本可能使用 use_auth_token。 不要同时传 token 和 use_auth_token,否则部分环境会报错。 """ if not HF_TOKEN: return [{}] return [{"token": HF_TOKEN}, {"use_auth_token": HF_TOKEN}, {}] def get_resolution(resolution: str) -> Tuple[int, int]: match = re.search(r"(\d+)\s*[×x]\s*(\d+)", str(resolution)) if match: return int(match.group(1)), int(match.group(2)) return 1024, 1024 def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker") -> Image.Image: img = Image.new("RGB", (width, height), (20, 20, 20)) draw = ImageDraw.Draw(img) try: font = ImageFont.load_default() except Exception: font = None draw.rectangle([0, 0, width, 90], fill=(160, 0, 0)) draw.text((20, 30), text, fill=(255, 255, 255), font=font) return img def _load_nsfw_placeholder(width=1024, height=1024) -> Image.Image: """ 命中 NSFW 时优先加载工作目录的 nsfw.png; 不存在就生成一张占位图,避免文件缺失导致再次报错。 """ if os.path.exists("nsfw.png"): try: return Image.open("nsfw.png").convert("RGB") except Exception: pass return _make_blocked_image(width, height, "NSFW blocked") def _move_pipeline_to_device(p) -> Any: """ 兼容不同 diffusers 版本的 .to() 调用方式。 """ if p is None: return p # 如果使用 device_map 加载,通常不要再强行 .to() if getattr(p, "hf_device_map", None): print(f"[Init] Pipeline already has hf_device_map: {getattr(p, 'hf_device_map', None)}") return p if DEVICE == "cuda": attempts = [ lambda: p.to("cuda"), lambda: p.to(torch_dtype=DTYPE), lambda: p.to(device="cuda"), lambda: p.to("cuda", torch_dtype=DTYPE), ] else: attempts = [ lambda: p.to("cpu"), lambda: p.to(torch_dtype=torch.float32), lambda: p.to(device="cpu"), ] last_error = None for fn in attempts: try: p = fn() return p except Exception as e: last_error = e print(f"[Init] Warning: pipeline.to(...) failed, continue anyway. Error: {last_error}") return p def _set_attention_backend_if_possible(p, backend: str) -> None: """ attention backend 不是所有环境都支持。 失败时自动回退 native,仍失败也不阻塞主流程。 """ if not p: return transformer = getattr(p, "transformer", None) if transformer is None: return if not hasattr(transformer, "set_attention_backend"): print("[Init] Transformer has no set_attention_backend method, skip.") return try: transformer.set_attention_backend(backend) print(f"[Init] Attention backend set to: {backend}") return except Exception as e: print(f"[Init] set_attention_backend('{backend}') failed: {e}") try: transformer.set_attention_backend("native") print("[Init] Attention backend fallback to: native") except Exception as e: print(f"[Init] set_attention_backend('native') also failed, ignored: {e}") def _compile_transformer_if_possible(p) -> Any: """ torch.compile 可能加速,也可能导致首轮非常慢或直接失败。 因此默认关闭,且失败时不影响主流程。 """ if not ENABLE_COMPILE: return p if DEVICE != "cuda": print("[Init] ENABLE_COMPILE=true but DEVICE is not cuda, skip compile.") return p transformer = getattr(p, "transformer", None) if transformer is None: print("[Init] No transformer found, skip compile.") return p try: print("[Init] Enabling torch.compile optimizations...") torch._inductor.config.conv_1x1_as_mm = True torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.max_autotune_gemm = True torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN" torch._inductor.config.triton.cudagraphs = False p.transformer = torch.compile( transformer, mode="max-autotune-no-cudagraphs", fullgraph=False, ) print("[Init] Transformer compiled.") except Exception: print("[Init] torch.compile failed, continue without compile:") traceback.print_exc() return p def try_enable_aoti(p) -> None: """ AoTI / ZeroGPU 加速。 可用则启用,不可用则跳过。 """ if not ENABLE_AOTI: print("[Init] ENABLE_AOTI=false, skip AoTI.") return if p is None: return try: transformer = getattr(p, "transformer", None) if transformer is None: print("[Init] No transformer found, skip AoTI.") return target = None if hasattr(transformer, "layers"): target = transformer.layers if hasattr(target, "_repeated_blocks"): target._repeated_blocks = ["ZImageTransformerBlock"] else: target = transformer if hasattr(target, "_repeated_blocks"): target._repeated_blocks = ["ZImageTransformerBlock"] if target is not None: spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3") print("[Init] AoTI blocks loaded.") except Exception: print("[Init] AoTI not enabled, safe to ignore:") traceback.print_exc() def _load_safety_checker_if_enabled(p) -> Any: """ Safety checker 默认关闭,因为它会额外占用内存。 即使开启,加载失败也不影响主模型。 """ if not ENABLE_SAFETY_CHECKER: print("[Init] ENABLE_SAFETY_CHECKER=false, skip safety checker.") try: p.safety_feature_extractor = None p.safety_checker = None except Exception: pass return p try: from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker try: from transformers import CLIPImageProcessor as _CLIPProcessor except Exception: from transformers import CLIPFeatureExtractor as _CLIPProcessor # type: ignore safety_model_id = "CompVis/stable-diffusion-safety-checker" last_error = None for token_kwargs in hf_token_candidates(): try: safety_feature_extractor = _CLIPProcessor.from_pretrained( safety_model_id, **token_kwargs, ) safety_checker = StableDiffusionSafetyChecker.from_pretrained( safety_model_id, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, **token_kwargs, ) safety_checker = safety_checker.to(DEVICE) p.safety_feature_extractor = safety_feature_extractor p.safety_checker = safety_checker print("[Init] Safety checker loaded.") return p except Exception as e: last_error = e raise RuntimeError(f"Safety checker load failed: {last_error}") except Exception: print("[Init] Safety checker init failed. NSFW filtering will be skipped:") traceback.print_exc() try: p.safety_feature_extractor = None p.safety_checker = None except Exception: pass return p def _load_with_diffusion_pipeline(model_path: str) -> Any: """ 优先使用官方推荐的 DiffusionPipeline 加载方式。 做多组参数尝试,以兼容 diffusers 新旧版本。 """ print("[Init] Trying DiffusionPipeline loading strategy...") local = is_local_model_path(model_path) token_candidates = [{}] if local else hf_token_candidates() dtype_candidates: List[Dict[str, Any]] = [] if DEVICE == "cuda": dtype_candidates.extend([ # 新版 diffusers 某些文档使用 dtype {"dtype": DTYPE, "device_map": "cuda"}, {"dtype": DTYPE}, # 旧版常用 torch_dtype {"torch_dtype": DTYPE, "device_map": "cuda"}, {"torch_dtype": DTYPE}, # 某些 Z-Image 示例需要 low_cpu_mem_usage=False {"torch_dtype": DTYPE, "low_cpu_mem_usage": False}, {"dtype": DTYPE, "low_cpu_mem_usage": False}, # 兼容 custom pipeline / older discussions {"torch_dtype": DTYPE, "trust_remote_code": True}, {"dtype": DTYPE, "trust_remote_code": True}, ]) else: dtype_candidates.extend([ {"torch_dtype": torch.float32}, {"dtype": torch.float32}, {"torch_dtype": torch.float32, "low_cpu_mem_usage": False}, {}, ]) errors: List[str] = [] for token_kwargs in token_candidates: for extra_kwargs in dtype_candidates: kwargs: Dict[str, Any] = {} kwargs.update(token_kwargs) kwargs.update(extra_kwargs) try: print(f"[Init] DiffusionPipeline.from_pretrained kwargs={list(kwargs.keys())}") p = DiffusionPipeline.from_pretrained(model_path, **kwargs) print("[Init] DiffusionPipeline loaded.") p = _move_pipeline_to_device(p) return p except Exception as e: err = f"kwargs={kwargs} -> {type(e).__name__}: {e}" print(f"[Init] DiffusionPipeline attempt failed: {err}") errors.append(err) raise RuntimeError( "All DiffusionPipeline loading attempts failed.\n" + "\n".join(errors[-8:]) ) def _from_pretrained_component(cls, path_or_repo: str, subfolder: Optional[str], torch_dtype: Optional[torch.dtype]) -> Any: """ 手动组件加载的兼容封装。 """ local = is_local_model_path(path_or_repo) if local: load_path = os.path.join(path_or_repo, subfolder) if subfolder else path_or_repo kwargs: Dict[str, Any] = {} if torch_dtype is not None: kwargs["torch_dtype"] = torch_dtype return cls.from_pretrained(load_path, **kwargs) errors = [] for token_kwargs in hf_token_candidates(): kwargs = {} if subfolder: kwargs["subfolder"] = subfolder if torch_dtype is not None: kwargs["torch_dtype"] = torch_dtype kwargs.update(token_kwargs) try: return cls.from_pretrained(path_or_repo, **kwargs) except Exception as e: errors.append(f"{type(e).__name__}: {e}") raise RuntimeError( f"Failed to load component {cls} subfolder={subfolder}. " + " | ".join(errors[-4:]) ) def _load_with_manual_components(model_path: str) -> Any: """ 回退方案:按你原来的方式手动加载 VAE、text_encoder、tokenizer、transformer。 如果 diffusers 环境里没有 ZImagePipeline / ZImageTransformer2DModel,会在这里给出明确错误。 """ print("[Init] Trying manual component loading strategy...") try: from diffusers import ZImagePipeline # type: ignore from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel # type: ignore except Exception as e: raise RuntimeError( "Current diffusers does not provide ZImagePipeline / ZImageTransformer2DModel. " "Please upgrade diffusers, for example: pip install -U diffusers transformers accelerate" ) from e model_dtype = DTYPE if DEVICE == "cuda" else torch.float32 vae = _from_pretrained_component( AutoencoderKL, model_path, "vae", model_dtype, ) text_encoder = _from_pretrained_component( AutoModelForCausalLM, model_path, "text_encoder", model_dtype, ).eval() tokenizer = _from_pretrained_component( AutoTokenizer, model_path, "tokenizer", None, ) tokenizer.padding_side = "left" p = ZImagePipeline( scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, ) transformer = _from_pretrained_component( ZImageTransformer2DModel, model_path, "transformer", None, ) transformer = transformer.to(DEVICE, DTYPE) p.transformer = transformer p = _move_pipeline_to_device(p) print("[Init] Manual component loading finished.") return p def load_models(model_path: str) -> Any: """ 统一模型加载入口。 先尝试 DiffusionPipeline,失败后回退手动组件加载。 """ print("=" * 80) print(f"[Init] Loading model from: {model_path}") print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}") print(f"[Init] USE_DIFFUSION_PIPELINE={USE_DIFFUSION_PIPELINE}") print(f"[Init] ENABLE_COMPILE={ENABLE_COMPILE}") print(f"[Init] ENABLE_WARMUP={ENABLE_WARMUP}") print(f"[Init] ATTENTION_BACKEND={ATTENTION_BACKEND}") print("=" * 80) last_error = None if USE_DIFFUSION_PIPELINE: try: p = _load_with_diffusion_pipeline(model_path) _set_attention_backend_if_possible(p, ATTENTION_BACKEND) p = _compile_transformer_if_possible(p) p = _load_safety_checker_if_enabled(p) return p except Exception as e: last_error = e print("[Init] DiffusionPipeline strategy failed:") traceback.print_exc() cuda_cleanup() try: p = _load_with_manual_components(model_path) _set_attention_backend_if_possible(p, ATTENTION_BACKEND) p = _compile_transformer_if_possible(p) p = _load_safety_checker_if_enabled(p) return p except Exception as e: print("[Init] Manual component strategy failed:") traceback.print_exc() cuda_cleanup() raise RuntimeError( "Model loading failed in all strategies. " f"First error: {last_error}. " f"Second error: {e}" ) from e def generate_image( p, prompt: str, resolution: str = "1024x1024", seed: int = 42, guidance_scale: float = 0.0, num_inference_steps: int = 9, shift: float = 3.0, max_sequence_length: int = 512, ) -> Image.Image: """ 单张图片生成。 """ width, height = get_resolution(resolution) if DEVICE == "cuda": generator = torch.Generator(device="cuda").manual_seed(int(seed)) else: generator = torch.Generator().manual_seed(int(seed)) # Z-Image-Turbo 常用 FlowMatchEulerDiscreteScheduler try: p.scheduler = FlowMatchEulerDiscreteScheduler( num_train_timesteps=1000, shift=float(shift), ) except Exception: print("[Generate] Failed to assign scheduler, continue with existing scheduler:") traceback.print_exc() call_kwargs = dict( prompt=prompt, height=int(height), width=int(width), guidance_scale=float(guidance_scale), num_inference_steps=int(num_inference_steps), generator=generator, max_sequence_length=int(max_sequence_length), ) # 不同 pipeline 版本参数支持可能不同,失败后去掉 max_sequence_length 再试。 try: out = p(**call_kwargs) except TypeError: call_kwargs.pop("max_sequence_length", None) out = p(**call_kwargs) image = out.images[0] if not isinstance(image, Image.Image): image = Image.fromarray(image) return image.convert("RGB") def warmup_model(p) -> None: """ 极简 warmup。 原代码遍历全部分辨率,每个分辨率生成两张,风险很高。 这里仅在用户显式开启 ENABLE_WARMUP=true 时,对 1024x1024 跑一次短步数。 """ if not ENABLE_WARMUP: return try: print("[Warmup] Starting minimal warmup...") generate_image( p, prompt="warmup", resolution="1024x1024", num_inference_steps=2, guidance_scale=0.0, seed=42, ) print("[Warmup] Completed.") except Exception: print("[Warmup] Failed, ignored:") traceback.print_exc() cuda_cleanup() # ==================== Prompt Expander ==================== @dataclass class PromptOutput: status: bool prompt: str seed: int system_prompt: str message: str class PromptExpander: def __init__(self, backend="api", **kwargs): self.backend = backend def decide_system_prompt(self, template_name=None): return prompt_template class APIPromptExpander(PromptExpander): def __init__(self, api_config=None, **kwargs): super().__init__(backend="api", **kwargs) self.api_config = api_config or {} self.client = self._init_api_client() def _init_api_client(self): try: from openai import OpenAI api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY base_url = self.api_config.get( "base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1", ) if not api_key: print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.") return None return OpenAI(api_key=api_key, base_url=base_url) except ImportError: print("[PE] openai package not installed. Prompt enhance unavailable.") return None except Exception: print("[PE] Failed to initialize API client:") traceback.print_exc() return None def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs): return self.extend(prompt, system_prompt, seed, **kwargs) def extend(self, prompt, system_prompt=None, seed=-1, **kwargs): if self.client is None: return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized") if system_prompt is None: system_prompt = self.decide_system_prompt() if "{prompt}" in system_prompt: system_prompt = system_prompt.format(prompt=prompt) prompt = " " try: model = self.api_config.get("model", "qwen3-max-preview") response = self.client.chat.completions.create( model=model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, ], temperature=0.7, top_p=0.8, ) content = response.choices[0].message.content or "" expanded_prompt = content json_start = content.find("```json") if json_start != -1: json_end = content.find("```", json_start + 7) if json_end != -1: json_str = content[json_start + 7: json_end].strip() try: data = json.loads(json_str) expanded_prompt = data.get("revised_prompt", content) except Exception: expanded_prompt = content return PromptOutput(True, expanded_prompt, seed, system_prompt, content) except Exception as e: return PromptOutput(False, "", seed, system_prompt, str(e)) def create_prompt_expander(backend="api", **kwargs): if backend == "api": return APIPromptExpander(**kwargs) raise ValueError("Only 'api' backend is supported.") def get_or_create_prompt_expander(): """ Prompt enhancer 懒加载,避免启动阶段因 openai / key 问题影响主程序。 """ global prompt_expander if prompt_expander is not None: return prompt_expander try: prompt_expander = create_prompt_expander( backend="api", api_config={"model": "qwen3-max-preview"}, ) print("[PE] Prompt expander ready.") except Exception: print("[PE] Prompt expander init failed:") traceback.print_exc() prompt_expander = None return prompt_expander def prompt_enhance(prompt: str, enable_enhance: bool) -> Tuple[str, str]: if not enable_enhance: return prompt, "Enhancement disabled." expander = get_or_create_prompt_expander() if not expander: return prompt, "Prompt expander unavailable." if not prompt.strip(): return "", "Please enter a prompt." try: result = expander(prompt) if result.status: return result.prompt, result.message return prompt, f"Enhancement failed: {result.message}" except Exception as e: return prompt, f"Error: {str(e)}" def get_or_load_pipe(): """ 关键修复: 不在应用启动阶段加载大模型; 只在 Generate 点击后进入 @spaces.GPU 环境时懒加载。 """ global pipe, MODEL_LOAD_ERROR refresh_runtime_device() if pipe is not None: return pipe with model_lock: if pipe is not None: return pipe try: MODEL_LOAD_ERROR = "" loaded_pipe = load_models(MODEL_PATH) try_enable_aoti(loaded_pipe) warmup_model(loaded_pipe) pipe = loaded_pipe print("[Init] Model loaded successfully.") return pipe except Exception: MODEL_LOAD_ERROR = traceback.format_exc() print("[Init] Model loading failed with full traceback:") print(MODEL_LOAD_ERROR) pipe = None cuda_cleanup() raise gr.Error( "Model loading failed. Please open the Space logs to view the full traceback. " "Common fixes: upgrade diffusers/transformers/accelerate, disable compile/warmup, " "or check MODEL_PATH / HF_TOKEN." ) def normalize_gallery_items(gallery_images) -> List[Any]: """ 兼容 Gradio Gallery 在不同版本下返回的格式。 """ if not gallery_images: return [] result = [] for item in list(gallery_images): try: if isinstance(item, Image.Image): result.append(item) elif isinstance(item, str) and os.path.exists(item): result.append(Image.open(item).convert("RGB")) elif isinstance(item, (tuple, list)) and len(item) > 0: first = item[0] if isinstance(first, Image.Image): result.append(first) elif isinstance(first, str) and os.path.exists(first): result.append(Image.open(first).convert("RGB")) elif isinstance(item, dict): img_obj = item.get("image") or item.get("path") or item.get("name") if isinstance(img_obj, Image.Image): result.append(img_obj) elif isinstance(img_obj, str) and os.path.exists(img_obj): result.append(Image.open(img_obj).convert("RGB")) except Exception: continue return result[: max(0, MAX_GALLERY_HISTORY - 1)] def run_safety_check_if_available(p, image: Image.Image, width: int, height: int) -> Image.Image: """ 生成后安全检查。 默认不会启用,因为 ENABLE_SAFETY_CHECKER=false。 """ try: if getattr(p, "safety_feature_extractor", None) is None: return image if getattr(p, "safety_checker", None) is None: return image import numpy as np clip_inputs = p.safety_feature_extractor([image], return_tensors="pt") clip_input = clip_inputs.pixel_values.to(DEVICE) img_np = np.array(image).astype("float32") / 255.0 img_np = img_np[None, ...] _checked_images, has_nsfw = p.safety_checker( images=img_np, clip_input=clip_input, ) if isinstance(has_nsfw, (list, tuple)) and len(has_nsfw) > 0 and bool(has_nsfw[0]): return _load_nsfw_placeholder(width, height) return image except Exception: print("[Safety] Check failed, ignored:") traceback.print_exc() return image @spaces.GPU def generate( prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=8, shift=3.0, random_seed=True, gallery_images=None, enhance=False, progress=gr.Progress(track_tqdm=True), ): """ Gradio 生成入口。 这个函数在 ZeroGPU 环境中会触发动态 GPU 分配。 """ try: if not str(prompt or "").strip(): raise gr.Error("Please enter a prompt.") current_pipe = get_or_load_pipe() if random_seed: new_seed = random.randint(1, 1_000_000) else: try: new_seed = int(seed) except Exception: new_seed = 42 if new_seed == -1: new_seed = random.randint(1, 1_000_000) final_prompt = str(prompt or "").strip() pe_msg = "" if enhance: final_prompt, pe_msg = prompt_enhance(final_prompt, True) print(f"[PE] Enhanced prompt: {final_prompt}") print(f"[PE] Message: {pe_msg}") try: resolution_str = str(resolution).split(" ")[0] except Exception: resolution_str = "1024x1024" width, height = get_resolution(resolution_str) # Z-Image-Turbo 通常 8 steps 左右即可。 safe_steps = max(1, min(int(steps), 100)) image = generate_image( p=current_pipe, prompt=final_prompt, resolution=resolution_str, seed=new_seed, guidance_scale=0.0, num_inference_steps=safe_steps + 1, shift=float(shift), ) image = run_safety_check_if_available(current_pipe, image, width, height) old_images = normalize_gallery_items(gallery_images) gallery = [image] + old_images gallery = gallery[:MAX_GALLERY_HISTORY] status = ( f"Done. DEVICE={DEVICE}, resolution={resolution_str}, " f"steps={safe_steps + 1}, seed={new_seed}" ) if pe_msg: status += f"\nPrompt Enhance: {pe_msg[:300]}" return gallery, str(new_seed), int(new_seed), status except gr.Error: raise except Exception as e: print("[Generate] Failed:") traceback.print_exc() cuda_cleanup() raise gr.Error(f"Generation failed: {type(e).__name__}: {e}") def update_res_choices(_res_cat): if str(_res_cat) in RES_CHOICES: res_choices = RES_CHOICES[str(_res_cat)] else: res_choices = RES_CHOICES["1024"] return gr.update(value=res_choices[0], choices=res_choices) def get_model_status(): """ 页面按钮:检查当前模型状态。 """ if pipe is not None: return f"Model loaded. DEVICE={DEVICE}, DTYPE={DTYPE}, MODEL_PATH={MODEL_PATH}" if MODEL_LOAD_ERROR: return "Model not loaded. Last loading error:\n" + MODEL_LOAD_ERROR[-4000:] return ( "Model not loaded yet. This is normal. " "The model will be loaded when you click Generate." ) css = """ .fillable { max-width: 1230px !important; } .gradio-container { max-width: 1280px !important; } """ # ==================== Gradio UI ==================== with gr.Blocks(title="Z-Image Demo") as demo: gr.Markdown( """