Skip to content
Snippets Groups Projects
Commit 9303d866 authored by haikalardzi's avatar haikalardzi
Browse files

t

parent 7791b738
Branches
No related merge requests found
# scripts/model_training.py
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
import mlflow
import mlflow.spark
from mlflow.tracking import MlflowClient
def create_spark_session():
return SparkSession.builder \
.appName("ModelTraining") \
.config("spark.mlflow.tracking.uri", "http://mlflow:5000") \
.getOrCreate()
def prepare_features(df, feature_columns, label_column):
"""Prepare feature vector for training"""
assembler = VectorAssembler(
inputCols=feature_columns,
outputCol="features"
)
return assembler
def train_model(spark, data_path, feature_columns, label_column):
"""Train the model and log metrics to MLflow"""
# Read cleaned data
df = spark.read.parquet(data_path)
# Split data into training and validation sets
train_df, val_df = df.randomSplit([0.8, 0.2], seed=42)
# Create feature pipeline
feature_assembler = prepare_features(df, feature_columns, label_column)
# Initialize the model
rf = RandomForestClassifier(
labelCol=label_column,
featuresCol="features",
numTrees=100,
maxDepth=10,
seed=42
)
# Create pipeline
pipeline = Pipeline(stages=[feature_assembler, rf])
# Start MLflow run
with mlflow.start_run() as run:
# Log parameters
mlflow.log_params({
"num_trees": 100,
"max_depth": 10,
"feature_columns": feature_columns,
"label_column": label_column
})
# Train model
model = pipeline.fit(train_df)
# Make predictions on validation set
predictions = model.transform(val_df)
# Evaluate model
evaluator = MulticlassClassificationEvaluator(
labelCol=label_column,
predictionCol="prediction"
)
# Calculate metrics
accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
f1 = evaluator.setMetricName("f1").evaluate(predictions)
precision = evaluator.setMetricName("precision").evaluate(predictions)
recall = evaluator.setMetricName("recall").evaluate(predictions)
# Log metrics
mlflow.log_metrics({
"accuracy": accuracy,
"f1": f1,
"precision": precision,
"recall": recall
})
# Log model
mlflow.spark.log_model(
model,
"model",
registered_model_name="my_model"
)
# Compare with current production model and promote if better
promote_model_if_better(run.info.run_id, accuracy)
return model
def promote_model_if_better(run_id, new_accuracy):
"""Promote model to production if it performs better"""
client = MlflowClient()
# Get current production model if it exists
production_versions = client.get_latest_versions("my_model", stages=["Production"])
should_promote = True
if production_versions:
# Get the current production model's metrics
prod_run = client.get_run(production_versions[0].run_id)
prod_accuracy = prod_run.data.metrics["accuracy"]
# Only promote if new model is better
should_promote = new_accuracy > prod_accuracy
if should_promote:
# Get the latest version number
latest_version = client.get_latest_versions("my_model", stages=None)[-1].version
# Transition current Production model to Archived if it exists
if production_versions:
client.transition_model_version_stage(
name="my_model",
version=production_versions[0].version,
stage="Archived"
)
# Promote new model to Production
client.transition_model_version_stage(
name="my_model",
version=latest_version,
stage="Production"
)
if __name__ == "__main__":
# Initialize Spark session
spark = create_spark_session()
# Define feature columns and label column
feature_columns = ["feature1", "feature2", "feature3"] # Replace with your actual feature columns
label_column = "label" # Replace with your actual label column
# Train model
model = train_model(spark, "cleaned_data_path", feature_columns, label_column)
spark.stop()
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment