# Author: mahesh1209
# Last commit: 88de55d "Update app.py" (verified)
import torch, torch.nn as nn, torch.nn.functional as F
import gradio as gr, random
# Device: all tensors and the model live on the CPU.
device = "cpu"
# Training corpus: (question, answer) pairs the model memorises.
qa_pairs = [
    ("What is AI?",
     "Artificial Intelligence simulates human cognition using machines."),
    ("What is ML?",
     "Machine Learning enables systems to learn from data."),
    ("What is deep learning?",
     "Deep learning uses neural networks with many layers to learn complex patterns."),
    ("What is supervised learning?",
     "Supervised learning trains models on labeled data to make predictions."),
    ("What is unsupervised learning?",
     "Unsupervised learning finds patterns in unlabeled data."),
]
# Tokenizer (word-level). The vocabulary is every whitespace-separated word in
# the Q&A pairs plus the special tokens below. "Q:" and "A:" must be listed
# explicitly: they occur only in the prompt template used for training and
# generation ("Q: <question> A: <answer> <END>"), and `encode` silently drops
# any word that is not in the vocabulary, which would strip the separators.
SPECIAL_TOKENS = ["Q:", "A:", "<END>"]
vocab = sorted(
    set(" ".join(q + " " + a for q, a in qa_pairs).split()) | set(SPECIAL_TOKENS)
)
stoi = {w: i for i, w in enumerate(vocab)}  # word -> token id
itos = {i: w for w, i in stoi.items()}      # token id -> word


def encode(s):
    """Return the token ids of the whitespace-separated words of *s*.

    Words not in the vocabulary are silently skipped.
    """
    return [stoi[w] for w in s.split() if w in stoi]


def decode(l):
    """Return the words for token ids *l*, joined with single spaces."""
    return " ".join(itos[i] for i in l)
# Dataset: one (input, target) pair per Q&A, where the target is the input
# sequence shifted left by one token (next-token prediction).
def _to_training_pair(question, answer):
    """Encode one Q&A as next-token (inputs, targets) LongTensors."""
    token_ids = encode("Q: " + question + " A: " + answer + " <END>")
    inputs = torch.tensor(token_ids[:-1], dtype=torch.long)
    targets = torch.tensor(token_ids[1:], dtype=torch.long)
    return inputs, targets


data = [_to_training_pair(q, a) for q, a in qa_pairs]
# Model: token embedding -> single-layer GRU -> linear projection to vocab logits.
class TinySLM(nn.Module):
    """Minimal recurrent word-level language model.

    Args:
        vocab_size: Number of distinct token ids.
        n_embed: Width of the embedding and of the GRU hidden state.
    """

    def __init__(self, vocab_size, n_embed=64):
        super().__init__()
        # Submodules are created in this order so parameter initialisation
        # consumes the RNG identically; attribute names are state_dict keys.
        self.embed = nn.Embedding(vocab_size, n_embed)
        self.rnn = nn.GRU(n_embed, n_embed, batch_first=True)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``."""
        hidden_seq = self.rnn(self.embed(idx))[0]  # GRU returns (outputs, h_n)
        return self.lm_head(hidden_seq)
# Instantiate the model on `device` and an AdamW optimiser over its parameters.
model = TinySLM(vocab_size=len(vocab))
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
# Train: 1000 optimiser steps, each on a single randomly sampled Q&A sequence
# (batch of one), minimising next-token cross-entropy.
for _train_step in range(1000):
    inputs, targets = random.choice(data)
    step_logits = model(inputs.unsqueeze(0)).squeeze(0)  # (seq, vocab)
    step_loss = F.cross_entropy(step_logits, targets)
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()
# Generate: greedy decoding of an answer for a free-text question.
def generate_answer(question, max_new=50):
    """Greedily decode an answer to *question* with the trained model.

    The prompt is encoded as ``"Q: <question> A:"``; words that are not in the
    vocabulary are dropped by ``encode``. At each step the argmax of the last
    position's logits is appended, until ``<END>`` is produced or ``max_new``
    tokens have been generated.

    Args:
        question: User question text.
        max_new: Maximum number of answer tokens to generate.

    Returns:
        The decoded answer (space-joined words, ``<END>`` excluded). If none of
        the prompt's words are in the vocabulary, a short apology string.
    """
    model.eval()
    prompt_ids = encode("Q: " + question + " A:")
    if not prompt_ids:
        # A zero-length sequence would make the GRU raise; answer gracefully.
        return "Sorry, I don't recognise any words in that question."
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    answer_tokens = []
    # Inference only: don't build an autograd graph at every decoding step.
    with torch.no_grad():
        for _ in range(max_new):
            logits = model(idx)
            next_id = torch.argmax(logits[:, -1], dim=-1).item()
            if itos[next_id] == "<END>":
                break
            answer_tokens.append(next_id)
            next_tok = torch.tensor([[next_id]], dtype=torch.long, device=device)
            # Re-feed the whole sequence each step (model keeps no hidden state).
            idx = torch.cat([idx, next_tok], dim=1)
    return decode(answer_tokens)
# Gradio callback: forward the text-box contents to the generator.
def chat_interface(user_input):
    """Return the model's answer for the text typed into the UI."""
    answer = generate_answer(question=user_input)
    return answer
# Build the Gradio app (text in, text out) and start serving it.
_interface_options = dict(
    fn=chat_interface,
    title="SLM Chatbot PYTORCH DEEPIKA",
    description="Ask a question about AI or ML",
    inputs="text",
    outputs="text",
)
demo = gr.Interface(**_interface_options)
demo.launch()