mahesh1209 committed on
Commit
6b88f90
·
verified ·
1 Parent(s): ddfa3e7

Create train_once.py

Browse files
Files changed (1) hide show
  1. app/train_once.py +41 -0
app/train_once.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, random, torch, torch.nn.functional as F
2
+ from slm_qa import TinyTransformer, encode, wrap_bos_eos, itos, PAD, DATA_QA
3
+
4
def make_sequences():
    """Build one token sequence per QA pair, wrapped with BOS/EOS.

    Each (question, answer) pair in DATA_QA is encoded as the
    concatenation of "q: <question>" and "a: <answer>" token ids,
    then wrapped with begin/end-of-sequence markers.
    """
    sequences = []
    for question, answer in DATA_QA:
        tokens = encode("q: " + question) + encode("a: " + answer)
        sequences.append(wrap_bos_eos(tokens))
    return sequences
6
+
7
def pad_batches(seqs, batch_size=4, device=torch.device("cpu")):
    """Shuffle sequences and pack them into padded next-token batches.

    Each sequence is right-padded with PAD to its batch's max length T,
    then split into teacher-forcing pairs: inputs x = tokens[:-1] and
    targets y = tokens[1:], both shaped (B, T-1).

    Args:
        seqs: list of token-id lists (presumably already BOS/EOS-wrapped
            by make_sequences — confirm with caller).
        batch_size: number of sequences per batch.
        device: device the finished batch tensors are moved to.

    Returns:
        List of (x, y) LongTensor pairs, one per batch.
    """
    # Shuffle a copy: the original version shuffled the caller's list
    # in place, a surprising side effect for a batching helper.
    seqs = list(seqs)
    random.shuffle(seqs)
    batches = []
    for i in range(0, len(seqs), batch_size):
        chunk = seqs[i:i + batch_size]
        T = max(len(s) for s in chunk)
        # Explicit dtype: token ids must be int64 for embedding lookup and
        # cross_entropy targets; don't rely on dtype inference from PAD.
        x = torch.full((len(chunk), T - 1), PAD, dtype=torch.long)
        y = torch.full((len(chunk), T - 1), PAD, dtype=torch.long)
        for bi, s in enumerate(chunk):
            s_pad = s + [PAD] * (T - len(s))
            x[bi] = torch.tensor(s_pad[:-1], dtype=torch.long)
            y[bi] = torch.tensor(s_pad[1:], dtype=torch.long)
        batches.append((x.to(device), y.to(device)))
    return batches
18
+
19
def main():
    """Train TinyTransformer on the QA data, checkpointing the best epoch.

    Runs 60 epochs of next-token cross-entropy training (PAD positions
    ignored) and saves the model weights whenever the average epoch loss
    improves on the best seen so far.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyTransformer(vocab_size=len(itos), max_len=128).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    train_batches = pad_batches(make_sequences(), device=device)

    best_loss = 1e9
    for epoch in range(60):  # more epochs
        epoch_loss = 0
        for inputs, targets in train_batches:
            logits = model(inputs)
            B, T, V = logits.size()
            # Flatten (B, T, V) -> (B*T, V) for token-level cross entropy;
            # padded positions are excluded via ignore_index.
            loss = F.cross_entropy(
                logits.view(B * T, V),
                targets.view(B * T),
                ignore_index=PAD,
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        mean_loss = epoch_loss / len(train_batches)
        print(f"Epoch {epoch+1:02d} loss {mean_loss:.4f}")
        if mean_loss < best_loss:
            best_loss = mean_loss
            os.makedirs("app/models", exist_ok=True)
            torch.save(model.state_dict(), "app/models/slm_qa_best.pt")


if __name__ == "__main__":
    main()