# Tiny CoT SLM + Canonical Verifier (single-page, CPU-only, Gradio UI)
# Hugging Face Spaces-ready (fast startup on CPU)
import math
import random
import time
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr

# -----------------------------
# Config
# -----------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("cpu")

# Sequence / model sizes.
MAX_LEN = 256
EMBED_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
DROPOUT = 0.1

# Optimisation budget (reduced for faster Space startup).
LR = 2e-3
EPOCHS = 3
BATCH_SIZE = 16
TRAIN_STEPS_PER_EPOCH = 100

# Decoding.
TEMPERATURE = 0.7
TOP_K = 8

# -----------------------------
# Synthetic CoT dataset (Step/Final format)
# -----------------------------
BASE = [
    ("Q: What is 13 + 27?\nA:", "Step: 13 + 27 = 40\nFinal: 40\n"),
    ("Q: If x = 5, what is 3x + 2?\nA:", "Step: 3*5 = 15\nStep: 15 + 2 = 17\nFinal: 17\n"),
    ("Q: What is 12 * 4?\nA:", "Step: 12 * 4 = 48\nFinal: 48\n"),
]


def gen_add():
    """Return a (prompt, target) pair for a random two-digit addition."""
    left = random.randint(10, 99)
    right = random.randint(10, 99)
    total = left + right
    return (f"Q: What is {left} + {right}?\nA:",
            f"Step: {left} + {right} = {total}\nFinal: {total}\n")


def gen_sub():
    """Return a (prompt, target) pair for a subtraction with a positive result."""
    left = random.randint(20, 99)
    right = random.randint(1, left - 1)
    diff = left - right
    return (f"Q: What is {left} - {right}?\nA:",
            f"Step: {left} - {right} = {diff}\nFinal: {diff}\n")


def gen_mul():
    """Return a (prompt, target) pair for a random times-table product (2..12)."""
    left = random.randint(2, 12)
    right = random.randint(2, 12)
    product = left * right
    return (f"Q: What is {left} * {right}?\nA:",
            f"Step: {left} * {right} = {product}\nFinal: {product}\n")


def build_dataset(extra=800):
    """Return BASE plus `extra` random samples, each as one prompt+target string."""
    pairs = list(BASE)
    # Lazily evaluated: each iteration picks a generator and then calls it,
    # keeping the global RNG draw order deterministic for a given SEED.
    pairs.extend(random.choice([gen_add, gen_sub, gen_mul])() for _ in range(extra))
    return [prompt + target for prompt, target in pairs]


SAMPLES = build_dataset()

# -----------------------------
# Char-level tokenizer (direct, no external libs)
# -----------------------------
def build_vocab(texts):
    """Build a character vocabulary from `texts`.

    Returns:
        (vocab, stoi, itos): the vocab list (three reserved slots followed by the
        sorted distinct characters), a char->id dict and an id->char dict.

    NOTE(review): the three reserved slots are all the empty string, so in
    `stoi` the key '' ends up mapped only to index 2, and `itos` has no entry
    for ids 0 or 1. Code that needs distinct pad/bos/eos ids must refer to the
    reserved slots by position, not by looking up '' in `stoi`.
    """
    alphabet = sorted(set("".join(texts)))
    vocab = ['', '', ''] + alphabet
    stoi = {}
    for position, symbol in enumerate(vocab):
        stoi[symbol] = position
    itos = {position: symbol for symbol, position in stoi.items()}
    return vocab, stoi, itos


VOCAB, STOI, ITOS = build_vocab(SAMPLES)
# Special-token ids.
# build_vocab() reserves the first three vocabulary slots for pad/bos/eos but
# fills all three with the same placeholder string, so looking the ids up by
# name (STOI[...]) collapses pad, bos and eos onto one index. With
# ignore_index=PAD_ID that also hid every EOS target from the loss, so the
# model never learned to stop. Address the reserved slots by position instead.
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2
SPECIAL_IDS = frozenset((PAD_ID, BOS_ID, EOS_ID))
VOCAB_SIZE = len(VOCAB)
# Characters missing from the vocabulary are encoded as '?'. Fall back to
# PAD_ID (instead of a None id that torch.tensor rejects) if '?' is absent.
UNK_ID = STOI.get('?', PAD_ID)


def encode(s, max_len=MAX_LEN, add_bos=True, add_eos=True, pad=True):
    """Encode text as a 1-D LongTensor of token ids.

    Args:
        s: text to encode; characters not in the vocabulary become '?'.
        max_len: maximum sequence length, special tokens included.
        add_bos: prepend BOS_ID.
        add_eos: append EOS_ID (a slot is reserved for it before truncation).
        pad: right-pad with PAD_ID up to max_len. Generation turns this off so
            the model is not conditioned on trailing padding.

    Returns:
        torch.long tensor of length max_len when pad=True, otherwise the
        unpadded (possibly truncated) sequence.
    """
    ids = [BOS_ID] if add_bos else []
    char_budget = max_len - 2 if add_eos else max_len - 1
    ids += [STOI.get(ch, UNK_ID) for ch in s][:char_budget]
    if add_eos:
        ids.append(EOS_ID)
    ids = ids[:max_len]
    if pad:
        ids += [PAD_ID] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long)


def decode(ids):
    """Convert token ids back to text, dropping pad/bos/eos."""
    return ''.join(ITOS[int(i)] for i in ids if int(i) not in SPECIAL_IDS)


# -----------------------------
# Dataset
# -----------------------------
class CoTDataset(torch.utils.data.Dataset):
    """Next-character prediction pairs built from padded, BOS/EOS-wrapped samples."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Input and target are the same encoded sequence shifted by one step.
        x = encode(self.samples[idx], MAX_LEN)
        return x[:-1], x[1:]


train_ds = CoTDataset(SAMPLES)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)


# -----------------------------
# Model (tiny GRU LM + LayerNorm head)
# -----------------------------
class TinyGRULM(nn.Module):
    """Character-level GRU language model: embedding -> GRU -> LayerNorm -> linear."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers,
                          batch_first=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        """Return (logits, h) for token ids x (batch-first); h is the GRU state."""
        emb = self.embed(x)
        out, h = self.gru(emb, h)
        logits = self.head(self.norm(out))
        return logits, h


model = TinyGRULM(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=LR)

# -----------------------------
# Train (quick warm-up for Spaces)
# -----------------------------
for epoch in range(EPOCHS):
    model.train()
    epoch_losses = []
    for step, (x, y) in enumerate(train_loader):
        if step >= TRAIN_STEPS_PER_EPOCH:
            break
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        # Padding targets are ignored; EOS is now a distinct, learned target.
        loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1),
                               ignore_index=PAD_ID)
        opt.zero_grad()
        loss.backward()
        opt.step()
        epoch_losses.append(loss.item())
    # Report the epoch mean rather than only the last batch; skip if no batches.
    if epoch_losses:
        print(f"Epoch {epoch+1} done, loss={sum(epoch_losses) / len(epoch_losses):.3f}")


# -----------------------------
# Sampling
# -----------------------------
def sample_next(probs, top_k=8, temp=0.7):
    """Sample one token id from a 1-D probability vector.

    Applies temperature to the log-probabilities, then keeps only the top_k
    candidates (top_k <= 0 disables the filter).
    """
    logits = torch.log(probs + 1e-9) / max(temp, 1e-6)
    if top_k > 0:
        k = min(top_k, logits.shape[-1])
        values, indices = torch.topk(logits, k)
        filtered = torch.full_like(logits, float('-inf'))
        filtered.scatter_(0, indices, values)
        logits = filtered
    return torch.multinomial(F.softmax(logits, dim=-1), 1).item()


def generate_raw(prompt, max_new_tokens=160):
    """Sample a continuation of `prompt` and return the text after the first "A:".

    The prompt is encoded without padding and run through the GRU once. After
    that, only the newly sampled token is fed each step, together with the
    carried hidden state, so no context is counted twice. Generation stops at
    EOS or after max_new_tokens.
    """
    model.eval()
    with torch.no_grad():
        prompt_ids = encode(prompt, MAX_LEN, add_bos=True, add_eos=False, pad=False)
        out_ids = prompt_ids.tolist()
        x = prompt_ids.unsqueeze(0).to(DEVICE)
        h = None
        for _ in range(max_new_tokens):
            logits, h = model(x, h)
            probs = F.softmax(logits[0, -1], dim=-1)
            nid = sample_next(probs, TOP_K, TEMPERATURE)
            out_ids.append(nid)
            if nid == EOS_ID:
                break
            x = torch.tensor([[nid]], dtype=torch.long, device=DEVICE)
    gen = decode(out_ids)
    if "A:" in gen:
        gen = gen.split("A:", 1)[1].strip()
    return gen


# -----------------------------
# Canonical verifier (accurate Step/Final for +, -, *)
# -----------------------------
# One integer on each side of a +, - or * operator. The lookarounds stop a
# match from starting or ending inside a decimal or comma-grouped number
# ("1.5 + 2" must not be read as "5 + 2").
_BINOP_RE = re.compile(r'(?<![\d.,])(-?\d+)\s*([+\-*])\s*(-?\d+)(?![.,]?\d)')
# An operator sitting between two numbers. Used to count operators so chained
# expressions such as "2 + 3 * 4" are not reduced to their first operation.
_OPERATOR_RE = re.compile(r'(?<=\d)\s*[+\-*]\s*(?=-?\d)')


def compute_step_final(q):
    """Return canonical ("Step: ...", "Final: ...") lines for `q`, or None.

    Only questions containing exactly one integer operation "a OP b" with OP in
    {+, -, *} are answered. Anything else (no operation, chained operations,
    decimals) returns None so the caller falls back to the model's own output.
    """
    if len(_OPERATOR_RE.findall(q)) != 1:
        return None
    m = _BINOP_RE.search(q)
    if not m:
        return None
    a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
    result = {'+': a + b, '-': a - b, '*': a * b}[op]
    return f"Step: {a} {op} {b} = {result}", f"Final: {result}"


def normalize_generation(gen):
    """Split generated text into stripped, non-empty lines."""
    return [ln.strip() for ln in gen.splitlines() if ln.strip()]


def verify_and_format(question, gen):
    """Return the answer text for `question`.

    If compute_step_final can answer the question, its canonical Step/Final
    lines replace the model output. Otherwise, return the model's first
    "Step:" and first "Final:" line. If it has neither, return up to its first
    four non-empty lines.
    """
    sf = compute_step_final(question)
    if sf is not None:
        canonical_step, canonical_final = sf
        return "\n".join([canonical_step, canonical_final])
    lines = normalize_generation(gen)
    step_lines = [ln for ln in lines if ln.lower().startswith("step:")]
    final_lines = [ln for ln in lines if ln.lower().startswith("final:")]
    out = step_lines[:1] + final_lines[:1]
    if not out:
        out = lines[:4]
    return "\n".join(out)


# -----------------------------
# Quick validation on held-out samples (sanity check)
# -----------------------------
def _final_value(text):
    """Return the integer on the first "Final:" line of `text`, or None."""
    m = re.search(r'^Final:\s*(-?\d+)\s*$', text, re.MULTILINE)
    return int(m.group(1)) if m else None


def validate(n=20, raw=False):
    """Score n freshly generated arithmetic questions; return (correct, n).

    With raw=False (default), the verified output is scored. For these
    single-operation questions it always comes from compute_step_final, so this
    checks the pipeline, not the model. Use raw=True to score the model's own
    "Final:" line.
    """
    questions = []
    for _ in range(n):
        kind = random.choice([gen_add, gen_sub, gen_mul])
        inp, _tgt = kind()
        # Drop the trailing "\nA:" so the prompt built below matches training.
        questions.append(inp[:-len("\nA:")])
    ok = 0
    for q in questions:
        gen = generate_raw(q + "\nA:", 160)
        ans = gen if raw else verify_and_format(q, gen)
        target = compute_step_final(q)
        got = _final_value(ans)
        if target is not None and got is not None and got == _final_value(target[1]):
            ok += 1
    return ok, n


acc, total = validate(30)
print(f"Verifier accuracy on random held-out arithmetic: {acc}/{total} correct")
raw_acc, raw_total = validate(30, raw=True)
print(f"Raw model accuracy on random arithmetic: {raw_acc}/{raw_total} correct")


# -----------------------------
# Inference
# -----------------------------
def answer_fn(question):
    """Gradio callback: generate with the model, then apply the canonical verifier."""
    q = (question or "").strip()
    # Training prompts start with "Q: "; add it when the user leaves it out.
    prompt = q if q.startswith("Q:") else f"Q: {q}"
    raw = generate_raw(prompt + "\nA:", max_new_tokens=160)
    return verify_and_format(q, raw)


# -----------------------------
# Gradio UI
# -----------------------------
demo = gr.Interface(
    fn=answer_fn,
    inputs=gr.Textbox(lines=2, placeholder="Ask a math question (e.g., What is 12 * 4?)"),
    outputs="text",
    title="Tiny CoT SLM (CPU) with canonical Steps and Final",
    description="Guarantees accurate 'Step:' and 'Final:' for +, -, * via a safe verifier. Model adds natural CoT flavor."
)

if __name__ == "__main__":
    demo.launch()