diff --git a/spark/clean.py b/spark/clean.py
index 49b0c5605dffc4e9e04acf0e1813e3156c23c6f6..463099adf2dd047dc07506f6a6f75f18f547b7c2 100644
--- a/spark/clean.py
+++ b/spark/clean.py
@@ -1,4 +1,5 @@
 from pyspark.sql import SparkSession
+from pyspark.sql import functions as F
 
 spark = SparkSession.builder \
     .appName("TelcoCustomerChurn") \
@@ -18,7 +19,17 @@ df = df.withColumn("TotalCharges", col("TotalCharges").cast("double"))
 # tenure_zero_rows = df.filter(col("tenure") == 0)
 # tenure_zero_rows.show()
 
-df = df.fillna({"TotalCharges": 0})
+df = df.withColumn(
+    "TotalCharges",
+    F.when(F.col("tenure") == 0, 0).otherwise(F.col("TotalCharges"))
+)
+
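+# Impute any remaining missing TotalCharges with the approximate median over customers with tenure > 0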
+quantiles = df.filter(F.col("tenure") != 0).approxQuantile("TotalCharges", [0.5], 0.01)
+median_total_charges = quantiles[0]
+
+df = df.fillna({"TotalCharges": median_total_charges})
 
 # missing_counts = df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])
 # missing_counts.show()