diff --git a/dags/data_drift_simulation.py b/dags/data_drift_simulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2491c37354351710ca9031c4a51f36be6b599392
--- /dev/null
+++ b/dags/data_drift_simulation.py
@@ -0,0 +1,41 @@
+from airflow import DAG
+from airflow.operators.python import PythonOperator
+from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
+from airflow.utils.dates import days_ago
+
+dag = DAG(
+    dag_id="data_drift_simulation",
+    default_args={
+        "owner": "admin",
+        "start_date": days_ago(1)
+    },
+    schedule_interval="@daily",
+    catchup=False
+)
+
+start = PythonOperator(
+    task_id="start",
+    python_callable=lambda: print("Data Drift Simulation started"),
+    dag=dag
+)
+
+data_drift_job = SparkSubmitOperator(
+    task_id="data_drift",
+    conn_id="spark_conn",
+    application="jobs/python/data_drift.py",
+    dag=dag,
+)
+
+trigger_sparking_flow = PythonOperator(
+    task_id="trigger_sparking_flow",
+    python_callable=lambda: print("Triggering sparking_flow DAG"),
+    dag=dag,
+)
+
+end = PythonOperator(
+    task_id="end",
+    python_callable=lambda: print("Data Drift Simulation completed successfully"),
+    dag=dag
+)
+
+start >> data_drift_job >> trigger_sparking_flow >> end
diff --git a/jobs/python/data_drift.py b/jobs/python/data_drift.py
new file mode 100644
index 0000000000000000000000000000000000000000..8adcdbe0cea7146f53bbbda85481e3e20bb2e78a
--- /dev/null
+++ b/jobs/python/data_drift.py
@@ -0,0 +1,97 @@
+from pyspark.sql import SparkSession
+from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
+from pyspark.sql.functions import when, col, isnull
+from pyspark.ml.classification import LogisticRegressionModel
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml.feature import VectorAssembler, StringIndexer
+
+spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
+
+schema = StructType([
+    StructField("customerID", StringType(), True),
+    StructField("gender", StringType(), True),
+    StructField("SeniorCitizen", StringType(), True),
+    StructField("Partner", StringType(), True),
+    StructField("Dependents", StringType(), True),
+    StructField("tenure", DoubleType(), True),
+    StructField("PhoneService", StringType(), True),
+    StructField("MultipleLines", StringType(), True),
+    StructField("InternetService", StringType(), True),
+    StructField("OnlineSecurity", StringType(), True),
+    StructField("OnlineBackup", StringType(), True),
+    StructField("DeviceProtection", StringType(), True),
+    StructField("TechSupport", StringType(), True),
+    StructField("StreamingTV", StringType(), True),
+    StructField("StreamingMovies", StringType(), True),
+    StructField("Contract", StringType(), True),
+    StructField("PaperlessBilling", StringType(), True),
+    StructField("PaymentMethod", StringType(), True),
+    StructField("MonthlyCharges", DoubleType(), True),
+    StructField("TotalCharges", DoubleType(), True),
+    StructField("Churn", StringType(), True)
+])
+
+original_data_path = "dataset/cleaned_data.csv"
+original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)
+
+categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
+
+# Flag any categorical columns that contain nulls before indexing.
+for column in categorical_columns:
+    if original_data_df.filter(isnull(col(column))).count() > 0:
+        print(f"Null values found in {column}")
+
+original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
+
+# Simulate drifted data by scoring a 20% sample of the original rows.
+data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
+
+if data_drift_df.count() == 0:
+    print("data_drift_df is empty after sampling.")
+
+data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
+
+model_path = 'logistic_regression_model'
+loaded_model = LogisticRegressionModel.load(model_path)
+
+# NOTE: refitting StringIndexer on the sample can assign different indices
+# than the ones used at training time; handleInvalid="keep" only guards
+# against unseen labels, not against a reshuffled encoding.
+for column in categorical_columns:
+    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
+    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)
+
+for column in categorical_columns:
+    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
+        print(f"Null values found in {column + '_index'}")
+
+feature_columns = [c + "_index" for c in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
+assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
+
+data = assembler.transform(data_drift_df)
+
+predictions = loaded_model.transform(data)
+
+evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
+accuracy = evaluator.evaluate(predictions)
+
+print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
+
+if accuracy < 0.70:
+    print("Accuracy is below 70%. Triggering retraining of the model.")
+    # This only works if Airflow, its DAGs folder, and its metadata database
+    # are all reachable from the Spark driver's environment.
+    from airflow.models import DagBag
+    dag_bag = DagBag()
+    sparking_flow_dag = dag_bag.get_dag('sparking_flow')
+
+    if sparking_flow_dag:
+        sparking_flow_dag.clear()
+        sparking_flow_dag.run()
+        print("Successfully triggered sparking_flow DAG.")
+    else:
+        print("sparking_flow DAG not found.")
+else:
+    print("Accuracy is acceptable.")
+
+spark.stop()
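
Note (not part of the patch): in this change the trigger_sparking_flow task only prints a message, and the actual retraining trigger lives inside the Spark job via DagBag, which depends on Airflow being importable from the Spark driver. A more conventional alternative would be to let the DAG itself do the triggering with Airflow's built-in TriggerDagRunOperator and ShortCircuitOperator. The sketch below assumes the Spark job persists its accuracy somewhere the DAG can read; the read_accuracy() helper and the dataset/drift_accuracy.txt hand-off path are hypothetical, not something this diff creates.

from airflow.operators.python import ShortCircuitOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

def read_accuracy(path="dataset/drift_accuracy.txt"):
    # Hypothetical hand-off file the Spark job would write.
    with open(path) as f:
        return float(f.read().strip())

def needs_retraining():
    # Short-circuit the downstream trigger when accuracy is acceptable.
    return read_accuracy() < 0.70

check_accuracy = ShortCircuitOperator(
    task_id="check_accuracy",
    python_callable=needs_retraining,
    dag=dag,
)

trigger_retraining = TriggerDagRunOperator(
    task_id="trigger_retraining",
    trigger_dag_id="sparking_flow",
    dag=dag,
)

start >> data_drift_job >> check_accuracy >> trigger_retraining >> end

This keeps all orchestration concerns (DAG discovery, scheduling, retries) inside Airflow, so the Spark job stays a pure compute step that neither imports Airflow nor needs access to its metadata database.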