From b0f95b17c17bef0647f99ba810dd3673d626eb16 Mon Sep 17 00:00:00 2001
From: Hidayatullah Wildan Ghaly Buchary <wildanghaly1@gmail.com>
Date: Thu, 9 Jan 2025 22:07:13 +0700
Subject: [PATCH] chore: data drift can finally trigger the sparking flow BUT MULTIPLE TIMES AT ONCE

---
 jobs/python/data_drift.py | 176 +++++++++++++++++++++-----------------
 1 file changed, 96 insertions(+), 80 deletions(-)

diff --git a/jobs/python/data_drift.py b/jobs/python/data_drift.py
index 8adcdbe..bd6a7a0 100644
--- a/jobs/python/data_drift.py
+++ b/jobs/python/data_drift.py
@@ -1,90 +1,106 @@
-from pyspark.sql import SparkSession
-from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
-from pyspark.sql.functions import when, col, isnull
-from pyspark.ml.classification import LogisticRegressionModel
-from pyspark.ml.evaluation import MulticlassClassificationEvaluator
-from pyspark.ml.feature import VectorAssembler, StringIndexer
-
-spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
-
-schema = StructType([
-    StructField("customerID", StringType(), True),
-    StructField("gender", StringType(), True),
-    StructField("SeniorCitizen", StringType(), True),
-    StructField("Partner", StringType(), True),
-    StructField("Dependents", StringType(), True),
-    StructField("tenure", DoubleType(), True),
-    StructField("PhoneService", StringType(), True),
-    StructField("MultipleLines", StringType(), True),
-    StructField("InternetService", StringType(), True),
-    StructField("OnlineSecurity", StringType(), True),
-    StructField("OnlineBackup", StringType(), True),
-    StructField("DeviceProtection", StringType(), True),
-    StructField("TechSupport", StringType(), True),
-    StructField("StreamingTV", StringType(), True),
-    StructField("StreamingMovies", StringType(), True),
-    StructField("Contract", StringType(), True),
-    StructField("PaperlessBilling", StringType(), True),
-    StructField("PaymentMethod", StringType(), True),
-    StructField("MonthlyCharges", DoubleType(), True),
-    StructField("TotalCharges", DoubleType(), True),
-    StructField("Churn", StringType(), True)
-])
-
-original_data_path = "dataset/cleaned_data.csv"
-original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)
-
-for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:
-    if original_data_df.filter(isnull(col(column))).count() > 0:
-        print(f"Null values found in {column}")
-
-original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
-
-data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
-
-if data_drift_df.count() == 0:
-    print("data_drift_df is empty after sampling.")
-
-data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
-
-model_path = 'logistic_regression_model'
-loaded_model = LogisticRegressionModel.load(model_path)
-
-categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
+from pyspark.sql import SparkSession
+from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
+from pyspark.sql.functions import when, col, isnull, monotonically_increasing_id, rand
+from pyspark.ml.classification import LogisticRegressionModel
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml.feature import VectorAssembler, StringIndexer
+
+spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
+
+schema = StructType([
+    StructField("customerID", StringType(), True),
+    StructField("gender", StringType(), True),
+    StructField("SeniorCitizen", StringType(), True),
+    StructField("Partner", StringType(), True),
+    StructField("Dependents", StringType(), True),
+    StructField("tenure", DoubleType(), True),
+    StructField("PhoneService", StringType(), True),
+    StructField("MultipleLines", StringType(), True),
+    StructField("InternetService", StringType(), True),
+    StructField("OnlineSecurity", StringType(), True),
+    StructField("OnlineBackup", StringType(), True),
+    StructField("DeviceProtection", StringType(), True),
+    StructField("TechSupport", StringType(), True),
+    StructField("StreamingTV", StringType(), True),
+    StructField("StreamingMovies", StringType(), True),
+    StructField("Contract", StringType(), True),
+    StructField("PaperlessBilling", StringType(), True),
+    StructField("PaymentMethod", StringType(), True),
+    StructField("MonthlyCharges", DoubleType(), True),
+    StructField("TotalCharges", DoubleType(), True),
+    StructField("Churn", StringType(), True)
+])
+
+original_data_path = "dataset/cleaned_data.csv"
+original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)
+
+for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:
+    if original_data_df.filter(isnull(col(column))).count() > 0:
+        print(f"Null values found in {column}")
+
+original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
+
+data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
+
+if data_drift_df.count() == 0:
+    print("data_drift_df is empty after sampling.")
+
+data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
+
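+# Re-key the sampled rows with a synthetic, monotonically increasing customerID
+# (presumably so the drifted copies do not collide with the original rows they
+# are unioned with further below).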
+data_drift_df = data_drift_df.withColumn("customerID", monotonically_increasing_id().cast(StringType()))
+
+# Introduce significant changes to the new data to drop accuracy
+data_drift_df = data_drift_df.withColumn("MonthlyCharges", col("MonthlyCharges") * (1 + rand() * 0.5))
+data_drift_df = data_drift_df.withColumn("TotalCharges", col("TotalCharges") * (1 + rand() * 0.5))
+data_drift_df = data_drift_df.withColumn("Churn", when(rand() > 0.5, 1).otherwise(0))
+
+model_path = 'logistic_regression_model'
+loaded_model = LogisticRegressionModel.load(model_path)
+
+categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
+
+for column in categorical_columns:
+    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
+    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)
 
 for column in categorical_columns:
-    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
-    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)
+    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
+        print(f"Null values found in {column + '_index'}")
 
-for column in categorical_columns:
-    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
-        print(f"Null values found in {column + '_index'}")
+feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
+assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
 
-feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
-assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
+data = assembler.transform(data_drift_df)
 
-data = assembler.transform(data_drift_df)
+predictions = loaded_model.transform(data)
 
-predictions = loaded_model.transform(data)
+evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
+accuracy = evaluator.evaluate(predictions)
 
-evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
-accuracy = evaluator.evaluate(predictions)
+print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
 
-print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
+# Select only the original columns for merging
+data_drift_df = data_drift_df.select("customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn")
 
-if accuracy < 0.70:
-    print("Accuracy is below 70%. Triggering retraining of the model.")
-    from airflow.models import DagBag
-    dag_bag = DagBag()
-    sparking_flow_dag = dag_bag.get_dag('sparking_flow')
-
-    if sparking_flow_dag:
-        sparking_flow_dag.clear()
-        sparking_flow_dag.run()
-        print("Successfully triggered sparking_flow DAG.")
-    else:
-        print("sparking_flow DAG not found.")
-else:
-    print("Accuracy is acceptable.")
+# Merge old and new data
+merged_data_df = original_data_df.union(data_drift_df)
 
-spark.stop()
+output_path = "dataset/data_drifted_customers.csv"
+merged_data_df.write.csv(output_path, header=True, mode='overwrite')
+
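+# When accuracy drops below 70%, load the Airflow DagBag and run the
+# 'sparking_flow' DAG to retrain the model. As the commit subject warns, this
+# trigger has no guard against re-running, so the DAG can end up being
+# started multiple times at once.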
+if accuracy < 0.70:
+    print("Accuracy is below 70%. Triggering retraining of the model.")
+    from airflow.models import DagBag
+    dag_bag = DagBag()
+    sparking_flow_dag = dag_bag.get_dag('sparking_flow')
+
+    if sparking_flow_dag:
+        sparking_flow_dag.clear()
+        sparking_flow_dag.run()
+        print("Successfully triggered sparking_flow DAG.")
+    else:
+        print("sparking_flow DAG not found.")
+else:
+    print("Accuracy is acceptable.")
+
+spark.stop()
--
GitLab