Commit 2c5514a2 authored by Margaretha Olivia

[Feat] data preprocessing code with Spark

parent a9b43b06
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
# Initialize Spark session
spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
# Load the dataset
df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
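# Quick sanity check (a sketch; assumes the standard Telco churn CSV layout):
# confirm the inferred schema and row count before transforming.
df.printSchema()
print("rows loaded:", df.count())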
# Drop irrelevant or redundant columns
df = df.drop("customerID")
# Convert TotalCharges to numerical and handle invalid entries
df = df.withColumn("TotalCharges", when(col("TotalCharges") == " ", None).otherwise(col("TotalCharges").cast("double")))
# Fill missing or invalid values
df = df.fillna({"TotalCharges": 0, "MonthlyCharges": 0, "tenure": 0})
df = df.fillna("Unknown", subset=["gender", "Partner", "Dependents", "PhoneService", "InternetService"])
# Convert the Churn column into a binary format (Yes -> 1, No -> 0).
df = df.withColumn("Churn", when(col("Churn") == "Yes", 1).otherwise(0))
# Index and encode categorical columns
categorical_cols = ["gender", "Partner", "Dependents", "InternetService", "Contract", "PaymentMethod"]
# Loop variable is named "c" (not "col") to avoid shadowing pyspark.sql.functions.col;
# handleInvalid="keep" gives null/unseen categories their own index instead of erroring.
indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx", handleInvalid="keep")
            for c in categorical_cols]
encoded_cols = [c + "_ohe" for c in categorical_cols]
for indexer in indexers:
    df = indexer.fit(df).transform(df)
encoder = OneHotEncoder(inputCols=[c + "_idx" for c in categorical_cols],
                        outputCols=encoded_cols)
df = encoder.fit(df).transform(df)
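# Illustrative peek at one encoded column ("Contract" is taken from
# categorical_cols above): each category string maps to an index, which the
# encoder expands into a sparse one-hot vector.
df.select("Contract", "Contract_idx", "Contract_ohe").show(5, truncate=False)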
# Combine numerical and encoded categorical columns into a single feature vector
assembler = VectorAssembler(inputCols=["tenure", "MonthlyCharges", "TotalCharges"] + encoded_cols,
                            outputCol="features")
df = assembler.transform(df)
df.select("features", "Churn").show(5, truncate=False)
# Save the preprocessed data
df.select(["features", "Churn"]).write.mode("overwrite").parquet("preprocessed_data")
# Stop the Spark session
spark.stop()