Data Mining

# Versions of the libraries used are noted beside the import lines
# Python version: 3.11.7
import numpy as np                              # 1.26.4
import pandas as pd                             # 2.1.4
import jupyter_bokeh                            # 4.0.5                              
from itertools import combinations
from bokeh.plotting import figure               # bokeh 3.3.4
from bokeh.plotting import show
from bokeh.layouts import column
from bokeh.layouts import row
from bokeh.layouts import gridplot
from bokeh.models import ColumnDataSource
from bokeh.models import FactorRange
from bokeh.models import RadioButtonGroup
from bokeh.models import Div
from bokeh.models import CustomJS
from bokeh.models import Legend
from bokeh.models import MultiChoice
from bokeh.models import CheckboxGroup
from bokeh.models import ColorBar
from bokeh.models import BasicTicker
from bokeh.models import PrintfTickFormatter
from bokeh.models import Label, Select
from bokeh.models import DataRange1d
from bokeh.models import FixedTicker
from bokeh.models import LabelSet
from bokeh.models import LinearColorMapper
from bokeh.models import NumberFormatter
from bokeh.models import DataTable
from bokeh.models import TableColumn
from bokeh.models import HoverTool
from bokeh.models import Whisker
from bokeh.transform import linear_cmap
from bokeh.palettes import Blues256
from bokeh.palettes import Category10
from bokeh.io import output_notebook
import sklearn.preprocessing                    # scikit-learn 1.2.2
import sklearn.model_selection
import sklearn.neighbors 
import sklearn.metrics

# Required to generate static images for pdf export
from bokeh.io import export_png
from IPython.display import Image
from IPython.display import display

# Enable notebook output
output_notebook()

create_pdf = False

Introduction

This notebook demonstrates data analysis of medical data relating to cardiovascular disease using Python and bokeh. Cardiovascular disease is a leading cause of death worldwide, and its early detection and timely intervention can improve the outcomes of the patient. I have recent experience with the early prediction of cardiovascular disease and have developed an interest in risk prediction using health data. Most predictions of cardiovascular risks are based on mathematical models developed by analysing various risk factors from medical studies. The most widely used dataset for cardiovascular risk prediction is the Framingham Heart Study dataset https://www.framinghamheartstudy.org/fhs-risk-functions/cardiovascular-disease-10-year-risk/. The Framingham heart study found eight primary predictors of cardiovascular disease: sex, age, systolic blood pressure, total cholesterol, HDL cholesterol, BMI, smoking, and diabetes.

In this analysis, I use data from the US Centers for Disease Control and Prevention National Health and Nutrition Examination Survey (NHANES, https://wwwn.cdc.gov/nchs/nhanes/) to investigate whether a person’s diagnosis of cardiovascular disease can be predicted using machine learning techniques. The bokeh package was used for all visualisations as requested in the task requirement. I could not get Python callbacks in bokeh to work in the Jupyter Notebook, so each visualisation uses custom JavaScript callbacks.

Privacy and Ethics

Cardiovascular disease prediction raises ethical concerns, particularly regarding data privacy and the potential for bias. Individuals in the NHANES data have been de-identified and replaced with a sequence (ID) number. However, I will be aggregating data from different parts of the dataset for each individual, increasing their identification risk. I have been mindful of this and limited the data I have aggregated. Investigating the fairness of models and, therefore, the impact of bias was beyond the scope of this study but would need to be considered if the modelling was to be used for prediction.

Data

The data used in the analysis was sourced from the NHANES 2017-March 2020 pre-pandemic datasets (https://wwwn.cdc.gov/nchs/nhanes/continuousnhanes/default.aspx?Cycle=2017-2020) and consisted of demographic, examination, laboratory and questionnaire data. The specific data used were:

Each data file was downloaded in XPT format and converted to a CSV file using the SAS Universal Viewer. The data is in the public domain and can be used as specified in the license (https://www.cdc.gov/other/agencymaterials.html).

Load data and clean

Each CSV file was loaded into a pandas data frame. For each data frame, only the required variables were kept and the rest were dropped. The remaining variable names were changed to be more meaningful. The data types (dtype) were checked (not shown in the code), as was the presence of null values. For some datasets, rows with zero values were also removed.
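The dtype, null and zero checks described above can be bundled into one small report. The sketch below is illustrative only: `summarise_quality` is a hypothetical helper that does not appear in the notebook, and the toy frame uses made-up values in place of a real NHANES file.

```python
import numpy as np
import pandas as pd


def summarise_quality(df: pd.DataFrame) -> pd.DataFrame:
    """Return dtype, null count and zero count per column (hypothetical helper)."""
    numeric = df.select_dtypes(include="number")
    return pd.DataFrame({
        "dtype": df.dtypes.astype(str),
        "nulls": df.isnull().sum(),
        # Zero counts are only computed for numeric columns; others report 0
        "zeros": (numeric == 0).sum().reindex(df.columns, fill_value=0),
    })


# Toy frame standing in for one of the CSV files (values are made up)
toy = pd.DataFrame({
    "SEQN": [1, 2, 3, 4],
    "BMXBMI": [22.5, 0.0, np.nan, 31.2],
})
report = summarise_quality(toy)
print(report)
```

Running a report like this before and after each cleaning step makes it easier to see how many rows a zero or null filter would remove.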

Demographics variables

The variables retained were:

  • SEQN: Respondent sequence number
  • RIAGENDR: Gender (male=1, female=2)
  • RIDAGEYR: Age in years at screening

I adjusted the gender values so that male = 0 and female = 1.

# Load demographic data file in data frame
demo = pd.read_csv("Demographics.csv", comment="#")

# Dropping the unnecessary columns
columns_to_keep = ["SEQN", "RIAGENDR", "RIDAGEYR"]
demo = demo[columns_to_keep]

# Adjust gender so male = 0 and female = 1
demo["RIAGENDR"] -= 1

# Rename columns
demo.rename(columns={"RIAGENDR": "SEX", "RIDAGEYR": "AGE"}, inplace=True)

print(f"{demo.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(demo.isnull().sum())
15600 data points
Number of null values per variable
SEQN    0
SEX     0
AGE     0
dtype: int64

Body measurement variables

The variables retained were:

  • SEQN: Respondent sequence number
  • BMXBMI: Body Mass Index (kg/m^2)

# Load body data file in a data frame
body = pd.read_csv("Body_Measurements.csv", comment="#")

# Dropping the unnecessary columns
columns_to_keep = ["SEQN", "BMXBMI"]
body = body[columns_to_keep]

# Drop rows where the measurement is zero
body = body.loc[(body["BMXBMI"] != 0)]

# Rename columns
body.rename(columns={"BMXBMI": "BMI"}, inplace=True)

print(f"{body.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(body.isnull().sum())
13137 data points
Number of null values per variable
SEQN    0
BMI     0
dtype: int64

Blood pressure measurement variables

The variables retained were:

  • SEQN Respondent sequence number
  • BPXOSY1 Systolic - 1st oscillometric reading
  • BPXODI1 Diastolic - 1st oscillometric reading
  • BPXOSY2 Systolic - 2nd oscillometric reading
  • BPXODI2 Diastolic - 2nd oscillometric reading
  • BPXOSY3 Systolic - 3rd oscillometric reading
  • BPXODI3 Diastolic - 3rd oscillometric reading

The three systolic and diastolic pressure readings were averaged, and only the average was retained. Blood pressure is measured in mmHg.

# Load the blood pressure data file in data frame
blood = pd.read_csv("Blood_Pressure_Measurement.csv", comment="#")

# Dropping the unnecessary columns
blood.drop(columns=["BPAOARM", "BPAOCSZ", "BPXOPLS1", "BPXOPLS2", "BPXOPLS3"],
            inplace=True)

# Calculate the average of the three measurements, returned as an integer
blood["BPXOSYX"] = blood[["BPXOSY1", "BPXOSY2", "BPXOSY3"]].mean(axis=1).astype(int)
blood["BPXODIX"] = blood[["BPXODI1", "BPXODI2", "BPXODI3"]].mean(axis=1).astype(int)

# Now we have the average of blood pressure, we don't need the original measurements
blood.drop(columns=["BPXOSY1", "BPXOSY2", "BPXOSY3","BPXODI1", "BPXODI2", "BPXODI3"],
            inplace=True)

# Drop rows where the measurement is zero
blood = blood.loc[(blood["BPXOSYX"] != 0) & (blood["BPXODIX"] != 0)]

# Rename columns
blood.rename(columns={"BPXOSYX": "BPSY", "BPXODIX": "BPDI"}, inplace=True)

print(f"{blood.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(blood.isnull().sum())
10353 data points
Number of null values per variable
SEQN    0
BPSY    0
BPDI    0
dtype: int64
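Two details of the averaging step above are worth seeing on toy data: `mean(axis=1)` skips missing readings by default, and `astype(int)` truncates towards zero instead of rounding. The sketch below uses made-up readings. The rounded variant with the nullable `Int64` dtype is an alternative I am assuming for illustration, not what the notebook does.

```python
import numpy as np
import pandas as pd

# Made-up systolic readings: complete, one missing, all missing
readings = pd.DataFrame({
    "BPXOSY1": [120.0, 130.0, np.nan],
    "BPXOSY2": [123.0, np.nan, np.nan],
    "BPXOSY3": [125.0, 134.0, np.nan],
})

# Row-wise mean; NaN values are skipped by default (skipna=True)
row_mean = readings.mean(axis=1)

# Truncation, as in the notebook (rows with no readings must be dropped first,
# because astype(int) cannot represent NaN)
truncated = row_mean.dropna().astype(int)

# Assumed alternative: round to nearest and keep missing rows as <NA>
rounded = row_mean.round().astype("Int64")

print(truncated.tolist())
print(rounded.tolist())
```

For the first participant the mean is about 122.67, so truncation gives 122 while rounding gives 123.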

Cholesterol HDL variables

The variables retained were:

  • SEQN: Respondent sequence number
  • LBDHDD: Direct HDL-Cholesterol (mg/dL)

# Load Cholesterol - HDL data file in data frame
col_hdl = pd.read_csv("Cholesterol_HDL.csv", comment="#")

# Dropping the unnecessary columns
col_hdl.drop(columns=["LBDHDDSI"], inplace=True)

# Drop rows where the measurement is zero
col_hdl = col_hdl.loc[(col_hdl["LBDHDD"] != 0)]

# Rename columns
col_hdl.rename(columns={"LBDHDD": "CHDL"}, inplace=True)

print(f"{col_hdl.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(col_hdl.isnull().sum())
10828 data points
Number of null values per variable
SEQN    0
CHDL    0
dtype: int64

Cholesterol Total

The variables retained were:

  • SEQN: Respondent sequence number
  • LBXTC: Total Cholesterol (mg/dL)

# Load Cholesterol - Total data file in data frame
col_total = pd.read_csv("Cholesterol_Total.csv", comment="#")

# Dropping the unnecessary columns
col_total.drop(columns=["LBDTCSI"], inplace=True)

# Rename columns
col_total.rename(columns={"LBXTC": "CTOT"}, inplace=True)

print(f"{col_total.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(col_total.isnull().sum())
12198 data points
Number of null values per variable
SEQN    0
CTOT    0
dtype: int64

Diabetes

The variables retained were:

  • SEQN Respondent sequence number
  • DIQ010 Doctor told you have diabetes (1 = Yes, 2 = No, 3 = Borderline, 7 = Refused, 9 = Don’t know)

I am only interested in Yes/No, so I dropped any row with a value above 2 and adjusted the values so that No = 0 and Yes = 1.

# Load Diabetes questions file in data frame
diab = pd.read_csv("Diabetes.csv", comment="#")

# Dropping the unnecessary columns
columns_to_keep = ["SEQN", "DIQ010"]
diab = diab[columns_to_keep]

# Keep only those who answer 1 or 2.
diab = diab[diab["DIQ010"] <= 2]

# Adjust No value to be equal to 0
diab["DIQ010"] = np.where(diab["DIQ010"] == 2, diab["DIQ010"] - 2, diab["DIQ010"])

# Rename columns
diab.rename(columns={"DIQ010": "DIABETES"}, inplace=True)

print(f"{diab.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(diab.isnull().sum())
14694 data points
Number of null values per variable
SEQN        0
DIABETES    0
dtype: int64
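The Yes/No recoding above can also be written as an explicit mapping, which keeps the intent (1 = Yes becomes 1, 2 = No becomes 0) visible in one place. This is a minimal sketch on made-up answers. `recode_yes_no` is a hypothetical helper, not part of the notebook.

```python
import pandas as pd


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Map questionnaire codes 1 (Yes) / 2 (No) to 1 / 0 (hypothetical helper).

    Any other code (e.g. 3, 7, 9) is not in the mapping and becomes missing.
    """
    return series.map({1: 1, 2: 0}).astype("Int64")


# Made-up answers covering every code listed for DIQ010
answers = pd.Series([1, 2, 3, 7, 9, 2], name="DIQ010")

# Rows with Borderline / Refused / Don't know are dropped after recoding
recoded = recode_yes_no(answers).dropna()
print(recoded.tolist())
```

Because unmapped codes turn into missing values, a single `dropna()` replaces the separate "keep values <= 2" filter.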

Smoking

The variables retained were:

  • SEQN Respondent sequence number
  • SMQ020 Smoked at least 100 cigarettes in life (1 = Yes, 2 = No, 7 = Refused, 9 = Don’t know)

I am only interested in Yes/No, so I dropped any row with a value greater than 2 and adjusted the values so that No = 0 and Yes = 1.

# Load smoking questions file in data frame
smoke = pd.read_csv("Smoking.csv", comment="#")

# Dropping the unnecessary columns
columns_to_keep = ["SEQN", "SMQ020"]
smoke = smoke[columns_to_keep]

# Keep only those who answer 1 or 2.
smoke = smoke[smoke["SMQ020"] <= 2]

# Adjust No value to be equal to 0
smoke["SMQ020"] = np.where(smoke["SMQ020"] == 2, smoke["SMQ020"] - 2,
                            smoke["SMQ020"])

# Rename columns
smoke.rename(columns={"SMQ020": "SMOKES"}, inplace=True)
            
print(f"{smoke.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(smoke.isnull().sum())
11132 data points
Number of null values per variable
SEQN      0
SMOKES    0
dtype: int64

Coronary heart disease

This data is for coronary heart disease, a specific type of cardiovascular disease. Cardiovascular disease is a broad term encompassing all diseases of the heart and blood vessels. The NHANES dataset didn’t have data specifically for cardiovascular disease, so I’m using coronary heart disease data as a substitute.

The variables retained were:

  • SEQN Respondent sequence number
  • MCQ160C Ever told you had coronary heart disease (1 = Yes, 2 = No, 7 = Refused, 9 = Don’t know)

I am only interested in Yes/No, so I dropped any row with a value greater than 2 and adjusted the values so that No = 0 and Yes = 1.

# Load medical conditions questions file in data frame
medical = pd.read_csv("Medical_Conditions.csv", comment="#")

# Dropping the unnecessary columns
columns_to_keep = ["SEQN", "MCQ160C"]
medical = medical[columns_to_keep]

# Keep only those who answer 1 or 2.
medical = medical[medical["MCQ160C"] <= 2]

# Adjust No value to be equal to 0
medical["MCQ160C"] = np.where(medical["MCQ160C"] == 2, medical["MCQ160C"] - 2,
                                medical["MCQ160C"])

# Rename columns
medical.rename(columns={"MCQ160C": "CD"}, inplace=True)
            
print(f"{medical.shape[0]} data points")

# Check how many null values in the data frame
print("Number of null values per variable")
print(medical.isnull().sum())
14958 data points
Number of null values per variable
SEQN    0
CD      0
dtype: int64

Join data frames

Each participant is uniquely identified in the datasets by their sequence number (SEQN). The demographics data frame contains an entry for each participant in the survey. In contrast, the other data frames contain a subset of participants based on different criteria for the data being collected. A left merge from the demo data frame onto each of the other data frames keeps an entry for every participant and fills in data missing from the other data frames with NaNs.

# Data frames to merge
dfs = [body, blood, col_hdl, col_total, diab, smoke, medical]

# Primary data frame
nhanes_df = demo 

# Loop through all data frames, doing a left merge
for df in dfs:
    nhanes_df = pd.merge(nhanes_df, df, on="SEQN", how="left")

# Drop the sequence number column as it is no longer needed
nhanes_df = nhanes_df.drop("SEQN", axis=1)

# Check how many null values in the data frame
print(nhanes_df.isnull().sum())

print(f"{nhanes_df.shape[0]} number of points before dropping null values.")

# Drop all rows containing null values
nhanes_df = nhanes_df.dropna()

# Only interested in those aged 20 to 75 inclusive
nhanes_df = nhanes_df[(nhanes_df["AGE"] >= 20) & (nhanes_df["AGE"] <= 75)]


print(f"{nhanes_df.shape[0]} data points in the final data frame.")

# Final check for null values
print(nhanes_df.isnull().sum())
SEX            0
AGE            0
BMI         2423
BPSY        5207
BPDI        5207
CHDL        4732
CTOT        3362
DIABETES     866
SMOKES      4428
CD           602
dtype: int64
15600 number of points before dropping null values.
6298 data points in the final data frame.
SEX         0
AGE         0
BMI         0
BPSY        0
BPDI        0
CHDL        0
CTOT        0
DIABETES    0
SMOKES      0
CD          0
dtype: int64
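The left-merge behaviour used above is easy to check on a toy example. The sketch below uses three made-up frames and `functools.reduce` in place of the explicit loop; the frame names ending in `_toy` are hypothetical.

```python
from functools import reduce

import pandas as pd

# Made-up frames: the "demo" frame lists every participant,
# the other frames each cover only a subset
demo_toy = pd.DataFrame({"SEQN": [1, 2, 3], "AGE": [34, 52, 61]})
body_toy = pd.DataFrame({"SEQN": [1, 3], "BMI": [24.1, 29.8]})
smoke_toy = pd.DataFrame({"SEQN": [2, 3], "SMOKES": [1, 0]})

# Left-merge each frame onto the demographics frame in turn
merged = reduce(
    lambda left, right: pd.merge(left, right, on="SEQN", how="left"),
    [body_toy, smoke_toy],
    demo_toy,
)
print(merged)

# Only participants present in every frame survive dropna()
complete = merged.dropna()
print(complete)
```

All three participants remain after the merge, with NaN where a frame had no entry, and only participant 3 remains after dropping rows with missing values. This is the same mechanism that reduces the 15600 rows above to the final data frame.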

Exploratory Data Analysis

In this section, I explore the data used in the analysis to better understand it and identify patterns, relationships between variables, and potential issues, such as outliers.

Data Distributions

To visualise the distributions of each variable using bokeh, a function was created that groups the data by sex and allows the user to visualise each sex separately or together on the same histogram plot. The distributions of the categorical variables, smoking, diabetes, and coronary heart disease, were not plotted using histograms as they were uninteresting. For these variables, a stacked bar graph was prepared instead.

# Function to create an interactive histogram by sex for a given data frame
#  and variable
def plot_histogram_by_sex(df, variable, type, nbins=24, units="", add_categories=None):
    """
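    Create a histogram plot that lets the user show male, female,
    or both male and female data.

    :param df: Data frame to use
    :param variable: Name of the column to plot
    :param type: Display name of the variable, used in titles and axis labels
    :param nbins: Number of histogram bins
    :param units: Units string appended to the x-axis label
    :param add_categories: Optional list of (value, label, colour) tuples
        drawn as dashed vertical reference lines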
    """
    sex_group = df.groupby("SEX")
    # 0: 'Male', 1: 'Female'
    male_df = sex_group.get_group(0)[variable]
    female_df = sex_group.get_group(1)[variable]
    
    # Set up histogram bins
    max_value = df[variable].max()
    min_value = df[variable].min()
    bins = np.linspace(min_value, max_value, nbins+1)
    hist_male, edges_male = np.histogram(male_df, bins=bins)
    hist_female, edges_female = np.histogram(female_df, bins=bins)

    # Create data sources
    source_male = ColumnDataSource(data=dict(
        top=hist_male,
        bottom=np.zeros_like(hist_male),
        left=edges_male[:-1],
        right=edges_male[1:]))

    source_female = ColumnDataSource(data=dict(
        top=hist_female,
        bottom=np.zeros_like(hist_female),
        left=edges_female[:-1],
        right=edges_female[1:]))

    # Create the main figure for showing individual or overlaid histograms
    p = figure(
        title=type+" Distribution",
        x_axis_label=type+units,
        y_axis_label="Count",
        height=500,
        width=900,
        tools="pan,wheel_zoom,box_zoom,reset,save",
        active_drag="box_zoom")

    # Add the histograms as quad glyphs with initially male visible,
    # female invisible
    male_quad = p.quad(
        top="top", bottom="bottom", left="left", right="right",
        source=source_male,
        fill_color="#1E90FF",  # Dodger blue for males
        line_color="white",
        alpha=0.7,
        hover_fill_color="#4169E1",
        hover_line_color="white",
        legend_label="Male",
        name="male_hist")

    female_quad = p.quad(
        top="top", bottom="bottom", left="left", right="right",
        source=source_female,
        fill_color="#FF1493",  # Deeppink for females
        line_color="white",
        alpha=0.5,  # Lower alpha for better overlay visibility
        hover_fill_color="#FF1493",
        hover_line_color="white",
        legend_label="Female",
        visible=False,  # Initially hidden
        name="female_hist")

    if add_categories:
        # Add vertical lines to the plot
        max_height = max(max(hist_male), max(hist_female)) * 1.1

        # Add text annotations for the categories
        y_pos = max_height * 0.97

        i = 0
        for val, label, colour in add_categories:
            line = p.line([val, val], [0, max_height], line_color=colour,
                            line_width=2, line_dash="dashed")
            text = p.text(val + 1, y_pos - 16*i, [label], text_color=colour,
                            text_font_size="8pt")
            i += 1
            
    # Create a header
    header = Div(
        text=f"""
        <h2 style='text-align: center; color: #444444;'>{type} Distribution by Gender</h2>
        <p style='text-align: center;'>Select an option to view male, female, or overlaid {type} histograms</p>
        """,
        width=700)

    # Configure the legend
    p.legend.location = "top_right"
    p.legend.click_policy = "hide"  # Allow toggling by clicking the legend

    # Set up the radio button group for selection with the third "Both" option
    radio_button_group = RadioButtonGroup(labels=["Male", "Female", "Both"],
                                            active=0)

    # Add JavaScript callback for interactivity
    callback = CustomJS(args=dict(
            male_quad=male_quad,
            female_quad=female_quad,
            plot=p, type=type
        ), code="""
        if (cb_obj.active == 0) {
            // Male only
            male_quad.visible = true;
            female_quad.visible = false;
            let sex_text = "Male "
            let title = sex_text.concat(type, " Distribution");
            plot.title.text = title;
        } else if (cb_obj.active == 1) {
            // Female only
            male_quad.visible = false;
            female_quad.visible = true;
            let sex_text = "Female "
            let title = sex_text.concat(type, " Distribution");
            plot.title.text = title;
        } else {
            // Both overlaid
            male_quad.visible = true;
            female_quad.visible = true;
            let sex_text = "Male & Female "
            let title = sex_text.concat(type, " Distribution Comparison");
            plot.title.text = title;
        }
    """)

    radio_button_group.js_on_change("active", callback)

    # Update the title based on initial selection
    p.title.text = f"Male {type} Distribution"

    # Create the layout
    layout = column(header, radio_button_group, p)

    if create_pdf:
        # Create static images for pdf
        export_png(layout, filename=f"{variable}_plot.png")
        display(Image(filename=f"{variable}_plot.png"))
    else:
        # Display the plot in the notebook
        show(layout, notebook_handle=False)
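The function above builds one set of bin edges from the whole column and reuses it for both sexes. That is what makes the overlaid "Both" view comparable bar for bar. The sketch below isolates that step using synthetic data; the distributions and seed are arbitrary assumptions.

```python
import numpy as np

# Synthetic stand-ins for the male and female values of one variable
rng = np.random.default_rng(0)
male = rng.normal(28, 5, size=500)
female = rng.normal(29, 7, size=500)

# One set of edges spanning the combined range: nbins + 1 edges give nbins bars
nbins = 24
combined = np.concatenate([male, female])
bins = np.linspace(combined.min(), combined.max(), nbins + 1)

# Passing the same edges to both calls guarantees aligned bars
hist_male, edges_male = np.histogram(male, bins=bins)
hist_female, edges_female = np.histogram(female, bins=bins)

print(np.array_equal(edges_male, edges_female))
print(hist_male.sum(), hist_female.sum())
```

If each group were binned separately (for example with `bins=24` as an integer), the edges would depend on each group's own minimum and maximum and the overlaid bars would no longer line up.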

Distribution of Age

The distribution of participants’ ages is shown in the figure below. The distribution for males and females is approximately uniform, with similar counts between the sexes. A couple of age groups have slightly more data, and the 60-year-old group has almost twice the count of other age groups. The NHANES survey tries to balance participation across age groups and sexes, and this plot shows that this has been achieved.

plot_histogram_by_sex(nhanes_df, "AGE", "AGE".title(), units=" (years)")

Distribution of BMI

The BMI distribution, shown below, is unimodal and skewed to the right for both sexes, although the skewness appears to be higher for females than males. Also superimposed on the plot are the health classifications based on BMI. To the left of the green line is underweight, and above the red line is severely obese. More females are severely obese than males.

bmi_categories = [
    (18.5, "Normal", "#32CD32"),  # Green
    (25, "Overweight", "#FFD700"), # Gold
    (30, "Obese", "#FFA500"),      # Orange
    (35, "Severely Obese", "#FF4500")   # OrangeRed
]
plot_histogram_by_sex(nhanes_df, "BMI", "BMI", add_categories=bmi_categories,
                        units=" (kg/m^2)")
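The claim that more females than males fall in a given BMI class can be checked numerically by binning with the same thresholds as the dashed lines. The sketch below uses `pd.cut` on made-up BMI values. Treating each threshold as the inclusive lower bound of its class (`right=False`) is my assumption, not something the notebook states.

```python
import numpy as np
import pandas as pd

# Made-up BMI values, including two that sit exactly on a threshold
bmi = pd.Series([17.0, 18.5, 22.0, 27.3, 31.0, 35.0, 41.2], name="BMI")

# Same thresholds as the plot annotations, with open-ended outer classes
edges = [0, 18.5, 25, 30, 35, np.inf]
labels = ["Underweight", "Normal", "Overweight", "Obese", "Severely Obese"]

# right=False makes each interval [lower, upper), so 18.5 counts as Normal
bmi_class = pd.cut(bmi, bins=edges, labels=labels, right=False)
counts = bmi_class.value_counts(sort=False)
print(counts)
```

Applied to the real data, grouping by `SEX` before counting would give the per-sex totals behind the visual comparison.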

Distribution of Systolic Blood Pressure

The distribution of systolic blood pressure, shown below, is unimodal and skewed to the right for both sexes. More females have a systolic blood pressure that is either in Hypertensive crisis or below normal. Both sexes contain very low values below 50 mmHg. These are potential measurement issues and should be investigated further.

bpsy_categories = [
    (120, "Normal", "#32CD32"),  # Green
    (130, "St 1 Hypertension", "#FFD700"),     # Gold
    (140, "St 2 Hypertension", "#FFA500"),      # Orange
    (180, "Hypertensive crisis", "#FF4500")   # OrangeRed
]
plot_histogram_by_sex(nhanes_df, "BPSY", "Systolic Blood Pressure".title(),
                        add_categories=bpsy_categories, units=" (mmHg)")

Distribution of Diastolic Blood Pressure

The distribution of diastolic blood pressure, shown below, is unimodal and symmetric for both sexes. Both males and females have potential outliers on either side of the distribution. Both sexes have a median diastolic blood pressure below what is considered normal, with some very low values below 40 mmHg. These are potential measurement issues and should be investigated further.

bpdi_categories = [
    (80, "Normal", "#32CD32"),  # Green
    (90, "St 1 Hypertension", "#FFD700"),     # Gold
    (120, "Hypertensive crisis", "#FF4500")   # OrangeRed
]
plot_histogram_by_sex(nhanes_df, "BPDI", "Diastolic Blood Pressure".title(),
                        add_categories=bpdi_categories, units=" (mmHg)")
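As a first step towards investigating the very low blood pressure values mentioned above, the suspect rows can be pulled out for inspection. This is a minimal sketch on made-up data. `flag_low_readings` is a hypothetical helper, and the 50 and 40 mmHg cut-offs are simply the values quoted in the text, not clinical limits.

```python
import pandas as pd


def flag_low_readings(df: pd.DataFrame, column: str, threshold: float) -> pd.DataFrame:
    """Return the rows whose value in `column` is below `threshold` (hypothetical helper)."""
    return df[df[column] < threshold]


# Made-up averaged readings in mmHg
bp = pd.DataFrame({
    "BPSY": [118, 45, 132, 160],
    "BPDI": [76, 30, 38, 95],
})

low_sys = flag_low_readings(bp, "BPSY", 50)   # systolic below 50 mmHg
low_dia = flag_low_readings(bp, "BPDI", 40)   # diastolic below 40 mmHg

print(low_sys)
print(low_dia)
```

Inspecting the flagged rows rather than dropping them outright keeps the decision about whether they are measurement errors separate from the detection step.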

Distribution of HDL Cholesterol

The distribution of HDL cholesterol, shown below, is unimodal and skewed to the right for both sexes. Both males and females have potential outliers on the right side of the distribution, with a few values larger than 150 mg/dL. Males tend to have lower HDL levels than females, and lower HDL levels are undesirable.

plot_histogram_by_sex(nhanes_df, "CHDL", "HDL Cholesterol", units=" (mg/dL)")

Distribution of Total Cholesterol

The distribution of total cholesterol, shown below, is unimodal and nearly symmetric for both sexes. The median value is similar for both males and females and is less than the desirable value of 200 mg/dL. Both sexes have potential outliers on the right side of the distribution, with a few values larger than 350 mg/dL.

plot_histogram_by_sex(nhanes_df, "CTOT", "Total Cholesterol", units=" (mg/dL)")
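Descriptions such as "skewed to the right" and "nearly symmetric" can be backed with numbers by computing the median and sample skewness per sex. The sketch below uses synthetic data: a log-normal column stands in for a right-skewed variable and a normal column for a symmetric one. The parameters and seed are arbitrary assumptions, not fitted to NHANES.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 1000

# Synthetic stand-ins: CHDL is right-skewed, CTOT is symmetric
toy = pd.DataFrame({
    "SEX": np.repeat([0, 1], n),
    "CHDL": np.concatenate([
        rng.lognormal(3.8, 0.3, n),   # "male" group, lower typical value
        rng.lognormal(4.0, 0.3, n),   # "female" group, higher typical value
    ]),
    "CTOT": rng.normal(190, 35, 2 * n),
})

# Median and skewness for each variable, split by sex
summary = toy.groupby("SEX")[["CHDL", "CTOT"]].agg(["median", "skew"])
print(summary.round(2))
```

A clearly positive skewness indicates a right-skewed distribution, while a value near zero is consistent with symmetry.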

Smoking, Diabetes and Coronary heart disease

These variables are categorical, and a stacked bar chart was prepared to visualise their distribution. Each variable had a possible value of either No or Yes. The stacked bar chart below provides a visual breakdown of the Yes and No relative proportions for each variable and sex.

# Function for producing the stack bar chart of the categorical variables
def plot_categorical():
    """
    Creates a stacked bar chart for categorical variables

    This function uses ideas found at
    https://docs.bokeh.org/en/2.4.1/docs/user_guide/categorical.html
    """
    # Group by SEX and get the counts of the other categorical variables
    smoke_counts = nhanes_df.groupby("SEX")["SMOKES"].value_counts().unstack(
                                        fill_value=0)
    diabetes_counts = nhanes_df.groupby("SEX")["DIABETES"].value_counts().unstack(
                                        fill_value=0)
    cd_counts = nhanes_df.groupby("SEX")["CD"].value_counts().unstack(
                                        fill_value=0)

    factors = [("Smoke", "No"), ("Smoke", "Yes"),
        ("Diabetes", "No"), ("Diabetes", "Yes"),
        ("Coronary Heart Disease", "No"), ("Coronary Heart Disease", "Yes"),
    ]

    regions = ['Male', 'Female']

    source = ColumnDataSource(data=dict(
        x = factors,
        Male = [smoke_counts.iloc[0, 0], smoke_counts.iloc[0, 1],
                diabetes_counts.iloc[0, 0], diabetes_counts.iloc[0, 1],
                cd_counts.iloc[0, 0], cd_counts.iloc[0, 1]
               ],
        Female = [smoke_counts.iloc[1, 0], smoke_counts.iloc[1, 1],
                  diabetes_counts.iloc[1, 0], diabetes_counts.iloc[1, 1],
                  cd_counts.iloc[1, 0], cd_counts.iloc[1, 1]
                 ]))

    p = figure(x_range=FactorRange(*factors), width=900, height=500,
                y_axis_label="Count", toolbar_location=None, tools="")

    p.vbar_stack(regions, x='x', width=0.9, alpha=0.5,
                    color=["#1E90FF", "#FF1493"], source=source,
                    legend_label=regions)

    p.y_range.start = 0
    p.x_range.range_padding = 0.1
    p.xaxis.major_label_orientation = 1
    p.xgrid.grid_line_color = None
    p.legend.location = "top_center"
    p.legend.orientation = "horizontal"

    if create_pdf:
        # Create static images for pdf
        export_png(p, filename="categorical_plot.png")
        display(Image(filename="categorical_plot.png"))        
    else:
        show(p)
# Produce the plot
plot_categorical()
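The count tables used by the function above are read by position (`iloc`), so a slip in one index silently plots the wrong number. A label-based lookup is a more defensive option. The sketch below shows one way to do this with `pd.crosstab` on made-up data; it is an alternative I am suggesting, not the notebook's method.

```python
import pandas as pd

# Made-up data: 0 = male / No, 1 = female / Yes
toy = pd.DataFrame({
    "SEX": [0, 0, 0, 1, 1, 1],
    "CD":  [0, 0, 1, 0, 0, 0],
})

# Rows are SEX, columns are CD; reindex guarantees both rows and both columns
# exist even when a category happens to be absent from the data
cd_counts = (
    pd.crosstab(toy["SEX"], toy["CD"])
    .reindex(index=[0, 1], columns=[0, 1], fill_value=0)
)

# Label-based lookups: cd_counts.loc[sex, answer]
male_no, male_yes = cd_counts.loc[0, 0], cd_counts.loc[0, 1]
female_no, female_yes = cd_counts.loc[1, 0], cd_counts.loc[1, 1]
print(male_no, male_yes, female_no, female_yes)
```

Reading each count as `table.loc[sex, answer]` keeps the variable name next to the value it supplies, which makes a copy-and-paste mix-up between tables easier to spot.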

Some observations regarding smoking:

  • There are more non-smokers than smokers.
  • Approximately the same number of males are smokers and non-smokers.
  • There are more non-smoking females than smoking females.
  • There are more non-smoking females than non-smoking males.

Some observations regarding diabetes:

  • There are many more people without diabetes than with diabetes.
  • Of the people with diabetes, approximately 50% are male and 50% are female.
  • There are more males with diabetes than females.

Some observations regarding coronary heart disease:

  • Many more people haven’t been diagnosed than those who have.
  • Of the people without coronary heart disease, approximately 50% are male and 50% are female.
  • There are more females diagnosed with coronary heart disease than males.

Correlation

A correlogram plot was used to compare the linear association between all pairs of quantitative variables. A function was created to create this plot and provide interactivity for the user. The user can select the sex and other factors such as smoking or diabetes. The plot will then be updated to show the correlation between the chosen options, e.g., female, non-smoker with diabetes.
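The numbers behind a correlogram are the Pearson correlations for every unordered pair of variables. The sketch below computes them two ways on synthetic data: pair by pair with `itertools.combinations` and `np.corrcoef`, and all at once with `DataFrame.corr`. The toy columns and their relationship are assumptions made for illustration.

```python
from itertools import combinations

import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
n = 300
age = rng.uniform(20, 75, n)

# Synthetic data: BPSY is built to increase with AGE, BMI is independent
toy = pd.DataFrame({
    "AGE": age,
    "BPSY": 100 + 0.5 * age + rng.normal(0, 10, n),
    "BMI": rng.normal(28, 5, n),
})

# Every unordered pair of columns, each appearing once
pairs = list(combinations(toy.columns, 2))

# Pair-by-pair: element [0, 1] of the 2x2 matrix is the correlation
pairwise = {(a, b): np.corrcoef(toy[a], toy[b])[0, 1] for a, b in pairs}

# All at once: the full symmetric correlation matrix
corr_matrix = toy.corr()

for (a, b), r in pairwise.items():
    print(f"{a} vs {b}: {r:.2f}")
```

Both routes give the same values. The pair list is convenient for a correlogram because each pair maps to exactly one marker, whereas the full matrix also contains the diagonal and the mirrored entries.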

# Function to create correlation matrix
def create_correlation_plot():
    """
    Creates a correlogram to be used as a correlation matrix.

    The ideas for this function came from the correlogram example in the
    bokeh documentation.
    https://docs.bokeh.org/en/3.3.4/docs/examples/topics/categorical/correlogram.html
    """
    # Plot correlation matrix
    groups = nhanes_df.groupby(["SEX", "SMOKES", "DIABETES"])
    df = groups.get_group((0,0,0)) # Male, non-smoker, non-diabetic
    df = df[df.columns[~df.columns.isin(["SEX", "SMOKES", "DIABETES"])]]

    all_vars = list(df.columns)
    pairs = list(combinations(all_vars, 2))
     
    x, y = list(zip(*pairs))

    # Create a dictionary to store our data with all combinations
    data = {}
    
    # Generate data for all combinations
    for gender in ["male", "female"]:
        # male=0, female =1
        g_factor = 0 if gender == "male" else 1
    
        for smoker in [False, True]:
            # no = 0, yes =1
            s_factor = 0 if not smoker else 1
        
            for diabetic in [False, True]:
                # no = 0, yes = 1
                d_factor = 0 if not diabetic else 1
            
                # Create a key for this combination
                key = f'{gender}_{"smoker" if smoker else "nonsmoker"}_{"diabetic" if diabetic else "nondiabetic"}'

                df = groups.get_group((g_factor,s_factor,d_factor))
                df = df[df.columns[~df.columns.isin(["SEX", "SMOKES", "DIABETES"])]]

                correlations = []
                for a, b in pairs:
                    matrix = np.corrcoef(df[a], df[b])
                    correlations.append(matrix[0, 1])

                data[key] = {"var_1": x, "var_2": y,
                    "correlation": correlations,
                    "dot_size": [(1+ 10 * abs(corr)) * 10 for corr in correlations],
                    "texts": [f"{x:.2f}" for x in correlations]}

    new_df= pd.DataFrame(data["male_nonsmoker_nondiabetic"])
    x_range = new_df["var_1"].unique()
    y_range = list(new_df["var_2"].unique())
    
    data_sources= {}    
    # Create one data source per combination
    for key in data:
        data_sources[key] = ColumnDataSource(pd.DataFrame(data[key]))
        
    p = figure(width=800, height=800,
        x_range= x_range, y_range= y_range,
        x_axis_location="above", 
        toolbar_location=None,
        tools="hover",
        tooltips=[("correlation", "@texts")],
        background_fill_color="#fafafa")

    # Create a renderers dictionary to store all scatter renderers
    renderers = {}

    # Create a scatter renderer for each combination of factors
    for gender in ["male", "female"]:
        for smoker in [False, True]:
            for diabetic in [False, True]:
                renderer_key = f'{gender}_{"smoker" if smoker else "nonsmoker"}_{"diabetic" if diabetic else "nondiabetic"}'
                renderers[renderer_key] = p.scatter(x="var_1", y="var_2",
                    size="dot_size",
                    source=data_sources[renderer_key],
                    fill_color=linear_cmap("correlation", "RdYlGn9", -1, 1),
                    line_color="#202020",
                    visible=False)

    # Default is "male_nonsmoker_nondiabetic"
    c = renderers["male_nonsmoker_nondiabetic"]
    c.visible= True
    color_bar = c.construct_color_bar(
        location=(300, 0),
        ticker=FixedTicker(ticks=[-1, 0.0, 1]),
        title="correlation",
        major_tick_line_color=None,
        width=400,
        height=30)

    p.add_layout(color_bar, "below")
    p.axis.major_tick_line_color = None
    p.axis.major_tick_out = 0
    p.axis.axis_line_color = None
    p.grid.grid_line_color = None
    p.outline_line_color = None

    # Create MultiChoice widget with only demographic options
    options = [
        ("female", 'Female'),  # Default is male
        ("smoker", "Smoker"),  # Default is non-smoker
        ("diabetic", "Diabetic")  # Default is non-diabetic
    ]

    # Initially select no modifiers (default to male, non-smoker, non-diabetic)
    multi_choice = MultiChoice(value=[], options=options, 
                    title="Select sex and other factors:", width=300)

    # Create info text
    info_div = Div(text="""<p><b>Current Selection:</b> Male, Non-smoker, Non-diabetic</p>""", 
               width=400, height=30)

    # CustomJS callback to control visibility
    callback = CustomJS(args=dict(
            renderers=renderers, info_div=info_div
        ), code="""
        // Get the selected values from the MultiChoice widget
        const selected = cb_obj.value;
    
        // Check selections with defaults
        const isFemale = selected.includes('female');
        const gender = isFemale ? 'female' : 'male';
    
        const isSmoker = selected.includes('smoker');
        const smokerStatus = isSmoker ? 'smoker' : 'nonsmoker';
    
        const isDiabetic = selected.includes('diabetic');
        const diabeticStatus = isDiabetic ? 'diabetic' : 'nondiabetic';
    
        // Build the key
        const key = gender + '_' + smokerStatus + '_' + diabeticStatus;
    
        // Update info display
        let infoText = "<p><b>Current Selection:</b> " + 
                  (gender === 'male' ? 'Male' : 'Female') + ", " +
                  (smokerStatus === 'smoker' ? 'Smoker' : 'Non-smoker') + ", " +
                  (diabeticStatus === 'diabetic' ? 'Diabetic' : 'Non-diabetic') + "</p>";
        info_div.text = infoText;
    
        // Default all renderers to invisible first
        for (let name in renderers) {
            renderers[name].visible = false;
        }
    
        // Show the selected profile
        if (key in renderers) {
            renderers[key].visible = true;
        }
    """)

    # Attach the callback to the widget
    multi_choice.js_on_change("value", callback)

    # Create the layout and show the result
    layout = column(row(multi_choice, info_div), p)

    if create_pdf:
        # Create static images for pdf
        export_png(layout, filename="correlation_plot.png")
        display(Image(filename="correlation_plot.png"))
    else:
        show(layout)
    
    return data

correlation_data = create_correlation_plot()

# Create data table for the correlation information
def create_correlation_table(corr_data):
    """
    Creates a bokeh data table for the correlation information. The
    table has the same interactivity as the correlogram plot.

    :param corr_data: Dictionary containing the correlation data
    """
    # Store all datasets in a JS-accessible args dict
    all_data = corr_data

    # The source attached to the DataTable
    source =  ColumnDataSource(all_data["male_nonsmoker_nondiabetic"])

    columns = [TableColumn(field="var_1", title="Variable 1"),
                TableColumn(field="var_2", title="Variable 2"),
                TableColumn(field="correlation", title="Correlation",
                            formatter=NumberFormatter(format="0.000",
                            text_align="center"))
            ]
    data_table = DataTable(source=source, columns=columns, width=500, height=600)

    # Create MultiChoice widget with only demographic options
    options = [
        ("female", 'Female'),  # Default is male
        ("smoker", "Smoker"),  # Default is non-smoker
        ("diabetic", "Diabetic")  # Default is non-diabetic
        ]

    # Initially select no modifiers (default to male, non-smoker, non-diabetic)
    multi_choice = MultiChoice(value=[], options=options, 
                    title="Select sex and other factors:", width=300)

    # Create info text
    info_div = Div(text="""<p><b>Current Selection:</b> Male, Non-smoker, Non-diabetic</p>""", 
               width=400, height=30)

    # Pass all datasets as an argument to CustomJS
    callback = CustomJS(args=dict(
            source=source, all_data=all_data, info_div=info_div
        ), code="""
        // Get the selected values from the MultiChoice widget
        const selected = cb_obj.value;
    
        // Check selections with defaults
        const isFemale = selected.includes('female');
        const gender = isFemale ? 'female' : 'male';
    
        const isSmoker = selected.includes('smoker');
        const smokerStatus = isSmoker ? 'smoker' : 'nonsmoker';
    
        const isDiabetic = selected.includes('diabetic');
        const diabeticStatus = isDiabetic ? 'diabetic' : 'nondiabetic';
    
        // Build the key
        const key = gender + '_' + smokerStatus + '_' + diabeticStatus;
    
        // Update info display
        let infoText = "<p><b>Current Selection:</b> " + 
                  (gender === 'male' ? 'Male' : 'Female') + ", " +
                  (smokerStatus === 'smoker' ? 'Smoker' : 'Non-smoker') + ", " +
                  (diabeticStatus === 'diabetic' ? 'Diabetic' : 'Non-diabetic') + "</p>";
        info_div.text = infoText;

        // Overwrite the source's data
        source.data = {...all_data[key]};

        source.change.emit();
    """)

    # Attach the callback to the widget
    multi_choice.js_on_change("value", callback)

    layout = column(row(multi_choice, info_div), data_table)

    if create_pdf:
        # Create static images for pdf
        export_png(layout, filename="correlation_table.png")
        display(Image(filename="correlation_table.png"))
    else:
        show(layout)

create_correlation_table(correlation_data)

The correlogram plot and data table shown above contain lots of information. There is a strong positive linear association between the systolic and diastolic blood pressures for both sexes. Other specific observations by sex are:

Male

  • Apart from blood pressure, all other linear associations are weak (0.15 < |correlation| < 0.5) or non-existent.
  • For a non-smoking non-diabetic, there are weak linear associations between BMI and HDL cholesterol (negative), BMI and diastolic blood pressure (positive), age and systolic blood pressure (positive), diastolic blood pressure and total cholesterol (positive), age and total cholesterol (positive) and age and coronary heart disease (positive).
  • For a smoking non-diabetic, the following weak linear associations get slightly stronger: age and coronary heart disease, age and systolic blood pressure, and BMI and diastolic blood pressure.
  • Some of the weak linear associations for diabetic non-smokers are similar to those for non-diabetics, such as BMI and HDL cholesterol (negative). Still, new associations appear for age and diastolic blood pressure (negative), age and total cholesterol (negative), and age and coronary heart disease (positive).
  • The weak linear associations for a smoking diabetic are similar to those for a non-smoking diabetic. Most of the changes are to the strength of the associations, which mainly increases.

Female

  • As with males, the only strong linear association is between the systolic and diastolic blood pressures. All other linear associations are weak or non-existent.
  • Compared to males, females have stronger linear associations between age and total cholesterol (positive), age and systolic blood pressure (positive), total cholesterol and systolic blood pressure (positive), and total cholesterol and HDL cholesterol (positive).
  • For a non-diabetic, smoking reduces the strength of the weak linear associations.
  • For a diabetic non-smoker, the weak linear associations are weaker than for a non-diabetic non-smoker.
  • For a diabetic smoker, the weak associations involving age change from positive to negative.
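
The strength bands used in the observations above can be applied programmatically. The sketch below is only an illustration: it uses synthetic data rather than the NHANES data frame, assumes Pearson correlation (the pandas default), and assumes that "strong" means |correlation| >= 0.5, since the text only states the weak band explicitly. The names demo_df and label_strength are hypothetical.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in data: the column names mirror the notebook,
# but every value here is made up for illustration
rng = np.random.default_rng(0)
n = 400
bpsy = rng.normal(120, 15, n)
demo_df = pd.DataFrame({
    "SEX": rng.integers(0, 2, n),                # 0 = male, 1 = female
    "BPSY": bpsy,
    "BPDI": 0.5 * bpsy + rng.normal(10, 5, n),   # built to track BPSY closely
    "AGE": rng.uniform(20, 80, n),               # independent of the others
})


def label_strength(r, weak=0.15, strong=0.5):
    """Label a correlation using the bands quoted in the text.

    The 'strong' cut-off of 0.5 is an assumption.
    """
    size = abs(r)
    if size >= strong:
        return "strong"
    if size > weak:
        return "weak"
    return "none"


rows = []
for sex_code, group in demo_df.groupby("SEX"):
    corr = group[["BPSY", "BPDI", "AGE"]].corr()  # Pearson by default
    for var_1, var_2 in [("BPSY", "BPDI"), ("BPSY", "AGE")]:
        r = corr.loc[var_1, var_2]
        rows.append({"SEX": sex_code, "var_1": var_1, "var_2": var_2,
                     "correlation": r, "strength": label_strength(r)})

summary = pd.DataFrame(rows)
print(summary)
```

Grouping by additional columns (for example smoking and diabetes status) would give one row per demographic profile, similar in spirit to the interactive table.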

Scatter plot matrix

The scatter plot matrix is a helpful visualisation for understanding the pairwise relationships between the quantitative variables in the dataset. It can help identify linear or non-linear relationships between the variables, as well as potential outliers or unusual patterns in the data. A function was created that generates a scatter plot matrix for a given data frame and grouping variable. The grouping variable allows the data to be grouped by sex, smoking status, diabetes status, or coronary heart disease diagnosis.

def create_scatter_plot_matrix(df, sel_var):
    """
    Creates a scatter plot matrix for the provided data frame

    This function was created with ideas from:
     * https://docs.bokeh.org/en/3.1.1/docs/examples/topics/stats/splom.html
     * https://datawranglingpy.gagolewski.com/chapter/310-matrix.html#
     
    :param df: The data frame to use for plotting
    :param sel_var: String with the name of the selector variable
    """
    # Categorical variables
    cat_vars = ["SEX", "SMOKES", "DIABETES", "CD"]
    cat_vars.remove(sel_var)
    
    # Drop the categorical columns
    df = df[df.columns[~df.columns.isin(cat_vars)]]

    all_vars = list(df.columns)

    # Remove the selector variable from all_vars
    all_vars.remove(sel_var)
    n_var = len(all_vars)

    colours = { "SEX": ("#1E90FF", "#FF1493"),
        "SMOKES": ("#008080","#ffa500"),
        "DIABETES": ("#000000","#ff0000"),
        "CD": ("#5e4fa2","#9e0142")}
    units = ["(years)","(kg/m^2)","(mmHg)", "(mmHg)", "(mg/dL)", "(mg/dL)"]
    tools ="pan,wheel_zoom,box_zoom,reset"

    # Compute min/max for each variable and create DataRange1d with padding
    x_ranges = []
    y_ranges = []
    for var in all_vars:
        vmin = df[var].min()
        vmax = df[var].max()
        padding = 0.1 * (vmax - vmin)
        x_ranges.append(DataRange1d(start=vmin - padding, end=vmax + padding))
        y_ranges.append(DataRange1d(start=vmin - padding, end=vmax + padding))
    
    sources_1 = []
    sources_2 = []
    plots = []
    pwidth = 150
    for row_idx, y in enumerate(all_vars):
        row = []
        for col_idx, x in enumerate(all_vars):
            if x == y:
                # Dummy plot for diagonal: only label, no glyphs
                p = figure(width=pwidth, height=pwidth, tools=tools)
                # Hide axes and grid
                p.xaxis.visible = False
                p.yaxis.visible = False
                p.grid.visible = False
                # Add a centred text label
                label = Label(x=50, y=60, x_units="screen", y_units="screen",
                          text=y+" "+units[col_idx], text_align="center",
                          text_baseline="middle",
                          text_font_size="8pt")
                p.add_layout(label)
                row.append(p)
                sources_1.append(None)
                sources_2.append(None)
            else:
                first_mask = (df[sel_var] == 0)
                second_mask = (df[sel_var] == 1)
                sm = ColumnDataSource(data=dict(
                    x=df.loc[first_mask, x],
                    y=df.loc[first_mask, y]))
                sf = ColumnDataSource(data=dict(
                    x=df.loc[second_mask, x],
                    y=df.loc[second_mask, y]))
                p = figure(
                    width=pwidth, height=pwidth,
                    x_axis_label=x, y_axis_label=y,
                    tools=tools,
                    x_range=x_ranges[col_idx], y_range=y_ranges[row_idx])
                # Axis labels are suppressed on every panel; the diagonal
                # cells name each variable and its units
                p.xaxis.axis_label = ""
                p.yaxis.axis_label = ""
                
                p.circle("x", "y", source=sm, size=6, alpha=0.2,
                         color=colours[sel_var][0])
                p.circle("x", "y", source=sf, size=6, alpha=0.2,
                        color=colours[sel_var][1])
                p.xaxis.major_label_text_font_size = "6pt"
                p.yaxis.major_label_text_font_size = "6pt"
                row.append(p)
                sources_1.append(sm)
                sources_2.append(sf)
        plots.append(row)

    # Prepare data for JS
    data_dict = {v: df[v].values for v in all_vars}
    cat_arr = df[sel_var].values

    if sel_var == "SMOKES":
        sel_opts = ["Both", "Non-Smoker", "Smoker"]
    elif sel_var == "DIABETES":
        sel_opts = ["Both", "Non-Diabetic", "Diabetic"]
    elif sel_var == "CD":
        sel_opts = ["Both", "Not-Diagnosed", "Diagnosed"]        
    else:
        sel_opts = ["Both", "Male", "Female"]

    cat_select = Select(title="Category", value="Both", options=sel_opts)

    callback = CustomJS(
        args=dict(
            sources_f=sources_1,
            sources_s=sources_2,
            variables=all_vars,
            data_dict=data_dict,
            cat_arr=cat_arr,
            cat_select=cat_select,
            sel_opts=sel_opts
        ),
        code="""
        const cat = cat_select.value;
        const n = variables.length;
        let source_idx = 0;
        for (let row = 0; row < n; row++) {
            for (let col = 0; col < n; col++) {
                if (row === col) {
                    source_idx += 1;
                    continue;
                }
                let sm = sources_f[source_idx];
                let sf = sources_s[source_idx];
                if (cat === sel_opts[1]) {
                    sm.data = {
                        x: data_dict[variables[col]].filter((v, i) => cat_arr[i] === 0),
                        y: data_dict[variables[row]].filter((v, i) => cat_arr[i] === 0)
                    };
                    sf.data = {x: [], y: []};
                } else if (cat === sel_opts[2]) {
                    sm.data = {x: [], y: []};
                    sf.data = {
                        x: data_dict[variables[col]].filter((v, i) => cat_arr[i] === 1),
                        y: data_dict[variables[row]].filter((v, i) => cat_arr[i] === 1)
                    };
                } else { // Both
                    sm.data = {
                        x: data_dict[variables[col]].filter((v, i) => cat_arr[i] === 0),
                        y: data_dict[variables[row]].filter((v, i) => cat_arr[i] === 0)
                    };
                    sf.data = {
                        x: data_dict[variables[col]].filter((v, i) => cat_arr[i] === 1),
                        y: data_dict[variables[row]].filter((v, i) => cat_arr[i] === 1)
                    };
                }
                sm.change.emit();
                sf.change.emit();
                source_idx += 1;
            }
        }
        """
    )
    cat_select.js_on_change("value", callback)

    layout = column(cat_select, gridplot(plots)) 

    if create_pdf:
        # Create static images for pdf
        export_png(layout, filename=f"{sel_var}_spm_plot.png")
        display(Image(filename=f"{sel_var}_spm_plot.png"))
    else:
        show(layout)

Scatter plot matrix grouped by sex

create_scatter_plot_matrix(nhanes_df, "SEX")

Scatter plot matrix grouped by smoking

create_scatter_plot_matrix(nhanes_df, "SMOKES")

Scatter plot matrix grouped by diabetes

create_scatter_plot_matrix(nhanes_df, "DIABETES")

Scatter plot matrix grouped by coronary heart disease

create_scatter_plot_matrix(nhanes_df, "CD")

The scatter plot matrices show positive correlations between certain variables, such as age and systolic blood pressure. Systolic and diastolic blood pressure are linearly correlated. Some pairs of variables show a non-linear relationship, such as age and diastolic blood pressure (concave), age and HDL cholesterol (convex) and BMI and HDL cholesterol (L-shaped). Some of these patterns change when the data are grouped by smoking and/or diabetes status. However, there does not appear to be a clear linear relationship between any of the independent variables and the target variable, coronary heart disease.
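
One way to put a number on a concave or convex pattern is to fit a quadratic and inspect the sign of the squared term. The sketch below is an assumption-laden illustration on synthetic data, not the analysis used in this notebook: a negative coefficient corresponds to the "concave" shape described above (a peak in the middle) and a positive coefficient to the "convex" shape (a minimum in the middle).

```python
import numpy as np

# Synthetic data with a peak in middle age (made up for illustration)
rng = np.random.default_rng(1)
age = rng.uniform(20, 80, 500)
bpdi = 80 - 0.01 * (age - 50) ** 2 + rng.normal(0, 3, 500)

# Centre the predictor so the linear and quadratic terms are less entangled
age_c = age - age.mean()

# np.polyfit returns coefficients from the highest power down
a2, a1, a0 = np.polyfit(age_c, bpdi, deg=2)

shape = "concave (peak in the middle)" if a2 < 0 else "convex (minimum in the middle)"
print(f"quadratic coefficient: {a2:.4f} -> {shape}")

# The Pearson correlation can be close to zero for such a pattern,
# which is why a correlogram alone can miss it
r = np.corrcoef(age, bpdi)[0, 1]
print(f"Pearson correlation: {r:.3f}")
```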

Scatter plots

Whilst the scatter plot matrix allows for the comparison of all quantitative variables, the plots tend to be small, so individual plots can be used to inspect specific pairs of variables. A function was created that produces a scatter plot for a requested pair of variables. The plots are interactive and allow the user to filter the data based on sex and other factors (smoking or diabetes). Choosing a sex removes all data for the other sex. The points are coloured by sex by default; however, if the user selects the colour by coronary heart disease option, the points are coloured red for diagnosed and black for not diagnosed.

Below are some example plots; these were chosen to match some of the types of shapes discussed in the previous section.

def plot_scatter_by_cat(x, y):
    """
    Creates an individual scatter plot using bokeh. The plot is interactive
    and allows users to filter the data based on sex and other factors. If a
    particular sex is selected, only data for that sex is shown.

    :param x: Name of first variable as a string
    :param y: Name of second variable as a string
    """
    groups = nhanes_df.groupby(["SEX", "SMOKES", "DIABETES"])

    good_names = {"BPSY": "Systolic blood pressure",
                  "BPDI": "Diastolic blood pressure",
                  "BMI": "Body Mass Index",
                  "CTOT": "Total Cholesterol", "CHDL": "HDL Cholesterol",
                  "AGE": "Age"}
    
    units = {"BPSY": " (mmHg)", "BPDI": " (mmHg)", "BMI": " (kg/m^2)",
             "CTOT": " (mg/dL)", "CHDL": " (mg/dL)", "AGE": " (years)"}

    # Dictionary to hold all data sources
    data_sources = {}
    
    # Generate data for all combinations
    for gender in ["male", "female"]:
        # male=0, female =1
        g_factor = 0 if gender == "male" else 1
    
        for smoker in [False, True]:
            # no = 0, yes =1
            s_factor = 0 if not smoker else 1
        
            for diabetic in [False, True]:
                # no = 0, yes = 1
                d_factor = 0 if not diabetic else 1
            
                # Create a key for this combination
                key =f'{gender}_{"smoker" if smoker else "nonsmoker"}_{"diabetic" if diabetic else "nondiabetic"}'

                df = groups.get_group((g_factor, s_factor, d_factor))
                df = df[df.columns[~df.columns.isin(["SMOKES", "DIABETES"])]]

                data_sources[key] = ColumnDataSource(
                    pd.DataFrame({
                        "var_1": df[x],
                        "var_2": df[y],
                        "group": df["SEX"],
                        # Distinct loop names avoid shadowing gender and x
                        "colour": ["#FF1493" if sex else "#1E90FF" for sex in df["SEX"]],
                        "cd_colour": ["red" if cd else "black" for cd in df["CD"]]
                        }))
    # Create the figure
    p = figure(title=f"Scatter plot of {good_names[y]} versus {good_names[x]}",
            x_axis_label=good_names[x] + units[x],
            y_axis_label=good_names[y] + units[y],
            tools="box_zoom,wheel_zoom,pan,reset", width=800, height=800)

    # Create a list to store all scatter renderers
    renderers = []

    # Create a scatter renderer for each category
    for gender in ["male", "female"]:
        colour = "#1E90FF" if gender=="male" else "#FF1493"
        for smoker in [False, True]:
            for diabetic in [False, True]:
                if smoker:
                    if diabetic:
                        marker="triangle"
                    else:
                        marker="square"
                else:
                    if diabetic:
                        marker="inverted_triangle"
                    else:
                        marker="circle"
                    
                renderer_key = f'{gender}_{"smoker" if smoker else "nonsmoker"}_{"diabetic" if diabetic else "nondiabetic"}'
                
                renderers.append(p.scatter(
                    x="var_1", y="var_2",
                    source=data_sources[renderer_key],
                    color="colour",
                    marker=marker,
                    visible=True))

    # Define tooltips
    tooltips = [("(x, y)", "(@var_1, @var_2)")]

    # Add HoverTool
    hover = HoverTool(tooltips=tooltips)
    p.add_tools(hover)

    # Checkbox for data sources (applies to both sexes)
    checkbox_sources = CheckboxGroup(
        labels=["Non-Smoker/Non-Diabetic", "Non-Smoker/Diabetic", "Smoker/Non-Diabetic", "Smoker/Diabetic"],
        active=[0,1,2,3])
    
    # Checkbox for sexes
    checkbox_sexes = CheckboxGroup(labels=["Male", "Female"], active=[0,1])

    # Check box for colour by coronary heart disease
    checkbox_colourby = CheckboxGroup(labels=["Colour by coronary heart disease"],
                                      active=[])

    # Add label under check box to tell user what the colour red means
    label = Div(text='Diagnosed cases are shown as <span style="color:red;">red</span> symbols.',
                styles={'font-size': '12px', 'color': 'black'})
    
    # CustomJS callback
    callback = CustomJS(args=dict(
        renderers=renderers,
        sources_cb=checkbox_sources,
        sexes_cb=checkbox_sexes,
        colourby_cb=checkbox_colourby,
        data_sources=data_sources
    ), code="""
        // 0-3: male sources, 4-7: female sources
        for (let i = 0; i < 4; i++) {
            // Male
            renderers[i].visible = sources_cb.active.includes(i) && sexes_cb.active.includes(0);
            // Female
            renderers[i+4].visible = sources_cb.active.includes(i) && sexes_cb.active.includes(1);
        }

        // Update colours in all data sources
        for (const key in data_sources) {
            let source = data_sources[key];
            let data = source.data;
            let group = data['group'];
            let colour = data['colour'];
            let cd_colour = data['cd_colour'];
            if (colourby_cb.active.includes(0)) {
                for (let i = 0; i < group.length; i++) {
                    colour[i] = cd_colour[i];
                }
            } else {
                for (let i = 0; i < group.length; i++) {
                    colour[i] = (group[i] === 0) ? "#1E90FF" : "#FF1493";
                }
            }
            source.change.emit();
        }
    """)
    checkbox_sources.js_on_change("active", callback)
    checkbox_sexes.js_on_change("active", callback)
    checkbox_colourby.js_on_change("active", callback)

    layout = row(p, column(
                    checkbox_sexes,
                    checkbox_sources,
                    checkbox_colourby,
                    label
                )
            )

    if create_pdf:
        # Create static images for pdf
        export_png(layout, filename=f"{x}_v_{y}_s_plot.png")
        display(Image(filename=f"{x}_v_{y}_s_plot.png"))
    else:
        show(layout)

The figure below shows the scatter plot of diastolic versus systolic blood pressure. While there is some scatter, the plot does exhibit a positive relationship between the variables.

plot_scatter_by_cat(x = "BPSY", y = "BPDI")

The plot below shows the scatter plot of systolic blood pressure versus age, an example of a weak linear association, especially for female data. The scatter in systolic blood pressure measurement increases with age.

plot_scatter_by_cat(x = "AGE", y = "BPSY")
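
The observation that the scatter in systolic blood pressure grows with age can be checked numerically by comparing the standard deviation within age bands. The sketch below uses synthetic data in which the noise is deliberately made to grow with age; the band edges and the demo data frame are my own assumptions and are not taken from the NHANES data.

```python
import numpy as np
import pandas as pd

# Synthetic data: the noise standard deviation increases with age
rng = np.random.default_rng(2)
age = rng.uniform(20, 80, 2000)
bpsy = 100 + 0.4 * age + rng.normal(0, 1, 2000) * (5 + 0.2 * age)
demo = pd.DataFrame({"AGE": age, "BPSY": bpsy})

# Split age into four bands (edges chosen only for illustration)
bands = pd.cut(demo["AGE"], bins=[20, 35, 50, 65, 80], include_lowest=True)

# Standard deviation of systolic blood pressure within each band
spread = demo.groupby(bands, observed=True)["BPSY"].std()
print(spread)
```

A spread that rises from the youngest to the oldest band is consistent with the widening scatter seen in the plot.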

The plot below shows the scatter plot of HDL cholesterol versus age. If you isolate the male data, you will see that the data has a convex trend, i.e. HDL levels are higher at low and high ages, with a minimum in middle age.

plot_scatter_by_cat(x = "AGE", y = "CHDL")

The plot below shows the scatter plot of diastolic blood pressure versus age. Again, isolating the male data shows a concave trend, i.e. diastolic blood pressure is lower in young and old people and higher in middle age.

plot_scatter_by_cat(x = "AGE", y = "BPDI")

The last plot below shows the scatter plot of BMI versus HDL cholesterol, which is an example of what I describe as an L-shaped trend. If you colour this plot by coronary heart disease, you will notice that many of the points for diagnosed coronary heart disease occur at the lower values of HDL cholesterol.

plot_scatter_by_cat(x = "CHDL", y = "BMI")

Outliers

Outliers have already been mentioned when discussing the distributions of the quantitative variables. Here, I produce a boxplot using bokeh to allow visualisation of the outliers.

def create_boxplot(df):
    """
    Creates a boxplot using bokeh for the quantitative variables in the
    provided data frame.

    Ideas for this function came from the boxplot example in the bokeh
    documentation
    https://docs.bokeh.org/en/3.0.2/docs/examples/topics/stats/boxplot.html

    :param df: Input data frame
    """
    # Categorical variables
    cat_vars = ["SEX", "SMOKES", "DIABETES", "CD"]
    
    # Drop the categorical columns
    df = df[df.columns[~df.columns.isin(cat_vars)]]

    # Get list of remaining columns
    cols = list(df.columns)

    colours = Category10[len(cols)]

    # Calculate the statistics
    q1 = df[cols].quantile(0.25)
    q2 = df[cols].quantile(0.50)
    q3 = df[cols].quantile(0.75)
    iqr = q3 - q1
    upper = q3 + 1.5 * iqr
    lower = q1 - 1.5 * iqr

    # Cap lower whisker at zero
    lower = lower.clip(lower=0)

    # Create source
    source = ColumnDataSource(data=dict(
        variables=cols,
        q1=q1.values,
        q2=q2.values,
        q3=q3.values,
        upper=upper.values,
        lower=lower.values,
        colour=colours
    ))

    # Prepare outliers
    outx = []
    outy = []
    for i, col in enumerate(cols):
        col_outliers = df[(df[col] > upper[col]) | (df[col] < lower[col])][col]
        outx.extend([col] * col_outliers.count())
        outy.extend(col_outliers.values)

    outlier_source = ColumnDataSource(data={
        'variable': outx,
        'value': outy
    })
    
    p = figure(x_range=cols, background_fill_color="#efefef",
           width=800, height=800, 
           title="Boxplot of the quantitative variables in the data frame",
           tools="box_zoom,wheel_zoom,pan,reset"
          )

    # Add boxes
    box_1 = p.vbar(x='variables', width=0.7, bottom='q2', top='q3',
                    source=source, fill_color='colour', line_color="black")
    box_2 = p.vbar(x='variables', width=0.7, bottom='q1', top='q2',
                    source=source, fill_color='colour', line_color="black")

    # Add whiskers
    p.add_layout(Whisker(source=source, base="variables", upper="upper",
                        lower="lower"))
    # Add outliers
    outlier_renderer = p.circle(x='variable', y='value', source=outlier_source,
                                size=6, color="black", fill_alpha=0.6)

    hover_box = HoverTool(
        tooltips=[
            ("Variable", "@variables"),
            ("Q1", "@q1"),
            ("Median", "@q2"),
            ("Q3", "@q3"),
            ("Lower Whisker", "@lower"),
            ("Upper Whisker", "@upper"),
        ],
        renderers=[box_1, box_2]
    )
    p.add_tools(hover_box)

    hover_outlier = HoverTool(
        tooltips=[
            ("Variable", "@variable"),
            ("Outlier Value", "@value"),
        ],
        renderers=[outlier_renderer]
    )
    p.add_tools(hover_outlier)

    
    if create_pdf:
        # Create static images for pdf
        export_png(p, filename="box_plot.png")
        display(Image(filename="box_plot.png"))
    else:
        show(p)

The boxplot below shows that age has no outliers, BMI has outliers above the top whisker, and the systolic and diastolic blood pressures have outliers above and below the whiskers. It has already been noted that the lower blood pressure outliers had values that seemed too low for a person and, therefore, need further investigation. Total and HDL cholesterol have both lower and upper outliers, with a large number of upper outliers. Usually, the next step would be to consider treating the outliers, either by trimming the data or by transforming the data to reduce their influence. In this analysis, I acknowledge that the outliers exist but apply no treatment to remove or minimise them.
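
For reference, the treatments mentioned above could look like the sketch below. It is not applied anywhere in this notebook; it uses a synthetic right-skewed series, and the helper name iqr_fences is hypothetical. It shows trimming (dropping flagged rows), capping values at the fences, and a log transform.

```python
import numpy as np
import pandas as pd

# Synthetic right-skewed variable (made up for illustration)
rng = np.random.default_rng(3)
values = pd.Series(rng.lognormal(mean=4.0, sigma=0.4, size=1000), name="CHDL")


def iqr_fences(s, k=1.5):
    """Return the lower and upper fences q1 - k*IQR and q3 + k*IQR."""
    q1, q3 = s.quantile(0.25), s.quantile(0.75)
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr


lower, upper = iqr_fences(values)
is_outlier = (values < lower) | (values > upper)

# Option 1: trimming, i.e. dropping the flagged observations
trimmed = values[~is_outlier]

# Option 2: capping the values at the fences
capped = values.clip(lower=lower, upper=upper)

# Option 3: log transform, which often reduces right-skew outliers
logged = np.log(values)
log_lower, log_upper = iqr_fences(logged)
log_outliers = int(((logged < log_lower) | (logged > log_upper)).sum())

print(f"outliers on raw scale: {int(is_outlier.sum())}")
print(f"outliers on log scale: {log_outliers}")
```

Each option changes the data in a different way, so the choice would depend on whether the extreme values are measurement errors or genuine observations.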

create_boxplot(nhanes_df)

Machine learning model for classification

In this section, I look at answering three questions:

  • Given an individual’s age, BMI, systolic and diastolic blood pressure, HDL and total cholesterol, whether they smoke, if they have diabetes and if they have been diagnosed with coronary heart disease, can I predict their sex?
  • Given an individual’s sex, age, BMI, systolic and diastolic blood pressure, HDL and total cholesterol, whether they smoke and if they have been diagnosed with coronary heart disease, can I predict if they have diabetes?
  • Given an individual’s sex, age, BMI, systolic and diastolic blood pressure, HDL and total cholesterol, whether they smoke, and if they have diabetes, can I predict if they have been diagnosed with coronary heart disease?

All the questions above involve classifying individuals. Therefore, I have chosen to use a K-Nearest Neighbours model to answer these questions. The steps required to generate and assess a classifier model have been incorporated into two functions, KNClassifier and assess_model. The KNClassifier function performs the following steps:

  • Drops any columns not needed for the model,
  • Copies the variable to predict into a new data frame,
  • Drops the variable to predict from the data frame and assigns the resultant data frame to a new variable,
  • Standardises the independent variables to have zero mean and unit standard deviation,
  • Splits the data into training and testing sets, using a test size of 20% of the overall data,
  • Fits the KNN model to the training data using k = 5,
  • Predicts the variable using the testing dataset,
  • Calls assess_model.
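
The steps listed above can be sketched on a small synthetic data frame as follows. This is a minimal illustration and not the KNClassifier function itself: the column names ID and TARGET, the synthetic data and the random_state value are all assumptions made for the example.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Synthetic data (made up): a binary target loosely driven by AGE
rng = np.random.default_rng(4)
n = 500
demo_df = pd.DataFrame({
    "AGE": rng.uniform(20, 80, n),
    "BMI": rng.normal(27, 4, n),
    "ID": np.arange(n),          # hypothetical column the model should not use
})
demo_df["TARGET"] = (demo_df["AGE"] + rng.normal(0, 10, n) > 50).astype(int)

# 1. Drop columns not needed for the model
features = demo_df.drop(columns=["ID"])

# 2. Copy the variable to predict
y = features["TARGET"].copy()

# 3. Drop the variable to predict from the features
X = features.drop(columns=["TARGET"])

# 4. Standardise to zero mean and unit standard deviation
X_scaled = StandardScaler().fit_transform(X)

# 5. Split into training and testing sets (20% test)
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=0)   # random_state is an assumption

# 6. Fit the KNN model with k = 5
model = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)

# 7. Predict using the testing set
y_pred = model.predict(X_test)
print(f"accuracy: {accuracy_score(y_test, y_pred):.2f}")
```

Because the scaling happens before the split, the scaler also sees the test rows. A common variant is to fit the scaler on the training rows only and then transform both sets, which keeps the test set fully unseen.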

The assess_model function performs the following steps:

  • Calculates the model’s accuracy, precision, recall and F1-measure,
  • Prints the metrics to the screen,
  • Calculates the confusion matrix,
  • Creates a bokeh plot of the confusion matrix.
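
To make the link between the metrics and the confusion matrix concrete, the sketch below computes each metric by hand from the four cells of a binary confusion matrix and compares the results with sklearn.metrics. The label arrays are made up for illustration, and class 1 is treated as the positive class, which is the sklearn default.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)

# Made-up true and predicted labels for ten individuals
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0])

# For binary labels, ravel() gives the cells in the order tn, fp, fn, tp
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

accuracy = (tp + tn) / (tp + tn + fp + fn)   # share of all predictions that are right
precision = tp / (tp + fp)                   # share of predicted positives that are right
recall = tp / (tp + fn)                      # share of actual positives that are found
f1 = 2 * precision * recall / (precision + recall)

print(f"{'Accuracy':<20}{accuracy*100:.2f} %")
print(f"{'Precision':<20}{precision*100:.2f} %")
print(f"{'Recall':<20}{recall*100:.2f} %")
print(f"{'F1 measure':<20}{f1*100:.2f} %")

# Cross-check against the library implementations
assert np.isclose(accuracy, accuracy_score(y_true, y_pred))
assert np.isclose(precision, precision_score(y_true, y_pred))
assert np.isclose(recall, recall_score(y_true, y_pred))
assert np.isclose(f1, f1_score(y_true, y_pred))
```

With an imbalanced target, accuracy can look high even when recall for the positive class is poor, which is why all four metrics are reported.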

def assess_model(y, yHat, var, labels):
    """
    Assess a model for accuracy, precision, recall and F1 measure.
    Prints a summary table and plots a confusion matrix.

    Metrics are calculated using sklearn.metrics.

    :param y: Pandas dataframe with true values
    :param yHat: Pandas dataframe with predicted values
    :param var: Name of the predicted variable, used in the plot title and
     the exported file name
    :param labels: List of labels for the confusion matrix plot
    """
    # Calculate model accuracy
    accuracy = sklearn.metrics.accuracy_score(y, yHat)

    # Calculate model precision
    precision = sklearn.metrics.precision_score(y, yHat)

    # Calculate model recall
    recall = sklearn.metrics.recall_score(y, yHat)

    # Calculate model F1 score
    f1 = sklearn.metrics.f1_score(y, yHat)

    # Print summary
    print(f"{'Accuracy':<20}{accuracy*100:.2f} %")
    print(f"{'Precision':<20}{precision*100:.2f} %")
    print(f"{'Recall':<20}{recall*100:.2f} %")
    print(f"{'F1 measure':<20}{f1*100:.2f} %")

    # Calculate the confusion matrix
    model_confusion_matrix = sklearn.metrics.confusion_matrix(y, yHat)
    
    # Prepare data for Bokeh
    xname = []
    yname = []
    value = []
    text_colour = []
    vmin, vmax = model_confusion_matrix.min(), model_confusion_matrix.max()
    threshold = vmin + (vmax - vmin) * 0.5
    
    for i in range(2):
        for j in range(2):
            val = model_confusion_matrix[i, j]
            xname.append(labels[j])
            yname.append(labels[i])
            value.append(val)
            # Use white text for dark cells, black for light cells
            text_colour.append("white" if val > threshold else "black")

    source = ColumnDataSource(data=dict(xname=xname, yname=yname, value=value,
            text_colour=text_colour, val_str=[str(v) for v in value]))

    # Set up colour mapper
    colors = Blues256[::-1]  # Reverse so higher = darker
    mapper = LinearColorMapper(palette=colors, low=min(value), high=max(value))

    # Create figure
    p = figure(title=f"Confusion Matrix for prediction of {var}" ,
           x_range=labels, y_range=list(reversed(labels)),
           x_axis_location="above", width=400, height=400,
           tools="", toolbar_location=None)

    # Add cells to the matrix
    p.rect(x="xname", y="yname", width=1, height=1, source=source,
       line_color="black", fill_color={'field': 'value', 'transform': mapper})

    # Add text annotations
    labels_ = LabelSet(x="xname", y="yname", text="val_str",  level="glyph",
               x_offset=-10, y_offset=0, source=source,text_color="text_colour")
    p.add_layout(labels_)

    # Colour bar
    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="10pt",
                ticker=BasicTicker(desired_num_ticks=5),
                formatter=PrintfTickFormatter(format="%d"),
                label_standoff=6, border_line_color=None, location=(0,0))
    p.add_layout(color_bar, "right")

    # Axis labels
    p.xaxis.axis_label = "Predicted"
    p.yaxis.axis_label = "True"

    if create_pdf:
        # Create static images for pdf
        export_png(p, filename=f"{var}_confusion_matrix.png")
        display(Image(filename=f"{var}_confusion_matrix.png"))
    else:        
        show(p)
def KNClassifier(df, pred_var=None, labels=["No", "Yes"] , drop_cols=None):
    """
    Creates a K-Nearest Neighbours classifier for the requested prediction
    variable.

    The function also calls assess_model to print a summary of the performance
    metrics and plot a confusion matrix.

    :param df: Data frame containing the data
    :param pred_var: Name of the variable to predict as a string
    :param labels: List of labels for the confusion matrix plot
    :param drop_cols: (Optional) List of column names to exclude from the model
     fitting
    :return: Fitted KNN model
    """
    print(f"Prediction  variable: {pred_var}")

    if drop_cols:
        # Dropping requested columns
        df = df.drop(columns=drop_cols)
        
    # Get a copy of the variable we are trying to predict
    y = df[[pred_var]]
    
    # Drop the columns we are trying to predict
    x = df.drop(pred_var,axis=1)
        
    std_scaler = sklearn.preprocessing.StandardScaler()

    # Standardise independent variables by removing the mean and scaling to unit variance.
    x = std_scaler.fit_transform(x)

    # Split the data into train and test.
    x_train, x_test, y_train, y_test = sklearn.model_selection.train_test_split(
        x, y, test_size=0.2, random_state=10)

    print(f"Number of training points: {x_train.shape[0]}")
    print(f"Number of testing points: {x_test.shape[0]}")

    # Create a KNN classifier
    knn = sklearn.neighbors.KNeighborsClassifier(n_neighbors=5)

    # Fit the KNN, making sure the y data is flattened
    knn.fit(x_train, np.ravel(y_train,order="C"))

    # Predict on the model
    y_test_knn_pred = knn.predict(x_test)

    # Assess the model
    assess_model(y_test, y_test_knn_pred, pred_var, labels)
    
    # Return the model
    return knn
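
KNClassifier standardises the full data set before splitting it, so the test rows contribute to the scaler's mean and variance. The sketch below shows one possible variant, not the method used in this notebook: a scikit-learn pipeline that fits the scaler on the training rows only, combined with a stratified split. The synthetic data from make_classification is an assumption standing in for nhanes_df, and the metrics reported in this notebook were not produced this way.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (an assumption; the notebook uses nhanes_df)
X, y = make_classification(n_samples=500, n_features=6,
                           weights=[0.85, 0.15], random_state=10)

# stratify=y keeps the class proportions the same in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=10, stratify=y)

# The pipeline fits the scaler on the training rows only and reuses
# those statistics when transforming the test rows
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
model.fit(X_train, y_train)

test_accuracy = model.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.3f}")
```

With a rare positive class, the stratified split also guarantees that the test set contains a representative share of positive cases.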

Predicting the sex of individuals based on health data

This section uses the health data to build a classifying model to see if the person’s sex can be predicted from that data.

sex_mapping = {0: 'Male', 1: 'Female'}
sex_value_counts = nhanes_df["SEX"].value_counts(normalize=True)
sex_value_counts.index = sex_value_counts.index.map(sex_mapping)
percentage_counts_sex = sex_value_counts.mul(100).round(2).astype(str) + "%"
print("\nValue counts as percentages for the column:")
print(percentage_counts_sex)

Value counts as percentages for the column:
SEX
Female    51.08%
Male      48.92%
Name: proportion, dtype: object

Before building the model, I checked the proportion of females in the sample. From the above information, you can see females make up 51% of the sample, so this is a balanced classification problem.

knn_1 = KNClassifier(nhanes_df, pred_var="SEX", labels=["Male", "Female"])
Prediction  variable: SEX
Number of training points: 5038
Number of testing points: 1260
Accuracy            66.27 %
Precision           67.50 %
Recall              63.56 %
F1 measure          65.48 %

The model performance on the test data is shown above. The model predicts the individual’s sex correctly 66% of the time, which is higher than the 51% that would result from labelling everyone as female. The precision shows that 67.5% of people labelled as female are indeed female, and the recall indicates that 63.6% of females are correctly labelled as female. The F1 measure, the harmonic mean of precision and recall, is 65.5%. My conclusion is that the classifier is average: better than the naive baseline, but wrong for about one person in three.
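
As a quick check of how the F1 measure relates to the other two metrics, the harmonic mean can be recomputed by hand. This is a minimal sketch that uses the rounded precision and recall printed above as inputs, so the result agrees with the reported 65.48% only up to rounding.

```python
# Values copied from the printed metrics for the SEX model
precision = 0.6750
recall = 0.6356

# F1 is the harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)
print(f"{'F1 from rounded inputs':<28}{f1 * 100:.2f} %")
```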

Predicting if someone is diabetic

db_mapping = {0: 'No', 1: 'Yes'}
db_value_counts = nhanes_df["DIABETES"].value_counts(normalize=True)
db_value_counts.index = db_value_counts.index.map(db_mapping)
percentage_counts_db = db_value_counts.mul(100).round(2).astype(str) + "%"
print("\nValue counts as percentages for the column:")
print(percentage_counts_db)

Value counts as percentages for the column:
DIABETES
No     85.5%
Yes    14.5%
Name: proportion, dtype: object

Only 14.5% of the sample is diabetic, which makes this a heavily imbalanced classification problem.

knn_2 = KNClassifier(nhanes_df, pred_var="DIABETES")
Prediction  variable: DIABETES
Number of training points: 5038
Number of testing points: 1260
Accuracy            85.24 %
Precision           47.47 %
Recall              25.97 %
F1 measure          33.57 %

The model performance on the test data is shown above; 85% of the time, it correctly predicts the individual’s diabetic status (non-diabetic or diabetic). Given that 85.5% of the people in the sample are non-diabetic, the model’s accuracy is no better than assuming everyone is non-diabetic. The precision shows that 47.5% of people labelled as diabetic do have diabetes, whilst the recall indicates that only 26% of diabetics are correctly labelled as diabetic. My conclusion is that the classifier model for diabetes prediction is terrible. The underrepresentation of diabetic people in the data set is probably the cause of the poor model performance.
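
The "assume everyone is non-diabetic" baseline can be made explicit with scikit-learn's DummyClassifier. The sketch below uses hypothetical labels with the same 85.5% / 14.5% class balance as the DIABETES column, not the NHANES data itself. It illustrates why accuracy alone is misleading here: the baseline scores 85.5% accuracy while finding no diabetics at all.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical labels with the same class balance as the DIABETES column:
# 855 non-diabetic (0) and 145 diabetic (1)
y_true = np.array([0] * 855 + [1] * 145)

# The baseline ignores the features, so a placeholder column is enough
X_dummy = np.zeros((y_true.size, 1))

# Always predict the most frequent class (non-diabetic)
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_dummy, y_true)
y_pred = baseline.predict(X_dummy)

baseline_accuracy = accuracy_score(y_true, y_pred)
baseline_recall = recall_score(y_true, y_pred)
print(f"{'Baseline accuracy':<20}{baseline_accuracy * 100:.2f} %")
print(f"{'Baseline recall':<20}{baseline_recall * 100:.2f} %")
```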

Predicting if someone has coronary heart disease

cd_mapping = {0: 'No', 1: 'Yes'}
cd_value_counts = nhanes_df["CD"].value_counts(normalize=True)
cd_value_counts.index = cd_value_counts.index.map(cd_mapping)
percentage_counts_cd = cd_value_counts.mul(100).round(2).astype(str) + "%"
print("\nValue counts as percentages for the column:")
print(percentage_counts_cd)

Value counts as percentages for the column:
CD
No     96.71%
Yes     3.29%
Name: proportion, dtype: object

Only 3.3% of the sample has been diagnosed with coronary heart disease. Based on experience with diabetes prediction, the model is unlikely to be useful.

knn_3 = KNClassifier(nhanes_df, pred_var="CD")
Prediction  variable: CD
Number of training points: 5038
Number of testing points: 1260
Accuracy            96.03 %
Precision           14.29 %
Recall              2.22 %
F1 measure          3.85 %

As expected, the model’s performance is very poor, with a precision of 14.3% and a recall of only 2.2%. From the confusion matrix, you can see that 44 people who had coronary heart disease were predicted not to have the disease and, therefore, would go untreated. My hypothesis is that the severe underrepresentation of people with coronary heart disease in the sample is the cause of such a poorly performing model.
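
The printed metrics can be tied back to the confusion matrix counts. In the sketch below, only the 44 false negatives are quoted from the text; the other three counts are an assumption, inferred from the printed metrics and the 1260 test points rather than read from the plot. With these counts, all four metrics match the printed values to two decimal places.

```python
# Counts for the CD model. fn is quoted in the text; tp, fp and tn are
# inferred from the printed metrics and the 1260 test points (an
# assumption, not read from the plot).
tp, fn, fp, tn = 1, 44, 6, 1209

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)   # share of predicted positives that are real
recall = tp / (tp + fn)      # share of real positives that are found
f1 = 2 * precision * recall / (precision + recall)

print(f"{'Accuracy':<20}{accuracy * 100:.2f} %")
print(f"{'Precision':<20}{precision * 100:.2f} %")
print(f"{'Recall':<20}{recall * 100:.2f} %")
print(f"{'F1 measure':<20}{f1 * 100:.2f} %")
```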

To test this hypothesis, I filtered the data to only include smokers aged 45 and over (the scatter plots showed that very few young people had coronary heart disease). This increased the proportion in the sample with a diagnosis of coronary heart disease to 7.2%.

# Only people 45 and over
nhanes_df_filt = nhanes_df[(nhanes_df["AGE"] >= 45)]

# Group by smoking
groups = nhanes_df_filt.groupby(["SMOKES"])

# Choose only the smokers
df_2 = groups.get_group((1,))
df_2 = df_2[df_2.columns[~df_2.columns.isin(["SMOKES"])]]

cd_mapping = {0: 'No', 1: 'Yes'}
cd_value_counts = df_2["CD"].value_counts(normalize=True)
cd_value_counts.index = cd_value_counts.index.map(cd_mapping)
percentage_counts_cd = cd_value_counts.mul(100).round(2).astype(str) + "%"

print("\nValue counts as percentages for the column:")
print(percentage_counts_cd)

knn_4 = KNClassifier(df_2, pred_var="CD")

Value counts as percentages for the column:
CD
No     92.85%
Yes     7.15%
Name: proportion, dtype: object
Prediction  variable: CD
Number of training points: 1297
Number of testing points: 325
Accuracy            92.31 %
Precision           16.67 %
Recall              4.76 %
F1 measure          7.41 %

Filtering the data to only include people who are more likely to have been diagnosed with coronary heart disease improved the model’s precision (from 14.3% to 16.7%) and recall (from 2.2% to 4.8%), although both remain very low. This suggests that a more even representation of people with coronary heart disease in the sample would help improve the model’s performance and, therefore, its usefulness.

Future work

This analysis was not meant to be a fully comprehensive data science study into predicting coronary heart disease from health data. Thus, there are several areas where this analysis could be improved:

  • I used coronary heart disease data as a substitute for cardiovascular disease. Consider using a different dataset that contains data for cardiovascular disease.
  • Investigate and implement treatments for the outliers in the data.
  • Improve the balance of the classification problem by adding further samples with coronary heart disease or implementing a statistical sampling technique to generate synthetic samples.
  • Explore other factors that might have a stronger relationship to the development of coronary heart disease, such as physical activity level.
  • Investigate other machine learning classification models.
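
One simple way to rebalance the classes is random oversampling: duplicating minority-class rows, sampled with replacement, until both classes are the same size. This is a simpler technique than generating synthetic samples, and the sketch below is an illustration only. The training arrays are hypothetical stand-ins for the NHANES data, and the approach has not been evaluated in this notebook.

```python
import numpy as np
from sklearn.utils import resample

# Hypothetical training set: 970 negatives and 30 positives
rng = np.random.default_rng(10)
X_train = rng.normal(size=(1000, 4))
y_train = np.array([0] * 970 + [1] * 30)

X_min = X_train[y_train == 1]   # minority class rows
X_maj = X_train[y_train == 0]   # majority class rows

# Sample minority rows with replacement until both classes match in size
X_min_up = resample(X_min, replace=True, n_samples=len(X_maj),
                    random_state=10)

# Recombine into a balanced training set
X_bal = np.vstack([X_maj, X_min_up])
y_bal = np.concatenate([np.zeros(len(X_maj), dtype=int),
                        np.ones(len(X_min_up), dtype=int)])

print(f"Balanced training points: {X_bal.shape[0]}")
print(f"Share of positives: {y_bal.mean() * 100:.1f} %")
```

Oversampling should be applied to the training split only, so that the test set keeps the original class balance and duplicated rows never appear in both splits.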

Summary

This notebook analyses medical data from the NHANES survey (2017-2020) to predict a diagnosis of coronary heart disease (used here as a substitute for cardiovascular disease) using machine learning techniques with Python and Bokeh visualisations.

Key Findings from Exploratory Data Analysis:

  • Age distribution is uniform across sexes.
  • BMI shows a right-skewed distribution with more severely obese females than males.
  • Blood pressure measurements reveal some potential outliers requiring investigation.
  • A strong positive correlation exists between systolic and diastolic blood pressure.
  • Most other variable correlations are weak.
  • Clear patterns emerge in smoking (more male smokers), diabetes (balanced by sex), and coronary heart disease (more female diagnoses).

Machine Learning Results:

K-Nearest Neighbours classification models were tested on three prediction targets:

  1. Predicting Sex: 66.27% accuracy - considered average performance.
  2. Predicting Diabetes: 85.24% accuracy but poor precision (47.5%) and recall (26%) due to class imbalance (only 14.5% diabetic).
  3. Predicting Coronary Heart Disease: abysmal performance (96% accuracy but only 2.2% recall) due to severe class imbalance (only 3.3% with disease). Restricting the sample to smokers aged 45 and over raised recall only to 4.8%.

Main Conclusions:

The study demonstrates that predicting cardiovascular disease from basic health metrics is challenging, primarily due to the dataset’s class imbalance. The severe underrepresentation of people with coronary heart disease makes reliable prediction difficult with standard machine-learning approaches.

Future Improvements:

The analysis suggests using more balanced datasets, investigating outlier treatments, exploring additional risk factors like physical activity, and testing other machine learning models to improve precision and recall.
