Commit b0f95b17, authored by Hidayatullah Wildan Ghaly Buchary (parent a752cd0e)

chore: data drift can finally trigger the sparking flow BUT MULTIPLE TIMES AT ONCE

Merge requests: !10 "chore: add model version control and finally can trigger the sparking flow...", !5 "Complete the assignment"
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
from pyspark.sql.functions import when, col, isnull, monotonically_increasing_id, rand
from pyspark.ml.classification import LogisticRegressionModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer
spark = SparkSession.builder.master("local").appName("DataDriftSimulation").getOrCreate()
schema = StructType([
    StructField("customerID", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("SeniorCitizen", StringType(), True),
    StructField("Partner", StringType(), True),
    StructField("Dependents", StringType(), True),
    StructField("tenure", DoubleType(), True),
    StructField("PhoneService", StringType(), True),
    StructField("MultipleLines", StringType(), True),
    StructField("InternetService", StringType(), True),
    StructField("OnlineSecurity", StringType(), True),
    StructField("OnlineBackup", StringType(), True),
    StructField("DeviceProtection", StringType(), True),
    StructField("TechSupport", StringType(), True),
    StructField("StreamingTV", StringType(), True),
    StructField("StreamingMovies", StringType(), True),
    StructField("Contract", StringType(), True),
    StructField("PaperlessBilling", StringType(), True),
    StructField("PaymentMethod", StringType(), True),
    StructField("MonthlyCharges", DoubleType(), True),
    StructField("TotalCharges", DoubleType(), True),
    StructField("Churn", StringType(), True)
])
original_data_path = "dataset/cleaned_data.csv"
original_data_df = spark.read.csv(original_data_path, header=True, schema=schema)
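# spark.read.csv defaults to PERMISSIVE mode, so values that fail to parse under the
# schema (e.g., a blank TotalCharges) arrive as nulls; hence the null checks below.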
categorical_columns = ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
for column in categorical_columns:
    if original_data_df.filter(isnull(col(column))).count() > 0:
        print(f"Null values found in {column}")
original_data_df = original_data_df.withColumn("SeniorCitizen", col("SeniorCitizen").cast(IntegerType()))
data_drift_df = original_data_df.sample(fraction=0.2, seed=42)
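# DataFrame.sample is approximate: fraction=0.2 yields roughly 20% of the rows, and
# the exact count is not guaranteed, so guard against an empty sample.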
if data_drift_df.count() == 0:
    print("data_drift_df is empty after sampling.")
data_drift_df = data_drift_df.withColumn('Churn', when(data_drift_df['Churn'] == 'Yes', 1).otherwise(0).cast(IntegerType()))
data_drift_df = data_drift_df.withColumn("customerID", monotonically_increasing_id().cast(StringType()))
# Introduce significant changes to the new data to drop accuracy
data_drift_df = data_drift_df.withColumn("MonthlyCharges", col("MonthlyCharges") * (1 + rand() * 0.5))
data_drift_df = data_drift_df.withColumn("TotalCharges", col("TotalCharges") * (1 + rand() * 0.5))
data_drift_df = data_drift_df.withColumn("Churn", when(rand() > 0.5, 1).otherwise(0))
model_path = 'logistic_regression_model'
loaded_model = LogisticRegressionModel.load(model_path)
# Index each categorical column; handleInvalid="keep" gives unseen labels their own index.
for column in categorical_columns:
    indexer = StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid="keep")
    data_drift_df = indexer.fit(data_drift_df).transform(data_drift_df)
    if data_drift_df.filter(isnull(col(column + "_index"))).count() > 0:
        print(f"Null values found in {column + '_index'}")
feature_columns = [c + "_index" for c in categorical_columns] + ['tenure', 'MonthlyCharges', 'TotalCharges']
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
data = assembler.transform(data_drift_df)
predictions = loaded_model.transform(data)
evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy after data drift: {accuracy * 100:.2f}%")
# Select only the original columns for merging
data_drift_df = data_drift_df.select("customerID", "gender", "SeniorCitizen", "Partner", "Dependents", "tenure", "PhoneService", "MultipleLines", "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies", "Contract", "PaperlessBilling", "PaymentMethod", "MonthlyCharges", "TotalCharges", "Churn")
# Merge old and new data. The original frame still holds Churn as 'Yes'/'No' strings,
# so binarize it first to keep both sides of the union type-consistent.
original_data_df = original_data_df.withColumn('Churn', when(col('Churn') == 'Yes', 1).otherwise(0).cast(IntegerType()))
merged_data_df = original_data_df.union(data_drift_df)
output_path = "dataset/data_drifted_customers.csv"
merged_data_df.write.csv(output_path, header=True, mode='overwrite')

if accuracy < 0.70:
    print("Accuracy is below 70%. Triggering retraining of the model.")
    from airflow.models import DagBag
    dag_bag = DagBag()
    sparking_flow_dag = dag_bag.get_dag('sparking_flow')
    if sparking_flow_dag:
        sparking_flow_dag.clear()
        sparking_flow_dag.run()
        print("Successfully triggered sparking_flow DAG.")
    else:
        print("sparking_flow DAG not found.")
else:
    print("Accuracy is acceptable.")

# Stop the session only after the merged dataset has been written.
spark.stop()
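# The commit message notes the DAG can fire "MULTIPLE TIMES AT ONCE": DagBag-based
# clear() + run() behaves like a backfill, so every cleared past run can be re-executed.
# A minimal one-shot alternative, assuming the Airflow 2.x CLI is on PATH (the
# function name is illustrative, not part of this repo):
import subprocess

def trigger_retraining_once(dag_id="sparking_flow"):
    # "airflow dags trigger" queues exactly one new DagRun for the given DAG.
    result = subprocess.run(["airflow", "dags", "trigger", dag_id],
                            capture_output=True, text=True)
    print(result.stdout if result.returncode == 0 else result.stderr)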