File size: 2,855 Bytes
c4908ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load pretrained classification model
# Checkpoint name passed to both from_pretrained() calls below, so the
# tokenizer and the model are loaded from the same checkpoint.
# NOTE(review): this looks like a general-purpose base checkpoint rather than
# one fine-tuned for churn - confirm the classification head is actually
# trained before relying on the predicted labels.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Sample Telecom Dataset
# Three built-in example customers, one tuple per customer. The values in each
# tuple line up positionally with the names listed under `columns`.
sample_data = pd.DataFrame(
    data=[
        ("Female", 0, "Yes", "No", 1, "No", "DSL", "Month-to-month", 29.85),
        ("Male", 1, "No", "No", 34, "Yes", "Fiber optic", "One year", 56.95),
        ("Female", 0, "Yes", "Yes", 2, "Yes", "DSL", "Month-to-month", 53.85),
    ],
    columns=[
        "gender",
        "SeniorCitizen",
        "Partner",
        "Dependents",
        "tenure",
        "PhoneService",
        "InternetService",
        "Contract",
        "MonthlyCharges",
    ],
)
# Convert structured row to prompt
def preprocess(row):
    """Render one customer record as a single-sentence text prompt.

    Args:
        row: Any object indexable by column name (a dict or a pandas Series
            both work) that provides the nine customer columns read below.

    Returns:
        A string of comma-separated "label: value" pairs followed by
        ". Will they churn?".

    Raises:
        KeyError: If ``row`` lacks one of the expected columns.
    """
    # (label shown in the prompt, column it is read from), in output order.
    labelled_columns = (
        ("Customer", "gender"),
        ("senior", "SeniorCitizen"),
        ("partner", "Partner"),
        ("dependents", "Dependents"),
        ("tenure", "tenure"),
        ("phone", "PhoneService"),
        ("internet", "InternetService"),
        ("contract", "Contract"),
        ("charges", "MonthlyCharges"),
    )
    description = ", ".join(
        f"{label}: {row[column]}" for label, column in labelled_columns
    )
    return f"{description}. Will they churn?"
# Prediction logic
def _is_blank_cell(value):
    """Return True when a table cell is missing or a whitespace-only string."""
    if isinstance(value, str):
        return not value.strip()
    return bool(pd.api.types.is_scalar(value) and pd.isna(value))


def predict_churn(input_mode, table=None):
    """Label each customer row as "Churn" or "No Churn".

    Args:
        input_mode: "Use Sample Data" scores a copy of the module-level
            ``sample_data``; any other value scores ``table``.
        table: Manually entered rows: a DataFrame, anything else that
            ``pd.DataFrame`` accepts, or None. Ignored in sample mode.

    Returns:
        A DataFrame holding the scored rows plus a "Prediction" column whose
        values are "Churn", "No Churn", or "Error" for a row that could not
        be scored. An empty DataFrame when there is nothing to score.
    """
    import logging

    if input_mode == "Use Sample Data":
        df = sample_data.copy()
    elif table is None:
        df = pd.DataFrame()
    else:
        # Gradio's Dataframe component passes a pandas DataFrame by default,
        # and truth-testing a DataFrame (`table or []`) raises ValueError, so
        # the input is converted explicitly. The copy keeps the caller's
        # object from gaining a "Prediction" column.
        df = pd.DataFrame(table).copy()
        if not df.empty:
            # Unused rows of the entry grid presumably arrive as all-blank
            # rows; drop them so they are not scored as customers.
            has_content = [
                not all(_is_blank_cell(cell) for cell in cells)
                for cells in df.itertuples(index=False, name=None)
            ]
            df = df.loc[has_content].reset_index(drop=True)

    if df.empty:
        return pd.DataFrame()

    log = logging.getLogger(__name__)
    results = []
    for row_label, row in df.iterrows():
        try:
            prompt = preprocess(row)
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=1).item()
        except Exception:
            # Best effort per row: one bad row must not fail the whole batch,
            # but the cause is logged instead of being silently discarded.
            log.exception("Could not score row %s", row_label)
            results.append("Error")
        else:
            # Class index 1 is reported as churn; every other index is not.
            results.append("Churn" if pred == 1 else "No Churn")

    df["Prediction"] = results
    return df
# Gradio UI
with gr.Blocks() as demo:
    # Page heading, rendered as Markdown.
    gr.Markdown("## Telecom Customer Churn Predictor")

    # Chooses between the built-in sample rows and the editable table below.
    input_mode = gr.Radio(
        ["Use Sample Data", "Manual Entry"],
        label="Choose Input Mode",
        value="Use Sample Data",
    )

    # Editable grid for manual entry; the headers are the column names that
    # preprocess() looks up on each row.
    table_input = gr.Dataframe(
        headers=[
            "gender",
            "SeniorCitizen",
            "Partner",
            "Dependents",
            "tenure",
            "PhoneService",
            "InternetService",
            "Contract",
            "MonthlyCharges",
        ],
        label="Enter Customer Data",
        row_count=3,
    )

    output = gr.Dataframe(label="Prediction Results")
    run_btn = gr.Button("Predict")

    # Clicking the button calls predict_churn(input_mode, table_input) and
    # shows the DataFrame it returns in `output`.
    run_btn.click(
        fn=predict_churn,
        inputs=[input_mode, table_input],
        outputs=output,
    )

# Start the Gradio app.
demo.launch()
|