from transformers import pipeline, set_seed import gradio as gr import torch import re MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall" generator = pipeline( "text-generation", model=MODEL_NAME, device=-1, torch_dtype=torch.float32 ) def is_mostly_chinese(text, threshold=0.7): if not text.strip(): return False total = len(text) chinese_count = sum( 1 for c in text if '\u4e00' <= c <= '\u9fff' or c in ',。!?;:「」()《》…\n\t ' ) return (chinese_count / total) >= threshold if total > 0 else False def generate_story(starter, max_length=100, temperature=0.9): if not starter.strip(): return "請輸入一個故事開頭喔!" # ✅ 要求更長、更豐富的故事 prompt = f"""請用現代繁體中文寫一個迪士尼風格的童話故事。 開頭:「{starter}」 要求:150字以內,要有魔法、動物朋友、情節轉折、溫馨結局。 故事:""" try: set_seed(42) outputs = generator( prompt, max_new_tokens=int(max_length), # ✅ 完全跟隨滑桿值 temperature=float(temperature), do_sample=True, top_k=60, top_p=0.9, repetition_penalty=1.2, pad_token_id=generator.tokenizer.eos_token_id ) full_text = outputs[0]['generated_text'] if "故事:" in full_text: story = full_text.split("故事:", 1)[-1].strip() else: story = full_text.replace(prompt, "").strip() story = re.sub(r'\s+', ' ', story).strip() story = re.sub(r'[「」]+', '', story) if len(story) < 20 or not is_mostly_chinese(story): # ✅ 放寬最小字數 return "AI 今天文思泉湧... 試試調高『創意溫度』或再按一次!" return story except Exception as e: return f"❌ 錯誤:{str(e)}\n建議降低長度或重試" demo = gr.Interface( fn=generate_story, inputs=[ gr.Textbox( lines=2, placeholder="例如:小矮人發現了一張藏寶圖,指向雲端的糖果城堡...", label="📖 輸入你的故事開頭" ), gr.Slider(50, 200, value=120, step=10, label="📏 故事長度(token數)"), # ✅ 最高 200 gr.Slider(0.8, 1.3, value=1.0, step=0.1, label="🎨 創意溫度") ], outputs=gr.Textbox( label="📜 AI 生成的童話故事", lines=12, # ✅ 增加顯示行數 max_lines=30, # ✅ 增加最大行數 placeholder="長篇魔法故事即將誕生...", autoscroll=True ), title="🐉 長篇中文 AI 說書人", description="✨ 支援更長故事!輸入開頭,AI 幫你寫出情節豐富的奇幻冒險~", examples=[ ["小矮人發現了一張藏寶圖,指向雲端的糖果城堡...", 150, 1.1], ["會說話的雲朵邀請我去參加天空動物園的開幕典禮...", 180, 1.2], ["我的鬧鐘其實是時光機,每天早上帶我去不同年代...", 160, 1.0] ], theme=gr.themes.Soft() ) demo.css = """ .output-text { font-size: 18px !important; line-height: 1.8; } """ if __name__ == "__main__": demo.launch()