In this post, I’m going to develop machine learning models to predict death in heart failure patients.
## Dataset
The description of the dataset can be found here.
```python
import pandas as pd

df = pd.read_csv('heart_failure_clinical_records_dataset.csv')
df.head()
```
| | age | anaemia | creatinine_phosphokinase | diabetes | ejection_fraction | high_blood_pressure | platelets | serum_creatinine | serum_sodium | sex | smoking | time | DEATH_EVENT |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 75.0 | 0 | 582 | 0 | 20 | 1 | 265000.00 | 1.9 | 130 | 1 | 0 | 4 | 1 |
| 1 | 55.0 | 0 | 7861 | 0 | 38 | 0 | 263358.03 | 1.1 | 136 | 1 | 0 | 6 | 1 |
| 2 | 65.0 | 0 | 146 | 0 | 20 | 0 | 162000.00 | 1.3 | 129 | 1 | 1 | 7 | 1 |
| 3 | 50.0 | 1 | 111 | 0 | 20 | 0 | 210000.00 | 1.9 | 137 | 1 | 0 | 7 | 1 |
| 4 | 65.0 | 1 | 160 | 1 | 20 | 0 | 327000.00 | 2.7 | 116 | 0 | 0 | 8 | 1 |
```python
df.shape
```

```
(299, 13)
```
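With only 299 rows, the balance of the `DEATH_EVENT` target matters when interpreting the accuracy figures reported later. A quick way to check it is `value_counts(normalize=True)`; the sketch below runs on a hypothetical mini-sample standing in for the full frame (on the real data you would call it on `df` directly):

```python
import pandas as pd

# Hypothetical mini-sample standing in for the 299-row dataset;
# normalize=True turns the counts into class proportions.
sample = pd.DataFrame({'DEATH_EVENT': [1, 1, 0, 0, 0, 0, 0, 1]})
print(sample['DEATH_EVENT'].value_counts(normalize=True))
```

A skewed split here would mean plain accuracy overstates performance, which is why the comparison below also reports Kappa and MCC.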
Here I’m going to build classifiers using the machine learning library PyCaret.
```python
from pycaret.classification import *

cls = setup(df, target='DEATH_EVENT',
            categorical_features=['anaemia', 'diabetes', 'high_blood_pressure', 'sex', 'smoking'],
            numeric_features=['age', 'creatinine_phosphokinase', 'ejection_fraction', 'platelets', 'serum_creatinine', 'serum_sodium'],
            ignore_features=['time'])
compare_models()
```
| | Model | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC | TT (Sec) |
|---|---|---|---|---|---|---|---|---|---|
| lr | Logistic Regression | 0.7560 | 0.7323 | 0.4571 | 0.7388 | 0.5476 | 0.3967 | 0.4265 | 0.0050 |
| ridge | Ridge Classifier | 0.7371 | 0.0000 | 0.4429 | 0.6872 | 0.5214 | 0.3556 | 0.3816 | 0.0040 |
| lda | Linear Discriminant Analysis | 0.7371 | 0.7643 | 0.4429 | 0.6872 | 0.5214 | 0.3556 | 0.3816 | 0.0050 |
| rf | Random Forest Classifier | 0.7274 | 0.7681 | 0.5000 | 0.6433 | 0.5475 | 0.3598 | 0.3749 | 0.0330 |
| catboost | CatBoost Classifier | 0.7226 | 0.7690 | 0.4429 | 0.6788 | 0.5151 | 0.3332 | 0.3607 | 0.4300 |
| nb | Naive Bayes | 0.7033 | 0.6956 | 0.2857 | 0.6483 | 0.3909 | 0.2312 | 0.2686 | 0.0040 |
| et | Extra Trees Classifier | 0.6986 | 0.7374 | 0.3714 | 0.6421 | 0.4317 | 0.2536 | 0.2849 | 0.0310 |
| ada | Ada Boost Classifier | 0.6943 | 0.6952 | 0.4857 | 0.6104 | 0.5160 | 0.2997 | 0.3207 | 0.0090 |
| xgboost | Extreme Gradient Boosting | 0.6933 | 0.7191 | 0.4571 | 0.5673 | 0.4902 | 0.2798 | 0.2918 | 0.1980 |
| lightgbm | Light Gradient Boosting Machine | 0.6888 | 0.7279 | 0.4143 | 0.5753 | 0.4604 | 0.2570 | 0.2736 | 0.1200 |
| gbc | Gradient Boosting Classifier | 0.6798 | 0.6988 | 0.4857 | 0.5406 | 0.4998 | 0.2684 | 0.2774 | 0.0090 |
| qda | Quadratic Discriminant Analysis | 0.6795 | 0.6440 | 0.3286 | 0.5483 | 0.4027 | 0.2044 | 0.2206 | 0.0070 |
| dt | Decision Tree Classifier | 0.6557 | 0.6173 | 0.5000 | 0.5048 | 0.4900 | 0.2338 | 0.2423 | 0.0040 |
| svm | SVM – Linear Kernel | 0.6317 | 0.0000 | 0.1000 | 0.0333 | 0.0500 | 0.0000 | 0.0000 | 0.0040 |
| knn | K Neighbors Classifier | 0.6126 | 0.4562 | 0.1286 | 0.3083 | 0.1756 | -0.0139 | -0.0085 | 0.0060 |
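A note on the table: the AUC of 0.0000 for the ridge classifier and the linear SVM is an artifact, since those models don't expose predicted probabilities, so PyCaret cannot compute a ROC curve for them. The Kappa and MCC columns are the more robust summaries here because the target is imbalanced. As a reminder of what MCC measures, it can be computed directly from the four confusion-matrix counts; a minimal pure-Python sketch (with toy counts, not taken from any model above):

```python
import math

def mcc(tp, fp, fn, tn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return num / den if den else 0.0

# Toy counts for illustration only:
print(mcc(tp=30, fp=10, fn=15, tn=45))
```

Unlike accuracy, MCC only approaches 1 when the classifier does well on both classes, which is why it reorders some of the rows above relative to the Accuracy column.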
## Comments