From 7acca518a81dcc70bb29375b45d5deebe5eea69f Mon Sep 17 00:00:00 2001
From: Hidayatullah Wildan Ghaly Buchary <wildanghaly1@gmail.com>
Date: Thu, 9 Jan 2025 22:32:59 +0700
Subject: [PATCH] chore: add model version control; the sparking flow can
 finally be triggered once, but the retrain path is still wrong

---
 dags/data_drift_simulation.py      |  89 ++++++++-----
 jobs/python/data_drift.py          | 201 ++++++++++++++---------------
 jobs/python/logistic_regression.py |   8 +-
 3 files changed, 155 insertions(+), 143 deletions(-)

diff --git a/dags/data_drift_simulation.py b/dags/data_drift_simulation.py
index 2491c37..8b6dc1c 100644
--- a/dags/data_drift_simulation.py
+++ b/dags/data_drift_simulation.py
@@ -1,40 +1,59 @@
-import airflow  
-from airflow import DAG  
-from airflow.operators.python import PythonOperator  
-from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator  
-
-dag = DAG(  
-    dag_id="data_drift_simulation",  
-    default_args={  
-        "owner": "admin",  
-        "start_date": airflow.utils.dates.days_ago(1)  
-    },  
-    schedule_interval="@daily"  
-)  
-
-start = PythonOperator(  
-    task_id="start",  
-    python_callable=lambda: print("Data Drift Simulation started"),  
-    dag=dag  
-)  
-
-data_drift_job = SparkSubmitOperator(  
-    task_id="data_drift",  
-    conn_id="spark_conn",  
-    application="jobs/python/data_drift.py",  
+import airflow    
+from airflow import DAG    
+from airflow.operators.python import PythonOperator, BranchPythonOperator
+from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator    
+from airflow.operators.trigger_dagrun import TriggerDagRunOperator
+from airflow.models import Variable  
+  
+dag = DAG(    
+    dag_id="data_drift_simulation",    
+    default_args={    
+        "owner": "admin",    
+        "start_date": airflow.utils.dates.days_ago(1)    
+    },    
+    schedule_interval="@daily"    
+)    
+  
+start = PythonOperator(    
+    task_id="start",    
+    python_callable=lambda: print("Data Drift Simulation started"),    
+    dag=dag    
+)    
+  
+data_drift_job = SparkSubmitOperator(    
+    task_id="data_drift",    
+    conn_id="spark_conn",    
+    application="jobs/python/data_drift.py",    
+    dag=dag,    
+)    
+  
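+# Decide whether to retrain: read the accuracy that data_drift.py stored in the
+# "data_drift_accuracy" Variable and follow the sparking_flow branch when it drops below 70%.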
+def check_accuracy(**kwargs):  
+    accuracy_data = Variable.get("data_drift_accuracy", deserialize_json=True)  
+    accuracy = accuracy_data.get("accuracy", 0)  
+    if accuracy < 0.70:  
+        return 'trigger_sparking_flow'  
+    else:  
+        return 'end'  
+  
+branch_task = BranchPythonOperator(  
+    task_id='check_accuracy',  
+    python_callable=check_accuracy,  
+    provide_context=True,  
     dag=dag,  
 )  
-
-trigger_sparking_flow = PythonOperator(  
+  
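+# Actually trigger the retraining DAG instead of only printing a message.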
+trigger_sparking_flow = TriggerDagRunOperator(  
     task_id="trigger_sparking_flow",  
-    python_callable=lambda: print("Triggering sparking_flow DAG"),  
+    trigger_dag_id="sparking_flow",  
     dag=dag,  
 )  
-
-end = PythonOperator(  
-    task_id="end",  
-    python_callable=lambda: print("Data Drift Simulation completed successfully"),  
-    dag=dag  
-)  
-
-start >> data_drift_job >> trigger_sparking_flow >> end  
+  
+end = PythonOperator(
+    task_id="end",
+    python_callable=lambda: print("Data Drift Simulation completed successfully"),
+    # Run no matter which branch was taken; the skipped branch must not block this task.
+    trigger_rule="none_failed",
+    dag=dag
+)
+  
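+# Healthy runs go straight to "end"; drifted runs detour through trigger_sparking_flow first.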
+start >> data_drift_job >> branch_task  
+branch_task >> trigger_sparking_flow >> end  
+branch_task >> end  
diff --git a/jobs/python/data_drift.py b/jobs/python/data_drift.py
index bd6a7a0..f9ada4b 100644
--- a/jobs/python/data_drift.py
+++ b/jobs/python/data_drift.py
@@ -1,106 +1,95 @@
-from pyspark.sql import SparkSession        
-from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType        
-from pyspark.sql.functions import when, col, isnull, monotonically_increasing_id, rand        
-from pyspark.ml.classification import LogisticRegressionModel        
-from pyspark.ml.evaluation import MulticlassClassificationEvaluator        
-from pyspark.ml.feature import VectorAssembler, StringIndexer        
-  
-spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()        
-  
-schema = StructType([        
-    StructField("customerID", StringType(), True),        
-    StructField("gender", StringType(), True),        
-    StructField("SeniorCitizen", StringType(), True),        
-    StructField("Partner", StringType(), True),        
-    StructField("Dependents", StringType(), True),        
-    StructField("tenure", DoubleType(), True),        
-    StructField("PhoneService", StringType(), True),        
-    StructField("MultipleLines", StringType(), True),        
-    StructField("InternetService", StringType(), True),        
-    StructField("OnlineSecurity", StringType(), True),        
-    StructField("OnlineBackup", StringType(), True),        
-    StructField("DeviceProtection", StringType(), True),        
-    StructField("TechSupport", StringType(), True),        
-    StructField("StreamingTV", StringType(), True),        
-    StructField("StreamingMovies", StringType(), True),        
-    StructField("Contract", StringType(), True),        
-    StructField("PaperlessBilling", StringType(), True),        
-    StructField("PaymentMethod", StringType(), True),        
-    StructField("MonthlyCharges", DoubleType(), True),        
-    StructField("TotalCharges", DoubleType(), True),        
-    StructField("Churn", StringType(), True)        
-])        
-  
-original_data_path = "dataset/cleaned_data.csv"        
-original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)        
-  
-for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:        
-    if original_data_df.filter(isnull(col(column))).count() > 0:        
-        print(f"Null values found in {column}")        
-  
-original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))        
-  
-data_drift_df = original_data_df.sample(fraction=0.2, seed=42)        
-  
-if data_drift_df.count() == 0:        
-    print("data_drift_df is empty after sampling.")        
-  
-data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))        
-  
-data_drift_df = data_drift_df.withColumn("customerID", monotonically_increasing_id().cast(StringType()))    
-  
-# Introduce significant changes to the new data to drop accuracy  
-data_drift_df = data_drift_df.withColumn("MonthlyCharges", col("MonthlyCharges") * (1 + rand() * 0.5))    
-data_drift_df = data_drift_df.withColumn("TotalCharges", col("TotalCharges") * (1 + rand() * 0.5))    
-data_drift_df = data_drift_df.withColumn("Churn", when(rand() > 0.5, 1).otherwise(0))    
-  
-model_path = 'logistic_regression_model'        
-loaded_model = LogisticRegressionModel.load(model_path)        
-  
-categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']        
-  
-for column in categorical_columns:        
-    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")        
-    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)        
-  
-for column in categorical_columns:      
-    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:      
-        print(f"Null values found in {column + '_index'}")      
-  
-feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']        
-assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')        
-  
-data = assembler.transform(data_drift_df)        
-  
-predictions = loaded_model.transform(data)        
-  
-evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')        
-accuracy = evaluator.evaluate(predictions)        
-  
-print(f"Accuracy after data drift: {accuracy * 100:.2f}%")        
-  
-# Select only the original columns for merging  
-data_drift_df = data_drift_df.select("customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn")  
-  
-# Merge old and new data  
-merged_data_df = original_data_df.union(data_drift_df)  
-  
-output_path = "dataset/data_drifted_customers.csv"    
-merged_data_df.write.csv(output_path, header=True, mode='overwrite')    
-  
-if accuracy < 0.70:        
-    print("Accuracy is below 70%. Triggering retraining of the model.")        
-    from airflow.models import DagBag        
-    dag_bag = DagBag()        
-    sparking_flow_dag = dag_bag.get_dag('sparking_flow')        
-            
-    if sparking_flow_dag:        
-        sparking_flow_dag.clear()     
-        sparking_flow_dag.run()        
-        print("Successfully triggered sparking_flow DAG.")        
-    else:        
-        print("sparking_flow DAG not found.")        
-else:        
-    print("Accuracy is acceptable.")        
-  
-spark.stop()        
+from pyspark.sql import SparkSession              
+from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType              
+from pyspark.sql.functions import when, col, isnull, monotonically_increasing_id, rand              
+from pyspark.ml.classification import LogisticRegressionModel              
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator              
+from pyspark.ml.feature import VectorAssembler, StringIndexer              
+from airflow.models import Variable    
+import json    
+import os  
+  
+spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()              
+  
+schema = StructType([              
+    StructField("customerID", StringType(), True),              
+    StructField("gender", StringType(), True),              
+    StructField("SeniorCitizen", StringType(), True),              
+    StructField("Partner", StringType(), True),              
+    StructField("Dependents", StringType(), True),              
+    StructField("tenure", DoubleType(), True),              
+    StructField("PhoneService", StringType(), True),              
+    StructField("MultipleLines", StringType(), True),              
+    StructField("InternetService", StringType(), True),              
+    StructField("OnlineSecurity", StringType(), True),              
+    StructField("OnlineBackup", StringType(), True),              
+    StructField("DeviceProtection", StringType(), True),              
+    StructField("TechSupport", StringType(), True),              
+    StructField("StreamingTV", StringType(), True),              
+    StructField("StreamingMovies", StringType(), True),              
+    StructField("Contract", StringType(), True),              
+    StructField("PaperlessBilling", StringType(), True),              
+    StructField("PaymentMethod", StringType(), True),              
+    StructField("MonthlyCharges", DoubleType(), True),              
+    StructField("TotalCharges", DoubleType(), True),              
+    StructField("Churn", StringType(), True)              
+])              
+  
+original_data_path = "dataset/cleaned_data.csv"              
+original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)              
+  
+for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:              
+    if original_data_df.filter(isnull(col(column))).count() > 0:              
+        print(f"Null values found in {column}")              
+  
+original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))              
+  
+data_drift_df = original_data_df.sample(fraction=0.2, seed=42)              
+  
+if data_drift_df.count() == 0:              
+    print("data_drift_df is empty after sampling.")              
+  
+data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))              
+data_drift_df = data_drift_df.withColumn("customerID", monotonically_increasing_id().cast(StringType()))          
+data_drift_df = data_drift_df.withColumn("MonthlyCharges", col("MonthlyCharges") * (1 + rand() * 0.5))          
+data_drift_df = data_drift_df.withColumn("TotalCharges", col("TotalCharges") * (1 + rand() * 0.5))          
+data_drift_df = data_drift_df.withColumn("Churn", when(rand() > 0.5, 1).otherwise(0))          
+  
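+# Load the newest model version: logistic_regression.py writes each training run to a
+# version_<timestamp> subdirectory, so the most recently modified entry is the latest model.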
+model_base_path = 'logistic_regression_model'  
+latest_model_path = sorted([os.path.join(model_base_path, d) for d in os.listdir(model_base_path)], key=os.path.getmtime)[-1]  
+loaded_model = LogisticRegressionModel.load(latest_model_path)              
+  
+categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']              
+  
+for column in categorical_columns:              
+    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")              
+    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)              
+  
+for column in categorical_columns:            
+    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:            
+        print(f"Null values found in {column + '_index'}")            
+      
+feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']              
+assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')              
+data = assembler.transform(data_drift_df)              
+  
+predictions = loaded_model.transform(data)              
+  
+evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')              
+accuracy = evaluator.evaluate(predictions)              
+  
+print(f"Accuracy after data drift: {accuracy * 100:.2f}%")              
+  
+data_drift_df = data_drift_df.select("customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn")        
+  
+merged_data_df = original_data_df.union(data_drift_df)        
+  
+output_path = "dataset/data_drifted_customers.csv"          
+merged_data_df.write.csv(output_path, header=True, mode='overwrite')          
+  
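+# Publish the accuracy to an Airflow Variable so the DAG's branch task can read it.
+# Note: this assumes the Spark driver can reach the Airflow metadata database.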
+print(f"Accuracy: {accuracy}")    
+from airflow.models import Variable    
+Variable.set("data_drift_accuracy", json.dumps({"accuracy": accuracy}))    
+  
+spark.stop()  
diff --git a/jobs/python/logistic_regression.py b/jobs/python/logistic_regression.py
index 697c65b..f2861be 100644
--- a/jobs/python/logistic_regression.py
+++ b/jobs/python/logistic_regression.py
@@ -5,6 +5,7 @@ from pyspark.ml.evaluation import MulticlassClassificationEvaluator
 from pyspark.ml.classification import LogisticRegressionModel
 import os
 import shutil
+import datetime
 
 spark = SparkSession.builder.master("local").appName("LogisticRegressionJob").getOrCreate()
 
@@ -30,14 +31,17 @@ train_data, test_data = data.randomSplit([0.80, 0.20], seed=135)
 lr = LogisticRegression(featuresCol='features', labelCol='Churn', maxIter=2000)
 model = lr.fit(train_data)
 
-model_path = 'logistic_regression_model'
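+# Version the model: each training run is saved under logistic_regression_model/version_<timestamp>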
+model_base_path = 'logistic_regression_model'  
+timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")  
+model_path = f"{model_base_path}/version_{timestamp}" 
 
 if os.path.exists(model_path):
     shutil.rmtree(model_path)
 
 model.save(model_path)
 
-loaded_model = LogisticRegressionModel.load(model_path)
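+# Reload the most recently written version (by modification time) before evaluating on the test split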
+latest_model_path = sorted([os.path.join(model_base_path, d) for d in os.listdir(model_base_path)], key=os.path.getmtime)[-1]  
+loaded_model = LogisticRegressionModel.load(latest_model_path)
 
 predictions = loaded_model.transform(test_data)
 
-- 
GitLab