| | |
| | import spaces |
| | import gradio as gr |
| | from gradio import update |
| | from functools import lru_cache |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | from opencc import OpenCC |
| |
|
| | |
# Shared OpenCC converter built with the 's2t' (Simplified -> Traditional
# Chinese) configuration; suggest_next runs every generated candidate
# through it.
cc = OpenCC('s2t')
| |
|
| | |
# Model ids offered in the "模型" dropdown. The first entry is the default
# selection, and the chosen id is what get_pipeline passes to
# from_pretrained().
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]
| |
|
def merge_common_prefixes(suggestions, min_len=2):
    """Collapse suggestions that share a leading substring into that prefix.

    Every pair of suggestions is compared. When the longest common prefix of
    a pair has at least ``min_len`` characters, the prefix is offered as a
    suggestion of its own and both members of the pair are dropped.

    Args:
        suggestions: List of candidate strings.
        min_len: Minimum prefix length (in characters) that triggers a merge.

    Returns:
        A new list: the distinct merged prefixes in first-seen order,
        followed by the suggestions that were not merged, in their original
        order.
    """

    def common_prefix(a, b):
        # Stop at the first mismatch. Collecting every position where the
        # two strings agree would stitch in characters from after the
        # divergence point and produce text that is not a prefix of either.
        length = 0
        for ch_a, ch_b in zip(a, b):
            if ch_a != ch_b:
                break
            length += 1
        return a[:length]

    prefixes = []
    merged = set()
    for i, first in enumerate(suggestions):
        for second in suggestions[i + 1:]:
            shared = common_prefix(first, second)
            if len(shared) >= min_len:
                prefixes.append(shared)
                merged.update((first, second))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_prefixes = list(dict.fromkeys(prefixes))
    remainder = [s for s in suggestions if s not in merged]
    return unique_prefixes + remainder
| |
|
@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Build a text-generation pipeline for ``model_name``.

    The result is memoised by ``lru_cache`` (up to 8 distinct names), so a
    repeated request for the same model returns the pipeline built earlier
    instead of loading the weights again.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    load_options = {"weights_only": False, "trust_remote_code": True}
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    # Move the weights to the GPU before wrapping them in a pipeline.
    model.to("cuda")

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0,
    )
| |
|
@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Generate next-text candidates for ``text`` using diverse beam search.

    Up to ``m`` continuations are generated, stripped of the prompt,
    converted with the module-level OpenCC converter, de-duplicated, and
    finally passed through ``merge_common_prefixes``.

    Args:
        text: Current content of the input box (the prompt).
        model_name: Model id handed to ``get_pipeline``.
        k: Maximum number of new tokens per candidate.
        m: Number of beams, which is also the number of returned sequences.
        num_beam_groups: Requested number of beam groups. Lowered to the
            nearest value that divides the beam count when it does not.
        diversity_penalty: Penalty applied between beam groups. A value of 0
            or less switches grouping off (plain beam search).

    Returns:
        A Gradio update that replaces the candidate choices and clears the
        current selection.
    """
    if not text:
        # Nothing to continue; clear the candidate bar instead of invoking
        # the model on an empty prompt.
        return update(choices=[], value=None)

    # Coerce defensively in case the slider values arrive as floats.
    max_new_tokens = int(k)
    num_beams = max(1, int(m))

    # The two sliders move independently, but group beam search requires the
    # beam count to be divisible by the group count. Fall back to the largest
    # divisor that does not exceed the requested value.
    groups = max(1, min(int(num_beam_groups), num_beams))
    while num_beams % groups:
        groups -= 1

    # Group beam search also requires a strictly positive penalty; with a
    # penalty of 0 fall back to a single group.
    if diversity_penalty <= 0:
        groups = 1
        diversity_penalty = 0.0

    gen_pipe = get_pipeline(model_name)
    outs = gen_pipe(
        text,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        num_beam_groups=groups,
        diversity_penalty=float(diversity_penalty),
        num_return_sequences=num_beams,
        do_sample=False,
        early_stopping=True,
    )

    # The pipeline output starts with the prompt; keep only the continuation.
    continuations = (out["generated_text"][len(text):].strip() for out in outs)
    converted = (cc.convert(s) for s in continuations if s)
    # dict.fromkeys removes duplicates while keeping generation order.
    unique_suggestions = list(dict.fromkeys(converted))

    final_suggestions = merge_common_prefixes(unique_suggestions, min_len=2)
    return update(choices=final_suggestions, value=None)
| |
|
| |
|
def append_suggestion(current, choice):
    """Return ``current`` with the picked candidate appended.

    ``choice`` is ``None`` when no candidate is selected; in that case
    ``current`` is handed back unchanged.
    """
    return current if choice is None else current + choice
| |
|
| | |
# Stylesheet passed to gr.Blocks(css=...). It targets the elem_id values set
# on the components: #suggestions-bar (candidate Radio), #input-box (Textbox)
# and #predict-button (Button).
# NOTE(review): the Radio sets elem_id="suggestions-bar" and
# elem_classes="candidate-list" on the same component, so the descendant
# selector "#suggestions-bar .candidate-list" may not match anything --
# confirm against the rendered DOM.
# NOTE(review): "input[type=radio]:checked + label" assumes the label is a
# sibling following the input -- verify against Gradio's Radio markup.
custom_css = """
#suggestions-bar {
width: 100%;
margin-bottom: 8px;
}
#suggestions-bar .candidate-list {
display: flex;
gap: 8px;
background: #fff;
border: 1px solid #999;
border-radius: 4px;
padding: 6px;
overflow-x: auto;
white-space: nowrap;
}
#suggestions-bar .candidate-list label {
cursor: pointer;
padding: 6px 10px;
font-size: 16px;
}
#suggestions-bar .candidate-list label:hover {
background: #f5f5f5;
}
#suggestions-bar .candidate-list input[type=radio]:checked + label {
background: #e6f7ff;
border: 1px solid #1890ff;
}
#input-box textarea {
width: 100%;
font-size: 16px;
padding: 6px;
box-sizing: border-box;
overflow: hidden;
resize: none;
}
#predict-button {
margin-top: 8px;
width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
#suggestions-bar .candidate-list label {
padding: 8px;
font-size: 18px;
}
#predict-button {
font-size: 18px;
}
}
"""
| |
|
| | |
# HTML snippet injected through gr.HTML: on window load it looks up the
# textarea inside #input-box and, on every input event, resets its height to
# its scrollHeight so the box grows with its content.
# NOTE(review): <script> tags inserted via gr.HTML may not be executed by the
# browser -- confirm this runs, or move it to the Blocks `head=`/`js=`
# parameter.
auto_height_js = """
<script>
window.addEventListener('load', () => {
const textarea = document.querySelector('#input-box textarea');
if (!textarea) return;
textarea.style.height = 'auto';
textarea.addEventListener('input', function() {
this.style.height = 'auto';
this.style.height = this.scrollHeight + 'px';
});
});
</script>
"""
| |
|
# Build the UI: a candidate bar above a text box, a predict button with an
# auto-predict toggle, and an accordion holding the generation settings.
with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): <script> tags inserted via gr.HTML may not execute --
    # confirm the auto-height behaviour actually works in the browser.
    gr.HTML(auto_height_js)
    gr.Markdown(
        # The heading must be terminated with a real line break. The previous
        # literal ended in a backslash followed by a physical newline, which
        # Python treats as a line continuation inside the string, so the
        # description was rendered as part of the "##" heading.
        "## 🇹🇼 繁體中文 IME 加速器\n\n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    with gr.Column():
        # Candidate bar; its choices are replaced by suggest_next.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        predict_button = gr.Button(
            "預測", elem_id="predict-button"
        )

    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=1.0,
            label="多樣性懲罰 (diversity_penalty)"
        )

    # Manual trigger: always predict with the current settings.
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider
        ],
        outputs=suggestions,
    )
    # Automatic trigger: predict on every text change while the checkbox is
    # ticked; otherwise just clear the candidate bar.
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict
        ],
        outputs=suggestions,
    )
    # Picking a candidate appends it to the text box.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )
| |
|
# Launch the Blocks app defined above with Gradio's default settings.
demo.launch()