# Aashir92's picture
# Upload app.py with huggingface_hub
# 6af65f0 verified
from pathlib import Path
from typing import Dict, List
import gradio as gr
import joblib
import pandas as pd
# Resolve artifact locations relative to this file rather than the current
# working directory, so the app finds its model wherever it is launched from.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath("churn_model_v1.pkl")
OUTPUT_BATCH_PATH = BASE_DIR.joinpath("batch_predictions.csv")

# Model input schema: the exact columns, in the exact order, that the
# pipeline is fed by _validate_and_cast.
FEATURE_COLUMNS: List[str] = [
    "gender", "SeniorCitizen", "Partner", "Dependents",
    "tenure",
    "PhoneService", "MultipleLines", "InternetService",
    "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport",
    "StreamingTV", "StreamingMovies",
    "Contract", "PaperlessBilling", "PaymentMethod",
    "MonthlyCharges", "TotalCharges",
]

# Shared vocabularies; each dict entry below gets its own list copy so no
# two columns alias the same list object.
_NO_YES = ("No", "Yes")
_INTERNET_ADDON = ("No", "No internet service", "Yes")

# Allowed string values for every categorical feature. Key order matters:
# validation walks this dict in order and reports the first bad column.
CATEGORICAL_OPTIONS: Dict[str, List[str]] = {
    "gender": list(("Female", "Male")),
    "SeniorCitizen": list(("0", "1")),
    "Partner": list(_NO_YES),
    "Dependents": list(_NO_YES),
    "PhoneService": list(_NO_YES),
    "MultipleLines": list(("No", "No phone service", "Yes")),
    "InternetService": list(("DSL", "Fiber optic", "No")),
    "OnlineSecurity": list(_INTERNET_ADDON),
    "OnlineBackup": list(_INTERNET_ADDON),
    "DeviceProtection": list(_INTERNET_ADDON),
    "TechSupport": list(_INTERNET_ADDON),
    "StreamingTV": list(_INTERNET_ADDON),
    "StreamingMovies": list(_INTERNET_ADDON),
    "Contract": list(("Month-to-month", "One year", "Two year")),
    "PaperlessBilling": list(_NO_YES),
    "PaymentMethod": list((
        "Bank transfer (automatic)",
        "Credit card (automatic)",
        "Electronic check",
        "Mailed check",
    )),
}

# Features coerced with pd.to_numeric during validation.
NUMERIC_COLUMNS: List[str] = ["tenure", "MonthlyCharges", "TotalCharges"]
def load_model():
    """Load the serialized churn model stored at ``MODEL_PATH``.

    Returns:
        The object deserialized by ``joblib.load``; the rest of the app calls
        ``predict`` and ``predict_proba`` on it.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
    """
    if MODEL_PATH.exists():
        # NOTE: joblib.load unpickles the file, which can run arbitrary code;
        # only deploy model files from a trusted source.
        return joblib.load(MODEL_PATH)
    missing_msg = (
        f"Model file not found at {MODEL_PATH}. "
        "Upload churn_model_v1.pkl to the Space root."
    )
    raise FileNotFoundError(missing_msg)


# Loaded once at import time so every request reuses the same model object.
MODEL = load_model()
def _to_binary_label(pred_value: int) -> str:
return "Churn" if int(pred_value) == 1 else "No Churn"
def _validate_and_cast(df: pd.DataFrame) -> pd.DataFrame:
    """Check raw customer rows against the model schema and return a clean copy.

    Args:
        df: Customer rows from the single-prediction form or an uploaded CSV.
            Columns outside ``FEATURE_COLUMNS`` (such as the optional
            ``customerID``) are discarded.

    Returns:
        A new DataFrame containing exactly ``FEATURE_COLUMNS`` in that order,
        with ``NUMERIC_COLUMNS`` converted by ``pd.to_numeric`` and
        ``SeniorCitizen`` converted to ``str``. The input is not modified.

    Raises:
        ValueError: Checks run in this order and stop at the first failure:
            (1) a required column is missing; (2) any numeric cell is missing
            or unparseable; (3) a categorical cell is not listed in
            ``CATEGORICAL_OPTIONS``.
    """
    absent = [name for name in FEATURE_COLUMNS if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Selecting only the feature columns (in model order) also drops
    # customerID and any other extra columns.
    frame = df[FEATURE_COLUMNS].copy()

    for name in NUMERIC_COLUMNS:
        # Unparseable entries become NaN so they can be counted below.
        frame[name] = pd.to_numeric(frame[name], errors="coerce")

    bad_numeric_rows = frame[NUMERIC_COLUMNS].isna().any(axis=1).sum()
    if bad_numeric_rows > 0:
        raise ValueError(
            f"Found {bad_numeric_rows} rows with invalid numeric values in "
            f"{NUMERIC_COLUMNS}. Please fix and re-upload."
        )

    for name, allowed in CATEGORICAL_OPTIONS.items():
        # Compare on the string form so e.g. integer 0/1 SeniorCitizen
        # values from a CSV match the "0"/"1" options.
        as_text = frame[name].astype(str)
        rejected = ~as_text.isin(allowed)
        if rejected.any():
            offenders = sorted(as_text[rejected].unique().tolist())
            raise ValueError(f"Invalid value(s) in '{name}': {offenders}")

    # The options treat SeniorCitizen as text, so store it as text too.
    frame["SeniorCitizen"] = frame["SeniorCitizen"].astype(str)
    return frame
def predict_single(
    gender,
    senior_citizen,
    partner,
    dependents,
    tenure,
    phone_service,
    multiple_lines,
    internet_service,
    online_security,
    online_backup,
    device_protection,
    tech_support,
    streaming_tv,
    streaming_movies,
    contract,
    paperless_billing,
    payment_method,
    monthly_charges,
    total_charges,
):
    """Predict churn for one customer described by the form inputs.

    Each parameter supplies the feature at the same position in
    ``FEATURE_COLUMNS`` (``gender`` -> "gender", ..., ``total_charges`` ->
    "TotalCharges").

    Returns:
        ``(label, probability)``: ``label`` is "Churn" or "No Churn"; ``probability``
        is column 1 of ``MODEL.predict_proba`` as a float.
        NOTE(review): this assumes ``MODEL.classes_`` is ``[0, 1]``, so that
        column 1 is the churn class. Confirm against the training code.

    Raises:
        gr.Error: If the inputs fail validation, e.g. a numeric field is left
            empty or a dropdown is cleared.
    """
    input_row = pd.DataFrame(
        [
            {
                "gender": gender,
                "SeniorCitizen": str(senior_citizen),
                "Partner": partner,
                "Dependents": dependents,
                "tenure": tenure,
                "PhoneService": phone_service,
                "MultipleLines": multiple_lines,
                "InternetService": internet_service,
                "OnlineSecurity": online_security,
                "OnlineBackup": online_backup,
                "DeviceProtection": device_protection,
                "TechSupport": tech_support,
                "StreamingTV": streaming_tv,
                "StreamingMovies": streaming_movies,
                "Contract": contract,
                "PaperlessBilling": paperless_billing,
                "PaymentMethod": payment_method,
                "MonthlyCharges": monthly_charges,
                "TotalCharges": total_charges,
            }
        ]
    )
    try:
        validated = _validate_and_cast(input_row)
    except ValueError as err:
        # Gradio shows the text of a gr.Error in the UI. Without show_error, a
        # plain ValueError appears only as a generic "Error".
        raise gr.Error(str(err)) from err
    pred = MODEL.predict(validated)[0]
    prob = MODEL.predict_proba(validated)[0][1]
    return _to_binary_label(pred), float(prob)
def predict_batch(file_obj):
    """Score every row of an uploaded CSV and write the results to disk.

    Args:
        file_obj: The value from the ``gr.File`` component. Either an object
            exposing the uploaded file's path as ``.name`` or a plain path
            string is accepted; ``None`` means nothing was uploaded.

    Returns:
        ``(path, message)``: the path of the written predictions CSV
        (``OUTPUT_BATCH_PATH``) as a string, and a status message. The CSV
        holds the validated feature columns, then ``predicted_churn_label`` and
        ``predicted_churn_probability``. When the input had a ``customerID``
        column, it is kept as the first column.

    Raises:
        gr.Error: If no file was uploaded, the file is not a ``.csv``, it cannot
            be parsed, it fails schema validation, or it has no data rows.
    """
    if file_obj is None:
        raise gr.Error("Please upload a CSV file.")
    # Accept both upload objects (``.name`` holds the path) and bare paths.
    input_path = Path(getattr(file_obj, "name", file_obj))
    if input_path.suffix.lower() != ".csv":
        raise gr.Error("Only CSV files are supported.")
    try:
        df = pd.read_csv(input_path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        raise gr.Error(f"Could not read CSV: {err}") from err
    try:
        validated = _validate_and_cast(df)
    except ValueError as err:
        # Surface the specific validation message in the UI instead of a
        # generic "Error".
        raise gr.Error(str(err)) from err
    if validated.empty:
        # A header-only CSV would otherwise fail deep inside the model.
        raise gr.Error("The uploaded CSV contains no data rows.")
    pred = MODEL.predict(validated)
    prob = MODEL.predict_proba(validated)[:, 1]
    output_df = validated.copy()
    if "customerID" in df.columns:
        # Keep the identifier (index-aligned with the input rows) so users can
        # join predictions back to their customers. It was excluded from the
        # model input only.
        output_df.insert(0, "customerID", df["customerID"])
    # Reuse the shared label mapping so batch and single predictions agree.
    output_df["predicted_churn_label"] = [_to_binary_label(p) for p in pred]
    output_df["predicted_churn_probability"] = prob
    output_df.to_csv(OUTPUT_BATCH_PATH, index=False)
    return str(OUTPUT_BATCH_PATH), f"Processed {len(output_df)} rows successfully."
# Build the Gradio UI with two tabs: a form that scores one customer and a
# CSV uploader that scores many. Components are laid out in the order they
# are created inside each Row/Column context, so statement order here is the
# page layout.
with gr.Blocks(title="Customer Churn Predictor") as app:
    gr.Markdown(
        """
# Customer Churn Prediction
Upload customer attributes to predict churn risk using a trained Scikit-learn pipeline.
"""
    )
    with gr.Tab("Single Prediction"):
        with gr.Row():
            # Left column: demographics, phone and internet services.
            with gr.Column():
                gender_in = gr.Dropdown(CATEGORICAL_OPTIONS["gender"], value="Female", label="gender")
                senior_in = gr.Dropdown(CATEGORICAL_OPTIONS["SeniorCitizen"], value="0", label="SeniorCitizen")
                partner_in = gr.Dropdown(CATEGORICAL_OPTIONS["Partner"], value="No", label="Partner")
                dependents_in = gr.Dropdown(CATEGORICAL_OPTIONS["Dependents"], value="No", label="Dependents")
                tenure_in = gr.Number(value=1, label="tenure")
                phone_in = gr.Dropdown(CATEGORICAL_OPTIONS["PhoneService"], value="Yes", label="PhoneService")
                multiple_in = gr.Dropdown(CATEGORICAL_OPTIONS["MultipleLines"], value="No", label="MultipleLines")
                internet_in = gr.Dropdown(CATEGORICAL_OPTIONS["InternetService"], value="DSL", label="InternetService")
                sec_in = gr.Dropdown(CATEGORICAL_OPTIONS["OnlineSecurity"], value="No", label="OnlineSecurity")
                backup_in = gr.Dropdown(CATEGORICAL_OPTIONS["OnlineBackup"], value="No", label="OnlineBackup")
            # Right column: add-on services, contract, billing and charges.
            with gr.Column():
                device_in = gr.Dropdown(CATEGORICAL_OPTIONS["DeviceProtection"], value="No", label="DeviceProtection")
                tech_in = gr.Dropdown(CATEGORICAL_OPTIONS["TechSupport"], value="No", label="TechSupport")
                tv_in = gr.Dropdown(CATEGORICAL_OPTIONS["StreamingTV"], value="No", label="StreamingTV")
                movies_in = gr.Dropdown(CATEGORICAL_OPTIONS["StreamingMovies"], value="No", label="StreamingMovies")
                contract_in = gr.Dropdown(CATEGORICAL_OPTIONS["Contract"], value="Month-to-month", label="Contract")
                paperless_in = gr.Dropdown(CATEGORICAL_OPTIONS["PaperlessBilling"], value="Yes", label="PaperlessBilling")
                payment_in = gr.Dropdown(CATEGORICAL_OPTIONS["PaymentMethod"], value="Electronic check", label="PaymentMethod")
                monthly_in = gr.Number(value=29.85, label="MonthlyCharges")
                total_in = gr.Number(value=29.85, label="TotalCharges")
        predict_btn = gr.Button("Predict")
        label_out = gr.Textbox(label="Prediction")
        prob_out = gr.Number(label="Churn Probability")
        # Values are passed to predict_single positionally, so this list must
        # stay in the same order as that function's parameters.
        predict_btn.click(
            fn=predict_single,
            inputs=[
                gender_in,
                senior_in,
                partner_in,
                dependents_in,
                tenure_in,
                phone_in,
                multiple_in,
                internet_in,
                sec_in,
                backup_in,
                device_in,
                tech_in,
                tv_in,
                movies_in,
                contract_in,
                paperless_in,
                payment_in,
                monthly_in,
                total_in,
            ],
            outputs=[label_out, prob_out],
        )
    with gr.Tab("Batch Prediction (CSV)"):
        gr.Markdown(
            "Upload CSV with Telco feature columns. Optional column `customerID` is ignored if present."
        )
        file_in = gr.File(label="Upload CSV", file_types=[".csv"])
        batch_btn = gr.Button("Run Batch Prediction")
        out_file = gr.File(label="Download Predictions CSV")
        out_msg = gr.Textbox(label="Status")
        # predict_batch returns (csv_path, status_message) for these outputs.
        batch_btn.click(fn=predict_batch, inputs=[file_in], outputs=[out_file, out_msg])
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()