diff --git a/spark/drift_check.py b/spark/drift_check.py
index c8ccc7acc1784185f384e0cfc284632a0be41cec..e1208c0df557ef83814ab5db4aa6f31505b75bfb 100644
--- a/spark/drift_check.py
+++ b/spark/drift_check.py
@@ -1,4 +1,5 @@
 import pandas as pd
+import numpy as np
 import glob
 from pyspark.sql import SparkSession
 
@@ -20,6 +21,25 @@ def get_latest_versions(base_path, base_name):
         return None, None
     return version_numbers[-1], version_numbers[-2]
 
+# Function to calculate PSI (Population Stability Index)
+def calculate_psi(base_distribution, new_distribution, bins=10):
+    psi = 0
+    for base_bin, new_bin in zip(base_distribution, new_distribution):
+        if base_bin > 0 and new_bin > 0:  # Avoid division by zero and log(0)
+            psi += (new_bin - base_bin) * np.log(new_bin / base_bin)
+    return psi
+
+def calculate_column_psi(latest_col, previous_col, bins=10):
+    # Create bins for both distributions
+    bin_edges = pd.cut(pd.concat([latest_col, previous_col]), bins=bins, retbins=True)[1]
+    
+    # Calculate distributions for each dataset
+    latest_distribution = pd.cut(latest_col, bins=bin_edges).value_counts(normalize=True, sort=False)
+    previous_distribution = pd.cut(previous_col, bins=bin_edges).value_counts(normalize=True, sort=False)
+    
+    # Calculate PSI
+    return calculate_psi(previous_distribution, latest_distribution)
+
 base_path = "/opt/data/versioning/"
 base_name = "clean"
 drift_result_file = "/opt/data/drift_result.txt"
@@ -42,27 +62,21 @@ else:
     latest_df = pd.read_csv(latest_file)
     previous_df = pd.read_csv(previous_file)
 
-    # Example of calculating statistics (mean and std for numeric columns)
-    latest_stats = latest_df.describe().to_dict()
-    previous_stats = previous_df.describe().to_dict()
-
-    # Validate mean and std values before calculating drift metric
     drift_detected = False
-    for column in latest_stats:
-        latest_mean = latest_stats[column].get("mean")
-        previous_mean = previous_stats[column].get("mean")
-        previous_std = previous_stats[column].get("std")
+    psi_threshold = 0.1  # Example threshold for PSI (adjust as necessary)
 
-        if latest_mean is not None and previous_mean is not None:
-            # Avoid division by zero (set a minimum value for std)
-            previous_std = previous_std or 1
+    # Iterate over numeric columns to calculate PSI
+    for column in latest_df.select_dtypes(include=['number']).columns:
+        latest_col = latest_df[column].dropna()
+        previous_col = previous_df[column].dropna()
 
-            # Calculate drift metric using Python's abs() function
-            drift_metric = abs(latest_mean - previous_mean) / previous_std
-            print(f"Drift metric for column '{column}': {drift_metric}")
+        # Ensure both columns are non-empty
+        if not latest_col.empty and not previous_col.empty:
+            column_psi = calculate_column_psi(latest_col, previous_col)
+            print(f"PSI for column '{column}': {column_psi}")
 
-            # Check for significant drift (threshold can be adjusted)
-            if drift_metric > 0.5:  # Example threshold for drift detection
+            # Check if PSI exceeds the threshold
+            if column_psi > psi_threshold:
                 drift_detected = True
                 break