Spaces:
Sleeping
Sleeping
Create train_once.py
Browse files- app/train_once.py +41 -0
app/train_once.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, random, torch, torch.nn.functional as F
|
| 2 |
+
from slm_qa import TinyTransformer, encode, wrap_bos_eos, itos, PAD, DATA_QA
|
| 3 |
+
|
| 4 |
+
def make_sequences():
    """Build one training sequence per question/answer pair in ``DATA_QA``.

    For every ``(question, answer)`` pair the encoding of ``"q: " + question``
    is concatenated with the encoding of ``"a: " + answer`` and the result is
    passed through ``wrap_bos_eos``.

    Returns:
        A list holding one wrapped sequence per ``DATA_QA`` entry, in the
        same order as ``DATA_QA``.
    """
    sequences = []
    for question, answer in DATA_QA:
        prompt_ids = encode("q: " + question)
        reply_ids = encode("a: " + answer)
        sequences.append(wrap_bos_eos(prompt_ids + reply_ids))
    return sequences
|
| 6 |
+
|
| 7 |
+
def pad_batches(seqs, batch_size=4, device=torch.device("cpu"), *, pad_id=None):
    """Shuffle token sequences and pack them into padded next-token batches.

    The sequences are shuffled (on a copy, so the caller's list keeps its
    order), cut into consecutive chunks of ``batch_size``, and each chunk is
    right-padded to the length ``T`` of its longest member.  For every chunk
    two ``(len(chunk), T - 1)`` integer tensors are produced: ``x`` holds each
    padded sequence without its last token and ``y`` holds it without its
    first token, so ``y[b, t]`` is the token that follows ``x[b, t]``.

    Args:
        seqs: Iterable of token-id sequences (lists or tuples of ints).
        batch_size: Maximum number of sequences per batch.
        device: Device the finished tensors are moved to.
        pad_id: Token id used for padding.  ``None`` (the default) uses the
            module-level ``PAD`` constant, matching the previous behaviour.

    Returns:
        A list of ``(x, y)`` tensor pairs, one per batch; empty when ``seqs``
        is empty.
    """
    if pad_id is None:
        pad_id = PAD

    # Shuffle a private copy: random.shuffle works in place and would
    # otherwise silently reorder the list owned by the caller.
    shuffled = list(seqs)
    random.shuffle(shuffled)

    batches = []
    for start in range(0, len(shuffled), batch_size):
        chunk = shuffled[start:start + batch_size]
        width = max(len(seq) for seq in chunk)
        shape = (len(chunk), width - 1)
        # Explicit integer dtype: do not rely on dtype inference from the
        # fill value.  Freshly allocated tensors are also contiguous, which
        # callers that flatten the targets with ``.view`` depend on.
        x = torch.full(shape, pad_id, dtype=torch.long)
        y = torch.full(shape, pad_id, dtype=torch.long)
        for row, seq in enumerate(chunk):
            padded = list(seq) + [pad_id] * (width - len(seq))
            x[row] = torch.tensor(padded[:-1], dtype=torch.long)
            y[row] = torch.tensor(padded[1:], dtype=torch.long)
        batches.append((x.to(device), y.to(device)))
    return batches
|
| 18 |
+
|
| 19 |
+
def main(epochs=60, lr=3e-4, checkpoint_path="app/models/slm_qa_best.pt"):
    """Train a ``TinyTransformer`` on the QA sequences and save the best weights.

    The batches are built once, before the first epoch, and reused for every
    epoch.  After each epoch the mean batch loss is printed; whenever it
    improves on the best value seen so far, the model's ``state_dict`` is
    written to ``checkpoint_path`` (its parent directory is created on
    demand).

    Args:
        epochs: Number of passes over the batches.
        lr: Learning rate handed to ``AdamW``.
        checkpoint_path: Destination file for the best ``state_dict``.

    Raises:
        ValueError: If ``make_sequences()`` yields no data, which previously
            surfaced as a bare ``ZeroDivisionError`` when averaging the loss.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyTransformer(vocab_size=len(itos), max_len=128).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    batches = pad_batches(make_sequences(), device=device)
    if not batches:
        raise ValueError("no training sequences available; DATA_QA is empty")

    checkpoint_dir = os.path.dirname(checkpoint_path)
    best = 1e9
    for ep in range(epochs):
        loss_sum = 0.0
        for x, y in batches:
            logits = model(x)
            B, T, V = logits.size()
            # Flatten to (B*T, V) vs (B*T,); padded target positions are
            # excluded from the loss through ignore_index.  reshape (rather
            # than view) also tolerates non-contiguous tensors.
            loss = F.cross_entropy(
                logits.reshape(B * T, V), y.reshape(B * T), ignore_index=PAD
            )
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            loss_sum += loss.item()
        avg = loss_sum / len(batches)
        print(f"Epoch {ep+1:02d} loss {avg:.4f}")
        if avg < best:
            best = avg
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
|
| 40 |
+
|
| 41 |
+
# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
|