Commit a752cd0e authored by Hidayatullah Wildan Ghaly Buchary

feat: add data drift simulation test (basic)

parent 7e2ac728
2 merge requests: !10 "chore: add model version control and finally can trigger the sparking flow...", !5 "Complete the assignment"
# Airflow DAG: data_drift_simulation
import airflow
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

dag = DAG(
    dag_id="data_drift_simulation",
    default_args={
        "owner": "admin",
        "start_date": airflow.utils.dates.days_ago(1),
    },
    schedule_interval="@daily",
)

start = PythonOperator(
    task_id="start",
    python_callable=lambda: print("Data Drift Simulation started"),
    dag=dag,
)

# Submit the drift-check job (jobs/python/data_drift.py) to Spark.
data_drift_job = SparkSubmitOperator(
    task_id="data_drift",
    conn_id="spark_conn",
    application="jobs/python/data_drift.py",
    dag=dag,
)

# Placeholder: this only logs a message; it does not actually trigger the
# sparking_flow DAG (see the sketch after this file).
trigger_sparking_flow = PythonOperator(
    task_id="trigger_sparking_flow",
    python_callable=lambda: print("Triggering sparking_flow DAG"),
    dag=dag,
)

end = PythonOperator(
    task_id="end",
    python_callable=lambda: print("Data Drift Simulation completed successfully"),
    dag=dag,
)

start >> data_drift_job >> trigger_sparking_flow >> end
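
The trigger_sparking_flow task above is only a logging placeholder. A minimal sketch of what a real trigger could look like with Airflow's built-in TriggerDagRunOperator, assuming a DAG with dag_id "sparking_flow" is deployed in the same Airflow instance:

from airflow.operators.trigger_dagrun import TriggerDagRunOperator

# Sketch only: replaces the print() placeholder with an actual DAG trigger.
trigger_sparking_flow = TriggerDagRunOperator(
    task_id="trigger_sparking_flow",
    trigger_dag_id="sparking_flow",  # the retraining DAG referenced in this commit
    wait_for_completion=False,       # fire-and-forget so this DAG can finish
    dag=dag,
)
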
# Spark job: jobs/python/data_drift.py (submitted by the DAG above)
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
from pyspark.sql.functions import when, col, isnull
from pyspark.ml.classification import LogisticRegressionModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer

spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
schema = StructType([
    StructField("customerID", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("SeniorCitizen", StringType(), True),
    StructField("Partner", StringType(), True),
    StructField("Dependents", StringType(), True),
    StructField("tenure", DoubleType(), True),
    StructField("PhoneService", StringType(), True),
    StructField("MultipleLines", StringType(), True),
    StructField("InternetService", StringType(), True),
    StructField("OnlineSecurity", StringType(), True),
    StructField("OnlineBackup", StringType(), True),
    StructField("DeviceProtection", StringType(), True),
    StructField("TechSupport", StringType(), True),
    StructField("StreamingTV", StringType(), True),
    StructField("StreamingMovies", StringType(), True),
    StructField("Contract", StringType(), True),
    StructField("PaperlessBilling", StringType(), True),
    StructField("PaymentMethod", StringType(), True),
    StructField("MonthlyCharges", DoubleType(), True),
    StructField("TotalCharges", DoubleType(), True),
    StructField("Churn", StringType(), True),
])
original_data_path = "dataset/cleaned_data.csv"
original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)

categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines',
                       'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection',
                       'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract',
                       'PaperlessBilling', 'PaymentMethod']

# Sanity check: report any nulls in the categorical columns.
for column in categorical_columns:
    if original_data_df.filter(isnull(col(column))).count() > 0:
        print(f"Null values found in {column}")

original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))

# Simulate drift with a 20% sample of the original data.
data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
if data_drift_df.count() == 0:
    print("data_drift_df is empty after sampling.")

# Encode the label: Churn == 'Yes' -> 1, otherwise 0.
data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
# Load the previously trained logistic regression model.
model_path = 'logistic_regression_model'
loaded_model = LogisticRegressionModel.load(model_path)

# Index categorical columns; handleInvalid="keep" buckets unseen labels into
# an extra index instead of failing.
for column in categorical_columns:
    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)

for column in categorical_columns:
    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
        print(f"Null values found in {column + '_index'}")

# Assemble the indexed categoricals plus the numeric columns into a feature vector.
feature_columns = [c + "_index" for c in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
data = assembler.transform(data_drift_df)

predictions = loaded_model.transform(data)
evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
# If accuracy has degraded below 70%, trigger retraining via the sparking_flow DAG.
if accuracy < 0.70:
    print("Accuracy is below 70%. Triggering retraining of the model.")
    # This import only works if Airflow (and access to its metadata database)
    # is available inside the Spark job's environment.
    from airflow.models import DagBag

    dag_bag = DagBag()
    sparking_flow_dag = dag_bag.get_dag('sparking_flow')
    if sparking_flow_dag:
        sparking_flow_dag.clear()
        sparking_flow_dag.run()
        print("Successfully triggered sparking_flow DAG.")
    else:
        print("sparking_flow DAG not found.")
else:
    print("Accuracy is acceptable.")

spark.stop()
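
As noted in the comments, the DagBag approach couples the Spark job to the Airflow installation, and DAG.run() is deprecated in Airflow 2.x. A more decoupled sketch using Airflow's stable REST API instead; the webserver URL and credentials below are placeholders, not values from this repo:

import requests  # assumes requests is available on the Spark worker

def trigger_retraining(dag_id="sparking_flow"):
    # POST /api/v1/dags/{dag_id}/dagRuns starts a new DAG run (Airflow 2.x stable REST API).
    url = f"http://airflow-webserver:8080/api/v1/dags/{dag_id}/dagRuns"
    resp = requests.post(url, json={"conf": {}}, auth=("admin", "admin"), timeout=10)
    return resp.status_code == 200

if accuracy < 0.70 and trigger_retraining():
    print("Successfully triggered sparking_flow DAG via REST API.")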