diff --git a/spark/prep.py b/spark/prep.py
new file mode 100644
index 0000000000000000000000000000000000000000..4553baaf905be7e802c2d441286636f92947dcad
--- /dev/null
+++ b/spark/prep.py
@@ -0,0 +1,34 @@
+from pyspark.sql import SparkSession
+from pyspark.ml.feature import StringIndexer
+
+spark = SparkSession.builder \
+    .appName("Preprocessing") \
+    .getOrCreate()
+
+# Load the cleaned dataset
+df = spark.read.csv("../data/clean.csv", header=True, inferSchema=True)
+
+# df.printSchema()
+# df.count()
+
+# Feature/target views (not used below: Spark ML expects features and
+# label to stay in the same DataFrame, so the split is done on df itself)
+data_X = df.drop('churn')     # Features
+data_y = df.select('churn')   # Target
+
+# 70/30 train/test split with a fixed seed for reproducibility
+data_train, data_test = df.randomSplit([0.7, 0.3], seed=1)
+
+# print(f"Training data count: {data_train.count()}")
+# print(f"Testing data count: {data_test.count()}")
+
+# Encode the string target into a numeric label column,
+# fitting the indexer on the training split only
+indexer = StringIndexer(inputCol="churn", outputCol="churn_encoded")
+
+indexer_model = indexer.fit(data_train)
+data_train_encoded = indexer_model.transform(data_train)
+data_test_encoded = indexer_model.transform(data_test)
+
+# data_train_encoded.select("churn", "churn_encoded").show(5)
+# data_test_encoded.select("churn", "churn_encoded").show(5)
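
A minimal sketch of how the encoded splits produced by prep.py might feed a downstream model, not part of this diff: Spark ML estimators expect the input features collected into a single vector column, so a VectorAssembler step would typically follow the indexing above. The feature column names used here ("tenure", "monthly_charges") are hypothetical placeholders; the real names come from clean.csv.

    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.classification import LogisticRegression

    # Assumed numeric feature columns -- replace with the actual columns in clean.csv
    feature_cols = ["tenure", "monthly_charges"]

    # Collect the feature columns into a single "features" vector column
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    train_ready = assembler.transform(data_train_encoded)
    test_ready = assembler.transform(data_test_encoded)

    # Fit a baseline classifier against the indexed label
    lr = LogisticRegression(featuresCol="features", labelCol="churn_encoded")
    model = lr.fit(train_ready)
    predictions = model.transform(test_ready)

Fitting the indexer (and, in a sketch like this, the assembler and classifier) only on data_train keeps the test split out of every fitted step, which is the reason prep.py applies the already-fitted indexer_model to data_test rather than refitting it.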