import random
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr

# Reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE = torch.device("cpu")

# Model and training hyperparameters (deliberately CPU-sized)
MAX_LEN = 256              # fixed sequence length in characters (incl. BOS/EOS)
EMBED_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
DROPOUT = 0.1
LR = 2e-3
EPOCHS = 3
BATCH_SIZE = 16
TRAIN_STEPS_PER_EPOCH = 100

# Sampling hyperparameters for generation
TEMPERATURE = 0.7
TOP_K = 8

# Hand-written seed examples in the exact "prompt -> target" format the
# generators below also produce.
BASE = [
    ("Q: What is 13 + 27?\nA:", "Step: 13 + 27 = 40\nFinal: 40\n"),
    ("Q: If x = 5, what is 3x + 2?\nA:", "Step: 3*5 = 15\nStep: 15 + 2 = 17\nFinal: 17\n"),
    ("Q: What is 12 * 4?\nA:", "Step: 12 * 4 = 48\nFinal: 48\n"),
]

def gen_add():
    a, b = random.randint(10, 99), random.randint(10, 99)
    return (f"Q: What is {a} + {b}?\nA:", f"Step: {a} + {b} = {a+b}\nFinal: {a+b}\n")


def gen_sub():
    # Draw b < a so results stay non-negative.
    a = random.randint(20, 99); b = random.randint(1, a - 1)
    return (f"Q: What is {a} - {b}?\nA:", f"Step: {a} - {b} = {a-b}\nFinal: {a-b}\n")


def gen_mul():
    a, b = random.randint(2, 12), random.randint(2, 12)
    return (f"Q: What is {a} * {b}?\nA:", f"Step: {a} * {b} = {a*b}\nFinal: {a*b}\n")

def build_dataset(extra=800):
    # Training strings are the prompt and target concatenated.
    pairs = BASE[:]
    for _ in range(extra):
        pairs.append(random.choice([gen_add, gen_sub, gen_mul])())
    return [inp + tgt for inp, tgt in pairs]


SAMPLES = build_dataset()
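# Each training sample is a single string, e.g. (from BASE[0]):
#   "Q: What is 13 + 27?\nA:Step: 13 + 27 = 40\nFinal: 40\n"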
def build_vocab(texts):
    # Character-level vocabulary built from the training data itself.
    chars = set()
    for t in texts:
        chars.update(t)
    vocab = ['<PAD>', '<BOS>', '<EOS>'] + sorted(chars)
    stoi = {c: i for i, c in enumerate(vocab)}
    itos = {i: c for c, i in stoi.items()}
    return vocab, stoi, itos


VOCAB, STOI, ITOS = build_vocab(SAMPLES)
PAD_ID, BOS_ID, EOS_ID = STOI['<PAD>'], STOI['<BOS>'], STOI['<EOS>']
VOCAB_SIZE = len(VOCAB)

def encode(s, max_len=MAX_LEN, add_bos=True, add_eos=True):
    # Map a string to a fixed-length LongTensor of character ids,
    # reserving room for BOS/EOS and padding the remainder.
    ids = [BOS_ID] if add_bos else []
    unk = STOI['?']  # '?' is always in the vocab: every question ends with one
    budget = max_len - len(ids) - (1 if add_eos else 0)
    ids += [STOI.get(ch, unk) for ch in s[:budget]]
    if add_eos:
        ids.append(EOS_ID)
    ids += [PAD_ID] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long)

def decode(ids):
    # Map ids back to text, dropping the special tokens.
    out = []
    for i in ids:
        ii = int(i)
        if ii in (PAD_ID, BOS_ID, EOS_ID):
            continue
        out.append(ITOS[ii])
    return ''.join(out)
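# Quick round-trip sanity check: encode/decode should be lossless for
# in-vocab text, since decode strips PAD/BOS/EOS.
assert decode(encode("Q: What is 2 + 3?\nA:")) == "Q: What is 2 + 3?\nA:"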
class CoTDataset(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Standard next-token setup: input is x[:-1], target is x[1:].
        x = encode(self.samples[idx], MAX_LEN)
        return x[:-1], x[1:]


train_ds = CoTDataset(SAMPLES)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
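# Teacher-forcing layout: for ids [BOS, c1, c2, ..., EOS, PAD, ...] the model
# sees ids[:-1] and is trained to predict ids[1:], so position t predicts t+1;
# PAD targets are masked out of the loss below via ignore_index.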
class TinyGRULM(nn.Module):
    # A small character-level GRU language model:
    # embedding -> GRU stack -> LayerNorm -> linear head over the vocabulary.
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers,
                          batch_first=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        emb = self.embed(x)
        out, h = self.gru(emb, h)
        out = self.norm(out)
        logits = self.head(out)
        return logits, h


model = TinyGRULM(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=LR)
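# Rough size check; the exact count depends on the dataset-driven vocab size.
n_params = sum(p.numel() for p in model.parameters())
print(f"TinyGRULM parameters: {n_params:,}")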
for epoch in range(EPOCHS):
    model.train()
    for step, (x, y) in enumerate(train_loader):
        if step >= TRAIN_STEPS_PER_EPOCH:
            break
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        # PAD targets are excluded from the loss.
        loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1), ignore_index=PAD_ID)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Note: this is the last mini-batch's loss, not an epoch average.
    print(f"Epoch {epoch+1} done, loss={loss.item():.3f}")
def sample_next(probs, top_k=8, temp=0.7):
    # Temperature-scale the (already softmaxed) distribution in log space,
    # then optionally restrict to the top-k tokens before sampling.
    logits = torch.log(probs + 1e-9) / max(temp, 1e-6)
    if top_k > 0:
        k = min(top_k, logits.shape[-1])
        v, idx = torch.topk(logits, k)
        mask = torch.full_like(logits, float('-inf'))
        mask.scatter_(0, idx, v)  # keep only the top-k logits
        logits = mask
    return torch.multinomial(F.softmax(logits, dim=-1), 1).item()
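# Deterministic alternative for debugging: greedy decoding, i.e.
# int(torch.argmax(probs)) instead of temperature/top-k sampling.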
def generate_raw(prompt, max_new_tokens=160):
    model.eval()
    with torch.no_grad():
        # Encode the prompt without padding: feeding PAD tokens before
        # generation would condition the model on garbage.
        ids = [BOS_ID] + [STOI.get(ch, STOI['?']) for ch in prompt]
        x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
        logits, h = model(x)  # one pass over the whole prompt
        out_ids = ids[:]
        for _ in range(max_new_tokens):
            probs = F.softmax(logits[0, -1], dim=-1)
            nid = sample_next(probs, TOP_K, TEMPERATURE)
            out_ids.append(nid)
            if nid == EOS_ID:
                break
            # Feed only the new token; the GRU state h carries the context.
            x = torch.tensor([[nid]], dtype=torch.long, device=DEVICE)
            logits, h = model(x, h)
    gen = decode(out_ids)
    # Return only the answer portion.
    if "A:" in gen:
        gen = gen.split("A:", 1)[1].strip()
    return gen
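# Example (stochastic; quality depends on the short training run):
#   print(generate_raw("Q: What is 7 * 6?\nA:"))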
def compute_step_final(q):
    # Deterministic verifier: parse the first "a op b" expression and compute
    # canonical Step/Final lines. Returns None if nothing parseable is found.
    m = re.search(r'(-?\d+)\s*([+\-*])\s*(-?\d+)', q)
    if not m:
        return None
    a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
    if op == '+':
        return f"Step: {a} + {b} = {a+b}", f"Final: {a+b}"
    if op == '-':
        return f"Step: {a} - {b} = {a-b}", f"Final: {a-b}"
    if op == '*':
        return f"Step: {a} * {b} = {a*b}", f"Final: {a*b}"
    return None
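# e.g. compute_step_final("What is 12 * 4?") -> ("Step: 12 * 4 = 48", "Final: 48")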
def normalize_generation(gen):
    # Split the generation into non-empty, stripped lines.
    return [ln.strip() for ln in gen.splitlines() if ln.strip()]
def verify_and_format(question, gen):
    sf = compute_step_final(question)
    lines = normalize_generation(gen)
    step_lines = [ln for ln in lines if ln.lower().startswith("step:")]
    final_lines = [ln for ln in lines if ln.lower().startswith("final:")]

    # Unparseable question: fall back to whatever the model produced,
    # preferring its Step/Final lines if present.
    if sf is None:
        out = []
        if step_lines:
            out.append(step_lines[0])
        if final_lines:
            out.append(final_lines[0])
        if not out:
            out = lines[:4]
        return "\n".join(out)

    # Parseable question: the verifier's canonical answer wins outright.
    canonical_step, canonical_final = sf
    return "\n".join([canonical_step, canonical_final])
def validate(n=20):
    # End-to-end check of the verified pipeline on fresh random questions.
    tests = []
    for _ in range(n):
        kind = random.choice([gen_add, gen_sub, gen_mul])
        inp, tgt = kind()
        q = inp[:-len("\nA:")]  # strip the answer cue to recover the bare question
        tests.append(q)
    ok = 0
    for q in tests:
        ans = verify_and_format(q, generate_raw(q + "\nA:", 160))
        m = re.search(r'Final:\s*(-?\d+)\s*$', ans)
        target = compute_step_final(q)
        if m and target and int(m.group(1)) == int(target[1].split(":")[1].strip()):
            ok += 1
    return ok, n


acc, total = validate(30)
print(f"Verifier accuracy on random held-out arithmetic: {acc}/{total} correct")
def answer_fn(question):
    # Gradio callback: generate, then pass the result through the verifier.
    raw = generate_raw(question + "\nA:", max_new_tokens=160)
    return verify_and_format(question, raw)
demo = gr.Interface(
    fn=answer_fn,
    inputs=gr.Textbox(lines=2, placeholder="Ask a math question (e.g., What is 12 * 4?)"),
    outputs="text",
    title="Tiny CoT SLM (CPU) with canonical Steps and Final",
    description=("Answers come from a tiny character-level GRU; for +, -, and * "
                 "questions a deterministic verifier replaces the model's 'Step:' "
                 "and 'Final:' lines with provably correct ones."),
)
if __name__ == "__main__":
    demo.launch()