mahesh1209 committed
Commit ddfa3e7 · verified · 1 Parent(s): 8685c87

Update app/app.py

Files changed (1)
  1. app/app.py +23 -59
app/app.py CHANGED
@@ -1,101 +1,65 @@
- import os
- import re
- import torch
+ import os, re, torch
  from flask import Flask, request, jsonify, send_from_directory
-
- # Import local SLM components
  from slm_qa import TinyTransformer, encode, wrap_bos_eos, itos, PAD, BOS, EOS

- # Resolve absolute paths
  BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  STATIC_DIR = os.path.join(BASE_DIR, "static")
  MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")

  app = Flask(__name__, static_folder=STATIC_DIR)
-
- # Moderation rule: ban the exact word "sex" (case-insensitive)
  BAN_REGEX = re.compile(r"(?i)\bsex\b")

- # Model setup
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  VOCAB_SIZE = len(itos)
  MAX_LEN = 128

- model = TinyTransformer(
-     vocab_size=VOCAB_SIZE,
-     d_model=128,
-     n_heads=4,
-     n_layers=2,
-     d_ff=256,
-     dropout=0.1,
-     max_len=MAX_LEN,
- ).to(DEVICE)
-
- # Load checkpoint if present
+ # NOTE: assumes TinyTransformer's defaults match the checkpoint architecture
+ # (previously passed explicitly: d_model=128, n_heads=4, n_layers=2, d_ff=256, dropout=0.1).
+ model = TinyTransformer(vocab_size=VOCAB_SIZE, max_len=MAX_LEN).to(DEVICE)
  ckpt_path = os.path.join(MODELS_DIR, "slm_qa_best.pt")
  if os.path.exists(ckpt_path):
      model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
  model.eval()

+ def sample_next(logits, top_k=5):
+     """Top-k sampling instead of greedy argmax."""
+     probs = torch.softmax(logits, dim=-1)               # [1, V]
+     topk_probs, topk_idx = torch.topk(probs, k=top_k)   # both [1, top_k]
+     idx = torch.multinomial(topk_probs, 1)              # sampled slot, [1, 1]
+     return topk_idx.gather(-1, idx).item()              # map slot back to a vocab id
+
  def generate_answer(question: str, max_new_tokens: int = 40) -> str:
      q_ids = encode("q: " + question)
      a_prefix = encode("a:")
      tokens = wrap_bos_eos(q_ids + a_prefix)[:-1]  # BOS + q + "a:" (no EOS)
      x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)

      with torch.no_grad():
          for _ in range(max_new_tokens):
              if x.size(1) >= MAX_LEN:
                  break
              logits = model(x)  # [1, T, V]
-             next_id = logits[:, -1, :].argmax(dim=-1).item()
+             next_id = sample_next(logits[:, -1, :])
              if next_id == EOS:
                  break
-             x = torch.cat(
-                 [x, torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)],
-                 dim=1,
-             )
+             x = torch.cat([x, torch.tensor([[next_id]], device=DEVICE)], dim=1)

      gen_ids = x.squeeze(0).tolist()
      prefix_len = 1 + len(q_ids) + len(a_prefix)  # BOS + question + "a:"
      answer_ids = gen_ids[prefix_len:]
      out = " ".join(itos[i] for i in answer_ids if i not in (PAD, BOS, EOS)).strip()
      return out if out else "..."

  @app.route("/")
  def index():
      return send_from_directory(STATIC_DIR, "index.html")
-
- @app.route("/static/<path:filename>")
- def static_files(filename):
-     return send_from_directory(STATIC_DIR, filename)
-
- @app.route("/health")
- def health():
-     return jsonify({"ok": True})

  @app.route("/api/moderate", methods=["POST"])
  def moderate():
-     data = request.get_json(force=True, silent=True) or {}
-     text = (data.get("text") or "").strip()
-     banned = bool(BAN_REGEX.search(text))
-     return jsonify({"banned": banned})
+     text = (request.json.get("text") or "").strip()
+     return jsonify({"banned": bool(BAN_REGEX.search(text))})

  @app.route("/api/answer", methods=["POST"])
  def answer():
-     data = request.get_json(force=True, silent=True) or {}
-     question = (data.get("question") or "").strip()
+     question = (request.json.get("question") or "").strip()
      if not question:
-         return jsonify({"ok": False, "answer": "", "error": "Empty question"}), 400
+         return jsonify({"ok": False, "answer": "", "error": "Empty"}), 400
      if BAN_REGEX.search(question):
          return jsonify({"ok": False, "answer": "", "error": "banned"}), 403
-
-     try:
-         ans = generate_answer(question)
-         return jsonify({"ok": True, "answer": ans})
-     except Exception:
-         # Avoid leaking stack traces in production
-         return jsonify({"ok": False, "answer": "", "error": "server_error"}), 500
+     return jsonify({"ok": True, "answer": generate_answer(question)})

  if __name__ == "__main__":
      port = int(os.environ.get("PORT", "7860"))
 
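
With top-k sampling the decoder can only ever emit one of the k highest-probability token ids at each step (greedy argmax is the k=1 special case). A standalone sanity check on dummy logits, duplicating sample_next so the snippet runs on its own:

    import torch

    def sample_next(logits, top_k=5):
        """Top-k sampling instead of greedy argmax."""
        probs = torch.softmax(logits, dim=-1)               # [1, V]
        topk_probs, topk_idx = torch.topk(probs, k=top_k)   # both [1, top_k]
        idx = torch.multinomial(topk_probs, 1)              # sampled slot, [1, 1]
        return topk_idx.gather(-1, idx).item()              # map slot to vocab id

    # Ids 3, 1, and 5 carry the three largest logits in this fake [1, V] row.
    logits = torch.tensor([[0.1, 2.0, 0.3, 5.0, 0.2, 1.5]])
    samples = {sample_next(logits, top_k=3) for _ in range(200)}
    assert samples <= {1, 3, 5}  # only the top-3 ids can ever be drawn

Note the gather call: torch.multinomial returns an index into the top-k slots, which still has to be mapped back to a vocabulary id through topk_idx.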
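
The prompt layout drives the prefix_len arithmetic: the [:-1] slice and the 1 + len(q_ids) + len(a_prefix) offset both imply that wrap_bos_eos prepends BOS and appends EOS. Under that assumption, a minimal sketch with stand-in token ids:

    # Illustrative only: the ids and wrap_bos_eos are stand-ins.
    BOS, EOS = 1, 2

    def wrap_bos_eos(ids):
        return [BOS] + ids + [EOS]

    q_ids = [10, 11, 12]  # stands in for encode("q: " + question)
    a_prefix = [13]       # stands in for encode("a:")

    tokens = wrap_bos_eos(q_ids + a_prefix)[:-1]  # drop the trailing EOS
    assert tokens == [BOS, 10, 11, 12, 13]        # BOS + question + "a:"
    assert len(tokens) == 1 + len(q_ids) + len(a_prefix)  # == prefix_len

Slicing gen_ids at prefix_len therefore keeps exactly the tokens generated after the "a:" marker.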
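
The moderation rule bans only the exact word, case-insensitively: the \b word boundaries keep the pattern from firing inside longer words. A few illustrative checks:

    import re

    BAN_REGEX = re.compile(r"(?i)\bsex\b")

    assert BAN_REGEX.search("SEX ed")      # exact word in any case: banned
    assert not BAN_REGEX.search("Sussex")  # no word boundary before the match
    assert not BAN_REGEX.search("sexual")  # no word boundary after the match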
 
 
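One behavioral consequence of the rewrite: the handlers now read request.json, which in Flask requires a Content-Type: application/json header, whereas the removed get_json(force=True, silent=True) tolerated its absence. A minimal client sketch using only the standard library, assuming the app is running locally on its default port 7860:

    import json
    from urllib.request import Request, urlopen

    def post_json(path, payload):
        """POST a JSON payload to the local app and decode the JSON reply."""
        req = Request(
            "http://localhost:7860" + path,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        with urlopen(req) as resp:
            return json.loads(resp.read())

    print(post_json("/api/moderate", {"text": "hello"}))  # {'banned': False}
    print(post_json("/api/answer", {"question": "hi"}))   # {'ok': True, 'answer': ...}

Error responses (400 for an empty question, 403 for a banned one) surface as urllib.error.HTTPError, so a robust caller would wrap urlopen in a try/except; note also that the commit dropped the old try/except around generate_answer, so model failures now bubble up as Flask's default 500 response.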