diff --git a/dags/data_cleanup.py b/dags/data_cleanup.py
index a7f098b54492a5914238166da642ce8fbd4fd1d7..003101e981d29b9c230cb25eed48ae065462af72 100644
--- a/dags/data_cleanup.py
+++ b/dags/data_cleanup.py
@@ -72,13 +72,21 @@ def clean_data(file_path):
     df_transformed = df_transformed.drop("features")
     df_transformed.show()
 
+    output_path = ""
     # Hapus folder keluaran jika ada
-    output_path = "telco_customer_churn_transformed.csv"
+    if (file_path == "telco.customer_churn.csv"):
+        output_path = "telco_customer_churn_transformed.csv"
+    else:
+        output_path = "telco_customer_churn_drift_transformed.csv"
     if os.path.exists(output_path):
         shutil.rmtree(output_path)
 
     # Save the cleaned data
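+    # coalesce(1) merges the data into one partition so the directory holds a single CSV part file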
     df_transformed.coalesce(1).write.csv(output_path, header=True, mode='overwrite')
 
+    return output_path
+
 if __name__ == "__main__":
     clean_data("telco_customer_churn.csv")
\ No newline at end of file
diff --git a/dags/prediction.py b/dags/prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..54a91e2d6399455d02643c3d68ba656bb316d16b
--- /dev/null
+++ b/dags/prediction.py
@@ -0,0 +1,85 @@
+import mlflow
+import mlflow.spark
+from pyspark.sql import SparkSession
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import StringIndexer, VectorAssembler
+from data_cleanup import clean_data
+
+def init_spark():
+    return SparkSession.builder.appName("Churn Prediction").getOrCreate()
+
+def predict(path_cleaned_data):
+    # Set the tracking URI to the MLflow server (assumed to be running at localhost:5000)
+    mlflow.set_tracking_uri('http://localhost:5000')
+
+    # Start MLflow tracking
+    mlflow.start_run()
+
+    spark = init_spark()
+
+    # Read the cleaned data
+    cleaned_data = spark.read.csv(path_cleaned_data, header=True, inferSchema=True)
+
+    # Show the data to verify the format
+    cleaned_data.show()
+
+    # Categorical (string) columns and numeric columns
+    categorical_columns = ["gender", "Partner", "Dependents", "PhoneService", "MultipleLines",
+                           "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection",
+                           "TechSupport", "StreamingTV", "StreamingMovies", "Contract",
+                           "PaperlessBilling", "PaymentMethod"]
+    numeric_columns = ["SeniorCitizen", "tenure", "MonthlyCharges", "TotalCharges"]
+
+    # Convert the 'Churn' target label to a numeric `label` column
+    label_indexer = StringIndexer(inputCol="Churn", outputCol="label")
+    indexed_data = label_indexer.fit(cleaned_data).transform(cleaned_data)
+
+    # Index the categorical string columns with StringIndexer
+    indexers = [
+        StringIndexer(inputCol=c, outputCol=c + "_indexed").setHandleInvalid("skip")
+        for c in categorical_columns
+    ]
+
+    # Assemble the numeric columns and the indexed columns into a single `features` column
+    feature_columns = [c + "_indexed" for c in categorical_columns] + numeric_columns
+    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
+
+    # The model was trained in an earlier MLflow run (a pipeline of the indexers, the
+    # assembler, and LogisticRegression), so load it here rather than retraining
+    model = mlflow.spark.load_model("runs:/ee8d9e5282da46558784493539696b9d/logistic_regression_model")
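+
+    # The earlier run built features with these same transformers; assuming the logged
+    # artifact is only the fitted logistic regression stage (not the full pipeline),
+    # fit and apply them here so the `features` column the model expects exists
+    feature_pipeline = Pipeline(stages=indexers + [assembler])
+    prepared_data = feature_pipeline.fit(indexed_data).transform(indexed_data)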
+
+    # Score the data, then compute and log accuracy to MLflow
+    predictions = model.transform(prepared_data)
+    accuracy = predictions.filter(predictions.label == predictions.prediction).count() / float(predictions.count())
+    mlflow.log_metric("accuracy", accuracy)
+    print(f"Accuracy: {accuracy}")
+
+    # End MLflow tracking
+    mlflow.end_run()
+
+def main():
+    try:
+        # Clean the raw data, then score it with the loaded model
+        print("Cleaning data...")
+        cleaned_path = clean_data("telco_customer_churn_drift.csv")
+        predict(cleaned_path)
+
+        # Save predictions
+        # output_path = "path/to/output/churn_predictions.csv"
+        # print(f"\nSaving predictions to {output_path}")
+        # final_predictions.write.csv(output_path,
+        #                           header=True,
+        #                           mode="overwrite")
+
+    except Exception as e:
+        print(f"Error in pipeline: {e}")
+        raise
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file