Commit 656c921f authored by mikeleo03's avatar mikeleo03
parents c0906d30 27ae315f
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
# Initialize Spark session
spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
# Load the dataset
df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
# Drop irrelevant or redundant columns
df = df.drop("customerID")
# Convert TotalCharges to numerical and handle invalid entries
df = df.withColumn("TotalCharges", when(col("TotalCharges") == " ", None).otherwise(col("TotalCharges").cast("double")))
# Fill missing or invalid values
df = df.fillna({"TotalCharges": 0, "MonthlyCharges": 0, "tenure": 0})
df = df.fillna("Unknown", subset=["gender", "Partner", "Dependents", "PhoneService", "InternetService"])
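# Optional sanity check (not in the original script): confirm the fill above
# left no nulls in the numeric columns before assembling features.
assert df.filter(col("TotalCharges").isNull() |
                 col("MonthlyCharges").isNull() |
                 col("tenure").isNull()).count() == 0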
# Convert the Churn column into a binary format (Yes -> 1, No -> 0).
df = df.withColumn("Churn", when(col("Churn") == "Yes", 1).otherwise(0))
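# Optional aside (an assumption, not part of the original pipeline): churn
# labels are typically imbalanced, which matters for any downstream model,
# so it is worth inspecting the class distribution here.
df.groupBy("Churn").count().show()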
# Index and encode categorical columns
categorical_cols = ["gender", "Partner", "Dependents", "InternetService", "Contract", "PaymentMethod"]
indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx") for c in categorical_cols]
encoded_cols = [c + "_ohe" for c in categorical_cols]
for indexer in indexers:
    df = indexer.fit(df).transform(df)
encoder = OneHotEncoder(inputCols=[c + "_idx" for c in categorical_cols],
                        outputCols=encoded_cols)
df = encoder.fit(df).transform(df)
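# Optional check (not in the original script): the encoder emits one sparse
# vector per categorical column; a quick look verifies the shapes are sane.
df.select(encoded_cols).show(3, truncate=False)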
# Combine numerical and encoded categorical columns into a single feature vector
assembler = VectorAssembler(inputCols=["tenure", "MonthlyCharges", "TotalCharges"] + encoded_cols,
outputCol="features")
df = assembler.transform(df)
# Preview the assembled feature vectors alongside the label
df.select("features", "Churn").show(5, truncate=False)
# Save the preprocessed data
df.select(["features", "Churn"]).write.mode("overwrite").parquet("preprocessed_data")
# Stop the Spark session
spark.stop()
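# --- Hedged sketch (not part of this commit): one way the saved parquet
# output could be consumed downstream. The model choice, split ratio, and
# seed below are illustrative assumptions, not the author's method.
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression

spark = SparkSession.builder.appName("ChurnPredictionTrain").getOrCreate()
data = spark.read.parquet("preprocessed_data")
train, test = data.randomSplit([0.8, 0.2], seed=42)
model = LogisticRegression(featuresCol="features", labelCol="Churn").fit(train)
print("Test AUC:", model.evaluate(test).areaUnderROC)
spark.stop()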
# ===============================================================
# Second file in this commit: quick sanity check of the raw CSV
# ===============================================================
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
# Load the dataset
df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
# Show the DataFrame
df.show()
# Stop the Spark session
spark.stop()