mahesh1209's picture
Create app.py
c4908ec verified
import logging

import gradio as gr
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 🔹 Load the pretrained sequence-classification model and its tokenizer.
# NOTE(review): "distilbert-base-uncased" is a generic base checkpoint, not a
# churn classifier; its classification head is presumably freshly initialized
# by from_pretrained, so predictions are not meaningful until a fine-tuned
# checkpoint is substituted — confirm before trusting results.
model_name = "distilbert-base-uncased"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# 🔹 Sample telecom dataset: three hand-written customer records, stored
# column-by-column (one list per field, one entry per customer).
sample_data = pd.DataFrame(
    {
        "gender": ["Female", "Male", "Female"],
        "SeniorCitizen": [0, 1, 0],
        "Partner": ["Yes", "No", "Yes"],
        "Dependents": ["No", "No", "Yes"],
        "tenure": [1, 34, 2],
        "PhoneService": ["No", "Yes", "Yes"],
        "InternetService": ["DSL", "Fiber optic", "DSL"],
        "Contract": ["Month-to-month", "One year", "Month-to-month"],
        "MonthlyCharges": [29.85, 56.95, 53.85],
    }
)
# 🔹 Convert a structured customer row into a text prompt.
def preprocess(row):
    """Render one customer record as a natural-language churn question.

    ``row`` may be anything indexable by column name (e.g. a pandas row from
    ``DataFrame.iterrows`` or a plain dict). Fields are emitted in a fixed
    order as ``label: value`` pairs joined by ", ", then ". Will they churn?"
    is appended. A missing column raises ``KeyError``.
    """
    labelled_columns = (
        ("Customer", "gender"),
        ("senior", "SeniorCitizen"),
        ("partner", "Partner"),
        ("dependents", "Dependents"),
        ("tenure", "tenure"),
        ("phone", "PhoneService"),
        ("internet", "InternetService"),
        ("contract", "Contract"),
        ("charges", "MonthlyCharges"),
    )
    description = ", ".join(
        f"{label}: {row[column]}" for label, column in labelled_columns
    )
    return description + ". Will they churn?"
# 🔹 Prediction logic
logger = logging.getLogger(__name__)


def _manual_table(table):
    """Return the manual-entry grid as a fresh DataFrame with blank rows dropped.

    ``table`` is whatever the Gradio grid hands over: typically a pandas
    DataFrame (``gr.Dataframe`` defaults to ``type="pandas"``), possibly a
    list of rows, or ``None``. Truth-testing a DataFrame raises ``ValueError``,
    so ``None`` is checked explicitly instead of using ``table or []``.
    A copy is returned so callers can add columns without touching ``table``.
    """
    if table is None:
        return pd.DataFrame()
    df = pd.DataFrame(table)
    if df.empty:
        return df
    # The editable grid starts with placeholder rows; treat a row as blank
    # when every cell is missing or contains only whitespace.
    blank_cells = df.isna() | df.astype(str).apply(lambda col: col.str.strip() == "")
    return df.loc[~blank_cells.all(axis=1)].copy()


def predict_churn(input_mode, table=None):
    """Label each customer row as "Churn" or "No Churn".

    Args:
        input_mode: "Use Sample Data" scores the built-in ``sample_data``;
            any other value scores the manually entered ``table``.
        table: Manual-entry rows (DataFrame, list of rows, or ``None``).
            It is never modified.

    Returns:
        A copy of the scored rows with an added ``Prediction`` column, or an
        empty DataFrame when there is nothing to score. A row that fails to
        score is labelled "Error" and the exception is logged.
    """
    if input_mode == "Use Sample Data":
        df = sample_data.copy()
    else:
        df = _manual_table(table)
    if df.empty:
        return pd.DataFrame()

    results = []
    for index, row in df.iterrows():
        try:
            prompt = preprocess(row)
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=1).item()
        except Exception:
            # Best-effort per row: one malformed row must not fail the whole
            # request, but the cause is logged instead of silently dropped.
            logger.exception("Churn prediction failed for row %s", index)
            results.append("Error")
        else:
            # NOTE(review): assumes class index 1 means "churn" — confirm
            # against the checkpoint's id2label mapping.
            results.append("Churn" if pred == 1 else "No Churn")
    df["Prediction"] = results
    return df
# 🔹 Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 📞 Telecom Customer Churn Predictor")
    input_mode = gr.Radio(
        ["Use Sample Data", "Manual Entry"],
        label="Choose Input Mode",
        value="Use Sample Data",
    )
    # predict_churn reads this grid only in "Manual Entry" mode, so it starts
    # hidden (matching the default mode) and is shown only when it matters.
    table_input = gr.Dataframe(
        headers=[
            "gender", "SeniorCitizen", "Partner", "Dependents", "tenure",
            "PhoneService", "InternetService", "Contract", "MonthlyCharges",
        ],
        label="Enter Customer Data",
        row_count=3,
        visible=False,
    )
    output = gr.Dataframe(label="Prediction Results")
    run_btn = gr.Button("Predict")
    input_mode.change(
        fn=lambda mode: gr.update(visible=mode == "Manual Entry"),
        inputs=input_mode,
        outputs=table_input,
    )
    run_btn.click(fn=predict_churn, inputs=[input_mode, table_input], outputs=output)
demo.launch()