Classifying Heart Disease with Decision Trees

We will be using a dataset with 297 data points and the following attributes:

  1. age
  2. sex
  3. chest pain type (4 values)
  4. resting blood pressure
  5. serum cholesterol in mg/dl
  6. fasting blood sugar > 120 mg/dl
  7. resting electrocardiographic results (values 0,1,2)
  8. maximum heart rate achieved
  9. exercise induced angina
  10. oldpeak = ST depression induced by exercise relative to rest
  11. the slope of the peak exercise ST segment
  12. number of major vessels (0-3) colored by fluoroscopy
  13. thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
  14. target: the output variable, indicating whether heart disease is diagnosed (value 1) or not (value 0)
In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn import tree
from sklearn.utils import shuffle
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import seaborn as sns
In [33]:
# load the data

df = pd.read_csv('heart.csv', names=['age','sex','cp','trestbps','chol','fbs','restecg','thalach','exang','oldpeak','slope','ca','thal','target'])
df.head()
Out[33]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63.0 1.0 1.0 145.0 233.0 1.0 2.0 150.0 0.0 2.3 3.0 0.0 6.0 0
1 67.0 1.0 4.0 160.0 286.0 0.0 2.0 108.0 1.0 1.5 2.0 3.0 3.0 1
2 67.0 1.0 4.0 120.0 229.0 0.0 2.0 129.0 1.0 2.6 2.0 2.0 7.0 1
3 37.0 1.0 3.0 130.0 250.0 0.0 0.0 187.0 0.0 3.5 3.0 0.0 3.0 0
4 41.0 0.0 2.0 130.0 204.0 0.0 2.0 172.0 0.0 1.4 1.0 0.0 3.0 0
In [34]:
df.describe()
Out[34]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
count 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000 297.000000
mean 54.542088 0.676768 3.158249 131.693603 247.350168 0.144781 0.996633 149.599327 0.326599 1.055556 1.602694 0.676768 4.730640 0.461279
std 9.049736 0.468500 0.964859 17.762806 51.997583 0.352474 0.994914 22.941562 0.469761 1.166123 0.618187 0.938965 1.938629 0.499340
min 29.000000 0.000000 1.000000 94.000000 126.000000 0.000000 0.000000 71.000000 0.000000 0.000000 1.000000 0.000000 3.000000 0.000000
25% 48.000000 0.000000 3.000000 120.000000 211.000000 0.000000 0.000000 133.000000 0.000000 0.000000 1.000000 0.000000 3.000000 0.000000
50% 56.000000 1.000000 3.000000 130.000000 243.000000 0.000000 1.000000 153.000000 0.000000 0.800000 2.000000 0.000000 3.000000 0.000000
75% 61.000000 1.000000 4.000000 140.000000 276.000000 0.000000 2.000000 166.000000 1.000000 1.600000 2.000000 1.000000 7.000000 1.000000
max 77.000000 1.000000 4.000000 200.000000 564.000000 1.000000 2.000000 202.000000 1.000000 6.200000 3.000000 3.000000 7.000000 1.000000

A decision tree is an information-based algorithm rather than a similarity-based one, so there is no need to scale the data.
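
As an illustrative aside, the minimal sketch below fits the same tree on raw and on min-max-scaled copies of the features; because tree splits depend only on value thresholds, both trees should make the same predictions (the names X, y, tree_raw and tree_scaled are used only here):

In [ ]:
# illustrative sanity check: min-max scaling should not change the fitted tree
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier

X = df.drop('target', axis=1)
y = df['target']
X_scaled = MinMaxScaler().fit_transform(X)

tree_raw = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# expect True: identical predictions on the training data (up to floating-point tie-breaking)
print((tree_raw.predict(X) == tree_scaled.predict(X_scaled)).all())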

In [35]:
df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 297 entries, 0 to 301
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       297 non-null    float64
 1   sex       297 non-null    float64
 2   cp        297 non-null    float64
 3   trestbps  297 non-null    float64
 4   chol      297 non-null    float64
 5   fbs       297 non-null    float64
 6   restecg   297 non-null    float64
 7   thalach   297 non-null    float64
 8   exang     297 non-null    float64
 9   oldpeak   297 non-null    float64
 10  slope     297 non-null    float64
 11  ca        297 non-null    float64
 12  thal      297 non-null    float64
 13  target    297 non-null    int64  
dtypes: float64(13), int64(1)
memory usage: 34.8 KB
In [36]:
df.groupby('target').mean()
Out[36]:
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal
target
0 52.643750 0.556250 2.793750 129.175000 243.493750 0.143750 0.843750 158.581250 0.143750 0.598750 1.412500 0.275000 3.787500
1 56.759124 0.817518 3.583942 134.635036 251.854015 0.145985 1.175182 139.109489 0.540146 1.589051 1.824818 1.145985 5.832117

By grouping on the target variable and averaging each attribute, we can see some noticeable differences between people with and without a diagnosed heart disease.
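
To rank these differences, the illustrative sketch below standardises the gap between the two class means by each attribute's standard deviation (std_gap is an ad-hoc name):

In [ ]:
# illustrative: rank attributes by the standardised gap between the two class means
group_means = df.groupby('target').mean()
std_gap = (group_means.loc[1] - group_means.loc[0]).abs() / df.drop('target', axis=1).std()
print(std_gap.sort_values(ascending=False))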

We can visualize these distribution differences with pairplots, using the target variable as the color hue.

In [37]:
sns.pairplot(df[['age','sex','cp','trestbps','chol','fbs','target']], hue="target")
Out[37]:
<seaborn.axisgrid.PairGrid at 0x156aa5b2be0>
In [38]:
sns.pairplot(df[['restecg','thalach','exang','oldpeak','slope','ca','thal','target']], hue="target")
Out[38]:
<seaborn.axisgrid.PairGrid at 0x156b0ce8040>

Shuffle the data and split it into train (70%) and test (30%) sets

In [54]:
# shuffle
df = shuffle(df).reset_index(drop=True)

# divide attributes between output variable (y) and explanatory variables (x)

df_x = df.drop(['target'], axis=1)
df_y = df['target']
In [55]:
# split into training and test sets

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df_x, df_y, test_size=0.3, random_state=1)

Train a decision tree and check its performance to get a benchmark

In [56]:
# fit a decision tree on the train set

dtc = DecisionTreeClassifier()
dtc.fit(X_train, y_train)
Out[56]:
DecisionTreeClassifier()
In [57]:
# get the performance of our decision tree on the test set

prediction = dtc.predict(X_test)

# score

score_tree = accuracy_score(y_test, prediction)

print("Accuracy score: " + str(score_tree))
Accuracy score: 0.7444444444444445
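
As an illustrative aside, since this figure comes from a single random train/test split, a quick 5-fold cross-validation on the training set gives a feel for how much the benchmark accuracy varies:

In [ ]:
# illustrative: spread of the benchmark accuracy across 5 cross-validation folds
from sklearn.model_selection import cross_val_score

cv_scores = cross_val_score(DecisionTreeClassifier(), X_train, y_train, cv=5, scoring='accuracy')
print(cv_scores.round(3), cv_scores.mean().round(3))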
In [58]:
# visualize the decision tree

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), dpi=300)
tree.plot_tree(dtc, filled=True, max_depth=3)
Out[58]:
[decision tree plot: root split on cp (X[2] <= 3.5), gini = 0.495, 207 training samples; deeper branches are truncated at the max_depth=3 display limit]

The tree grows quite deep, which could lead to overfitting. We will try changing some hyperparameters and "prune" the tree to make it shallower.
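
An illustrative way to see this is to compare accuracy on the training set against the test accuracy obtained above:

In [ ]:
# illustrative: a large gap between train and test accuracy is a typical sign of overfitting
train_acc = accuracy_score(y_train, dtc.predict(X_train))
print('train accuracy:', train_acc)   # an unconstrained tree usually fits the training data (almost) perfectly
print('test accuracy:', score_tree)   # the 0.7444... benchmark from above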

Improving the decision tree

We will try different split criteria and then prune the tree to improve accuracy.

In [59]:
### try different criteria to measure the quality of a split

# gini criterion
dtc = DecisionTreeClassifier(criterion='gini')
dtc.fit(X_train, y_train)
pred = dtc.predict(X_test)
print('Criterion = gini', accuracy_score(y_test, pred))

# entropy criterion
dtc = DecisionTreeClassifier(criterion='entropy')
dtc.fit(X_train, y_train)
pred = dtc.predict(X_test)
print('Criterion = entropy', accuracy_score(y_test, pred))
Criterion = gini 0.7444444444444445
Criterion = entropy 0.7222222222222222

The Gini criterion gives slightly higher accuracy on this split.

To choose the best hyperparameters ('max_depth' and 'min_samples_leaf') we will use GridSearchCV; note that the search below is run with the entropy criterion.

In [60]:
# vary both hyperparameters from 2 to 10
param_grid = {'max_depth': np.arange(2, 11), 'min_samples_leaf': np.arange(2, 11)}

# entropy criterion for the grid search
dtc = DecisionTreeClassifier(criterion='entropy')

model = GridSearchCV(estimator=dtc, param_grid=param_grid, n_jobs=-1, verbose=1,
                     cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=17))
model.fit(X_train, y_train)
Fitting 3 folds for each of 81 candidates, totalling 243 fits
Out[60]:
GridSearchCV(cv=StratifiedKFold(n_splits=3, random_state=17, shuffle=True),
             estimator=DecisionTreeClassifier(criterion='entropy'), n_jobs=-1,
             param_grid={'max_depth': array([ 2,  3,  4,  5,  6,  7,  8,  9, 10]),
                         'min_samples_leaf': array([ 2,  3,  4,  5,  6,  7,  8,  9, 10])},
             verbose=1)
In [61]:
# get best parameters
model.best_params_
Out[61]:
{'max_depth': 3, 'min_samples_leaf': 7}
In [62]:
# check improved accuracy with new hyperparameters

pred = model.predict(X_test)
accuracy_score(y_test, pred)
Out[62]:
0.7555555555555555
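
As an illustrative aside, the grid-search results stored in model.cv_results_ show how sensitive the cross-validated accuracy is to the two hyperparameters:

In [ ]:
# illustrative: best parameter combinations by mean cross-validated accuracy
cv_results = pd.DataFrame(model.cv_results_)
print(cv_results[['param_max_depth', 'param_min_samples_leaf', 'mean_test_score']]
      .sort_values('mean_test_score', ascending=False)
      .head())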
In [64]:
# we visualize the improved tree

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), dpi=300)
tree.plot_tree(model.best_estimator_, filled=True, max_depth=3)
Out[64]:
[pruned decision tree plot, depth 3: root split on cp (X[2] <= 3.5); the left branch splits on age and sex, the right branch on oldpeak, ca and thal]

We ended up with a tree that achieves slightly higher accuracy while limiting its depth to avoid overfitting.

Note: we could keep improving on this with more sophisticated methods, such as random forests, but the objective of this exercise was to work only with decision trees.
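
For reference only, an illustrative sketch (outside the scope of this exercise) of how such a follow-up might look:

In [ ]:
# illustrative follow-up, not part of this exercise: a random forest on the same split
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=200, random_state=1)
rf.fit(X_train, y_train)
print('Random forest accuracy:', accuracy_score(y_test, rf.predict(X_test)))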
