diff --git a/spark/drift_check.py b/spark/drift_check.py
index c8ccc7acc1784185f384e0cfc284632a0be41cec..e1208c0df557ef83814ab5db4aa6f31505b75bfb 100644
--- a/spark/drift_check.py
+++ b/spark/drift_check.py
@@ -1,4 +1,5 @@
 import pandas as pd
+import numpy as np
 import glob
 from pyspark.sql import SparkSession
 
@@ -20,6 +21,25 @@ def get_latest_versions(base_path, base_name):
         return None, None
     return version_numbers[-1], version_numbers[-2]
 
+# Function to calculate PSI (Population Stability Index)
+def calculate_psi(base_distribution, new_distribution, bins=10):
+    psi = 0
+    for base_bin, new_bin in zip(base_distribution, new_distribution):
+        if base_bin > 0 and new_bin > 0:  # Avoid division by zero and log(0)
+            psi += (new_bin - base_bin) * np.log(new_bin / base_bin)
+    return psi
+
+def calculate_column_psi(latest_col, previous_col, bins=10):
+    # Create bins for both distributions
+    bin_edges = pd.cut(pd.concat([latest_col, previous_col]), bins=bins, retbins=True)[1]
+
+    # Calculate distributions for each dataset
+    latest_distribution = pd.cut(latest_col, bins=bin_edges).value_counts(normalize=True, sort=False)
+    previous_distribution = pd.cut(previous_col, bins=bin_edges).value_counts(normalize=True, sort=False)
+
+    # Calculate PSI
+    return calculate_psi(previous_distribution, latest_distribution)
+
 base_path = "/opt/data/versioning/"
 base_name = "clean"
 drift_result_file = "/opt/data/drift_result.txt"
@@ -42,27 +62,21 @@ else:
     latest_df = pd.read_csv(latest_file)
     previous_df = pd.read_csv(previous_file)
 
-    # Example of calculating statistics (mean and std for numeric columns)
-    latest_stats = latest_df.describe().to_dict()
-    previous_stats = previous_df.describe().to_dict()
-
-    # Validate mean and std values before calculating drift metric
     drift_detected = False
-    for column in latest_stats:
-        latest_mean = latest_stats[column].get("mean")
-        previous_mean = previous_stats[column].get("mean")
-        previous_std = previous_stats[column].get("std")
+    psi_threshold = 0.1  # Example threshold for PSI (adjust as necessary)
 
-        if latest_mean is not None and previous_mean is not None:
-            # Avoid division by zero (set a minimum value for std)
-            previous_std = previous_std or 1
+    # Iterate over numeric columns to calculate PSI
+    for column in latest_df.select_dtypes(include=['number']).columns:
+        latest_col = latest_df[column].dropna()
+        previous_col = previous_df[column].dropna()
 
-            # Calculate drift metric using Python's abs() function
-            drift_metric = abs(latest_mean - previous_mean) / previous_std
-            print(f"Drift metric for column '{column}': {drift_metric}")
+        # Ensure both columns are non-empty
+        if not latest_col.empty and not previous_col.empty:
+            column_psi = calculate_column_psi(latest_col, previous_col)
+            print(f"PSI for column '{column}': {column_psi}")
 
-            # Check for significant drift (threshold can be adjusted)
-            if drift_metric > 0.5:  # Example threshold for drift detection
+            # Check if PSI exceeds the threshold
+            if column_psi > psi_threshold:
                 drift_detected = True
                 break
 
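
Reviewer note (not part of the patch): a minimal sketch of how the new helpers could be sanity-checked on synthetic data, assuming calculate_column_psi from the diff above is in scope; the variable names and sample sizes below are made up for illustration. A half-standard-deviation mean shift should typically land above the 0.1 threshold, while a fresh sample from the same distribution should stay well below it.

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
baseline = pd.Series(rng.normal(loc=0.0, scale=1.0, size=5000))   # reference sample
same = pd.Series(rng.normal(loc=0.0, scale=1.0, size=5000))       # same distribution, new sample
shifted = pd.Series(rng.normal(loc=0.5, scale=1.0, size=5000))    # mean shifted by 0.5 sd

# PSI against the baseline: expected to be small for 'same', larger for 'shifted'
print(calculate_column_psi(same, baseline))     # typically well below 0.1
print(calculate_column_psi(shifted, baseline))  # typically above 0.1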