# 🏠 首页
# ⬇️ 下载
import torch
import os
import torch.multiprocessing as mp
from modelscope import ZImagePipeline
from datetime import datetime
from flask import Flask, render_template_string, send_from_directory, request, jsonify
import logging
import signal
import sys
import random
import time
import json

# ================= Configuration =================
# Directory with the locally downloaded Z-Image-Turbo weights
# (the GPU workers load it with local_files_only=True).
MODEL_PATH = "./Z-Image-Turbo"
# Where finished PNGs and their JSON metadata sidecars are written and served from.
OUTPUT_DIR = "./img"
# Port for the Flask web UI (bound to 0.0.0.0).
HTTP_PORT = 5000
# Negative prompt passed with every generation request.
DEFAULT_NEG = (
    "ugly, deformed, noisy, blurry, low contrast, text, watermark, "
    "bad anatomy, bad hands, low quality"
)

app = Flask(import_name=__name__)

# Keep werkzeug quiet: the page polls /api/stats and /api/images every 2 s,
# so per-request access logging would flood the console. Only errors are shown.
log = logging.getLogger('werkzeug')
log.setLevel(level=logging.ERROR)

# Handles of the spawned GPU worker processes; filled in by the __main__ block
# and terminated by signal_handler on shutdown.
processes = list()
# Manager-backed task queue shared with the workers. It stays None until the
# __main__ block creates it.
shared_queue = None

# ================= Frontend UI (single-page template) =================
# Served by index() through Flask's render_template_string, so it is parsed as
# a Jinja template: the markup must not contain Jinja delimiters such as
# "{{", "{%" or "{#" unless they are meant as template syntax.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no, viewport-fit=cover">
    <title>Z-Image-Turbo</title>
    <style>
        :root { --primary: #00E5FF; --bg: #050505; --panel: #111; --border: #222; --text: #ddd; }
        * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
        body { font-family: "PingFang SC", sans-serif; background: var(--bg); color: var(--text); margin: 0; display: flex; height: 100vh; overflow: hidden; }
        
        /* === 侧边栏优化 === */
        .sidebar { 
            width: 380px; background: var(--panel); padding: 20px; 
            border-right: 1px solid var(--border); display: flex; flex-direction: column; gap: 16px; 
            flex-shrink: 0; overflow-y: auto; z-index: 20;
        }
        
        /* 强制显示滚动条 (解决电脑端看不到底部的问题) */
        .sidebar::-webkit-scrollbar { width: 6px; }
        .sidebar::-webkit-scrollbar-track { background: #000; }
        .sidebar::-webkit-scrollbar-thumb { background: #333; border-radius: 3px; }
        .sidebar::-webkit-scrollbar-thumb:hover { background: var(--primary); }

        .header h2 { margin: 0; font-size: 20px; color: #fff; letter-spacing: 1.5px; font-weight: 800; font-style: italic; background: linear-gradient(90deg, #fff, var(--primary)); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
        .header .status { font-size: 12px; color: #666; margin-top: 6px; font-family: monospace; }
        .control-label { font-size: 12px; color: #888; font-weight: 600; margin-bottom: 5px; display: block; }
        
        /* 输入框 */
        textarea, input { width: 100%; background: #1a1a1a; color: #fff; border: 1px solid var(--border); border-radius: 8px; padding: 10px; font-size: 13px; transition: all 0.2s; }
        textarea:focus, input:focus { border-color: var(--primary); background: #222; }
        textarea { min-height: 90px; resize: vertical; }
        
        /* 分辨率按钮 (紧凑) */
        .res-presets { display: grid; grid-template-columns: repeat(2, 1fr); gap: 8px; margin-bottom: 8px; }
        .res-btn { background: #222; border: 1px solid #333; color: #ccc; padding: 10px; border-radius: 6px; font-size: 12px; cursor: pointer; display: flex; flex-direction: column; align-items: center; gap: 2px; transition: 0.2s; }
        .res-btn span { font-size: 10px; color: #666; font-family: monospace; }
        .res-btn.active { background: rgba(0, 229, 255, 0.15); border-color: var(--primary); color: #fff; }
        .res-btn.active span { color: var(--primary); }
        
        /* 自定义分辨率 */
        .custom-res { display: none; gap: 8px; align-items: center; animation: slideDown 0.3s; }
        .custom-res.show { display: flex; }
        .res-input-wrap { position: relative; flex: 1; }
        .res-input-wrap span { position: absolute; right: 10px; top: 50%; transform: translateY(-50%); font-size: 11px; color: #555; }
        @keyframes slideDown { from { opacity: 0; transform: translateY(-10px); } to { opacity: 1; transform: translateY(0); } }
        
        /* === 参数并排显示 (节省空间且直观) === */
        .params-row { display: grid; grid-template-columns: 1fr 1fr; gap: 12px; background: #161616; padding: 12px; border-radius: 8px; border: 1px solid var(--border); }
        .param-item { display: flex; flex-direction: column; gap: 5px; }
        .param-header { display: flex; justify-content: space-between; font-size: 11px; color: #888; }
        input[type=range] { -webkit-appearance: none; width: 100%; background: transparent; margin: 5px 0; }
        input[type=range]::-webkit-slider-thumb { -webkit-appearance: none; height: 12px; width: 12px; border-radius: 50%; background: var(--primary); margin-top: -4px; }
        input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 4px; background: #333; border-radius: 2px; }
        
        /* 按钮 */
        button#gen-btn { width: 100%; padding: 14px; border-radius: 8px; border: none; font-weight: 700; background: var(--primary); color: #000; cursor: pointer; margin-top: 5px; letter-spacing: 1px; font-size: 15px; }
        button:disabled { background: #333; color: #777; }
        
        /* 画廊 */
        .main { flex: 1; padding: 20px; overflow-y: auto; background: var(--bg); }
        .gallery { display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 15px; }
        .img-card { background: #111; border-radius: 12px; overflow: hidden; position: relative; cursor: zoom-in; aspect-ratio: 1; box-shadow: 0 4px 10px rgba(0,0,0,0.3); border: 1px solid transparent; transition: 0.2s; }
        .img-card:hover { border-color: #333; }
        .img-card img { width: 100%; height: 100%; object-fit: cover; display: block; opacity: 0; animation: fadeIn 0.5s forwards; }
        @keyframes fadeIn { to { opacity: 1; } }
        
        /* 灯箱 */
        .lightbox { display: none; position: fixed; top: 0; left: 0; width: 100%; height: 100%; background: rgba(0,0,0,0.95); backdrop-filter: blur(8px); z-index: 1000; flex-direction: column; }
        .lightbox.active { display: flex; }
        .lb-img-area { flex: 1; display: flex; align-items: center; justify-content: center; padding: 20px; overflow: hidden; }
        .lb-img-area img { max-width: 100%; max-height: 100%; box-shadow: 0 0 30px rgba(0,0,0,0.5); border-radius: 4px; cursor: zoom-out; }
        .lb-panel { background: #161616; border-top: 1px solid #333; padding: 20px; width: 100%; flex-shrink: 0; max-height: 40vh; overflow-y: auto; }
        .lb-meta-row { display: flex; gap: 15px; margin-bottom: 12px; font-size: 12px; color: var(--primary); font-family: monospace; }
        .lb-prompt-box { background: #222; padding: 12px; border-radius: 8px; font-size: 13px; line-height: 1.5; color: #ccc; white-space: pre-wrap; word-break: break-word; border: 1px solid #333; }
        .lb-label { font-size: 11px; color: #666; margin-bottom: 4px; display: block; font-weight: bold; }
        
        @media (max-width: 768px) { body { flex-direction: column; height: auto; overflow-y: auto; } .sidebar { width: 100%; border-right: none; padding: 16px; gap: 12px; background: #000; } .main { padding: 10px; } .gallery { grid-template-columns: repeat(2, 1fr); gap: 8px; } .lb-panel { padding: 15px; } }
    </style>
</head>
<body>
    <div class="sidebar">
        <div class="header"><h2>Z-Image-Turbo</h2><div class="status">队列: <span id="q-size">0</span> | 完成: <span id="f-count">0</span></div></div>
        
        <div><label class="control-label">提示词 (Prompt)</label><textarea id="prompt" placeholder="在此输入英文提示词...">Cyberpunk city, neon lights, 8k best quality</textarea></div>
        
        <div>
            <label class="control-label">尺寸选择</label>
            <div class="res-presets">
                <div class="res-btn" onclick="selectRes(512, 512, this)">极速<span>512x512</span></div>
                <div class="res-btn active" onclick="selectRes(1024, 1024, this)">标准<span>1024x1024</span></div>
                <div class="res-btn" onclick="selectRes(2048, 2048, this)">超清<span>2048x2048</span></div>
                <div class="res-btn" onclick="toggleCustom(this)">自定义<span>Custom</span></div>
            </div>
            <div class="custom-res" id="custom-inputs">
                <div class="res-input-wrap"><input type="number" id="custom-w" value="1024" step="8"><span>宽</span></div>
                <div style="color:#666">×</div>
                <div class="res-input-wrap"><input type="number" id="custom-h" value="1024" step="8"><span>高</span></div>
            </div>
        </div>

        <div><label class="control-label">种子 (Seed)</label><input type="number" id="seed" placeholder="留空为随机 (-1)" value=""></div>

        <div>
            <label class="control-label">高级参数</label>
            <div class="params-row">
                <div class="param-item">
                    <div class="param-header"><span>CFG 引导</span><span id="val-cfg" style="color:var(--primary)">0.0</span></div>
                    <input type="range" id="cfg" min="0" max="5" step="0.5" value="0" oninput="document.getElementById('val-cfg').innerText=this.value">
                </div>
                <div class="param-item">
                    <div class="param-header"><span>迭代步数</span><span id="val-steps" style="color:var(--primary)">4</span></div>
                    <input type="range" id="steps" min="1" max="12" value="4" oninput="document.getElementById('val-steps').innerText=this.value">
                </div>
            </div>
        </div>
        
        <button id="gen-btn" onclick="generate()">立即生成 (4张)</button>
    </div>
    
    <div class="main"><div class="gallery" id="gallery"></div></div>
    
    <div id="lightbox" class="lightbox" onclick="closeLightbox(event)">
        <div class="lb-img-area"><img id="lb-img" src=""></div>
        <div class="lb-panel" onclick="event.stopPropagation()">
            <div class="lb-meta-row"><span id="lb-res">尺寸: --</span><span id="lb-seed">种子: --</span><span id="lb-steps">步数: --</span></div>
            <label class="lb-label">提示词 (PROMPT)</label>
            <div class="lb-prompt-box" id="lb-prompt">加载中...</div>
            <div style="margin-top:10px; text-align:right;"><a id="lb-dl" href="#" target="_blank" style="color:#666; text-decoration:underline; font-size:12px;">查看原图</a></div>
        </div>
    </div>
    
    <script>
        let lastJson="", isGenerating=false, currentW=1024, currentH=1024, isCustom=false;
        function selectRes(w,h,btn){isCustom=false;currentW=w;currentH=h;document.querySelectorAll('.res-btn').forEach(b=>b.classList.remove('active'));btn.classList.add('active');document.getElementById('custom-inputs').classList.remove('show');}
        function toggleCustom(btn){isCustom=true;document.querySelectorAll('.res-btn').forEach(b=>b.classList.remove('active'));btn.classList.add('active');document.getElementById('custom-inputs').classList.add('show');}
        
        // Restore the generate button to its idle state (used after success and after failure).
        function resetGenBtn(){const btn=document.getElementById('gen-btn');btn.disabled=false;btn.innerText="立即生成 (4张)";isGenerating=false;}
        
        async function generate(){
            if(isGenerating)return;
            const btn=document.getElementById('gen-btn');
            
            let finalW=isCustom?parseInt(document.getElementById('custom-w').value):currentW;
            let finalH=isCustom?parseInt(document.getElementById('custom-h').value):currentH;
            finalW-=(finalW%8);finalH-=(finalH%8);
            // An empty custom field yields NaN (serialised as null) and a value below 8 rounds to 0: neither is usable.
            if(!Number.isFinite(finalW)||!Number.isFinite(finalH)||finalW<=0||finalH<=0)return alert("请输入有效的尺寸 (至少 8 像素)");
            
            let seedInput = document.getElementById('seed').value;
            let finalSeed = (seedInput === "" || seedInput === null) ? -1 : parseInt(seedInput);
            
            const data={
                prompt:document.getElementById('prompt').value,
                width:finalW,
                height:finalH,
                seed:finalSeed,
                steps:parseInt(document.getElementById('steps').value),
                cfg:parseFloat(document.getElementById('cfg').value)
            };
            
            if(!data.prompt)return alert("请输入提示词");
            isGenerating=true;btn.disabled=true;btn.innerText="生成中...";
            try{
                const res=await fetch('/api/generate',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(data)});
                // fetch() only rejects on network failure; HTTP error statuses must be checked explicitly.
                if(!res.ok)throw new Error("HTTP "+res.status);
                btn.innerText="已加入队列";
                setTimeout(resetGenBtn,2000);
            }catch(e){alert("请求失败: "+e);resetGenBtn();}
        }
        
        async function openLightbox(f){
            const lb=document.getElementById('lightbox');lb.classList.add('active');
            document.getElementById('lb-img').src=`/img/${f}`;document.getElementById('lb-dl').href=`/img/${f}`;
            document.getElementById('lb-prompt').innerText="正在读取信息...";
            // Clear the previously opened image's metadata so stale values are never shown.
            document.getElementById('lb-res').innerText="尺寸: --";
            document.getElementById('lb-seed').innerText="种子: --";
            document.getElementById('lb-steps').innerText="步数: --";
            try{
                const res=await fetch(`/api/meta/${f}`);
                const data=await res.json();
                document.getElementById('lb-prompt').innerText=data.prompt||"无数据";
                document.getElementById('lb-res').innerText=`尺寸: ${data.width}x${data.height}`;
                document.getElementById('lb-seed').innerText=`种子: ${data.seed}`;
                document.getElementById('lb-steps').innerText=`步数: ${data.steps}`;
            }catch(e){document.getElementById('lb-prompt').innerText="读取失败";}
        }
        function closeLightbox(e){document.getElementById('lightbox').classList.remove('active');}
        async function refresh(){try{const s=await(await fetch('/api/stats')).json();document.getElementById('q-size').innerText=s.queue;document.getElementById('f-count').innerText=s.total;const imgs=await(await fetch('/api/images')).json();const jsonStr=JSON.stringify(imgs);if(jsonStr!==lastJson){document.getElementById('gallery').innerHTML=imgs.map(f=>`<div class="img-card" onclick="openLightbox('${f}')"><img src="/img/${f}" loading="lazy"></div>`).join('');lastJson=jsonStr;}}catch(e){}}setInterval(refresh,2000);refresh();
    </script>
</body>
</html>
"""

# ================= 后端逻辑 =================

@app.route('/')
def index():
    """Serve the single-page web UI rendered from HTML_TEMPLATE."""
    page = HTML_TEMPLATE
    return render_template_string(page)

@app.route('/api/stats')
def stats():
    """Report pending work and the number of finished images.

    Returns JSON ``{"queue": int, "total": int}``:
      * ``queue`` -- batch requests still waiting in the shared queue (tasks
        already taken by a worker are no longer counted), 0 if the queue has
        not been created yet;
      * ``total`` -- finished PNGs in OUTPUT_DIR, 0 if the directory is missing.
    """
    q_size = shared_queue.qsize() if shared_queue is not None else 0
    try:
        names = os.listdir(OUTPUT_DIR)
    except FileNotFoundError:
        return jsonify({"queue": q_size, "total": 0})
    # Skip "tmp_" files a worker is still writing, so this counter matches the
    # gallery produced by /api/images.
    count = sum(1 for n in names if n.endswith('.png') and not n.startswith('tmp_'))
    return jsonify({"queue": q_size, "total": count})

@app.route('/api/images')
def get_images():
    """Return up to 60 finished image filenames, most recently modified first.

    Only ``.png`` files are listed, and in-progress ``tmp_`` files are skipped.
    An empty list is returned when OUTPUT_DIR does not exist.
    """
    if not os.path.exists(OUTPUT_DIR):
        return jsonify([])

    def mtime_of(name):
        return os.path.getmtime(os.path.join(OUTPUT_DIR, name))

    finished = sorted(
        (name for name in os.listdir(OUTPUT_DIR)
         if name.endswith('.png') and not name.startswith('tmp_')),
        key=mtime_of,
        reverse=True,
    )
    return jsonify(finished[:60])

@app.route('/api/meta/<path:filename>')
def get_meta(filename):
    """Return generation metadata for an image stored in OUTPUT_DIR.

    Reads the ``<filename>.json`` sidecar written next to the image. If the
    sidecar is missing or unreadable, the fields that can be recovered from
    the filename itself (``{w}x{h}_{seed}_...``) are returned, with "?" for
    anything unknown.

    Only plain file names are accepted. A name with a directory component
    (e.g. ``../``) gets HTTP 400, so this route cannot be used to read
    arbitrary ``*.json`` files outside OUTPUT_DIR.
    """
    name = os.path.basename(filename)
    if not name or name != filename or name in ('.', '..'):
        return jsonify({"error": "invalid filename"}), 400

    json_path = os.path.join(OUTPUT_DIR, name + ".json")
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            return jsonify(json.load(f))
    except FileNotFoundError:
        pass  # No sidecar: fall back to parsing the filename below.
    except (OSError, ValueError) as e:
        # Unreadable or corrupt sidecar (ValueError covers JSON/Unicode decode
        # errors). Report it, but still answer with the filename-based fallback.
        print(f"[系统] 读取元数据失败: {json_path}: {e}")

    parts = name.split('_')
    size = parts[0]
    has_size = 'x' in size
    return jsonify({
        "prompt": "元数据丢失",
        "width": size.split('x')[0] if has_size else "?",
        "height": size.split('x')[1] if has_size else "?",
        "seed": parts[1] if len(parts) > 1 else "?",
        "steps": "?",
    })

@app.route('/img/<path:f>')
def img(f):
    """Serve a file (image or JSON sidecar) from OUTPUT_DIR.

    Path safety is delegated to Flask's send_from_directory.
    """
    base_dir = OUTPUT_DIR
    return send_from_directory(base_dir, f)

def _build_task(d):
    """Validate a /api/generate payload and convert it into a worker task dict.

    Raises ValueError or TypeError with a readable message when the payload
    is not a JSON object or a field is missing or malformed.
    """
    if not isinstance(d, dict):
        raise ValueError("request body must be a JSON object")
    prompt = d.get('prompt')
    if not isinstance(prompt, str) or not prompt:
        raise ValueError("prompt is required")

    # Same rule as the web UI: round sizes down to a multiple of 8.
    w = int(d.get('width', 1024))
    h = int(d.get('height', 1024))
    w -= w % 8
    h -= h % 8
    if w <= 0 or h <= 0:
        raise ValueError("width and height must be at least 8")

    steps = int(d.get('steps', 4))
    if steps < 1:
        raise ValueError("steps must be >= 1")
    cfg = float(d.get('cfg', 0.0))

    # An absent or null seed is treated like -1, i.e. "pick a random seed".
    raw_seed = d.get('seed')
    seed = -1 if raw_seed is None else int(raw_seed)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    return {'p': prompt, 'w': w, 'h': h, 's': seed, 'steps': steps, 'cfg': cfg}


@app.route('/api/generate', methods=['POST'])
def generate_api():
    """Queue one batch-generation request for the GPU workers.

    Expects a JSON body with ``prompt`` (required) and optional ``width``,
    ``height``, ``seed`` (-1 or absent = random), ``steps`` and ``cfg``.
    Returns ``{"status": "ok"}`` once queued. Returns HTTP 400 for an invalid
    payload, and HTTP 503 if the worker queue has not been created (the
    module was not started through its __main__ block).
    """
    if shared_queue is None:
        return jsonify({"status": "error", "message": "worker queue not initialised"}), 503
    try:
        task = _build_task(request.get_json(silent=True))
    except (TypeError, ValueError) as e:
        return jsonify({"status": "error", "message": str(e)}), 400
    shared_queue.put(task)
    return jsonify({"status": "ok"})

def signal_handler(sig, frame):
    """SIGINT handler for the main process: stop every GPU worker, then exit.

    All live workers are sent terminate() first and only then joined, so they
    shut down in parallel rather than one after another. A worker still alive
    after its join timeout is killed outright.
    """
    print("\n[系统] 正在关闭服务...")
    alive = [p for p in processes if p.is_alive()]
    for p in alive:
        p.terminate()
    for p in alive:
        p.join(timeout=1)
        if p.is_alive():
            # The worker did not exit after terminate(); force it down.
            p.kill()
            p.join(timeout=1)
    sys.exit(0)

# ================= GPU Worker =================

def _save_image_with_meta(img, output_dir, fn, meta):
    """Write one image and its JSON sidecar, then publish both under final names.

    Both files are first written with a ``tmp_`` prefix, which the gallery
    ignores, and then moved into place. The sidecar is moved before the PNG,
    so by the time the image shows up in the gallery its metadata can already
    be read.
    """
    tmp_png = os.path.join(output_dir, f"tmp_{fn}")
    tmp_json = os.path.join(output_dir, f"tmp_{fn}.json")
    img.save(tmp_png)
    with open(tmp_json, 'w', encoding='utf-8') as f:
        json.dump(meta, f)
    # os.replace, unlike os.rename, also overwrites an existing target on Windows.
    os.replace(tmp_json, os.path.join(output_dir, fn + ".json"))
    os.replace(tmp_png, os.path.join(output_dir, fn))


def gpu_worker(rank, model_path, output_dir, queue):
    """Worker process: load the pipeline on GPU ``rank`` and serve tasks forever.

    Each task dict taken from ``queue`` has keys ``p`` (prompt), ``w``/``h``
    (size), ``s`` (seed), ``steps`` and ``cfg``. Every task produces a batch
    of 4 images, saved in ``output_dir`` as
    ``{w}x{h}_{seed}_{timestamp}_g{rank}_{i}.png`` together with a ``.json``
    metadata sidecar. Returns, ending the process, if the model fails to load.
    Errors raised while handling a task are printed, and the loop continues.
    """
    # Ctrl-C is handled by the parent process, which terminates the workers.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    device = f"cuda:{rank}"
    try:
        torch.cuda.set_device(device)
        print(f"[GPU {rank}] 正在加载模型...")
        torch.cuda.empty_cache()
        pipe = ZImagePipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16, local_files_only=True).to(device)
        print(f"[GPU {rank}] 就绪")
    except Exception as e:
        print(f"[GPU {rank}] ❌ 初始化失败: {e}")
        return

    while True:
        try:
            task = queue.get()
            prompt, w, h, s = task['p'], task['w'], task['h'], task['s']
            steps, cfg = task['steps'], task['cfg']

            print(f"[GPU {rank}] 绘图: {w}x{h} | S:{s} | CFG:{cfg}")
            g = torch.Generator(device).manual_seed(s)

            with torch.inference_mode():
                imgs = pipe(
                    prompt=prompt,
                    negative_prompt=DEFAULT_NEG,
                    height=h, width=w,
                    num_inference_steps=steps,
                    guidance_scale=cfg,
                    num_images_per_prompt=4,
                    generator=g,
                ).images

            # One timestamp per batch, with date and milliseconds. An HHMMSS-only
            # stamp let a repeat of the same size and seed on the same GPU (in the
            # same second, or at the same time on another day) overwrite files.
            ts = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
            for i, img in enumerate(imgs):
                fn = f"{w}x{h}_{s}_{ts}_g{rank}_{i}.png"
                meta = {
                    "prompt": prompt, "width": w, "height": h, "seed": s,
                    "steps": steps, "cfg": cfg, "timestamp": ts, "gpu": rank
                }
                _save_image_with_meta(img, output_dir, fn, meta)

            torch.cuda.empty_cache()

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"[GPU {rank}] 错误: {e}")
            time.sleep(1)

if __name__ == "__main__":
    signal.signal(signal.SIGINT, signal_handler)
    # Workers use CUDA, so they are started with 'spawn' rather than 'fork'
    # (see PyTorch's multiprocessing notes on CUDA in subprocesses).
    mp.set_start_method('spawn', force=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    manager = mp.Manager()
    # Assigned at module level so the Flask handlers see the same queue object.
    shared_queue = manager.Queue()
    num_gpus = torch.cuda.device_count()
    print(f"[系统] 启动中... (检测到 {num_gpus} GPU)")
    if num_gpus == 0:
        # Without workers, queued tasks would wait forever; make that visible.
        print("[系统] ⚠️ 未检测到 GPU, 生成请求将一直排队而不会被处理")
    for r in range(num_gpus):
        p = mp.Process(target=gpu_worker, args=(r, MODEL_PATH, OUTPUT_DIR, shared_queue))
        p.daemon = True
        p.start()
        processes.append(p)
    app.run(host='0.0.0.0', port=HTTP_PORT, threaded=True, use_reloader=False)
# 链接已复制!
# 可在终端使用 wget "链接" 下载