# Z-Image-Turbo / app.py
# Author: cpuai
# Commit message: Update app.py
# Commit: 8616178 (verified)
import os
import sys
import re
import json
import random
import logging
import warnings
import traceback
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
# ==================== `spaces` compatibility shim ====================
# On HuggingFace Spaces the `spaces` package is installed. When running
# locally without it, substitute a minimal stand-in instead of crashing.
try:
    import spaces  # type: ignore
except Exception:

    class _SpacesFallback:
        """No-op replacement for the `spaces` package outside HuggingFace Spaces."""

        @staticmethod
        def GPU(fn=None, **kwargs):
            # Works both as `@spaces.GPU` (fn given) and as
            # `@spaces.GPU(duration=...)` (fn omitted -> identity decorator).
            if fn is not None:
                return fn
            return lambda f: f

        @staticmethod
        def aoti_blocks_load(*args, **kwargs):
            # AoTI blocks only exist on ZeroGPU; callers treat this as "not available".
            raise RuntimeError("spaces.aoti_blocks_load is unavailable outside HuggingFace Spaces.")

    spaces = _SpacesFallback()  # type: ignore
from diffusers import (
AutoencoderKL,
DiffusionPipeline,
FlowMatchEulerDiscreteScheduler,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
# ------------------------- Optional dependency: Prompt Enhancer template -------------------------
# If the project ships a pe.py next to this file, its `prompt_template` is used.
# Otherwise a built-in fallback template is used; Prompt Enhance is off by default anyway.
try:
    _this_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(_this_dir)
    from pe import prompt_template  # type: ignore
except Exception:
    prompt_template = (
        "You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. "
        "Return JSON with key revised_prompt."
    )
# ==================== Environment Variables ====================
_TRUTHY = ("1", "true", "yes", "on")
_FALSY = ("0", "false", "no", "off")


def env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from the environment.

    Accepts "1/true/yes/on" and "0/false/no/off" case-insensitively, ignoring
    surrounding whitespace. An unset or empty variable yields ``default``.
    An unrecognised value also yields ``default``, with a warning.
    Previously only the literal "true" counted, so e.g. ``ENABLE_AOTI=1``
    silently turned a default-on feature off.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    value = raw.strip().lower()
    if value in _TRUTHY:
        return True
    if value in _FALSY:
        return False
    print(f"[Config] Unrecognised value {name}={raw!r}, using default {default}.")
    return default


def env_int(name: str, default: int, minimum: Optional[int] = None) -> int:
    """Read an integer from the environment, falling back to ``default`` on bad input.

    If ``minimum`` is given, the result is clamped to be at least ``minimum``.
    """
    raw = os.environ.get(name)
    value = default
    if raw is not None and raw.strip():
        try:
            value = int(raw)
        except ValueError:
            print(f"[Config] Invalid integer {name}={raw!r}, using default {default}.")
    if minimum is not None:
        value = max(minimum, value)
    return value


MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
# torch.compile is off by default: it can time out the first load or fail on
# ZeroGPU. Enable with ENABLE_COMPILE=true once the environment is known to be stable.
ENABLE_COMPILE = env_flag("ENABLE_COMPILE", False)
# Warmup is off by default; warming up many resolutions made startup fail easily.
ENABLE_WARMUP = env_flag("ENABLE_WARMUP", False)
# "native" is the safest default. Set ATTENTION_BACKEND=flash_3 only where it is supported.
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "native")
# ZeroGPU AoTI: attempted by default; failure does not affect the main flow.
ENABLE_AOTI = env_flag("ENABLE_AOTI", True)
# The safety checker costs extra memory, so it is off by default.
# Set ENABLE_SAFETY_CHECKER=true to enable it.
ENABLE_SAFETY_CHECKER = env_flag("ENABLE_SAFETY_CHECKER", False)
# Prefer DiffusionPipeline loading; fall back to manual component loading on failure.
USE_DIFFUSION_PIPELINE = env_flag("USE_DIFFUSION_PIPELINE", True)
# Maximum number of images kept in the gallery history. It is clamped to >= 1
# so the freshly generated image is always shown.
MAX_GALLERY_HISTORY = env_int("MAX_GALLERY_HISTORY", 8, minimum=1)
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# ===============================================================
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Initial guess only; refresh_runtime_device() re-evaluates these at call time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# Lazily created singletons and their guard.
pipe = None
prompt_expander = None
model_lock = threading.Lock()
# Full traceback of the most recent failed model load ("" when none).
MODEL_LOAD_ERROR = ""
# Selectable output sizes, grouped by the resolution category shown in the UI.
# Each label reads "<width>x<height> ( <ratio> )"; only the size part is parsed.
RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1536x864 ( 16:9 )",
        "864x1536 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
    "1536": [
        "1536x1536 ( 1:1 )",
        "1728x1344 ( 9:7 )",
        "1344x1728 ( 7:9 )",
        "1728x1296 ( 4:3 )",
        "1296x1728 ( 3:4 )",
        "1872x1248 ( 3:2 )",
        "1248x1872 ( 2:3 )",
        "2048x1152 ( 16:9 )",
        "1152x2048 ( 9:16 )",
        "2016x864 ( 21:9 )",
        "864x2016 ( 9:21 )",
    ],
}

# Every resolution label across all categories, in category order.
RESOLUTION_SET: List[str] = [label for group in RES_CHOICES.values() for label in group]

# Rows for gr.Examples: each row holds a single value for the prompt textbox.
EXAMPLE_PROMPTS = [
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
]
def refresh_runtime_device() -> Tuple[str, torch.dtype]:
    """Re-detect the compute device and its dtype, updating the DEVICE/DTYPE globals.

    On ZeroGPU, CUDA may be invisible while the app starts and only become
    visible inside a ``@spaces.GPU`` function. The device must therefore be
    re-checked at call time rather than trusted from import time.

    Returns:
        ``("cuda", torch.bfloat16)`` when CUDA is available, otherwise
        ``("cpu", torch.float32)``.
    """
    global DEVICE, DTYPE
    if torch.cuda.is_available():
        DEVICE, DTYPE = "cuda", torch.bfloat16
    else:
        DEVICE, DTYPE = "cpu", torch.float32
    print(f"[Runtime] DEVICE={DEVICE}, DTYPE={DTYPE}")
    return DEVICE, DTYPE
def cuda_cleanup():
    """Best-effort release of cached CUDA memory after a failure.

    Does nothing without CUDA. Any error raised while cleaning up is
    deliberately ignored, so this never masks the original failure.
    """
    try:
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    except Exception:
        pass
def is_local_model_path(model_path: str) -> bool:
    """Return True when *model_path* is an existing local directory.

    Callers treat any other value (e.g. "org/model") as a Hub repo id.
    """
    found_directory = os.path.isdir(model_path)
    return found_directory
def hf_token_candidates() -> List[Dict[str, str]]:
    """List the auth-kwarg variants to try with ``from_pretrained``, in order.

    Newer transformers/diffusers take ``token``; older ones take
    ``use_auth_token``. The two must never be passed together, because some
    environments reject that combination. Each variant therefore carries at
    most one of them, and the last variant carries no auth at all. Without
    HF_TOKEN, only the empty variant is returned.
    """
    if HF_TOKEN:
        return [{"token": HF_TOKEN}, {"use_auth_token": HF_TOKEN}, {}]
    return [{}]
def get_resolution(resolution: str) -> Tuple[int, int]:
    """Parse the first "<width>x<height>" (or "×") pair found in *resolution*.

    Surrounding text such as a "( 16:9 )" suffix is ignored. Returns
    ``(1024, 1024)`` when no size can be found.
    """
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", str(resolution))
    if found is None:
        return 1024, 1024
    width_text, height_text = found.groups()
    return int(width_text), int(height_text)
def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker") -> Image.Image:
    """Draw a placeholder: a dark ``width`` x ``height`` canvas with *text* on a red top banner.

    The text uses Pillow's default font when it can be loaded, otherwise
    ``font=None``.
    """
    canvas = Image.new("RGB", (width, height), (20, 20, 20))
    painter = ImageDraw.Draw(canvas)
    font = None
    try:
        font = ImageFont.load_default()
    except Exception:
        pass
    painter.rectangle([0, 0, width, 90], fill=(160, 0, 0))
    painter.text((20, 30), text, fill=(255, 255, 255), font=font)
    return canvas
def _load_nsfw_placeholder(width=1024, height=1024) -> Image.Image:
    """Return the image shown in place of a generation flagged as NSFW.

    If ``nsfw.png`` exists in the working directory, it is returned at its own
    size, converted to RGB. ``width``/``height`` are used only for the fallback.
    If the file is missing or unreadable, a generated "NSFW blocked"
    placeholder is returned instead, so a missing asset never causes a second
    error.
    """
    placeholder_path = "nsfw.png"
    if os.path.exists(placeholder_path):
        try:
            # `with` guarantees the file handle is closed. For some formats
            # (e.g. multi-frame images) Pillow may keep the file open after
            # loading otherwise.
            with Image.open(placeholder_path) as src:
                return src.convert("RGB")
        except Exception:
            # Previously swallowed silently; log why the asset was not used.
            print("[Safety] Failed to read nsfw.png, using generated placeholder:")
            traceback.print_exc()
    return _make_blocked_image(width, height, "NSFW blocked")
def _move_pipeline_to_device(p) -> Any:
    """Move pipeline *p* onto the runtime DEVICE, tolerating diffusers ``.to()`` API differences.

    Pipelines loaded with a ``device_map`` (non-empty ``hf_device_map``) are
    returned untouched. Otherwise several ``.to(...)`` signatures are tried in
    turn, and the first that does not raise wins.

    Every candidate now actually targets DEVICE. The old list included a
    dtype-only ``p.to(torch_dtype=...)`` fallback; if the device move failed,
    that call "succeeded", was returned as if the move had worked, and left
    the pipeline on the CPU.

    If every attempt fails, a warning is printed and *p* is returned
    unchanged; this function never raises.

    Returns:
        The object returned by the successful ``.to(...)`` call, or *p*.
    """
    if p is None:
        return p

    # With a device_map the weights are already placed; do not force a .to().
    device_map = getattr(p, "hf_device_map", None)
    if device_map:
        print(f"[Init] Pipeline already has hf_device_map: {device_map}")
        return p

    if DEVICE == "cuda":
        attempts = [
            lambda: p.to("cuda"),
            lambda: p.to(device="cuda"),
            lambda: p.to("cuda", torch_dtype=DTYPE),
        ]
    else:
        attempts = [
            lambda: p.to("cpu"),
            lambda: p.to(device="cpu"),
        ]

    last_error = None
    for attempt in attempts:
        try:
            return attempt()
        except Exception as e:
            last_error = e
    print(f"[Init] Warning: pipeline.to(...) failed, continue anyway. Error: {last_error}")
    return p
def _set_attention_backend_if_possible(p, backend: str) -> None:
    """Best-effort: select the attention implementation on ``p.transformer``.

    Skips quietly when *p* is None or has no transformer, and logs a skip when
    the transformer has no ``set_attention_backend`` method. If *backend* is
    rejected, falls back to ``"native"``; if *backend* itself was ``"native"``
    there is nothing to fall back to, so the failing call is not repeated.
    Failures are logged and never raised.
    """
    # Explicit None check: truthiness of an arbitrary pipeline object is not
    # a reliable "is loaded" test.
    if p is None:
        return
    transformer = getattr(p, "transformer", None)
    if transformer is None:
        return
    if not hasattr(transformer, "set_attention_backend"):
        print("[Init] Transformer has no set_attention_backend method, skip.")
        return

    try:
        transformer.set_attention_backend(backend)
        print(f"[Init] Attention backend set to: {backend}")
        return
    except Exception as e:
        print(f"[Init] set_attention_backend('{backend}') failed: {e}")

    if backend == "native":
        # Retrying would just repeat the call that failed.
        return
    try:
        transformer.set_attention_backend("native")
        print("[Init] Attention backend fallback to: native")
    except Exception as e:
        print(f"[Init] set_attention_backend('native') also failed, ignored: {e}")
def _compile_transformer_if_possible(p) -> Any:
    """Optionally wrap ``p.transformer`` with ``torch.compile``.

    Runs only when ENABLE_COMPILE is set and DEVICE is "cuda". It is skipped
    (with a log line) when the device is not CUDA or there is no transformer.
    Compiling can speed things up but can also make the first run very slow
    or fail, so any error while configuring Inductor or compiling is printed
    and ignored. Always returns *p*; on success ``p.transformer`` is replaced
    by the compiled module.
    """
    if not ENABLE_COMPILE:
        return p
    if DEVICE != "cuda":
        print("[Init] ENABLE_COMPILE=true but DEVICE is not cuda, skip compile.")
        return p
    transformer = getattr(p, "transformer", None)
    if transformer is None:
        print("[Init] No transformer found, skip compile.")
        return p

    try:
        print("[Init] Enabling torch.compile optimizations...")
        inductor_cfg = torch._inductor.config
        # Inductor tuning knobs, applied in the same order as before.
        for option, value in (
            ("conv_1x1_as_mm", True),
            ("coordinate_descent_tuning", True),
            ("epilogue_fusion", False),
            ("coordinate_descent_check_all_directions", True),
            ("max_autotune_gemm", True),
            ("max_autotune_gemm_backends", "TRITON,ATEN"),
        ):
            setattr(inductor_cfg, option, value)
        # CUDA graphs stay disabled, consistent with the "no-cudagraphs" mode below.
        inductor_cfg.triton.cudagraphs = False
        p.transformer = torch.compile(
            transformer,
            mode="max-autotune-no-cudagraphs",
            fullgraph=False,
        )
        print("[Init] Transformer compiled.")
    except Exception:
        print("[Init] torch.compile failed, continue without compile:")
        traceback.print_exc()
    return p
def try_enable_aoti(p) -> None:
    """Best-effort: load ZeroGPU ahead-of-time-compiled (AoTI) transformer blocks.

    Skipped when ENABLE_AOTI is off, *p* is None, or the pipeline has no
    transformer. The target is ``transformer.layers`` when that attribute is
    present and not None, otherwise the transformer itself.

    ``_repeated_blocks`` is now assigned on the target unconditionally.
    Previously it was set only if the attribute already existed. A plain
    ``layers`` container (e.g. an ``nn.ModuleList``) typically lacks it, so
    the block list was silently never registered.

    Any failure, including running outside Spaces where the fallback
    ``aoti_blocks_load`` raises, is printed and ignored.
    """
    if not ENABLE_AOTI:
        print("[Init] ENABLE_AOTI=false, skip AoTI.")
        return
    if p is None:
        return
    try:
        transformer = getattr(p, "transformer", None)
        if transformer is None:
            print("[Init] No transformer found, skip AoTI.")
            return
        target = getattr(transformer, "layers", None)
        if target is None:
            target = transformer
        target._repeated_blocks = ["ZImageTransformerBlock"]
        spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3")
        print("[Init] AoTI blocks loaded.")
    except Exception:
        print("[Init] AoTI not enabled, safe to ignore:")
        traceback.print_exc()
def _load_safety_checker_if_enabled(p) -> Any:
    """Attach the CompVis NSFW safety checker to pipeline *p*, or explicitly detach it.

    It is off by default because it costs extra memory. When
    ENABLE_SAFETY_CHECKER is off, ``p.safety_feature_extractor`` and
    ``p.safety_checker`` are set to None.

    When it is on, the CLIP image processor and the
    ``StableDiffusionSafetyChecker`` are loaded, trying each entry of
    ``hf_token_candidates()``. The checker uses float16 on CUDA and float32
    otherwise, is moved to DEVICE, and both objects are attached to *p*. Any
    failure is printed and leaves both attributes None, so the main model is
    never affected. Always returns *p*.
    """

    def _detach() -> None:
        # Assignment can fail on unusual pipeline objects; that is ignored.
        try:
            p.safety_feature_extractor = None
            p.safety_checker = None
        except Exception:
            pass

    if not ENABLE_SAFETY_CHECKER:
        print("[Init] ENABLE_SAFETY_CHECKER=false, skip safety checker.")
        _detach()
        return p

    try:
        from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
        try:
            from transformers import CLIPImageProcessor as _CLIPProcessor
        except Exception:
            # Older transformers only ship the deprecated name.
            from transformers import CLIPFeatureExtractor as _CLIPProcessor  # type: ignore

        safety_model_id = "CompVis/stable-diffusion-safety-checker"
        checker_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        last_error = None
        for token_kwargs in hf_token_candidates():
            try:
                extractor = _CLIPProcessor.from_pretrained(safety_model_id, **token_kwargs)
                checker = StableDiffusionSafetyChecker.from_pretrained(
                    safety_model_id,
                    torch_dtype=checker_dtype,
                    **token_kwargs,
                ).to(DEVICE)
                p.safety_feature_extractor = extractor
                p.safety_checker = checker
                print("[Init] Safety checker loaded.")
                return p
            except Exception as e:
                last_error = e
        raise RuntimeError(f"Safety checker load failed: {last_error}")
    except Exception:
        print("[Init] Safety checker init failed. NSFW filtering will be skipped:")
        traceback.print_exc()
        _detach()
        return p
def _load_with_diffusion_pipeline(model_path: str) -> Any:
    """Load the model via ``DiffusionPipeline.from_pretrained``, trying several kwarg sets.

    Every combination of an auth variant (``hf_token_candidates()``, or none
    for a local directory) and a dtype/placement variant is tried in order.
    The first success is moved to the runtime device and returned.

    ``torch_dtype`` variants are now tried before ``dtype`` ones.
    ``torch_dtype`` is the long-standing diffusers keyword. A release that
    does not recognise ``dtype`` may merely warn and ignore it, which would
    "successfully" load float32 weights (twice the memory) on the very first
    attempt.

    Raises:
        RuntimeError: if every attempt fails; the message lists the last 8 errors.
    """
    print("[Init] Trying DiffusionPipeline loading strategy...")
    local = is_local_model_path(model_path)
    token_candidates = [{}] if local else hf_token_candidates()

    dtype_candidates: List[Dict[str, Any]] = []
    if DEVICE == "cuda":
        dtype_candidates.extend([
            # Long-standing keyword first.
            {"torch_dtype": DTYPE, "device_map": "cuda"},
            {"torch_dtype": DTYPE},
            # Newer spelling used by some docs.
            {"dtype": DTYPE, "device_map": "cuda"},
            {"dtype": DTYPE},
            # Some Z-Image examples need low_cpu_mem_usage=False.
            {"torch_dtype": DTYPE, "low_cpu_mem_usage": False},
            {"dtype": DTYPE, "low_cpu_mem_usage": False},
            # Custom pipeline code / older discussions.
            {"torch_dtype": DTYPE, "trust_remote_code": True},
            {"dtype": DTYPE, "trust_remote_code": True},
        ])
    else:
        dtype_candidates.extend([
            {"torch_dtype": torch.float32},
            {"dtype": torch.float32},
            {"torch_dtype": torch.float32, "low_cpu_mem_usage": False},
            {},
        ])

    errors: List[str] = []
    for token_kwargs in token_candidates:
        for extra_kwargs in dtype_candidates:
            kwargs: Dict[str, Any] = {**token_kwargs, **extra_kwargs}
            try:
                print(f"[Init] DiffusionPipeline.from_pretrained kwargs={list(kwargs.keys())}")
                p = DiffusionPipeline.from_pretrained(model_path, **kwargs)
                print("[Init] DiffusionPipeline loaded.")
                return _move_pipeline_to_device(p)
            except Exception as e:
                err = f"kwargs={kwargs} -> {type(e).__name__}: {e}"
                print(f"[Init] DiffusionPipeline attempt failed: {err}")
                errors.append(err)
    raise RuntimeError(
        "All DiffusionPipeline loading attempts failed.\n"
        + "\n".join(errors[-8:])
    )
def _from_pretrained_component(cls, path_or_repo: str, subfolder: Optional[str], torch_dtype: Optional[torch.dtype]) -> Any:
    """Call ``cls.from_pretrained`` for one pipeline component, with version-tolerant auth.

    For a local directory, loads from ``<path_or_repo>/<subfolder>`` directly,
    with no auth kwargs; errors propagate unchanged.

    For a Hub repo, passes ``subfolder`` as a kwarg and tries each entry of
    ``hf_token_candidates()``. If all attempts fail, raises RuntimeError
    listing the last 4 errors.

    ``torch_dtype`` is forwarded only when it is not None.
    """
    dtype_kwargs: Dict[str, Any] = {} if torch_dtype is None else {"torch_dtype": torch_dtype}

    if is_local_model_path(path_or_repo):
        load_path = path_or_repo if not subfolder else os.path.join(path_or_repo, subfolder)
        return cls.from_pretrained(load_path, **dtype_kwargs)

    base_kwargs: Dict[str, Any] = {"subfolder": subfolder} if subfolder else {}
    base_kwargs.update(dtype_kwargs)
    failures = []
    for token_kwargs in hf_token_candidates():
        try:
            return cls.from_pretrained(path_or_repo, **{**base_kwargs, **token_kwargs})
        except Exception as e:
            failures.append(f"{type(e).__name__}: {e}")
    raise RuntimeError(
        f"Failed to load component {cls} subfolder={subfolder}. "
        + " | ".join(failures[-4:])
    )
def _load_with_manual_components(model_path: str) -> Any:
    """Fallback loader: build a ZImagePipeline from individually loaded components.

    Loads the VAE, the text encoder (put in eval mode), the tokenizer (left
    padding) and the transformer from the ``vae`` / ``text_encoder`` /
    ``tokenizer`` / ``transformer`` subfolders of *model_path*. It then
    assembles the pipeline with no scheduler (generate_image installs one per
    call) and moves it to the runtime device.

    The transformer is now loaded directly in ``model_dtype`` (bfloat16 on
    CUDA). Previously it was loaded with the library's default dtype and only
    cast afterwards, which can materialise a full float32 copy of the weights
    in host memory first.

    Raises:
        RuntimeError: if the installed diffusers lacks ZImagePipeline /
            ZImageTransformer2DModel. Component load errors propagate from
            ``_from_pretrained_component``.
    """
    print("[Init] Trying manual component loading strategy...")
    try:
        from diffusers import ZImagePipeline  # type: ignore
        from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel  # type: ignore
    except Exception as e:
        raise RuntimeError(
            "Current diffusers does not provide ZImagePipeline / ZImageTransformer2DModel. "
            "Please upgrade diffusers, for example: pip install -U diffusers transformers accelerate"
        ) from e

    model_dtype = DTYPE if DEVICE == "cuda" else torch.float32

    vae = _from_pretrained_component(AutoencoderKL, model_path, "vae", model_dtype)
    text_encoder = _from_pretrained_component(
        AutoModelForCausalLM,
        model_path,
        "text_encoder",
        model_dtype,
    ).eval()
    tokenizer = _from_pretrained_component(AutoTokenizer, model_path, "tokenizer", None)
    tokenizer.padding_side = "left"

    p = ZImagePipeline(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=None,
    )

    # Load straight into the target dtype to avoid a float32 staging copy.
    transformer = _from_pretrained_component(
        ZImageTransformer2DModel,
        model_path,
        "transformer",
        model_dtype,
    )
    p.transformer = transformer.to(DEVICE, DTYPE)

    p = _move_pipeline_to_device(p)
    print("[Init] Manual component loading finished.")
    return p
def load_models(model_path: str) -> Any:
    """Single entry point for model loading.

    Tries ``DiffusionPipeline.from_pretrained`` first (unless
    USE_DIFFUSION_PIPELINE is off), then falls back to manual component
    loading. Whichever succeeds is post-processed by
    ``_finalize_loaded_pipeline``, which applies the attention backend,
    optional compile and optional safety checker. CUDA caches are cleared
    after each failed strategy.

    Raises:
        RuntimeError: when every attempted strategy fails. If the first
            strategy was skipped, the message says so explicitly instead of
            reporting "First error: None".
    """
    print("=" * 80)
    print(f"[Init] Loading model from: {model_path}")
    print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}")
    print(f"[Init] USE_DIFFUSION_PIPELINE={USE_DIFFUSION_PIPELINE}")
    print(f"[Init] ENABLE_COMPILE={ENABLE_COMPILE}")
    print(f"[Init] ENABLE_WARMUP={ENABLE_WARMUP}")
    print(f"[Init] ATTENTION_BACKEND={ATTENTION_BACKEND}")
    print(f"[Init] ENABLE_AOTI={ENABLE_AOTI}")
    print(f"[Init] ENABLE_SAFETY_CHECKER={ENABLE_SAFETY_CHECKER}")
    print("=" * 80)

    first_error = None
    if USE_DIFFUSION_PIPELINE:
        try:
            return _finalize_loaded_pipeline(_load_with_diffusion_pipeline(model_path))
        except Exception as e:
            first_error = e
            print("[Init] DiffusionPipeline strategy failed:")
            traceback.print_exc()
            cuda_cleanup()

    try:
        return _finalize_loaded_pipeline(_load_with_manual_components(model_path))
    except Exception as e:
        print("[Init] Manual component strategy failed:")
        traceback.print_exc()
        cuda_cleanup()
        if first_error is None:
            first_part = "DiffusionPipeline strategy skipped (USE_DIFFUSION_PIPELINE=false). "
        else:
            first_part = f"First error: {first_error}. "
        raise RuntimeError(
            "Model loading failed in all strategies. "
            + first_part
            + f"Second error: {e}"
        ) from e


def _finalize_loaded_pipeline(p) -> Any:
    """Apply the shared post-load steps to a freshly loaded pipeline and return it.

    The steps are: attention backend, optional torch.compile, and optional
    safety checker. Each helper already handles its own failures.
    """
    _set_attention_backend_if_possible(p, ATTENTION_BACKEND)
    p = _compile_transformer_if_possible(p)
    p = _load_safety_checker_if_enabled(p)
    return p
def generate_image(
    p,
    prompt: str,
    resolution: str = "1024x1024",
    seed: int = 42,
    guidance_scale: float = 0.0,
    num_inference_steps: int = 9,
    shift: float = 3.0,
    max_sequence_length: int = 512,
) -> Image.Image:
    """Render one image with pipeline *p* and return it as an RGB PIL image.

    *resolution* is parsed by ``get_resolution`` ("WxH", falling back to
    1024x1024). The random generator is seeded with *seed* on the current
    DEVICE. Before the call, ``p.scheduler`` is replaced with a fresh
    ``FlowMatchEulerDiscreteScheduler`` using the requested *shift*; if that
    assignment fails, it is logged and the existing scheduler is kept.

    If the pipeline rejects ``max_sequence_length`` (a signature without that
    parameter), the call is retried once without it. Any other ``TypeError``
    is now re-raised immediately. Previously every TypeError triggered a
    second full generation that would usually just repeat the same error.
    """
    width, height = get_resolution(resolution)

    if DEVICE == "cuda":
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = torch.Generator().manual_seed(int(seed))

    # Z-Image-Turbo is commonly run with FlowMatchEulerDiscreteScheduler.
    try:
        p.scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=1000,
            shift=float(shift),
        )
    except Exception:
        print("[Generate] Failed to assign scheduler, continue with existing scheduler:")
        traceback.print_exc()

    call_kwargs = dict(
        prompt=prompt,
        height=int(height),
        width=int(width),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        max_sequence_length=int(max_sequence_length),
    )

    try:
        out = p(**call_kwargs)
    except TypeError as e:
        # Only an unsupported `max_sequence_length` justifies a retry.
        if "max_sequence_length" not in str(e):
            raise
        call_kwargs.pop("max_sequence_length", None)
        out = p(**call_kwargs)

    image = out.images[0]
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    return image.convert("RGB")
def warmup_model(p) -> None:
    """Run one tiny generation to warm up the model; only when ENABLE_WARMUP is set.

    A single 1024x1024, 2-step render (seed 42, guidance 0) replaces the old
    sweep over every resolution, which made startup fragile. Errors are
    printed and swallowed, followed by a CUDA cache cleanup, so warmup can
    never break model loading.
    """
    if ENABLE_WARMUP:
        try:
            print("[Warmup] Starting minimal warmup...")
            generate_image(
                p,
                prompt="warmup",
                resolution="1024x1024",
                num_inference_steps=2,
                guidance_scale=0.0,
                seed=42,
            )
            print("[Warmup] Completed.")
        except Exception:
            print("[Warmup] Failed, ignored:")
            traceback.print_exc()
            cuda_cleanup()
# ==================== Prompt Expander ====================
@dataclass
class PromptOutput:
    """Result of one prompt-enhancement request.

    Positional field order: status, prompt, seed, system_prompt, message.
    """

    status: bool  # True when the expander produced a prompt
    prompt: str  # the expanded prompt; "" on failure
    seed: int  # seed passed through from the request
    system_prompt: str  # system prompt used for the request ("" if none was resolved)
    message: str  # raw model reply on success, error text on failure
class PromptExpander:
    """Base class for prompt expanders; remembers which backend an instance uses."""

    def __init__(self, backend="api", **kwargs):
        # Extra keyword arguments are accepted for subclasses and ignored here.
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        """Return the system prompt to use: always the module-level ``prompt_template``."""
        return prompt_template
class APIPromptExpander(PromptExpander):
    """Prompt expander backed by an OpenAI-compatible chat API (DashScope by default).

    ``api_config`` may contain:
    - ``api_key``: falls back to DASHSCOPE_API_KEY.
    - ``base_url``: defaults to DashScope's compatible-mode endpoint.
    - ``model``: defaults to "qwen3-max-preview".

    If the ``openai`` package or an API key is missing, ``self.client`` is
    None and every call returns a failed PromptOutput instead of raising.
    """

    def __init__(self, api_config=None, **kwargs):
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        self.client = self._init_api_client()

    def _init_api_client(self):
        """Create the OpenAI client, or return None (with a log line) if that is impossible."""
        try:
            from openai import OpenAI
            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get(
                "base_url",
                "https://dashscope.aliyuncs.com/compatible-mode/v1",
            )
            if not api_key:
                print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("[PE] openai package not installed. Prompt enhance unavailable.")
            return None
        except Exception:
            print("[PE] Failed to initialize API client:")
            traceback.print_exc()
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        return self.extend(prompt, system_prompt, seed, **kwargs)

    @staticmethod
    def _extract_revised_prompt(content):
        """Pull ``revised_prompt`` out of a model reply, falling back to the raw reply.

        Understands a fenced ```json block and, new here, a reply that is a bare
        JSON object. The fallback template explicitly asks for a bare object,
        which previously leaked through as raw JSON text. A ``revised_prompt``
        that is missing, non-string or blank is ignored in favour of *content*.
        """
        json_str = None
        json_start = content.find("```json")
        if json_start != -1:
            json_end = content.find("```", json_start + 7)
            if json_end != -1:
                json_str = content[json_start + 7: json_end].strip()
        elif content.lstrip().startswith("{"):
            json_str = content.strip()
        if json_str is None:
            return content
        try:
            data = json.loads(json_str)
        except ValueError:
            return content
        if isinstance(data, dict):
            revised = data.get("revised_prompt")
            if isinstance(revised, str) and revised.strip():
                return revised
        return content

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Ask the chat model to expand *prompt*; always returns a PromptOutput.

        If *system_prompt* is None, ``decide_system_prompt()`` is used. When the
        template contains "{prompt}", the user prompt is substituted into it and
        the user turn is sent as a single space. Template formatting now happens
        inside the error handler, so a template with stray braces yields a
        failed PromptOutput instead of an uncaught exception.
        """
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized")
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        try:
            if "{prompt}" in system_prompt:
                system_prompt = system_prompt.format(prompt=prompt)
                prompt = " "
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content or ""
            expanded_prompt = self._extract_revised_prompt(content)
            return PromptOutput(True, expanded_prompt, seed, system_prompt, content)
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))
def create_prompt_expander(backend="api", **kwargs):
    """Build a prompt expander; only the "api" backend exists.

    Raises:
        ValueError: for any other *backend*.
    """
    if backend != "api":
        raise ValueError("Only 'api' backend is supported.")
    return APIPromptExpander(**kwargs)
def get_or_create_prompt_expander():
    """Return the shared prompt expander, creating it on first use.

    Creation is deferred so that openai/API-key problems cannot affect app
    startup. If construction raises, the error is printed and None is
    returned; because the global stays None, the next call retries.
    """
    global prompt_expander
    if prompt_expander is None:
        try:
            prompt_expander = create_prompt_expander(
                backend="api",
                api_config={"model": "qwen3-max-preview"},
            )
            print("[PE] Prompt expander ready.")
        except Exception:
            print("[PE] Prompt expander init failed:")
            traceback.print_exc()
            prompt_expander = None
    return prompt_expander
def prompt_enhance(prompt: str, enable_enhance: bool) -> Tuple[str, str]:
    """Optionally rewrite *prompt* with the LLM prompt expander.

    Returns ``(prompt_to_use, status_message)``.

    The original prompt is kept whenever enhancement is disabled, unavailable,
    fails, or now also when it yields an empty result. Previously a
    "successful" empty reply replaced the user's prompt with "".

    An empty or whitespace-only prompt returns ``("", "Please enter a prompt.")``
    before the API client is created.
    """
    if not enable_enhance:
        return prompt, "Enhancement disabled."
    if not (prompt or "").strip():
        return "", "Please enter a prompt."
    expander = get_or_create_prompt_expander()
    if expander is None:
        return prompt, "Prompt expander unavailable."
    try:
        result = expander(prompt)
        if not result.status:
            return prompt, f"Enhancement failed: {result.message}"
        enhanced = result.prompt if isinstance(result.prompt, str) else ""
        if not enhanced.strip():
            return prompt, "Enhancement returned an empty prompt; using the original prompt."
        return enhanced, result.message
    except Exception as e:
        return prompt, f"Error: {str(e)}"
def get_or_load_pipe():
    """Return the shared pipeline, loading it on first use.

    The large model is not loaded at app startup. It is loaded lazily on the
    first Generate click, inside the ``@spaces.GPU`` context. The runtime
    device is refreshed on every call.

    A lock plus a second ``pipe`` check ensures only one caller of this
    function performs the load. On failure, the full traceback is stored in
    MODEL_LOAD_ERROR (shown by the "Model Status" button), ``pipe`` stays None
    so the next call retries, and a user-facing ``gr.Error`` is raised.
    """
    global pipe, MODEL_LOAD_ERROR
    refresh_runtime_device()
    if pipe is not None:
        return pipe
    with model_lock:
        if pipe is None:
            try:
                MODEL_LOAD_ERROR = ""
                candidate = load_models(MODEL_PATH)
                try_enable_aoti(candidate)
                warmup_model(candidate)
                pipe = candidate
                print("[Init] Model loaded successfully.")
            except Exception:
                MODEL_LOAD_ERROR = traceback.format_exc()
                print("[Init] Model loading failed with full traceback:")
                print(MODEL_LOAD_ERROR)
                pipe = None
                cuda_cleanup()
                raise gr.Error(
                    "Model loading failed. Please open the Space logs to view the full traceback. "
                    "Common fixes: upgrade diffusers/transformers/accelerate, disable compile/warmup, "
                    "or check MODEL_PATH / HF_TOKEN."
                )
        return pipe
def _gallery_entry_to_image(item: Any) -> Optional[Image.Image]:
    """Convert one Gallery value to a PIL image, or return None if it is not recognised.

    Accepted inputs:
    - a PIL image, returned as-is;
    - an existing file path, opened and converted to RGB;
    - an ``(image, caption)`` tuple or list;
    - a dict whose first truthy ``image`` / ``path`` / ``name`` entry is any
      of the above. This includes the nested ``{"image": {"path": ...}}``
      shape, which was previously skipped.
    """
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, str):
        if not os.path.exists(item):
            return None
        # Close the file handle deterministically.
        with Image.open(item) as src:
            return src.convert("RGB")
    if isinstance(item, (tuple, list)):
        return _gallery_entry_to_image(item[0]) if item else None
    if isinstance(item, dict):
        for key in ("image", "path", "name"):
            value = item.get(key)
            if value:
                return _gallery_entry_to_image(value)
    return None


def normalize_gallery_items(gallery_images) -> List[Any]:
    """Turn the Gallery's current value into a list of PIL images to re-display.

    Handles the different shapes that Gradio versions return (see
    ``_gallery_entry_to_image``). Unreadable or unrecognised entries are
    skipped.

    At most ``MAX_GALLERY_HISTORY - 1`` images are returned, leaving room for
    the new image the caller prepends. Iteration stops as soon as that many
    are collected, so images that would be discarded are never decoded.
    """
    limit = max(0, MAX_GALLERY_HISTORY - 1)
    if not gallery_images or limit == 0:
        return []
    result: List[Any] = []
    for item in list(gallery_images):
        try:
            image = _gallery_entry_to_image(item)
        except Exception:
            continue
        if image is None:
            continue
        result.append(image)
        if len(result) >= limit:
            break
    return result
def run_safety_check_if_available(p, image: Image.Image, width: int, height: int) -> Image.Image:
    """Screen a generated image with the pipeline's safety checker, if one is attached.

    Returns *image* unchanged in these cases:
    - *p* has no ``safety_feature_extractor`` or ``safety_checker`` (the
      default, since ENABLE_SAFETY_CHECKER is off);
    - nothing is flagged;
    - the check itself errors (logged; fails open).

    A flagged image is replaced by ``_load_nsfw_placeholder(width, height)``.

    The CLIP input is now moved to the checker's own device and cast to its
    dtype. The checker is loaded in float16 on CUDA, and feeding it the
    extractor's float32 pixels raised a dtype mismatch, which silently
    disabled the check.
    """
    try:
        extractor = getattr(p, "safety_feature_extractor", None)
        checker = getattr(p, "safety_checker", None)
        if extractor is None or checker is None:
            return image

        import numpy as np

        clip_inputs = extractor([image], return_tensors="pt")
        clip_input = clip_inputs.pixel_values.to(getattr(checker, "device", DEVICE))
        checker_dtype = getattr(checker, "dtype", None)
        if checker_dtype is not None:
            clip_input = clip_input.to(checker_dtype)

        # Batch of one HWC float image in [0, 1].
        img_np = (np.array(image).astype("float32") / 255.0)[None, ...]
        _checked_images, has_nsfw = checker(images=img_np, clip_input=clip_input)

        if has_nsfw is not None and len(has_nsfw) > 0 and bool(has_nsfw[0]):
            return _load_nsfw_placeholder(width, height)
        return image
    except Exception:
        print("[Safety] Check failed, ignored:")
        traceback.print_exc()
        return image
@spaces.GPU
def generate(
    prompt,
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=8,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    enhance=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler for the Generate button.

    On ZeroGPU, calling this function triggers dynamic GPU allocation
    (``@spaces.GPU``). Steps:
    - load the model lazily;
    - pick the seed: random when ``random_seed`` is set or the seed is -1,
      and 42 when the seed cannot be converted to int;
    - optionally enhance the prompt;
    - render one image with ``steps + 1`` inference steps, where ``steps`` is
      clamped to 1..100;
    - run the optional safety check and prepend the result to the gallery
      history.

    Returns:
        ``(gallery, seed_text, seed_int, status_text)`` for the gallery, the
        "Seed Used" box, the seed input and the status box.

    Raises:
        gr.Error: on an empty prompt, a model-loading failure, or any
            generation error (raised after a CUDA cache cleanup).
    """
    try:
        cleaned_prompt = str(prompt or "").strip()
        if not cleaned_prompt:
            raise gr.Error("Please enter a prompt.")
        current_pipe = get_or_load_pipe()

        if random_seed:
            new_seed = random.randint(1, 1_000_000)
        else:
            try:
                new_seed = int(seed)
            except Exception:
                new_seed = 42
            if new_seed == -1:
                new_seed = random.randint(1, 1_000_000)

        final_prompt, pe_msg = cleaned_prompt, ""
        if enhance:
            final_prompt, pe_msg = prompt_enhance(final_prompt, True)
            print(f"[PE] Enhanced prompt: {final_prompt}")
            print(f"[PE] Message: {pe_msg}")

        # Dropdown labels look like "1024x1024 ( 1:1 )"; keep only the size token.
        try:
            resolution_str = str(resolution).split(" ")[0]
        except Exception:
            resolution_str = "1024x1024"
        width, height = get_resolution(resolution_str)

        # Z-Image-Turbo usually needs about 8 steps; one extra step is always added.
        total_steps = max(1, min(int(steps), 100)) + 1
        image = generate_image(
            p=current_pipe,
            prompt=final_prompt,
            resolution=resolution_str,
            seed=new_seed,
            guidance_scale=0.0,
            num_inference_steps=total_steps,
            shift=float(shift),
        )
        image = run_safety_check_if_available(current_pipe, image, width, height)

        gallery = ([image] + normalize_gallery_items(gallery_images))[:MAX_GALLERY_HISTORY]

        status = f"Done. DEVICE={DEVICE}, resolution={resolution_str}, steps={total_steps}, seed={new_seed}"
        if pe_msg:
            status += f"\nPrompt Enhance: {pe_msg[:300]}"
        return gallery, str(new_seed), int(new_seed), status
    except gr.Error:
        raise
    except Exception as e:
        print("[Generate] Failed:")
        traceback.print_exc()
        cuda_cleanup()
        raise gr.Error(f"Generation failed: {type(e).__name__}: {e}")
def update_res_choices(_res_cat):
    """Refresh the resolution dropdown when the resolution category changes.

    An unknown category falls back to the "1024" list. The first entry of the
    chosen list becomes the selected value.
    """
    options = RES_CHOICES.get(str(_res_cat), RES_CHOICES["1024"])
    return gr.update(value=options[0], choices=options)
def get_model_status():
    """Describe the current model state for the "Model Status" button.

    - Once loaded: reports the device, dtype and model path.
    - After a failed load: returns the last 4000 characters of the most
      recent loading traceback.
    - Otherwise: explains that loading happens on the first Generate click.
    """
    if pipe is None and not MODEL_LOAD_ERROR:
        return (
            "Model not loaded yet. This is normal. "
            "The model will be loaded when you click Generate."
        )
    if pipe is None:
        return "Model not loaded. Last loading error:\n" + MODEL_LOAD_ERROR[-4000:]
    return f"Model loaded. DEVICE={DEVICE}, DTYPE={DTYPE}, MODEL_PATH={MODEL_PATH}"
# Custom page CSS: caps the width of the `.fillable` and `.gradio-container` elements.
css = """
.fillable {
max-width: 1230px !important;
}
.gradio-container {
max-width: 1280px !important;
}
"""
# ==================== Gradio UI ====================
# Build the Blocks container with the page CSS. `css` used to be defined but
# never passed anywhere, so its layout widths were never applied.
try:
    _demo_blocks = gr.Blocks(title="Z-Image Demo", css=css)
except TypeError:
    # Fallback for Gradio releases whose Blocks constructor does not accept `css`.
    _demo_blocks = gr.Blocks(title="Z-Image Demo")

with _demo_blocks as demo:
    gr.Markdown(
        """<div align="center">
# Z-Image Generation Demo
*ZeroGPU friendly lazy-loading version*
</div>"""
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder="Enter your prompt here...",
            )
            with gr.Row():
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(
                    value=1024,
                    choices=choices,
                    label="Resolution Category",
                )
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(
                    value=initial_res_choices[0],
                    choices=initial_res_choices,
                    label="Width x Height (Ratio)",
                )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
                random_seed = gr.Checkbox(
                    label="Random Seed",
                    value=True,
                )
            with gr.Row():
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    value=8,
                    step=1,
                    interactive=True,
                )
                shift = gr.Slider(
                    label="Time Shift",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.1,
                    interactive=True,
                )
            enhance = gr.Checkbox(
                label="Enhance Prompt with DashScope",
                value=False,
                info="Requires DASHSCOPE_API_KEY and openai package. Keep disabled if not needed.",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                status_btn = gr.Button("Model Status")
            status_box = gr.Textbox(
                label="Status / Logs",
                lines=6,
                interactive=False,
            )
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=prompt_input,
                label=None,
            )
        # Right column: outputs.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=600,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(
                label="Seed Used",
                interactive=False,
            )

    # Event wiring (must stay inside the Blocks context).
    res_cat.change(
        update_res_choices,
        inputs=res_cat,
        outputs=resolution,
    )
    # The gallery is both an input (history) and an output (history + new image);
    # the seed box is refreshed with the seed actually used.
    generate_btn.click(
        generate,
        inputs=[
            prompt_input,
            resolution,
            seed,
            steps,
            shift,
            random_seed,
            output_gallery,
            enhance,
        ],
        outputs=[
            output_gallery,
            used_seed,
            seed,
            status_box,
        ],
    )
    status_btn.click(
        get_model_status,
        inputs=[],
        outputs=[status_box],
    )
if __name__ == "__main__":
    # Newer Gradio versions accept `mcp_server`. Older ones reject the keyword
    # with a TypeError, in which case the app is launched with defaults.
    launch_options = {"mcp_server": True}
    try:
        demo.launch(**launch_options)
    except TypeError:
        demo.launch()