mahesh1209 committed
Commit db06e07 · verified · 1 Parent(s): 57be96b

Create app.py

Files changed (1): app.py +467 -0
app.py ADDED
@@ -0,0 +1,467 @@
# app.py
# Single-page Gradio app for Hugging Face Spaces
# - Trains MiniGPT and a classifier on startup (tiny datasets, short epochs by default)
# - Large, centered UI with three panels:
#     1) Instruction -> Response
#     2) Sentiment classification
#     3) Next word + dataset sentence completion (prefix of two words)
# - Instant input moderation: banned words trigger an immediate error and block the request
# - Greedy decoding for stable, minimal outputs

import math
import os
import re

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import gradio as gr

# ----------------------------
# 1) Data preparation
# ----------------------------
lm_corpus = [
    "the cat sits on the mat",
    "the dog chases the ball",
    "a small model can learn patterns",
    "language models predict next tokens",
    "transformers use attention mechanism",
    "training on tiny data is limited",
    "we build a model from scratch",
    "this is a minimal example",
    "positional embeddings encode order",
    "causal masking prevents peeking ahead",
]

cls_data = [
    ("this is bad", 0),
    ("i dislike this", 0),
    ("terrible and awful", 0),
    ("this is good", 1),
    ("i like this", 1),
    ("wonderful and great", 1),
]

inst_data_base = [
    ("<INSTR> write a short greeting <ENDINSTR>", "<RESP> hello! <ENDRESP>"),
    ("<INSTR> answer briefly what is a cat <ENDINSTR>", "<RESP> a small animal. <ENDRESP>"),
    ("<INSTR> continue the sun is <ENDINSTR>", "<RESP> bright. <ENDRESP>"),
]
inst_data = inst_data_base * 64  # replicate to stabilize tiny-data learning

# ----------------------------
# Tokenization (word-level)
# ----------------------------
def normalize_text(s):
    s = s.lower().strip()
    s = re.sub(r'([.!?,:;])', r' \1 ', s)  # space out punctuation so it tokenizes separately
    s = re.sub(r'\s+', ' ', s)
    return s

def build_vocab(texts):
    tokens = set()
    specials = ["<pad>", "<bos>", "<eos>"]
    for t in texts:
        t = normalize_text(t)
        for tok in t.split():
            tokens.add(tok)
    vocab = specials + sorted(tokens)
    stoi = {s: i for i, s in enumerate(vocab)}
    itos = {i: s for s, i in stoi.items()}
    return vocab, stoi, itos

all_texts = lm_corpus + [x for x, _ in cls_data] + [a for a, _ in inst_data_base] + [b for _, b in inst_data_base]
vocab, stoi, itos = build_vocab(all_texts)
PAD, BOS, EOS = stoi["<pad>"], stoi["<bos>"], stoi["<eos>"]
vocab_size = len(vocab)

def encode(text, max_len=None, add_special=True):
    text = normalize_text(text)
    toks = text.split()
    ids = ([BOS] if add_special else []) + [stoi.get(tok, PAD) for tok in toks] + ([EOS] if add_special else [])
    if max_len is not None:
        ids = ids[:max_len]
        if len(ids) < max_len:
            ids = ids + [PAD] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long)

def decode(ids):
    toks = [itos.get(i, "") for i in ids]
    toks = [t for t in toks if t not in ("<pad>", "<bos>", "<eos>")]
    out = " ".join(toks)
    out = re.sub(r'\s+([.!?,:;])', r'\1', out)  # reattach punctuation
    return out.strip()

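# Round-trip sanity check (illustrative; the concrete ids depend on the sorted
# vocabulary built above):
#   encode("The cat sits!", add_special=True)
#     -> ids for <bos> the cat sits ! <eos> as a LongTensor
#   decode(encode("the cat sits on the mat").tolist())
#     -> "the cat sits on the mat"
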
# ----------------------------
# Datasets
# ----------------------------
class LMPretrainDataset(Dataset):
    def __init__(self, texts, block_size=64):
        self.samples = []
        for t in texts:
            ids = encode(t, max_len=block_size, add_special=True)
            self.samples.append((ids[:-1], ids[1:]))  # shift-by-one next-token pairs
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

class ClassificationDataset(Dataset):
    def __init__(self, pairs, block_size=64):
        self.samples = []
        for text, label in pairs:
            ids = encode(text, max_len=block_size, add_special=True)
            self.samples.append((ids, torch.tensor(label, dtype=torch.long)))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

class InstructionDataset(Dataset):
    def __init__(self, pairs, block_size=64):
        self.samples = []
        for instr, resp in pairs:
            instr_ids = encode(instr, add_special=False).tolist()
            resp_ids = encode(resp, add_special=False).tolist()
            seq = [BOS] + instr_ids + [EOS] + [BOS] + resp_ids + [EOS]
            seq = seq[:block_size]
            if len(seq) < block_size:
                seq += [PAD] * (block_size - len(seq))
            ids = torch.tensor(seq, dtype=torch.long)
            self.samples.append((ids[:-1], ids[1:]))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

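# Packed instruction format (before padding), matching InstructionDataset above:
#   <bos> <instr> ... <endinstr> <eos> <bos> <resp> ... <endresp> <eos>
# Training pairs are this sequence shifted by one token, so the model learns to
# emit the response segment after seeing the instruction segment.
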
# ----------------------------
# 2) Model architecture (GPT-style)
# ----------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embed, n_head, dropout=0.1):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        self.qkv = nn.Linear(n_embed, 3 * n_embed)
        self.proj = nn.Linear(n_embed, n_embed)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.register_buffer("mask", None)  # causal mask cache, rebuilt lazily when T changes

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if (self.mask is None) or (self.mask.size(-1) != T):
            self.mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(self.mask == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        y = self.resid_drop(y)
        return y

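# Shape walk-through for one attention pass (with the defaults used below,
# n_embed=192, n_head=6, so head_dim=32):
#   x:   (B, T, 192)
#   qkv: (B, T, 576) -> q, k, v each (B, 6, T, 32) after view + transpose
#   att: (B, 6, T, T) scaled scores, lower-triangular masked, then softmax
#   y:   (B, 6, T, 32) -> (B, T, 192) after merging heads and projecting
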
class TransformerBlock(nn.Module):
    def __init__(self, n_embed, n_head, mlp_mult=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, dropout)
        self.ln2 = nn.LayerNorm(n_embed)
        self.mlp = nn.Sequential(
            nn.Linear(n_embed, mlp_mult * n_embed),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_mult * n_embed, n_embed),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embed=192, n_head=6, n_layer=4, block_size=64, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(n_embed, n_head, 4, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        B, T = idx.size()
        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = self.drop(tok + pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate_greedy(self, idx, max_new_tokens=20):
        # Greedy decoding: always take the argmax token; stop early at <eos>.
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_id], dim=1)
            if next_id.item() == EOS:
                break
        return idx

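# Usage sketch (hypothetical prompt; assumes the `device` and `model` globals
# defined below):
#   idx = torch.tensor([[BOS] + encode("the cat", add_special=False).tolist()], device=device)
#   out = model.generate_greedy(idx, max_new_tokens=8)
#   print(decode(out[0].tolist()))
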
# ----------------------------
# 3) Training pipeline
# ----------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 64

lm_dl = DataLoader(LMPretrainDataset(lm_corpus, block_size), batch_size=16, shuffle=True)
cls_dl = DataLoader(ClassificationDataset(cls_data, block_size), batch_size=6, shuffle=True)
inst_dl = DataLoader(InstructionDataset(inst_data, block_size), batch_size=32, shuffle=True)

model = MiniGPT(vocab_size=vocab_size, n_embed=192, n_head=6, n_layer=4, block_size=block_size, dropout=0.1).to(device)

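# Each LM/instruction batch is a pair of (B, block_size - 1) LongTensors: the
# 64-token encoding shifted by one position, so position t predicts token t + 1.
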
def pretrain(model, dataloader, epochs=8, lr=3e-4, grad_clip=1.0):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)  # padded positions do not contribute to the loss
    model.train()
    for _ in range(epochs):
        for inp, tgt in dataloader:
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            B, T, V = logits.size()
            loss = loss_fn(logits.view(B * T, V), tgt.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

class ClassificationHead(nn.Module):
    def __init__(self, backbone: MiniGPT, n_classes=2, freeze_backbone=False):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        n_embed = backbone.head.in_features
        self.classifier = nn.Sequential(nn.LayerNorm(n_embed), nn.Linear(n_embed, n_classes))

    def forward(self, idx):
        B, T = idx.size()
        tok = self.backbone.tok_emb(idx)
        pos = self.backbone.pos_emb(torch.arange(T, device=idx.device))
        x = self.backbone.drop(tok + pos)
        for blk in self.backbone.blocks:
            x = blk(x)
        x = self.backbone.ln_f(x)
        # Pool at the first <eos> position (the end of the real text, before
        # padding); fall back to the last position if no <eos> is present.
        eos_mask = (idx == EOS)
        last_idx = torch.where(
            eos_mask.any(dim=1),
            eos_mask.float().argmax(dim=1),
            torch.full((B,), T - 1, device=idx.device),
        )
        pooled = x[torch.arange(B, device=idx.device), last_idx]
        return self.classifier(pooled)

clf = ClassificationHead(model, n_classes=2, freeze_backbone=False).to(device)

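# Design note: in a causal transformer only the later positions have attended
# to the full input, so the hidden state at the first <eos> (which closes the
# text before padding) is a natural single-vector summary for classification.
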
def finetune_classification(clf, dataloader, epochs=6, lr=8e-4):
    opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, clf.parameters()), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    clf.train()
    for _ in range(epochs):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = clf(x)
            loss = loss_fn(logits, y)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

def finetune_instruction(model, dataloader, epochs=50, lr=1.5e-4, grad_clip=1.0):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
    model.train()
    for _ in range(epochs):
        for inp, tgt in dataloader:
            inp, tgt = inp.to(device), tgt.to(device)
            logits = model(inp)
            B, T, V = logits.size()
            loss = loss_fn(logits.view(B * T, V), tgt.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

# ----------------------------
# 4) Inference helpers
# ----------------------------
@torch.no_grad()
def classify_text(text):
    ids = encode(text, max_len=block_size, add_special=True).unsqueeze(0).to(device)
    logits = clf(ids)
    pred = logits.argmax(dim=-1).item()
    return "positive" if pred == 1 else "negative"

@torch.no_grad()
def generate_response(instruction, max_new_tokens=12):
    instr = f"<INSTR> {instruction} <ENDINSTR>"
    resp_start = "<RESP>"
    prefix_ids = encode(instr, add_special=False).tolist()
    resp_start_ids = encode(resp_start, add_special=False).tolist()
    # Mirror the training layout from InstructionDataset:
    # <bos> instruction <eos> <bos> <resp> ... (the model continues from here)
    seq = [BOS] + prefix_ids + [EOS] + [BOS] + resp_start_ids
    idx = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
    out = model.generate_greedy(idx, max_new_tokens=max_new_tokens)
    toks = [itos[i] for i in out[0].tolist()]
    try:
        resp_pos = toks.index("<resp>")  # special tokens are lowercased by normalize_text
    except ValueError:
        resp_pos = len(toks) - 1
    resp_toks = toks[resp_pos + 1:]
    if "<endresp>" in resp_toks:
        resp_toks = resp_toks[:resp_toks.index("<endresp>")]
    elif "<eos>" in resp_toks:
        resp_toks = resp_toks[:resp_toks.index("<eos>")]
    text = " ".join(resp_toks)
    text = re.sub(r'\s+([.!?,:;])', r'\1', text).strip()
    return text

# --- Next word + dataset sentence completion ---
@torch.no_grad()
def predict_next_word_and_complete(prefix_two_words, max_new_tokens=16):
    # Normalize and validate
    s = normalize_text(prefix_two_words)
    toks = s.split()
    if len(toks) < 2:
        return "(need at least two words)", "(no match)", "(no generation)"
    # Moderation is handled separately at the UI entry points

    # Next-word prediction via the LM. Do not append <eos> here: the model
    # should continue the prefix, so encode with a leading <bos> only.
    ids = torch.tensor([BOS] + encode(" ".join(toks), add_special=False).tolist(),
                       dtype=torch.long, device=device).unsqueeze(0)
    logits = model(ids)
    next_id = logits[:, -1, :].argmax(dim=-1).item()
    next_word = itos.get(next_id, "")

    # Dataset sentence completion: exact match on the first two words
    prefix = " ".join(toks[:2])
    matches = [sent for sent in lm_corpus if normalize_text(sent).startswith(prefix + " ")]
    matched = "; ".join(matches) if matches else "(no exact dataset sentence starts with those two words)"

    # Fallback generation to complete a sentence-like output
    gen_ids = model.generate_greedy(ids, max_new_tokens=max_new_tokens)
    gen_text = decode(gen_ids[0].tolist())

    return next_word, matched, gen_text

# ----------------------------
# 5) Moderation (instant lockout)
# ----------------------------
BANNED = {"hate", "kill", "self-harm", "suicide", "violence"}  # extend as needed

def check_banned(s: str):
    s_norm = normalize_text(s)
    bad = set(s_norm.split()).intersection(BANNED)
    if bad:
        raise gr.Error(f"Input contains prohibited words: {', '.join(sorted(bad))}. Submission blocked.")

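# Example: check_banned("i hate this") raises gr.Error, which Gradio surfaces
# as an error message in the UI and aborts the event handler.
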
# ----------------------------
# 6) Train-on-start (short epochs by default)
#    FAST_TRAIN defaults to "1"; set FAST_TRAIN=0 in the Space's environment
#    to run the longer training schedules.
# ----------------------------
FAST = os.getenv("FAST_TRAIN", "1") == "1"
PRE_EPOCHS = 2 if FAST else 8
CLS_EPOCHS = 2 if FAST else 6
INST_EPOCHS = 6 if FAST else 50

def bootstrap():
    pretrain(model, lm_dl, epochs=PRE_EPOCHS, lr=3e-4)
    finetune_classification(clf, cls_dl, epochs=CLS_EPOCHS, lr=8e-4)
    finetune_instruction(model, inst_dl, epochs=INST_EPOCHS, lr=1.5e-4)

bootstrap()

# ----------------------------
# 7) Gradio UI (large, centered)
# ----------------------------
def ui_generate(instruction, max_tokens):
    check_banned(instruction)
    resp = generate_response(instruction, max_new_tokens=max_tokens)
    return resp if resp.strip() else "(no response)"

def ui_classify(text):
    check_banned(text)
    return classify_text(text)

def ui_next_word(prefix_two_words, max_tokens):
    check_banned(prefix_two_words)
    return predict_next_word_and_complete(prefix_two_words, max_new_tokens=max_tokens)

with gr.Blocks(title="Minimal GPT-style LLM (word-level, greedy)") as demo:
    gr.HTML(
        """
        <div style="text-align:center; max-width: 880px; margin:auto;">
          <h1 style="font-size: 32px; margin-bottom: 10px;">Minimal GPT-style LLM</h1>
          <p style="font-size: 16px;">
            Word-level tokenizer • Tiny transformer • Greedy decoding • Instruction fine-tuning • Sentiment classification • Next-word prediction
          </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Instruction to response")
            instr = gr.Textbox(
                label="Instruction",
                placeholder="e.g., write a short greeting",
                lines=2,
                elem_id="instr_box",
            )
            max_toks = gr.Slider(4, 32, value=12, step=1, label="Max new tokens")
            gen_btn = gr.Button("Generate response", variant="primary", elem_id="gen_btn")
            resp = gr.Textbox(label="Model response", lines=4, interactive=False)
            gen_btn.click(fn=ui_generate, inputs=[instr, max_toks], outputs=resp)

        with gr.Column(scale=1):
            gr.Markdown("### Sentiment classification")
            cls_in = gr.Textbox(
                label="Text",
                placeholder="e.g., i like this",
                lines=2,
                elem_id="cls_box",
            )
            cls_btn = gr.Button("Classify sentiment", variant="primary", elem_id="cls_btn")
            cls_out = gr.Textbox(label="Prediction", lines=1, interactive=False)
            cls_btn.click(fn=ui_classify, inputs=cls_in, outputs=cls_out)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Next word + dataset sentence completion")
            two_words = gr.Textbox(
                label="Enter at least two words (prefix)",
                placeholder="e.g., the cat",
                lines=1,
                elem_id="nw_box",
            )
            max_toks_nw = gr.Slider(4, 32, value=16, step=1, label="Max new tokens for generation")
            nw_btn = gr.Button("Predict next word & complete", variant="primary", elem_id="nw_btn")
            next_word_out = gr.Textbox(label="Next word (LM greedy)", lines=1, interactive=False)
            matched_out = gr.Textbox(label="Dataset sentence match (exact prefix)", lines=2, interactive=False)
            gen_out = gr.Textbox(label="Generated completion (fallback)", lines=3, interactive=False)
            nw_btn.click(fn=ui_next_word, inputs=[two_words, max_toks_nw], outputs=[next_word_out, matched_out, gen_out])

    gr.HTML(
        """
        <style>
          #instr_box textarea, #cls_box textarea, #nw_box textarea {
            font-size: 18px; text-align: center;
          }
          #gen_btn, #cls_btn, #nw_btn {
            font-size: 18px; width: 100%; height: 52px;
          }
          .gradio-container { max-width: 980px !important; margin: auto !important; }
        </style>
        """
    )

if __name__ == "__main__":
    demo.launch()