From 6fb89f522c1c6ec70ad2c9d2013e1ece205a9a60 Mon Sep 17 00:00:00 2001 From: Arleen <13521059@mahasiswa.itb.ac.id> Date: Fri, 10 Jan 2025 17:18:17 +0700 Subject: [PATCH] [Feat] train, predict, eval, track model functions --- dags/model.py | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 dags/model.py diff --git a/dags/model.py b/dags/model.py new file mode 100644 index 0000000..829536b --- /dev/null +++ b/dags/model.py @@ -0,0 +1,107 @@ +import pickle +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import confusion_matrix +from sklearn.metrics import roc_auc_score +from sklearn.metrics import RocCurveDisplay +from sklearn.model_selection import cross_val_score +from sklearn.model_selection import RepeatedStratifiedKFold +from sklearn.metrics import classification_report +from xgboost import XGBClassifier +import mlflow +import mlflow.sklearn +import mlflow.xgboost + + +def model_train(x_train, y_train): + classifier = XGBClassifier(learning_rate=0.01, max_depth=3, n_estimators=1000) + classifier.fit(x_train, y_train) + + # Save the trained model to a pickle file + with open("models/XGBClassifier.pkl", "wb") as file: + pickle.dump(classifier, file) + + return classifier + + +def model_predict(x_test): + # Load the model from its pickle file + try: + with open("models/XGBClassifier.pkl", "rb") as file: + model = pickle.load(file) + except FileNotFoundError: + print("Model file not found. Training a new model...") + model = model_train(x_test, y_test) # Retrain with initial data + + # Make prediction + y_pred = model.predict(x_test) + print(y_pred) + + return y_pred + + +def model_evaluation(x_test,y_test): + y_pred = model_predict(x_test) + + # Classification Report + report = classification_report(y_test, y_pred, output_dict=True) + print(report) + return report + + +def model_tracking(model_name, params, report): + # Set experiment + mlflow.set_experiment('MLFlow Simulation 2') + mlflow.set_tracking_uri('http://127.0.0.1:5000/') + + with mlflow.start_run(run_name=model_name): + mlflow.log_params(params) + + # Log Metrics + metrics = { + 'accuracy': report['accuracy'], + 'f1_score_macro': report['macro avg']['f1-score'] + } + + if '0' in report: + metrics.update({ + 'recall_class_0': report['0']['recall'], + }) + if '1' in report: + metrics.update({ + 'recall_class_1': report['1']['recall'] + }) + + mlflow.log_metrics(metrics) + + # Log Model + try: + with open("models/XGBClassifier.pkl", "rb") as file: + model = pickle.load(file) + mlflow.xgboost.log_model(model, "model") + except Exception as e: + print(f"Error logging model: {e}") + + + + + +# testing only + +columns = [ + "SeniorCitizen", "Partner", "Dependents", "tenure", + "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", + "Contract", "PaperlessBilling", + "PaymentMethod", "MonthlyCharges", "TotalCharges" +] +data = [[0, "Yes", "No", 1, "No", "Yes", "No", "No", "Month-to-month", "Yes", "Electronic check", 29.85, 29.85]] + +x_test = pd.DataFrame(data, columns=columns) +x_test = x_test.apply(lambda col: col.astype('category').cat.codes if col.dtypes == 'object' else col) + +y_test = pd.Series([0]) + +report = model_evaluation(x_test, y_test) +model_tracking("XGBoost", {"learning_rate": 0.01, "max_depth": 3, "n_estimators": 1000}, report) \ No newline at end of file -- GitLab