diff --git a/jobs/python/data_drift.py b/jobs/python/data_drift.py
index 8adcdbe0cea7146f53bbbda85481e3e20bb2e78a..bd6a7a0569f9bb10d493dbf26498d8e7faef87d4 100644
--- a/jobs/python/data_drift.py
+++ b/jobs/python/data_drift.py
@@ -1,90 +1,106 @@
-from pyspark.sql import SparkSession      
-from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType      
-from pyspark.sql.functions import when, col, isnull      
-from pyspark.ml.classification import LogisticRegressionModel      
-from pyspark.ml.evaluation import MulticlassClassificationEvaluator      
-from pyspark.ml.feature import VectorAssembler, StringIndexer      
-  
-spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()      
-  
-schema = StructType([      
-    StructField("customerID", StringType(), True),      
-    StructField("gender", StringType(), True),      
-    StructField("SeniorCitizen", StringType(), True),      
-    StructField("Partner", StringType(), True),      
-    StructField("Dependents", StringType(), True),      
-    StructField("tenure", DoubleType(), True),      
-    StructField("PhoneService", StringType(), True),      
-    StructField("MultipleLines", StringType(), True),      
-    StructField("InternetService", StringType(), True),      
-    StructField("OnlineSecurity", StringType(), True),      
-    StructField("OnlineBackup", StringType(), True),      
-    StructField("DeviceProtection", StringType(), True),      
-    StructField("TechSupport", StringType(), True),      
-    StructField("StreamingTV", StringType(), True),      
-    StructField("StreamingMovies", StringType(), True),      
-    StructField("Contract", StringType(), True),      
-    StructField("PaperlessBilling", StringType(), True),      
-    StructField("PaymentMethod", StringType(), True),      
-    StructField("MonthlyCharges", DoubleType(), True),      
-    StructField("TotalCharges", DoubleType(), True),      
-    StructField("Churn", StringType(), True)      
-])      
-  
-original_data_path = "dataset/cleaned_data.csv"      
-original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)      
-  
-for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:      
-    if original_data_df.filter(isnull(col(column))).count() > 0:      
-        print(f"Null values found in {column}")      
-  
-original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))      
-  
-data_drift_df = original_data_df.sample(fraction=0.2, seed=42)      
-  
-if data_drift_df.count() == 0:      
-    print("data_drift_df is empty after sampling.")      
-  
-data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))      
-  
-model_path = 'logistic_regression_model'      
-loaded_model = LogisticRegressionModel.load(model_path)      
-  
-categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']      
+from pyspark.sql import SparkSession        
+from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType        
+from pyspark.sql.functions import when, col, isnull, monotonically_increasing_id, rand        
+from pyspark.ml.classification import LogisticRegressionModel        
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator        
+from pyspark.ml.feature import VectorAssembler, StringIndexer        
+  
+spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()        
+  
+schema = StructType([        
+    StructField("customerID", StringType(), True),        
+    StructField("gender", StringType(), True),        
+    StructField("SeniorCitizen", StringType(), True),        
+    StructField("Partner", StringType(), True),        
+    StructField("Dependents", StringType(), True),        
+    StructField("tenure", DoubleType(), True),        
+    StructField("PhoneService", StringType(), True),        
+    StructField("MultipleLines", StringType(), True),        
+    StructField("InternetService", StringType(), True),        
+    StructField("OnlineSecurity", StringType(), True),        
+    StructField("OnlineBackup", StringType(), True),        
+    StructField("DeviceProtection", StringType(), True),        
+    StructField("TechSupport", StringType(), True),        
+    StructField("StreamingTV", StringType(), True),        
+    StructField("StreamingMovies", StringType(), True),        
+    StructField("Contract", StringType(), True),        
+    StructField("PaperlessBilling", StringType(), True),        
+    StructField("PaymentMethod", StringType(), True),        
+    StructField("MonthlyCharges", DoubleType(), True),        
+    StructField("TotalCharges", DoubleType(), True),        
+    StructField("Churn", StringType(), True)        
+])        
+  
+original_data_path = "dataset/cleaned_data.csv"        
+original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)        
+  
+for column in ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']:        
+    if original_data_df.filter(isnull(col(column))).count() > 0:        
+        print(f"Null values found in {column}")        
+  
+original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))        
+  
+data_drift_df = original_data_df.sample(fraction=0.2, seed=42)        
+  
+if data_drift_df.count() == 0:        
+    print("data_drift_df is empty after sampling.")        
+  
+data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))        
+  
+data_drift_df = data_drift_df.withColumn("customerID", monotonically_increasing_id().cast(StringType()))    
+  
+# Introduce significant changes to the new data to drop accuracy  
+data_drift_df = data_drift_df.withColumn("MonthlyCharges", col("MonthlyCharges") * (1 + rand() * 0.5))    
+data_drift_df = data_drift_df.withColumn("TotalCharges", col("TotalCharges") * (1 + rand() * 0.5))    
+data_drift_df = data_drift_df.withColumn("Churn", when(rand() > 0.5, 1).otherwise(0))    
+  
+model_path = 'logistic_regression_model'        
+loaded_model = LogisticRegressionModel.load(model_path)        
+  
+categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']        
+  
+for column in categorical_columns:        
+    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")        
+    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)        
   
 for column in categorical_columns:      
-    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")      
-    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)      
+    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:      
+        print(f"Null values found in {column + '_index'}")      
   
-for column in categorical_columns:    
-    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:    
-        print(f"Null values found in {column + '_index'}")    
+feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']        
+assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')        
   
-feature_columns = [col + "_index" for col in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']      
-assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')      
+data = assembler.transform(data_drift_df)        
   
-data = assembler.transform(data_drift_df)      
+predictions = loaded_model.transform(data)        
   
-predictions = loaded_model.transform(data)      
+evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')        
+accuracy = evaluator.evaluate(predictions)        
   
-evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')      
-accuracy = evaluator.evaluate(predictions)      
+print(f"Accuracy after data drift: {accuracy * 100:.2f}%")        
   
-print(f"Accuracy after data drift: {accuracy * 100:.2f}%")      
+# Select only the original columns for merging  
+data_drift_df = data_drift_df.select("customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn")  
   
-if accuracy < 0.70:      
-    print("Accuracy is below 70%. Triggering retraining of the model.")      
-    from airflow.models import DagBag      
-    dag_bag = DagBag()      
-    sparking_flow_dag = dag_bag.get_dag('sparking_flow')      
-          
-    if sparking_flow_dag:      
-        sparking_flow_dag.clear()   
-        sparking_flow_dag.run()      
-        print("Successfully triggered sparking_flow DAG.")      
-    else:      
-        print("sparking_flow DAG not found.")      
-else:      
-    print("Accuracy is acceptable.")      
+# Merge old and new data  
+merged_data_df = original_data_df.union(data_drift_df)  
   
-spark.stop()      
+output_path = "dataset/data_drifted_customers.csv"    
+merged_data_df.write.csv(output_path, header=True, mode='overwrite')    
+  
+if accuracy < 0.70:        
+    print("Accuracy is below 70%. Triggering retraining of the model.")        
+    from airflow.models import DagBag        
+    dag_bag = DagBag()        
+    sparking_flow_dag = dag_bag.get_dag('sparking_flow')        
+            
+    if sparking_flow_dag:        
+        sparking_flow_dag.clear()     
+        sparking_flow_dag.run()        
+        print("Successfully triggered sparking_flow DAG.")        
+    else:        
+        print("sparking_flow DAG not found.")        
+else:        
+    print("Accuracy is acceptable.")        
+  
+spark.stop()