Update app/app.py
app/app.py  CHANGED  +23 -59
@@ -1,101 +1,65 @@
-import os
-import re
-import torch
+import os, re, torch
 from flask import Flask, request, jsonify, send_from_directory
-
-# Import local SLM components
 from slm_qa import TinyTransformer, encode, wrap_bos_eos, itos, PAD, BOS, EOS
 
-# Resolve absolute paths
 BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 STATIC_DIR = os.path.join(BASE_DIR, "static")
 MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
 
 app = Flask(__name__, static_folder=STATIC_DIR)
-
-# Moderation rule: ban exact word "sex" (case-insensitive)
 BAN_REGEX = re.compile(r"(?i)\bsex\b")
 
-# Model setup
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 VOCAB_SIZE = len(itos)
 MAX_LEN = 128
 
-model = TinyTransformer(
-    vocab_size=VOCAB_SIZE,
-    d_model=128,
-    n_heads=4,
-    n_layers=2,
-    d_ff=256,
-    dropout=0.1,
-    max_len=MAX_LEN,
-).to(DEVICE)
-
-# Load checkpoint if present
+model = TinyTransformer(vocab_size=VOCAB_SIZE, max_len=MAX_LEN).to(DEVICE)
 ckpt_path = os.path.join(MODELS_DIR, "slm_qa_best.pt")
 if os.path.exists(ckpt_path):
     model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
 model.eval()
 
+def sample_next(logits, top_k=5):
+    """Top-k sampling instead of greedy argmax."""
+    probs = torch.softmax(logits, dim=-1)
+    topk_probs, topk_idx = torch.topk(probs, k=top_k)
+    idx = torch.multinomial(topk_probs, 1)
+    return topk_idx.gather(-1, idx).item()  # map the sampled position back to a vocab id
+
 def generate_answer(question: str, max_new_tokens: int = 40) -> str:
     q_ids = encode("q: " + question)
     a_prefix = encode("a:")
     tokens = wrap_bos_eos(q_ids + a_prefix)[:-1]
     x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
 
     with torch.no_grad():
        for _ in range(max_new_tokens):
-            if x.size(1) >= MAX_LEN:
-                break
-            logits = model(x)
-            next_id = logits[:, -1, :].argmax(dim=-1).item()
-            if next_id == EOS:
-                break
-            x = torch.cat(
-                [x, torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)],
-                dim=1,
-            )
+            if x.size(1) >= MAX_LEN: break
+            logits = model(x)
+            next_id = sample_next(logits[:, -1, :])
+            if next_id == EOS: break
+            x = torch.cat([x, torch.tensor([[next_id]], device=DEVICE)], dim=1)
 
     gen_ids = x.squeeze(0).tolist()
     prefix_len = 1 + len(q_ids) + len(a_prefix)
     answer_ids = gen_ids[prefix_len:]
     out = " ".join(itos[i] for i in answer_ids if i not in (PAD, BOS, EOS)).strip()
     return out if out else "..."
 
 @app.route("/")
-def index():
-    return send_from_directory(STATIC_DIR, "index.html")
-
-@app.route("/static/<path:filename>")
-def static_files(filename):
-    return send_from_directory(STATIC_DIR, filename)
-
-@app.route("/health")
-def health():
-    return jsonify({"ok": True})
+def index(): return send_from_directory(STATIC_DIR, "index.html")
 
 @app.route("/api/moderate", methods=["POST"])
 def moderate():
-    banned = bool(BAN_REGEX.search(text))
-    return jsonify({"banned": banned})
+    text = (request.json.get("text") or "").strip()
+    return jsonify({"banned": bool(BAN_REGEX.search(text))})
 
 @app.route("/api/answer", methods=["POST"])
 def answer():
-    if BAN_REGEX.search(question):
-        return jsonify({"ok": False, "answer": "", "error": "banned"}), 403
-
-    try:
-        ans = generate_answer(question)
-        return jsonify({"ok": True, "answer": ans})
-    except Exception:
-        # Avoid leaking stack traces in production
-        return jsonify({"ok": False, "answer": "", "error": "server_error"}), 500
+    question = (request.json.get("question") or "").strip()
+    if not question: return jsonify({"ok": False, "answer": "", "error": "Empty"}), 400
+    if BAN_REGEX.search(question): return jsonify({"ok": False, "answer": "", "error": "banned"}), 403
+    return jsonify({"ok": True, "answer": generate_answer(question)})
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", "7860"))
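The moderation rule in this file is just the word-boundary regex (?i)\bsex\b, so only the exact word is flagged, never substrings. A quick standalone illustration of that behaviour (the sample strings are made up for the demo, not taken from the app):

import re

BAN_REGEX = re.compile(r"(?i)\bsex\b")

# The \b boundaries mean only the standalone word matches, in any letter case.
print(bool(BAN_REGEX.search("SEX education")))     # True
print(bool(BAN_REGEX.search("Sussex and Essex")))  # False
print(bool(BAN_REGEX.search("sextet")))            # False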
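With the Space running (Flask reads the PORT environment variable, defaulting to 7860), the two JSON endpoints can be exercised from Python. This is a minimal sketch, assuming a local instance and the third-party requests package; the base URL and sample payloads are illustrative, while the payload keys and response shapes follow the routes above:

import requests

BASE = "http://localhost:7860"  # assumed local instance; adjust for a deployed Space

# /api/moderate returns {"banned": true/false} for the submitted text.
r = requests.post(f"{BASE}/api/moderate", json={"text": "hello there"})
print(r.json())  # {'banned': False}

# /api/answer returns {"ok": True, "answer": "..."} on success,
# or an error payload with status 400 (empty question) / 403 (banned word).
r = requests.post(f"{BASE}/api/answer", json={"question": "what is the capital of france?"})
print(r.status_code, r.json())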