diff --git a/spark/clean.py b/spark/clean.py
index 49b0c5605dffc4e9e04acf0e1813e3156c23c6f6..463099adf2dd047dc07506f6a6f75f18f547b7c2 100644
--- a/spark/clean.py
+++ b/spark/clean.py
@@ -1,4 +1,5 @@
 from pyspark.sql import SparkSession
+from pyspark.sql import functions as F
 
 spark = SparkSession.builder \
     .appName("TelcoCustomerChurn") \
@@ -18,7 +19,15 @@ df = df.withColumn("TotalCharges", col("TotalCharges").cast("double"))
 # tenure_zero_rows = df.filter(col("tenure") == 0)
 # tenure_zero_rows.show()
 
-df = df.fillna({"TotalCharges": 0})
+df = df.withColumn(
+    "TotalCharges",
+    F.when(F.col("tenure") == 0, 0).otherwise(F.col("TotalCharges"))
+)
+
+percentile = df.filter(F.col("tenure") != 0).approxQuantile("TotalCharges", [0.5], 0.01)
+median_total_charges = percentile[0]
+
+df = df.fillna({"TotalCharges": median_total_charges})
 
 # missing_counts = df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])
 # missing_counts.show()