diff --git a/dags/model_training.py b/dags/model_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc30bc1c350557c30f668b440f1f9f85893def3
--- /dev/null
+++ b/dags/model_training.py
@@ -0,0 +1,164 @@
+# dags/model_training.py
+from pyspark.sql import SparkSession
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml import Pipeline
+import mlflow
+import mlflow.spark
+from mlflow.tracking import MlflowClient
+
+def create_spark_session():
+    # MLflow reads the tracking URI from mlflow.set_tracking_uri() (or the
+    # MLFLOW_TRACKING_URI env var), not from a Spark config key, so set it here.
+    mlflow.set_tracking_uri("http://mlflow:5000")
+    return SparkSession.builder \
+        .appName("ModelTraining") \
+        .getOrCreate()
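+
+# "http://mlflow:5000" assumes a tracking server reachable at hostname "mlflow"
+# (e.g. a docker-compose service on the same network); adjust for your deployment.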
+
+def prepare_features(feature_columns):
+    """Build a VectorAssembler that packs the feature columns into one vector column"""
+    return VectorAssembler(
+        inputCols=feature_columns,
+        outputCol="features"
+    )
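+
+# Example: with feature_columns=["age", "income"], the assembler turns a row
+# (age=35, income=50000.0) into a single vector column features=[35.0, 50000.0].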
+
+def train_model(spark, data_path, feature_columns, label_column):
+    """Train the model and log metrics to MLflow"""
+    # Read cleaned data
+    df = spark.read.parquet(data_path)
+    
+    # Split data into training and validation sets
+    train_df, val_df = df.randomSplit([0.8, 0.2], seed=42)
+    
+    # Create feature pipeline
+    feature_assembler = prepare_features(feature_columns)
+    
+    # Initialize the model
+    rf = RandomForestClassifier(
+        labelCol=label_column,
+        featuresCol="features",
+        numTrees=100,
+        maxDepth=10,
+        seed=42
+    )
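+    # numTrees=100 and maxDepth=10 are untuned starting points. A grid search
+    # could be layered on top, e.g. (sketch with assumed values, not wired in):
+    #   from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+    #   grid = ParamGridBuilder().addGrid(rf.numTrees, [50, 100, 200]).build()
+    #   cv = CrossValidator(estimator=pipeline, estimatorParamMaps=grid,
+    #                       evaluator=evaluator, numFolds=3)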
+    
+    # Create pipeline
+    pipeline = Pipeline(stages=[feature_assembler, rf])
+    
+    # Start MLflow run
+    with mlflow.start_run() as run:
+        # Log parameters
+        mlflow.log_params({
+            "num_trees": 100,
+            "max_depth": 10,
+            "feature_columns": feature_columns,
+            "label_column": label_column
+        })
+        
+        # Train model
+        model = pipeline.fit(train_df)
+        
+        # Make predictions on validation set
+        predictions = model.transform(val_df)
+        
+        # Evaluate model
+        evaluator = MulticlassClassificationEvaluator(
+            labelCol=label_column,
+            predictionCol="prediction"
+        )
+        
+        # Calculate metrics
+        accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
+        f1 = evaluator.setMetricName("f1").evaluate(predictions)
+        precision = evaluator.setMetricName("precision").evaluate(predictions)
+        recall = evaluator.setMetricName("recall").evaluate(predictions)
+        
+        # Log metrics
+        mlflow.log_metrics({
+            "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall
+        })
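+        # promote_model_if_better() below reads the "accuracy" metric back from
+        # this run, so the key name must stay in sync with the comparison logic.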
+        
+        # Log model
+        mlflow.spark.log_model(
+            model,
+            "model",
+            registered_model_name="my_model"
+        )
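+        # Passing registered_model_name makes MLflow create the "my_model"
+        # registered model on first use and add a new version on every run.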
+        
+        # Compare with current production model and promote if better
+        promote_model_if_better(run.info.run_id, accuracy)
+        
+        return model
+
+def promote_model_if_better(run_id, new_accuracy):
+    """Promote model to production if it performs better"""
+    client = MlflowClient()
+    
+    # Get current production model if it exists
+    production_versions = client.get_latest_versions("my_model", stages=["Production"])
+    
+    should_promote = True
+    
+    if production_versions:
+        # Get the current production model's metrics
+        prod_run = client.get_run(production_versions[0].run_id)
+        prod_accuracy = prod_run.data.metrics["accuracy"]
+        
+        # Only promote if new model is better
+        should_promote = new_accuracy > prod_accuracy
+    
+    if should_promote:
+        # Find the version that was just registered from this run; indexing into
+        # get_latest_versions() is not guaranteed to yield the newest version
+        new_version = client.search_model_versions(f"run_id='{run_id}'")[0].version
+        
+        # Transition current Production model to Archived if it exists
+        if production_versions:
+            client.transition_model_version_stage(
+                name="my_model",
+                version=production_versions[0].version,
+                stage="Archived"
+            )
+        
+        # Promote new model to Production
+        client.transition_model_version_stage(
+            name="my_model",
+            version=new_version,
+            stage="Production"
+        )
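+
+# Note: the stage-based Production/Archived workflow used above is the classic
+# MLflow registry API; newer MLflow releases favor model version aliases instead.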
+
+if __name__ == "__main__":
+    # Initialize Spark session
+    spark = create_spark_session()
+    
+    # Define feature columns and label column
+    feature_columns = ["feature1", "feature2", "feature3"]  # Replace with your actual feature columns
+    label_column = "label"  # Replace with your actual label column
+    
+    # Train model
+    model = train_model(spark, "cleaned_data_path", feature_columns, label_column)  # Replace with your actual data path
+    
+    spark.stop()