diff --git a/dags/model_training.py b/dags/model_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc30bc1c350557c30f668b440f1f9f85893def3
--- /dev/null
+++ b/dags/model_training.py
@@ -0,0 +1,143 @@
+# dags/model_training.py
+from pyspark.sql import SparkSession
+from pyspark.ml.feature import VectorAssembler
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml import Pipeline
+import mlflow
+import mlflow.spark
+from mlflow.tracking import MlflowClient
+
+def create_spark_session():
+    mlflow.set_tracking_uri("http://mlflow:5000")  # point the MLflow client at the tracking server
+    return SparkSession.builder \
+        .appName("ModelTraining") \
+        .getOrCreate()
+
+def prepare_features(df, feature_columns, label_column):
+    """Prepare feature vector for training"""
+    assembler = VectorAssembler(
+        inputCols=feature_columns,
+        outputCol="features"
+    )
+    return assembler
+
+def train_model(spark, data_path, feature_columns, label_column):
+    """Train the model and log metrics to MLflow"""
+    # Read cleaned data
+    df = spark.read.parquet(data_path)
+
+    # Split data into training and validation sets
+    train_df, val_df = df.randomSplit([0.8, 0.2], seed=42)
+
+    # Create feature pipeline
+    feature_assembler = prepare_features(df, feature_columns, label_column)
+
+    # Initialize the model
+    rf = RandomForestClassifier(
+        labelCol=label_column,
+        featuresCol="features",
+        numTrees=100,
+        maxDepth=10,
+        seed=42
+    )
+
+    # Create pipeline
+    pipeline = Pipeline(stages=[feature_assembler, rf])
+
+    # Start MLflow run
+    with mlflow.start_run() as run:
+        # Log parameters
+        mlflow.log_params({
+            "num_trees": 100,
+            "max_depth": 10,
+            "feature_columns": feature_columns,
+            "label_column": label_column
+        })
+
+        # Train model
+        model = pipeline.fit(train_df)
+
+        # Make predictions on validation set
+        predictions = model.transform(val_df)
+
+        # Evaluate model
+        evaluator = MulticlassClassificationEvaluator(
+            labelCol=label_column,
+            predictionCol="prediction"
+        )
+
+        # Calculate metrics
+        accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
+        f1 = evaluator.setMetricName("f1").evaluate(predictions)
+        precision = evaluator.setMetricName("weightedPrecision").evaluate(predictions)
+        recall = evaluator.setMetricName("weightedRecall").evaluate(predictions)
+
+        # Log metrics
+        mlflow.log_metrics({
+            "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall
+        })
+
+        # Log model
+        mlflow.spark.log_model(
+            model,
+            "model",
+            registered_model_name="my_model"
+        )
+
+        # Compare with current production model and promote if better
+        promote_model_if_better(run.info.run_id, accuracy)
+
+    return model
+
+def promote_model_if_better(run_id, new_accuracy):
+    """Promote model to production if it performs better"""
+    client = MlflowClient()
+
+    # Get current production model if it exists
+    production_versions = client.get_latest_versions("my_model", stages=["Production"])
+
+    should_promote = True
+
+    if production_versions:
+        # Get the current production model's metrics
+        prod_run = client.get_run(production_versions[0].run_id)
+        prod_accuracy = prod_run.data.metrics["accuracy"]
+
+        # Only promote if new model is better
+        should_promote = new_accuracy > prod_accuracy
+
+    if should_promote:
+        # Get the latest version number
+        latest_version = max(client.get_latest_versions("my_model"), key=lambda v: int(v.version)).version
+
+        # Transition current Production model to Archived if it exists
+        if production_versions:
+            client.transition_model_version_stage(
+                name="my_model",
+                version=production_versions[0].version,
+                stage="Archived"
+            )
+
+        # Promote new model to Production
+        client.transition_model_version_stage(
+            name="my_model",
+            version=latest_version,
+            stage="Production"
+        )
+
+if __name__ == "__main__":
+    # Initialize Spark session
+    spark = create_spark_session()
+
+    # Define feature columns and label column
+    feature_columns = ["feature1", "feature2", "feature3"]  # Replace with your actual feature columns
+    label_column = "label"  # Replace with your actual label column
+
+    # Train model
+    model = train_model(spark, "cleaned_data_path", feature_columns, label_column)
+
+    spark.stop()
\ No newline at end of file
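
Once a version of "my_model" has been promoted, downstream jobs can resolve whichever version sits in the Production stage through the registry instead of hard-coding a run ID. A minimal sketch, assuming the same tracking server and registered model name as in this patch (not part of the diff above; new_data_df is a hypothetical DataFrame with the same feature columns):

    import mlflow
    import mlflow.spark

    mlflow.set_tracking_uri("http://mlflow:5000")

    # Load the version currently in the Production stage of the registry
    model = mlflow.spark.load_model("models:/my_model/Production")

    # Score new data with the promoted pipeline (hypothetical input DataFrame)
    predictions = model.transform(new_data_df)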