| from pathlib import Path | |
| from typing import Dict, List | |
| import gradio as gr | |
| import joblib | |
| import pandas as pd | |
# Filesystem locations, resolved relative to this script (not the current
# working directory).
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR.joinpath("churn_model_v1.pkl")
OUTPUT_BATCH_PATH = BASE_DIR.joinpath("batch_predictions.csv")

# Columns fed to the model, in the exact order they are selected before
# prediction.
FEATURE_COLUMNS: List[str] = (
    "gender SeniorCitizen Partner Dependents tenure PhoneService "
    "MultipleLines InternetService OnlineSecurity OnlineBackup "
    "DeviceProtection TechSupport StreamingTV StreamingMovies Contract "
    "PaperlessBilling PaymentMethod MonthlyCharges TotalCharges"
).split()

# Choice sets shared by several categorical columns; each column below gets
# its own list copy so no two columns share a mutable list object.
_NO_YES = ("No", "Yes")
_INTERNET_ADDON_CHOICES = ("No", "No internet service", "Yes")
_INTERNET_ADDON_COLUMNS = (
    "OnlineSecurity",
    "OnlineBackup",
    "DeviceProtection",
    "TechSupport",
    "StreamingTV",
    "StreamingMovies",
)

# Allowed (string) values per categorical column. Insertion order matters:
# validation walks this dict in order and reports the first bad column.
CATEGORICAL_OPTIONS: Dict[str, List[str]] = {
    "gender": ["Female", "Male"],
    "SeniorCitizen": ["0", "1"],
    "Partner": list(_NO_YES),
    "Dependents": list(_NO_YES),
    "PhoneService": list(_NO_YES),
    "MultipleLines": ["No", "No phone service", "Yes"],
    "InternetService": ["DSL", "Fiber optic", "No"],
    **{column: list(_INTERNET_ADDON_CHOICES) for column in _INTERNET_ADDON_COLUMNS},
    "Contract": ["Month-to-month", "One year", "Two year"],
    "PaperlessBilling": list(_NO_YES),
    "PaymentMethod": [
        "Bank transfer (automatic)",
        "Credit card (automatic)",
        "Electronic check",
        "Mailed check",
    ],
}

# Columns coerced with pd.to_numeric during validation.
NUMERIC_COLUMNS: List[str] = ["tenure", "MonthlyCharges", "TotalCharges"]
def load_model(model_path=None):
    """Load the serialized churn model from disk.

    Args:
        model_path: Optional path (str or Path) to the model file. Defaults
            to ``MODEL_PATH`` (``churn_model_v1.pkl`` next to this script),
            so existing no-argument callers behave as before.

    Returns:
        Whatever object ``joblib.load`` deserializes from the file.

    Raises:
        FileNotFoundError: If no file exists at the resolved path.
    """
    path = MODEL_PATH if model_path is None else Path(model_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Model file not found at {path}. "
            f"Upload {path.name} to the Space root."
        )
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at model files you trust.
    return joblib.load(path)
# Load the model once at import time; predict_single and predict_batch both
# reuse this object. Importing the module raises FileNotFoundError if the
# model file is absent.
MODEL = load_model()
| def _to_binary_label(pred_value: int) -> str: | |
| return "Churn" if int(pred_value) == 1 else "No Churn" | |
def _validate_and_cast(df: pd.DataFrame) -> pd.DataFrame:
    """Check raw feature rows against the model schema and normalise dtypes.

    The caller's frame is never modified. An optional ``customerID`` column
    is discarded, the remaining data is reduced to ``FEATURE_COLUMNS`` (in
    that order), numeric columns are coerced with ``pd.to_numeric``, and
    each categorical column is compared (as strings) to
    ``CATEGORICAL_OPTIONS``.

    Args:
        df: Raw input rows, e.g. built from the form or read from a CSV.

    Returns:
        A new DataFrame holding exactly ``FEATURE_COLUMNS``, with numeric
        columns parsed and ``SeniorCitizen`` stored as strings.

    Raises:
        ValueError: If required columns are missing, any numeric cell cannot
            be parsed, or a categorical cell holds a value outside its
            allowed options (the first offending column, in
            ``CATEGORICAL_OPTIONS`` order, is reported).
    """
    source = df
    if "customerID" in source.columns:
        source = source.drop(columns=["customerID"])

    absent = [name for name in FEATURE_COLUMNS if name not in source.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Selecting + copying gives a private frame, so the assignments below
    # never touch the caller's data.
    features = source[FEATURE_COLUMNS].copy()

    for name in NUMERIC_COLUMNS:
        # Unparseable cells become NaN so they can be counted in one pass.
        features[name] = pd.to_numeric(features[name], errors="coerce")

    bad_row_count = features[NUMERIC_COLUMNS].isna().any(axis=1).sum()
    if bad_row_count > 0:
        raise ValueError(
            f"Found {bad_row_count} rows with invalid numeric values in "
            f"{NUMERIC_COLUMNS}. Please fix and re-upload."
        )

    for name, allowed in CATEGORICAL_OPTIONS.items():
        # Compare string forms so an integer-coded column (e.g. SeniorCitizen
        # read from CSV as 0/1) matches the "0"/"1" options.
        as_text = features[name].astype(str)
        rejected = ~as_text.isin(allowed)
        if rejected.any():
            offenders = sorted(as_text[rejected].unique().tolist())
            raise ValueError(f"Invalid value(s) in '{name}': {offenders}")

    features["SeniorCitizen"] = features["SeniorCitizen"].astype(str)
    return features
def predict_single(
    gender,
    senior_citizen,
    partner,
    dependents,
    tenure,
    phone_service,
    multiple_lines,
    internet_service,
    online_security,
    online_backup,
    device_protection,
    tech_support,
    streaming_tv,
    streaming_movies,
    contract,
    paperless_billing,
    payment_method,
    monthly_charges,
    total_charges,
):
    """Score a single customer entered through the form.

    Parameters correspond one-to-one, in order, to ``FEATURE_COLUMNS``.
    ``senior_citizen`` is stringified to match the "0"/"1" options.

    Returns:
        Tuple ``(label, probability)``. ``label`` is "Churn"/"No Churn" from
        ``MODEL.predict``. ``probability`` is column 1 of
        ``MODEL.predict_proba`` as a Python float.

    Raises:
        gr.Error: If the inputs fail validation. The validation message is
            forwarded so the UI shows the user what to fix.
    """
    input_row = pd.DataFrame(
        [
            {
                "gender": gender,
                "SeniorCitizen": str(senior_citizen),
                "Partner": partner,
                "Dependents": dependents,
                "tenure": tenure,
                "PhoneService": phone_service,
                "MultipleLines": multiple_lines,
                "InternetService": internet_service,
                "OnlineSecurity": online_security,
                "OnlineBackup": online_backup,
                "DeviceProtection": device_protection,
                "TechSupport": tech_support,
                "StreamingTV": streaming_tv,
                "StreamingMovies": streaming_movies,
                "Contract": contract,
                "PaperlessBilling": paperless_billing,
                "PaymentMethod": payment_method,
                "MonthlyCharges": monthly_charges,
                "TotalCharges": total_charges,
            }
        ]
    )
    try:
        validated = _validate_and_cast(input_row)
    except ValueError as err:
        # Gradio displays gr.Error messages to the user; a plain ValueError
        # would only surface as a generic error (e.g. a cleared number field
        # becomes NaN and fails numeric validation).
        raise gr.Error(str(err)) from err
    pred = MODEL.predict(validated)[0]
    # NOTE(review): assumes MODEL.classes_ is [0, 1], so column 1 of
    # predict_proba is the churn class; confirm against the training code.
    prob = MODEL.predict_proba(validated)[0][1]
    return _to_binary_label(pred), float(prob)
def predict_batch(file_obj):
    """Score every row of an uploaded CSV and write the results to disk.

    Args:
        file_obj: The uploaded file. This is either a path (``str``/``Path``)
            or an object exposing the path as ``.name``, as ``gr.File``
            provides. ``None`` when nothing was uploaded.

    Returns:
        ``(output_path, status_message)``. ``output_path`` is
        ``OUTPUT_BATCH_PATH`` as a string. It points at a CSV holding the
        validated feature columns (preceded by ``customerID`` when the
        input had one), plus ``predicted_churn_label`` and
        ``predicted_churn_probability``.

    Raises:
        gr.Error: If no file was uploaded, the file is not a ``.csv``, it
            cannot be parsed, it has no data rows, or it fails validation.
    """
    if file_obj is None:
        raise gr.Error("Please upload a CSV file.")
    # Accept a plain path as well as an upload object carrying its path in .name.
    raw_path = file_obj if isinstance(file_obj, (str, Path)) else file_obj.name
    input_path = Path(raw_path)
    if input_path.suffix.lower() != ".csv":
        raise gr.Error("Only CSV files are supported.")

    try:
        df = pd.read_csv(input_path)
        validated = _validate_and_cast(df)
    except ValueError as err:
        # Covers pandas ParserError/EmptyDataError, UnicodeDecodeError and
        # schema errors (all ValueError subclasses). Re-raise as gr.Error so
        # the message reaches the user.
        raise gr.Error(str(err)) from err

    if validated.empty:
        # A headers-only CSV passes validation but would make MODEL.predict
        # fail on zero samples.
        raise gr.Error("The uploaded CSV contains no data rows.")

    pred = MODEL.predict(validated)
    # NOTE(review): assumes MODEL.classes_ is [0, 1]; confirm.
    prob = MODEL.predict_proba(validated)[:, 1]

    output_df = validated.copy()
    if "customerID" in df.columns:
        # Not used for prediction, but carried through so each output row
        # can be matched back to its customer.
        output_df.insert(0, "customerID", df["customerID"].to_numpy())
    # Same label mapping as the single-prediction path.
    output_df["predicted_churn_label"] = [_to_binary_label(p) for p in pred]
    output_df["predicted_churn_probability"] = prob

    # NOTE(review): a single fixed output path is shared by all requests;
    # concurrent batch runs could overwrite each other's file.
    output_df.to_csv(OUTPUT_BATCH_PATH, index=False)
    return str(OUTPUT_BATCH_PATH), f"Processed {len(output_df)} rows successfully."
# Gradio UI with two tabs: a form for scoring one customer and a CSV upload
# for scoring many rows at once. Components are created in display order.
with gr.Blocks(title="Customer Churn Predictor") as app:
    gr.Markdown(
        """
# Customer Churn Prediction
Upload customer attributes to predict churn risk using a trained Scikit-learn pipeline.
"""
    )
    with gr.Tab("Single Prediction"):
        with gr.Row():
            # Dropdown choices come from CATEGORICAL_OPTIONS, the same lists
            # _validate_and_cast checks values against.
            with gr.Column():
                gender_in = gr.Dropdown(CATEGORICAL_OPTIONS["gender"], value="Female", label="gender")
                senior_in = gr.Dropdown(CATEGORICAL_OPTIONS["SeniorCitizen"], value="0", label="SeniorCitizen")
                partner_in = gr.Dropdown(CATEGORICAL_OPTIONS["Partner"], value="No", label="Partner")
                dependents_in = gr.Dropdown(CATEGORICAL_OPTIONS["Dependents"], value="No", label="Dependents")
                tenure_in = gr.Number(value=1, label="tenure")
                phone_in = gr.Dropdown(CATEGORICAL_OPTIONS["PhoneService"], value="Yes", label="PhoneService")
                multiple_in = gr.Dropdown(CATEGORICAL_OPTIONS["MultipleLines"], value="No", label="MultipleLines")
                internet_in = gr.Dropdown(CATEGORICAL_OPTIONS["InternetService"], value="DSL", label="InternetService")
                sec_in = gr.Dropdown(CATEGORICAL_OPTIONS["OnlineSecurity"], value="No", label="OnlineSecurity")
                backup_in = gr.Dropdown(CATEGORICAL_OPTIONS["OnlineBackup"], value="No", label="OnlineBackup")
            with gr.Column():
                device_in = gr.Dropdown(CATEGORICAL_OPTIONS["DeviceProtection"], value="No", label="DeviceProtection")
                tech_in = gr.Dropdown(CATEGORICAL_OPTIONS["TechSupport"], value="No", label="TechSupport")
                tv_in = gr.Dropdown(CATEGORICAL_OPTIONS["StreamingTV"], value="No", label="StreamingTV")
                movies_in = gr.Dropdown(CATEGORICAL_OPTIONS["StreamingMovies"], value="No", label="StreamingMovies")
                contract_in = gr.Dropdown(CATEGORICAL_OPTIONS["Contract"], value="Month-to-month", label="Contract")
                paperless_in = gr.Dropdown(CATEGORICAL_OPTIONS["PaperlessBilling"], value="Yes", label="PaperlessBilling")
                payment_in = gr.Dropdown(CATEGORICAL_OPTIONS["PaymentMethod"], value="Electronic check", label="PaymentMethod")
                monthly_in = gr.Number(value=29.85, label="MonthlyCharges")
                total_in = gr.Number(value=29.85, label="TotalCharges")
        predict_btn = gr.Button("Predict")
        label_out = gr.Textbox(label="Prediction")
        prob_out = gr.Number(label="Churn Probability")
        # Input values are passed to predict_single positionally, so this
        # list must stay in the same order as its parameters. The outputs
        # receive its (label, probability) tuple.
        predict_btn.click(
            fn=predict_single,
            inputs=[
                gender_in,
                senior_in,
                partner_in,
                dependents_in,
                tenure_in,
                phone_in,
                multiple_in,
                internet_in,
                sec_in,
                backup_in,
                device_in,
                tech_in,
                tv_in,
                movies_in,
                contract_in,
                paperless_in,
                payment_in,
                monthly_in,
                total_in,
            ],
            outputs=[label_out, prob_out],
        )
    with gr.Tab("Batch Prediction (CSV)"):
        gr.Markdown(
            "Upload CSV with Telco feature columns. Optional column `customerID` is ignored if present."
        )
        file_in = gr.File(label="Upload CSV", file_types=[".csv"])
        batch_btn = gr.Button("Run Batch Prediction")
        out_file = gr.File(label="Download Predictions CSV")
        out_msg = gr.Textbox(label="Status")
        # predict_batch returns (path of the written predictions CSV, status text).
        batch_btn.click(fn=predict_batch, inputs=[file_in], outputs=[out_file, out_msg])

# Launch the app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    app.launch()