diff --git a/data_preprocessing.py b/data_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ef7fffad31b279a6928c9863dab2a9c8337c84
--- /dev/null
+++ b/data_preprocessing.py
@@ -0,0 +1,55 @@
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import col, when
+from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
+
+# Initialize Spark session
+spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
+
+# Load the dataset
+df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
+
+
+# Drop the customer identifier; it carries no predictive signal
+df = df.drop("customerID")
+
+
+# Convert TotalCharges to numeric, nulling out the blank (" ") entries first
+df = df.withColumn("TotalCharges", when(col("TotalCharges") == " ", None).otherwise(col("TotalCharges").cast("double")))
+
+
+# Fill missing numeric values with 0 and categorical nulls with "Unknown" (StringIndexer fails on nulls)
+df = df.fillna({"TotalCharges": 0, "MonthlyCharges": 0, "tenure": 0})
+df = df.fillna("Unknown", subset=["gender", "Partner", "Dependents", "PhoneService", "InternetService", "Contract", "PaymentMethod"])
+
+
+# Convert the Churn column into a binary label (Yes -> 1, No -> 0)
+df = df.withColumn("Churn", when(col("Churn") == "Yes", 1).otherwise(0))
+
+
+# Index and one-hot encode categorical columns ("c" avoids shadowing the imported col function)
+categorical_cols = ["gender", "Partner", "Dependents", "InternetService", "Contract", "PaymentMethod"]
+indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx") for c in categorical_cols]
+encoded_cols = [c + "_ohe" for c in categorical_cols]
+
+for indexer in indexers:
+    df = indexer.fit(df).transform(df)
+
+encoder = OneHotEncoder(inputCols=[c + "_idx" for c in categorical_cols],
+                        outputCols=encoded_cols)
+df = encoder.fit(df).transform(df)
+
+
+# Combine numerical and encoded categorical columns into a single feature vector
+assembler = VectorAssembler(inputCols=["tenure", "MonthlyCharges", "TotalCharges"] + encoded_cols,
+                            outputCol="features")
+df = assembler.transform(df)
+
+# Inspect a sample of the assembled rows
+print(df.select(["features", "Churn"]).take(5))
+
+# Save the preprocessed data
+df.select(["features", "Churn"]).write.mode("overwrite").parquet("preprocessed_data")
+
+
+# Stop the Spark session
+spark.stop()
\ No newline at end of file
diff --git a/spark_code.py b/spark_code.py
deleted file mode 100644
index e4a9a0f0b3f8dd5c30ef78d6b19cff19a7e613ad..0000000000000000000000000000000000000000
--- a/spark_code.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from pyspark.sql import SparkSession
-
-# Initialize Spark session
-spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
-
-# Load the dataset
-df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
-
-# Show the DataFrame
-df.show()
-
-# Stop the Spark session
-spark.stop()
\ No newline at end of file
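
A quick way to sanity-check this change is to read the saved parquet back and fit a model on it. The sketch below is illustrative and not part of the diff: it assumes the "preprocessed_data" directory written by data_preprocessing.py, picks LogisticRegression as a stand-in classifier, and uses an arbitrary 80/20 split and seed.

from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression

spark = SparkSession.builder.appName("ChurnPredictionTrain").getOrCreate()

# Load the feature vectors written by data_preprocessing.py
data = spark.read.parquet("preprocessed_data")

# Hold out 20% of the rows for evaluation (arbitrary split and seed)
train, test = data.randomSplit([0.8, 0.2], seed=42)

# LogisticRegression is an assumed stand-in; the diff does not choose a model
lr = LogisticRegression(featuresCol="features", labelCol="Churn")
model = lr.fit(train)

# Rough accuracy on the held-out rows
predictions = model.transform(test)
correct = predictions.filter(predictions.Churn == predictions.prediction).count()
print(f"Holdout accuracy: {correct / test.count():.3f}")

spark.stop()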