Skip to content
Snippets Groups Projects
Commit a9b43b06 authored by Arleen's avatar Arleen
Browse files

[Feat] trying train with mlflow tracking

parent aefe3d3f
No related merge requests found
Pipeline #66336 failed with stages
from train import *
from pyspark.sql import SparkSession
if __name__=="__main__":
# Initialize Spark session
spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()
df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
train(df)
\ No newline at end of file
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
train.py 0 → 100644
import mlflow
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from imblearn.over_sampling import SMOTE
import pickle
def train(df):
# prepare data
X = df.drop('Churn')
y = df['Churn']
print(X.head)
print(y.head)
mn = MinMaxScaler()
X = mn.fit_transform(X)
# split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size= 0.2, stratify= y, random_state= 42)
smote = SMOTE(sampling_strategy='minority')
X_sm, y_sm = smote.fit_resample(X_train, y_train)
# MLFlow tracking
mlflow.autolog()
# train with ANN
model = keras.models.Sequential([
keras.layers.Dense(19,input_shape=(19,),activation='relu'),
keras.layers.Dense(128,activation='relu'),
keras.layers.Dense(1,activation='sigmoid')
])
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
model.fit(X_sm, y_sm,
batch_size=8,
epochs=10,
verbose=1)
pickle.dump(model, open('model.pkl', 'wb')) #Saving the model
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment