diff --git a/dags/prediction.py b/dags/prediction.py
index 54a91e2d6399455d02643c3d68ba656bb316d16b..bf3b16e03d1f8e4dd7ad93c17ea913e258d444e0 100644
--- a/dags/prediction.py
+++ b/dags/prediction.py
@@ -2,92 +2,112 @@ from pyspark.sql import SparkSession
 from pyspark.ml import Pipeline
 from pyspark.sql.functions import col, when, isnan, trim, regexp_replace, lower
 from pyspark.sql.types import DoubleType, IntegerType, StringType
-from pyspark.sql.functions import col, udf
-from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
-from pyspark.ml.linalg import SparseVector, DenseVector
+from pyspark.sql.functions import col, udf, log
+from pyspark.ml.feature import StringIndexer, VectorAssembler
 import mlflow.spark
 import shutil
 import os
 from data_cleanup import clean_data
+import math
 
 def init_spark():
     return SparkSession.builder.appName("Churn Prediction").getOrCreate()
 
+def calculate_psi(base_df, target_df, column, bins=10):
+    """Compute the Population Stability Index (PSI) of `column` between a baseline and a target DataFrame."""
+    # Derive bin edges and baseline counts from the baseline column; RDD.histogram(n) returns (edges, counts)
+    bin_edges, base_counts = base_df.select(column).rdd.flatMap(lambda x: x).histogram(bins)
+
+    # Baseline distribution as proportions
+    base_percents = [count / sum(base_counts) for count in base_counts]
+
+    # Distribution of the new data, bucketed with the baseline bin edges
+    target_counts = target_df.select(column).rdd.flatMap(lambda x: x).histogram(bin_edges)[1]
+    target_percents = [count / sum(target_counts) for count in target_counts]
+
+    # Floor the proportions to avoid division by zero and log(0)
+    base_percents = [max(p, 1e-6) for p in base_percents]
+    target_percents = [max(p, 1e-6) for p in target_percents]
+
+    # PSI = sum over bins of (base% - target%) * ln(base% / target%)
+    psi = sum((bp - tp) * math.log(bp / tp) for bp, tp in zip(base_percents, target_percents))
+    return psi
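+
+# A common rule of thumb for reading PSI (a general convention, not defined in this repo):
+# PSI < 0.1 suggests little or no shift, 0.1-0.25 a moderate shift, and > 0.25 a major shift.
+# Illustrative usage, assuming two Spark DataFrames that share the column being compared:
+#   psi = calculate_psi(training_df, scoring_df, "tenure", bins=10)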
+
+# Function to make predictions and calculate PSI
 def predict(path_cleaned_data):  
     # Set tracking URI to the MLflow server  
     mlflow.set_tracking_uri('http://localhost:5000')  
-  
+
     # Start the MLflow run
     mlflow.start_run()  
-  
+
     spark = SparkSession.builder.appName("CustomerChurn").getOrCreate()  
-      
-    # Read the data
-    cleaned_data = spark.read.csv(path_cleaned_data, header=True, inferSchema=True)  
-  
-    # Show the data to verify its format
-    cleaned_data.show()  
-  
-    # Define the string and numeric columns
+
+    # Read the new data to be scored
+    cleaned_data = spark.read.csv(path_cleaned_data, header=True, inferSchema=True)
+
+    # Load the baseline data (the data used to train the model)
+    baseline_data = spark.read.csv("telco_customer_churn_transformed.csv", header=True, inferSchema=True)
+
+    # Define the categorical (string) and numeric columns
     categorical_columns = ["gender", "Partner", "Dependents", "PhoneService", "MultipleLines",  
                            "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection",  
                            "TechSupport", "StreamingTV", "StreamingMovies", "Contract",  
                            "PaperlessBilling", "PaymentMethod"]  
     numeric_columns = ["SeniorCitizen", "tenure", "MonthlyCharges", "TotalCharges"]  
-  
-    # Convert the 'Churn' target label to numeric
+
+    # Convert the 'Churn' target label to numeric
     label_indexer = StringIndexer(inputCol="Churn", outputCol="label")  
-    indexed_data = label_indexer.fit(cleaned_data).transform(cleaned_data)  
-  
-    # Convert string columns to numeric indexes using StringIndexer
+    # Fit the label indexer once on the baseline so both datasets share the same label encoding
+    label_indexer_model = label_indexer.fit(baseline_data)
+    indexed_data = label_indexer_model.transform(cleaned_data)
+    base_data = label_indexer_model.transform(baseline_data)
+
+    # Calculate PSI on the indexed label to check for target drift
+    # (kept in a separate list so the numeric feature columns defined above are not overwritten)
+    psi_columns = ["label"]
+    psi_results = {}
+    for column in psi_columns:
+        psi_value = calculate_psi(base_data, indexed_data, column)
+        psi_results[column] = psi_value
+        print(f"PSI for {column}: {psi_value}")
+
+    # Log the PSI values to MLflow
+    for col_name, psi_val in psi_results.items():
+        mlflow.log_metric(f"psi_{col_name}", psi_val)
+
+    # Convert string columns to numeric indexes using StringIndexer
     indexers = [  
         StringIndexer(inputCol=col, outputCol=col + "_indexed").setHandleInvalid("skip")  
         for col in categorical_columns  
     ]  
-  
-    # Assemble the feature columns (numeric and StringIndexer output) into a single `features` column
+
+    # Assemble the feature columns into a single `features` vector
     feature_columns = [col + "_indexed" for col in categorical_columns] + numeric_columns  
     assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")  
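+
+    # Note: if the model loaded below is a fitted PipelineModel (as the original training pipeline suggests),
+    # it applies its own fitted indexing and assembly stages, so the indexers and assembler defined here are not used at scoring time.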
-  
-    # Model Logistic Regression  
-    # lr = LogisticRegression(labelCol="label", featuresCol="features")  
-  
-    # Pipeline to transform the data and train the model
-    # pipeline = Pipeline(stages=indexers + [assembler, lr])  
-  
-    # Training model  
-    model = mlflow.spark.load_model("runs:/ee8d9e5282da46558784493539696b9d/logistic_regression_model")
-
-    # Save the model with MLflow
-    # Log parameters and metrics
-    # If you want to log metrics, such as accuracy, you need to compute them
+
+    # Load the trained model from MLflow
+    model = mlflow.spark.load_model("runs:/f0acd65c02534b2197a220485e179aba/logistic_regression_model")
+
+    # Run predictions on the new data
     predictions = model.transform(indexed_data)  
+
+    # Compute model accuracy
     accuracy = predictions.filter(predictions.label == predictions.prediction).count() / float(predictions.count())  
     mlflow.log_metric("accuracy", accuracy)  
-    print(f"Accuracy: {accuracy}")
-  
-    # End the MLflow run
-    mlflow.end_run()  
+    print(f"Accuracy: {accuracy}")  
+
+    # End the MLflow run
+    mlflow.end_run()
 
 def main():
     try:
         # Initialize Spark
-        # Clean data
         print("Cleaning data...")
         cleaned_path = clean_data("telco_customer_churn_drift.csv")
         predict(cleaned_path)
-        # Prepare features
-        
-        # Save predictions
-        # output_path = "path/to/output/churn_predictions.csv"
-        # print(f"\nSaving predictions to {output_path}")
-        # final_predictions.write.csv(output_path,
-        #                           header=True,
-        #                           mode="overwrite")
-        
     except Exception as e:
         print(f"Error in pipeline: {str(e)}")
         raise e
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()