From a752cd0eebcf7d8d68fe391ce9c803373aeb950e Mon Sep 17 00:00:00 2001
From: Hidayatullah Wildan Ghaly Buchary <wildanghaly1@gmail.com>
Date: Thu, 9 Jan 2025 21:37:50 +0700
Subject: [PATCH] feat: add data drift simulation test (basic)

---
 dags/data_drift_simulation.py |  45 +++++++++++++++
 jobs/python/data_drift.py     | 108 ++++++++++++++++++++++++++++++++++++
 2 files changed, 153 insertions(+)
 create mode 100644 dags/data_drift_simulation.py
 create mode 100644 jobs/python/data_drift.py

diff --git a/dags/data_drift_simulation.py b/dags/data_drift_simulation.py
new file mode 100644
index 0000000..2491c37
--- /dev/null
+++ b/dags/data_drift_simulation.py
@@ -0,0 +1,45 @@
+from airflow import DAG
+from airflow.operators.python import PythonOperator
+from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
+from airflow.utils.dates import days_ago
+
+dag = DAG(
+    dag_id="data_drift_simulation",
+    default_args={
+        "owner": "admin",
+        "start_date": days_ago(1)
+    },
+    schedule_interval="@daily",
+    catchup=False
+)
+
+start = PythonOperator(
+    task_id="start",
+    python_callable=lambda: print("Data Drift Simulation started"),
+    dag=dag
+)
+
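+# Submits jobs/python/data_drift.py through the Spark connection configured
+# in Airflow as "spark_conn"; the path is resolved on the Airflow worker.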
+data_drift_job = SparkSubmitOperator(
+    task_id="data_drift",
+    conn_id="spark_conn",
+    application="jobs/python/data_drift.py",
+    dag=dag,
+)
+
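+# Log-only placeholder: the conditional retraining trigger itself lives in
+# jobs/python/data_drift.py and fires only when accuracy drops below 70%.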
+trigger_sparking_flow = PythonOperator(
+    task_id="trigger_sparking_flow",
+    python_callable=lambda: print("Triggering sparking_flow DAG"),
+    dag=dag,
+)
+
+end = PythonOperator(
+    task_id="end",
+    python_callable=lambda: print("Data Drift Simulation completed successfully"),
+    dag=dag
+)
+
+start >> data_drift_job >> trigger_sparking_flow >> end
diff --git a/jobs/python/data_drift.py b/jobs/python/data_drift.py
new file mode 100644
index 0000000..8adcdbe
--- /dev/null
+++ b/jobs/python/data_drift.py
@@ -0,0 +1,108 @@
+from pyspark.sql import SparkSession
+from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
+from pyspark.sql.functions import when, col, isnull
+from pyspark.ml.classification import LogisticRegressionModel
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml.feature import VectorAssembler, StringIndexer
+
+spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
+
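+# Explicit schema (rather than inferSchema) keeps column types stable across runs.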
+schema = StructType([
+    StructField("customerID", StringType(), True),
+    StructField("gender", StringType(), True),
+    StructField("SeniorCitizen", StringType(), True),
+    StructField("Partner", StringType(), True),
+    StructField("Dependents", StringType(), True),
+    StructField("tenure", DoubleType(), True),
+    StructField("PhoneService", StringType(), True),
+    StructField("MultipleLines", StringType(), True),
+    StructField("InternetService", StringType(), True),
+    StructField("OnlineSecurity", StringType(), True),
+    StructField("OnlineBackup", StringType(), True),
+    StructField("DeviceProtection", StringType(), True),
+    StructField("TechSupport", StringType(), True),
+    StructField("StreamingTV", StringType(), True),
+    StructField("StreamingMovies", StringType(), True),
+    StructField("Contract", StringType(), True),
+    StructField("PaperlessBilling", StringType(), True),
+    StructField("PaymentMethod", StringType(), True),
+    StructField("MonthlyCharges", DoubleType(), True),
+    StructField("TotalCharges", DoubleType(), True),
+    StructField("Churn", StringType(), True)
+])
+
+original_data_path = "dataset/cleaned_data.csv"
+original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)
+
+categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
+
+for column in categorical_columns:
+    if original_data_df.filter(isnull(col(column))).count() > 0:
+        print(f"Null values found in {column}")
+
+original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
+
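+# Basic drift simulation: re-score the model on a random 20% sample of the
+# original data. A fuller simulation would also perturb feature distributions.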
+data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
+
+if data_drift_df.count() == 0:
+    print("data_drift_df is empty after sampling.")
+
+data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
+
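+# Assumes the model was saved by the training pipeline at this path, relative
+# to the Spark working directory.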
+model_path = 'logistic_regression_model'
+loaded_model = LogisticRegressionModel.load(model_path)
+
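+# Caveat: re-fitting StringIndexer on the sampled data can assign different
+# indices than the ones used at training time; reusing the fitted pipeline
+# from training would be the more faithful setup.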
+for column in categorical_columns:
+    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
+    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)
+
+for column in categorical_columns:
+    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
+        print(f"Null values found in {column + '_index'}")
+
+feature_columns = [c + "_index" for c in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
+assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
+
+data = assembler.transform(data_drift_df)
+
+predictions = loaded_model.transform(data)
+
+evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
+accuracy = evaluator.evaluate(predictions)
+
+print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
+
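+# Note: triggering Airflow from inside a Spark job assumes the Airflow package,
+# DAGs folder, and metadata database are all reachable from the Spark driver.
+# A sketch of an alternative (assuming a reachable webserver with the stable
+# REST API enabled) would be:
+#   import requests
+#   requests.post("http://airflow-webserver:8080/api/v1/dags/sparking_flow/dagRuns",
+#                 auth=("admin", "admin"), json={"conf": {}})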
+if accuracy < 0.70:
+    print("Accuracy is below 70%. Triggering retraining of the model.")
+    try:
+        from airflow.models import DagBag
+        dag_bag = DagBag()
+        sparking_flow_dag = dag_bag.get_dag('sparking_flow')
+
+        if sparking_flow_dag:
+            sparking_flow_dag.clear()
+            sparking_flow_dag.run()
+            print("Successfully triggered sparking_flow DAG.")
+        else:
+            print("sparking_flow DAG not found.")
+    except Exception as e:
+        print(f"Failed to trigger sparking_flow DAG: {e}")
+else:
+    print("Accuracy is acceptable.")
+
+spark.stop()
-- 
GitLab