As part of the final project for my Data Science For Clinical Research class in graduate school, I needed to create a machine learning model capable of detecting asthmatic subjects. The purpose of this project was not to create a production-ready model using state-of-the-art machine learning algorithms but to understand the fundamentals of data science. These fundamentals are broken down into three main parts:

- data extraction
- data transformation
- machine learning
The class instructor gave us the option to complete this project with the programming language of our choice. The only requirements were to give a class presentation describing the steps taken, submit a document containing all code and output, and use at least two machine learning algorithms. This post serves as both the presentation and the document submission for the project. I decided to use R for the data extraction and transformation steps due to the numerous R packages that make connecting to databases and transforming data extremely easy. For the machine learning portion I was a bit conflicted between tidymodels in R and scikit-learn in Python. Ultimately, I decided to go with Python, as one of the main reasons for taking this class was to improve my Python skills. Enjoy!
Let’s begin by creating the `ohdsi` database from the MySQL data dump. The class instructor provided us with the 900 MB MySQL data dump, and I’ve uploaded it to Google Drive for anyone who would like to experiment with the data or recreate this analysis. I have MySQL set up locally and executed the following commands from the command line to create the database.
mysql -u root
create database ohdsi;
use ohdsi;
# import the database
source ./data/ohdsi_sample.sql;
# confirm that tables were created
show tables;
library(tidyverse)
library(RMariaDB)
library(DBI)
library(lubridate)
library(kableExtra)
library(janitor)
library(gtsummary)
library(recipes)
library(rsample)
library(reticulate)
# connect to database
con <- dbConnect(MariaDB(),
                 dbname = 'ohdsi',
                 host = 'localhost',
                 user = Sys.getenv('DB_USER'),
                 password = Sys.getenv('DB_PASSWORD'))
dbListTables(con, "ohdsi")
## [1] "concept" "condition_era" "condition_occurrence"
## [4] "drug_era" "drug_exposure" "location"
## [7] "measurement" "observation" "observation_period"
## [10] "payer_plan_period" "person" "procedure_occurrence"
## [13] "provider" "relationship" "source_to_concept_map"
## [16] "specimen" "visit_occurrence" "vocabulary"
There are multiple clinical data tables available, but for this project only the following will be used: `person`, `concept`, `condition_occurrence`, `condition_era`, `drug_era` and `measurement`. The data dictionary for each table can be found here.
As the goal is to create a model that can predict whether or not a subject has asthma, the first step will be to identify subjects as cases (asthma present) or controls (asthma not present). The `condition_occurrence` table will be queried to return all the conditions for each subject. The `concept` table will also be joined to the `condition_occurrence` table to get a descriptive name for each condition. If a subject has at least one condition whose name contains the string `asthma`, the subject will be tagged as a case; otherwise the subject will be a control.
<- "
sql select
person_id,
lower(concept_name) as concept_name,
condition_start_date,
case
when lower(concept_name) like '%asthma%' then 'case'
else 'control'
end as status
from
condition_occurrence
-- join to replace id with names
left join concept on condition_concept_id = concept_id
order by
person_id, status, condition_start_date
"
<- dbGetQuery(con, sql)
condition_occurence
kable(head(condition_occurence, 20), caption = "Table 1") %>%
kable_styling(full_width = FALSE)
person_id | concept_name | condition_start_date | status |
---|---|---|---|
1 | backache | 2009-07-25 | control |
1 | low back pain | 2009-07-25 | control |
1 | menopausal syndrome | 2009-07-25 | control |
1 | thoracic radiculitis | 2009-07-25 | control |
1 | hypocalcemia | 2009-10-14 | control |
1 | postoperative pain | 2009-10-14 | control |
1 | osteoporosis | 2010-03-12 | control |
1 | congestive heart failure | 2010-03-12 | control |
1 | antiallergenic drug adverse reaction | 2010-03-12 | control |
1 | pure hypercholesterolemia | 2010-03-12 | control |
1 | retention of urine | 2010-03-12 | control |
1 | bipolar disorder | 2010-04-01 | control |
1 | bipolar i disorder, single manic episode, in full remission | 2010-04-01 | control |
1 | neck sprain | 2010-08-17 | control |
1 | subchronic catatonic schizophrenia | 2010-11-05 | control |
1 | schizophrenia | 2010-11-05 | control |
2 | asthma | 2008-12-09 | case |
2 | closed fracture of phalanx of foot | 2008-10-04 | control |
2 | closed fracture of phalanx of foot | 2008-10-04 | control |
2 | mononeuropathy of lower limb | 2008-10-04 | control |
Notice the `order by` clause in the SQL query. This was necessary because a subject can have multiple diagnoses, which can lead to the subject being tagged as both a case and a control; subject #2 in Table 1 is an example of this. If a subject has at least one asthma diagnosis we’d like to keep the initial occurrence of the asthma diagnosis, and if a subject does not have an asthma diagnosis we’d like to keep that subject’s first diagnosis. Because 'case' sorts alphabetically before 'control', ordering by person_id, status and condition_start_date puts exactly the row we want to keep first for each subject, so keeping only the first row per subject gives the desired result.
person_status <- condition_occurrence %>%
  distinct(person_id, .keep_all = TRUE)

kable(head(person_status), caption = "Table 2") %>%
  kable_styling(full_width = FALSE)
person_id | concept_name | condition_start_date | status |
---|---|---|---|
1 | backache | 2009-07-25 | control |
2 | asthma | 2008-12-09 | case |
3 | constipation | 2009-10-11 | control |
4 | spinal stenosis | 2009-09-20 | control |
5 | osteochondropathy | 2008-06-01 | control |
6 | after-cataract with vision obscured | 2009-09-01 | control |
Query the `person` table to get `gender`, `race` and `year_of_birth`. An additional variable named `age_at_diagnosis` will also be created as the difference between the year of `condition_start_date` and `year_of_birth`.
<- "
sql select
person_id,
c1.concept_name as gender,
c2.concept_name as race,
year_of_birth
from
person
left join concept c1 on gender_concept_id = c1.concept_id
left join concept c2 on race_concept_id = c2.concept_id
"
<- dbGetQuery(con, sql)
person
demographics <- person %>%
  left_join(person_status %>%
              select(person_id, condition_start_date),
            by = "person_id") %>%
  mutate(age_at_diagnosis = year(condition_start_date) - year_of_birth) %>%
  select(-c(year_of_birth, condition_start_date))

kable(head(demographics), caption = "Table 3") %>%
  kable_styling(full_width = FALSE)
person_id | gender | race | age_at_diagnosis |
---|---|---|---|
1 | MALE | White | 86 |
2 | MALE | White | 65 |
3 | FEMALE | White | 73 |
4 | MALE | No matching concept | 68 |
5 | MALE | White | 72 |
6 | MALE | Black or African American | 66 |
Query the `drug_era` table to return all drugs that a subject has taken.
<- "
sql select
person_id,
lower(concept_name) as drug,
drug_era_start_date
from
drug_era
left join concept on drug_concept_id = concept_id
"
<- dbGetQuery(con, sql)
drug_era
kable(head(drug_era), caption = "Table 4") %>%
kable_styling(full_width = FALSE)
person_id | drug | drug_era_start_date |
---|---|---|
1 | cytomegalovirus immune globulin | 2010-04-17 |
1 | thioridazine | 2008-08-21 |
1 | thiothixene | 2009-03-19 |
1 | trazodone | 2009-04-09 |
1 | methocarbamol | 2010-05-20 |
1 | methylphenidate | 2009-02-16 |
Based on Table 4, let’s filter to all drugs taken within two years after the first diagnosis and then keep only the top 10 drugs for cases and the top 10 drugs for controls.
top_drugs_by_person_status <- person_status %>%
  select(person_id, status, condition_start_date) %>%
  left_join(drug_era, by = "person_id") %>%
  # get time between the date a drug was started and the date a diagnosis was provided
  mutate(days_passed = difftime(drug_era_start_date, condition_start_date,
                                units = "days")) %>%
  # filter to all drugs taken within two years after condition_start_date
  filter(between(days_passed, 0, 365 * 2)) %>%
  count(status, drug) %>%
  group_by(status) %>%
  # get the top 10 drugs by person status
  top_n(10, wt = n) %>%
  ungroup() %>%
  distinct(drug) %>%
  arrange(drug)

kable(top_drugs_by_person_status, caption = "Table 5") %>%
  kable_styling(full_width = FALSE)
drug |
---|
acetaminophen |
dipyridamole |
hydrochlorothiazide |
hydrocodone |
levothyroxine |
lisinopril |
lovastatin |
metformin |
oxygen |
propranolol |
simvastatin |
Finally, let’s create a data frame where each row represents a subject and the columns are the drugs shown in Table 5. The columns are binary: `1` specifies that the subject has taken that drug within two years of the diagnosis and `0` specifies the opposite.
drugs_taken <- drug_era %>%
  inner_join(top_drugs_by_person_status, by = "drug") %>%
  # tag the subject as being on drug
  mutate(on_drug = 1) %>%
  # ensure that the person-drug combination is unique
  distinct(person_id, drug, .keep_all = TRUE) %>%
  # convert from long to wide data frame
  pivot_wider(
    id_cols = -drug_era_start_date,
    names_from = drug,
    values_from = on_drug,
    # tag subject as not on drug
    values_fill = 0,
    names_sort = TRUE
  )

kable(head(drugs_taken), caption = "Table 6") %>%
  kable_styling() %>%
  scroll_box(width = "100%")
person_id | acetaminophen | dipyridamole | hydrochlorothiazide | hydrocodone | levothyroxine | lisinopril | lovastatin | metformin | oxygen | propranolol | simvastatin |
---|---|---|---|---|---|---|---|---|---|---|---|
1 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
2 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
3 | 0 | 0 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
4 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 |
5 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
6 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
Query the `condition_era` table to return all diagnoses for a subject. Additionally, all instances of asthma will be removed from `condition_era`, as there should not be any reference to an asthma diagnosis within the data used in the machine learning process.
<- "
sql select
person_id,
lower(concept_name) as diagnosis,
condition_era_start_date
from
condition_era
left join concept on condition_concept_id = concept_id
where
concept_name not like '%asthma%'
"
<- dbGetQuery(con, sql)
condition_era
kable(head(condition_era), caption = "Table 7") %>%
kable_styling(full_width = FALSE)
person_id | diagnosis | condition_era_start_date |
---|---|---|
1 | neck sprain | 2010-08-17 |
1 | osteoporosis | 2010-03-12 |
1 | backache | 2009-07-25 |
1 | retention of urine | 2010-03-12 |
1 | low back pain | 2009-07-25 |
1 | congestive heart failure | 2010-03-12 |
To identify other conditions people with asthma may experience, I chose a list of conditions based on online research. This list is not exhaustive and may not be 100% correct. I have absolutely zero domain knowledge in clinical research but, with Google as my sidekick, I was able to determine clinically relevant conditions that may be useful for predicting asthma status:

- atrial fibrillation
- chest pain
- conduction disorder
- congestive heart failure
- coronary arteriosclerosis
- gastroesophageal reflux disease
- malaise and fatigue

Note that several distinct diagnosis names in the data, such as `conduction disorder of the heart`, fall under the `conduction disorder` category; the prefix match below captures all of them.
dxs <- c(
  "atrial fibrillation",
  "chest pain",
  "conduction disorder",
  "congestive heart failure",
  "coronary arteriosclerosis",
  "gastroesophageal reflux disease",
  "malaise and fatigue"
)
common_dxs_with_asthma <- person_status %>%
  select(person_id, condition_start_date, status) %>%
  left_join(condition_era, by = "person_id") %>%
  mutate(days_passed = difftime(condition_era_start_date, condition_start_date,
                                units = "days")) %>%
  filter(
    days_passed > 0 &
      str_detect(diagnosis, paste0("^(", paste(dxs, collapse = "|"), ")"))
  ) %>%
  mutate(
    clean_diagnosis_name = case_when(
      str_detect(diagnosis, "coronary arteriosclerosis") ~ "coronary arteriosclerosis",
      TRUE ~ diagnosis
    )
  ) %>%
  distinct(diagnosis, clean_diagnosis_name)
diagnoses <- condition_era %>%
  inner_join(common_dxs_with_asthma, by = "diagnosis") %>%
  distinct(person_id, clean_diagnosis_name) %>%
  mutate(dx_present = 1) %>%
  pivot_wider(
    names_from = clean_diagnosis_name,
    values_from = dx_present,
    values_fill = 0,
    names_sort = TRUE
  ) %>%
  clean_names()

kable(head(diagnoses), caption = "Table 8") %>%
  kable_styling() %>%
  scroll_box(width = "100%")
person_id | atrial_fibrillation | chest_pain | conduction_disorder_of_the_heart | congestive_heart_failure | coronary_arteriosclerosis | gastroesophageal_reflux_disease | malaise_and_fatigue |
---|---|---|---|---|---|---|---|
1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
2 | 1 | 0 | 1 | 1 | 1 | 1 | 0 |
4 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
5 | 0 | 0 | 1 | 0 | 1 | 0 | 1 |
6 | 0 | 0 | 0 | 0 | 1 | 1 | 0 |
7 | 0 | 0 | 1 | 1 | 1 | 1 | 1 |
The lab tests data frame was constructed in the same way as the diagnoses data frame. The main difference is that the `measurement` table was queried instead of the `condition_era` table. One downside of the `measurement` table is that it does not contain the actual values for the lab tests, only the names of the lab tests taken by a subject. The three specific tests that were filtered for are `electrocardiogram`, `iron binding capacity` and `thyroid stimulating hormone`.
<- "
sql SELECT
person_id,
lower(concept_name) as measurement_name,
measurement_date
FROM
measurement
left join concept on measurement_concept_id = concept_id
"
<- dbGetQuery(con, sql)
measurement
# disconnet from DB
dbDisconnect(con)
common_tests_for_asthma <- person_status %>%
  select(person_id, condition_start_date, status) %>%
  left_join(measurement, by = "person_id") %>%
  mutate(days_passed = difftime(measurement_date, condition_start_date,
                                units = "days")) %>%
  filter(
    days_passed > 0 &
      str_detect(
        measurement_name,
        "^(iron binding capacity|electrocardiogram|thyroid stimulating hormone)"
      )
  ) %>%
  mutate(
    clean_measurement_name = case_when(
      str_detect(measurement_name, "electrocardiogram") ~ "electrocardiogram",
      str_detect(measurement_name, "iron binding capacity") ~ "iron binding capacity",
      TRUE ~ "thyroid stimulating hormone"
    )
  ) %>%
  distinct(measurement_name, clean_measurement_name)

lab_tests <- measurement %>%
  inner_join(common_tests_for_asthma, by = "measurement_name") %>%
  distinct(person_id, clean_measurement_name) %>%
  mutate(measurement_taken = 1) %>%
  pivot_wider(
    names_from = clean_measurement_name,
    values_from = measurement_taken,
    values_fill = 0,
    names_sort = TRUE
  ) %>%
  clean_names()

kable(head(lab_tests), caption = "Table 9") %>%
  kable_styling(full_width = FALSE)
person_id | electrocardiogram | iron_binding_capacity | thyroid_stimulating_hormone |
---|---|---|---|
2 | 1 | 0 | 1 |
5 | 1 | 0 | 0 |
6 | 1 | 0 | 0 |
7 | 1 | 1 | 1 |
8 | 1 | 1 | 1 |
11 | 1 | 0 | 1 |
The `person_status` data frame created in the Identify Cases and Controls section is the master data frame. The `demographics`, `drugs_taken`, `diagnoses` and `lab_tests` data frames will be left joined to the master data frame using `person_id` as the primary key. If a subject does not appear in the `drugs_taken`, `diagnoses`, or `lab_tests` data frames it means that the subject did not have any of the selected drugs, diagnoses or lab tests, so the NA values produced by the join will be replaced with 0.
asthma_prediction_data <- list(person_status, demographics,
                               drugs_taken, diagnoses, lab_tests) %>%
  reduce(left_join, by = "person_id") %>%
  # remove unknown race
  filter(race != "No matching concept") %>%
  # if subject has no drugs, diagnosis or labs data replace NA with 0
  mutate_if(is.numeric, replace_na, 0) %>%
  select(-c(person_id, concept_name, condition_start_date))

kable(head(asthma_prediction_data), caption = "Table 10") %>%
  kable_styling() %>%
  scroll_box(width = "100%")
status | gender | race | age_at_diagnosis | acetaminophen | dipyridamole | hydrochlorothiazide | hydrocodone | levothyroxine | lisinopril | lovastatin | metformin | oxygen | propranolol | simvastatin | atrial_fibrillation | chest_pain | conduction_disorder_of_the_heart | congestive_heart_failure | coronary_arteriosclerosis | gastroesophageal_reflux_disease | malaise_and_fatigue | electrocardiogram | iron_binding_capacity | thyroid_stimulating_hormone |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
control | MALE | White | 86 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
case | MALE | White | 65 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 1 | 1 | 1 | 0 | 1 | 0 | 1 |
control | FEMALE | White | 73 | 0 | 0 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
control | MALE | White | 72 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 |
control | MALE | Black or African American | 66 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 1 | 0 | 0 |
control | MALE | White | 86 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
The summary statistics table provides an overall summary of the data. It shows that there are 877 subjects of which 301 are cases and 576 are controls. Thus, the data is class imbalanced.
asthma_prediction_data %>%
  tbl_summary(by = status) %>%
  add_overall(last = TRUE)
Characteristic | case, N = 301¹ | control, N = 576¹ | Overall, N = 877¹ |
---|---|---|---|
gender | |||
FEMALE | 166 (55%) | 311 (54%) | 477 (54%) |
MALE | 135 (45%) | 265 (46%) | 400 (46%) |
race | |||
Black or African American | 34 (11%) | 55 (9.5%) | 89 (10%) |
White | 267 (89%) | 521 (90%) | 788 (90%) |
age_at_diagnosis | 74 (68, 82) | 73 (67, 79) | 73 (67, 80) |
acetaminophen | 165 (55%) | 312 (54%) | 477 (54%) |
dipyridamole | 100 (33%) | 142 (25%) | 242 (28%) |
hydrochlorothiazide | 162 (54%) | 287 (50%) | 449 (51%) |
hydrocodone | 105 (35%) | 218 (38%) | 323 (37%) |
levothyroxine | 149 (50%) | 267 (46%) | 416 (47%) |
lisinopril | 132 (44%) | 240 (42%) | 372 (42%) |
lovastatin | 121 (40%) | 235 (41%) | 356 (41%) |
metformin | 114 (38%) | 171 (30%) | 285 (32%) |
oxygen | 130 (43%) | 224 (39%) | 354 (40%) |
propranolol | 123 (41%) | 203 (35%) | 326 (37%) |
simvastatin | 139 (46%) | 253 (44%) | 392 (45%) |
atrial_fibrillation | 238 (79%) | 307 (53%) | 545 (62%) |
chest_pain | 229 (76%) | 292 (51%) | 521 (59%) |
conduction_disorder_of_the_heart | 167 (55%) | 175 (30%) | 342 (39%) |
congestive_heart_failure | 214 (71%) | 253 (44%) | 467 (53%) |
coronary_arteriosclerosis | 259 (86%) | 327 (57%) | 586 (67%) |
gastroesophageal_reflux_disease | 210 (70%) | 225 (39%) | 435 (50%) |
malaise_and_fatigue | 240 (80%) | 285 (49%) | 525 (60%) |
electrocardiogram | 275 (91%) | 365 (63%) | 640 (73%) |
iron_binding_capacity | 63 (21%) | 72 (12%) | 135 (15%) |
thyroid_stimulating_hormone | 230 (76%) | 285 (49%) | 515 (59%) |
¹ n (%); Median (IQR)
An 80-20 train-test split was performed. The split was stratified by status so that each split has the same proportion of cases and controls. Additionally, gender and race were one-hot encoded while age was normalized to the range 0 to 1. This normalization ensures that all variables are on the same scale.
set.seed(13)

# do stratified split on status as data is imbalanced
data_split <- initial_split(asthma_prediction_data,
                            strata = status, prop = 0.8)

train_set_split <- training(data_split)
test_set_split <- testing(data_split)

model_recipe <- recipe(status ~ ., train_set_split) %>%
  step_dummy(c(gender, race)) %>%
  step_range(age_at_diagnosis) %>%
  prep()

train_set <- bake(model_recipe, train_set_split)

train_features <- train_set %>% select(-status)

train_target <- train_set %>%
  select(status) %>%
  mutate(status = if_else(status == "case", 1, 0))

test_set <- bake(model_recipe, test_set_split)

test_features <- test_set %>% select(-status)

test_target <- test_set %>%
  select(status) %>%
  mutate(status = if_else(status == "case", 1, 0))
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC
from sklearn import model_selection, metrics
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import plot_confusion_matrix
from imblearn.pipeline import Pipeline, make_pipeline
from sklearn.model_selection import GridSearchCV
# access train_features and train_target created by R
train_features = r.train_features
train_target = r.train_target
train_features.info()
## <class 'pandas.core.frame.DataFrame'>
## RangeIndex: 702 entries, 0 to 701
## Data columns (total 24 columns):
## # Column Non-Null Count Dtype
## --- ------ -------------- -----
## 0 age_at_diagnosis 702 non-null float64
## 1 acetaminophen 702 non-null float64
## 2 dipyridamole 702 non-null float64
## 3 hydrochlorothiazide 702 non-null float64
## 4 hydrocodone 702 non-null float64
## 5 levothyroxine 702 non-null float64
## 6 lisinopril 702 non-null float64
## 7 lovastatin 702 non-null float64
## 8 metformin 702 non-null float64
## 9 oxygen 702 non-null float64
## 10 propranolol 702 non-null float64
## 11 simvastatin 702 non-null float64
## 12 atrial_fibrillation 702 non-null float64
## 13 chest_pain 702 non-null float64
## 14 conduction_disorder_of_the_heart 702 non-null float64
## 15 congestive_heart_failure 702 non-null float64
## 16 coronary_arteriosclerosis 702 non-null float64
## 17 gastroesophageal_reflux_disease 702 non-null float64
## 18 malaise_and_fatigue 702 non-null float64
## 19 electrocardiogram 702 non-null float64
## 20 iron_binding_capacity 702 non-null float64
## 21 thyroid_stimulating_hormone 702 non-null float64
## 22 gender_MALE 702 non-null float64
## 23 race_White 702 non-null float64
## dtypes: float64(24)
## memory usage: 131.8 KB
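As a quick sanity check (my addition, not part of the original project code), we can confirm that the stratified split preserved the roughly 1:2 case-to-control ratio in the training target:

# the stratified split should preserve the case-to-control ratio (1 = case, 0 = control)
print(train_target['status'].value_counts(normalize=True).round(2))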
Stratified five-fold cross validation was applied to the training set to simulate model performance on unseen data. The training data was split into five subsets of approximately equal size. Four subsets were combined and used to train the model; the fifth subset, sometimes called the hold-out or assessment set, was then used to evaluate the model’s performance on unseen data. It’s important to note that each data point appears in exactly one fold, as the sampling was performed without replacement. This process was repeated four more times, with a different subset reserved for evaluation each time, producing five sets of performance metrics from the five assessment sets. The performance metrics used were F1, precision and recall; the average of each metric across the five folds estimates the model’s ability to generalize to unseen data.
kfold = model_selection.StratifiedKFold(n_splits=5)

performance_metrics = ['f1', 'precision', 'recall']

# cross_validate prefixes assessment-set scores with 'test_'
assessment_performance_metrics = ['test_' + x for x in performance_metrics]
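To make the fold mechanics concrete, here is a small sketch (my addition, reusing the imports above with synthetic labels) showing that StratifiedKFold keeps the class ratio roughly constant in every assessment fold:

# sketch: StratifiedKFold preserves the class ratio within each fold
# (synthetic example: 8 positives and 16 negatives split into 4 folds)
demo_y = np.array([1] * 8 + [0] * 16)
demo_X = np.zeros((24, 1))  # feature values are irrelevant to the split itself

for fold, (train_idx, test_idx) in enumerate(
        model_selection.StratifiedKFold(n_splits=4).split(demo_X, demo_y)):
    print(f"fold {fold}: {demo_y[test_idx].sum()} of {len(test_idx)} assessment samples are positive")
# every assessment fold gets 2 of the 8 positives and 4 of the 16 negatives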
Based on the summary statistics we see that the ratio of cases to controls is almost 1:2. This imbalance can make it difficult for models to learn to distinguish between the majority and minority classes. It will especially be problematic for our purposes as the event of interest is the minority class (cases). Randomly oversampling the data is one way to overcome the challenges of class imbalance. The minority class was oversampled to contain 80% of the number of observations that the majority class contained.
oversample_strategy = RandomOverSampler(sampling_strategy=0.80, random_state=42)
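To see exactly what sampling_strategy=0.80 does, here is a toy illustration (mine, not from the project): the minority class is resampled with replacement until it reaches 80% of the majority class size.

# toy illustration: with sampling_strategy=0.80 the minority class is
# oversampled to 0.8 * (majority class size)
from collections import Counter

toy_X = np.arange(30).reshape(-1, 1)
toy_y = np.array([1] * 10 + [0] * 20)  # 10 cases, 20 controls

res_X, res_y = RandomOverSampler(sampling_strategy=0.80, random_state=0).fit_resample(toy_X, toy_y)
print(Counter(res_y))  # 20 controls and 16 cases, since 16 = 0.8 * 20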
It is vital that oversampling only occur on the training data, as the test data should reflect what one would expect in reality. Additionally, oversampling should occur inside cross validation to prevent the overly optimistic performance metrics that result from oversampling the assessment set. To facilitate this, the function `get_model_metrics` was created. It makes it convenient to test multiple models and ensures that oversampling is not performed on the assessment set within cross validation.
def get_model_metrics(model_to_use, model_name, oversample):
    if oversample == 'Yes':
        # oversample inside the pipeline so it is applied only to the
        # analysis folds, never to the assessment fold
        model = Pipeline([('oversample', oversample_strategy),
                          ('model', model_to_use)])
    else:
        model = model_to_use

    cv_results = model_selection.cross_validate(
        model,
        train_features,
        np.ravel(train_target),
        cv=kfold,
        scoring=performance_metrics)

    df = pd.DataFrame(columns=['Algorithm', 'Over Sampling', 'Metric', 'Value'])

    # average each metric across the five assessment folds
    for metric in assessment_performance_metrics:
        score_value = cv_results[metric].mean().round(2)
        df = df.append({'Algorithm': model_name,
                        'Over Sampling': oversample,
                        'Metric': metric,
                        'Value': score_value},
                       ignore_index=True)

    return df
Table 11 compares the average of the evaluation metrics on the assessment folds from cross validation for multiple models. The three performance metrics used are precision, recall and F1. Precision represents the proportion of positive identifications that were actually correct, recall measures the proportion of actual positives that were correctly identified, and F1 is the harmonic mean of precision and recall.

Based on Table 11, the linear support vector classifier trained on the oversampled data had the highest recall. This model also had among the lowest precision, as improving recall tends to decrease precision and vice versa. For our purposes this is not an issue, as a high recall is more important than a high precision since the goal is to correctly identify actual cases.

For the tuned model reported later in Table 12, the precision score of 0.47 means that when the model classifies a subject as a case it is correct 47% of the time; conversely, 53% of the subjects it classifies as cases are actually controls.

The recall score of 0.83 means that, of all the subjects who were actual cases, the model correctly identified 83% of them as cases.

In real-world clinical settings the metric to optimize for is often determined by subject matter experts. They can determine whether precision and recall are equally important or whether more emphasis should be placed on one of them. When both are equally important, the F1 score can be used, as a high F1 indicates both good precision and good recall.
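As a quick arithmetic check (my addition, not in the original write-up), the harmonic-mean definition lets us reproduce the tuned model’s F1 from the precision and recall quoted above:

# F1 is the harmonic mean of precision and recall; plugging in the
# tuned model's scores reproduces its F1 from Table 12
p, r = 0.47, 0.83
print(round(2 * p * r / (p + r), 2))  # 0.6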
logistic_reg = get_model_metrics(
    LogisticRegression(random_state=42),
    "Logistic Regression",
    oversample='No')

oversample_logistic_reg = get_model_metrics(
    LogisticRegression(random_state=42),
    "Logistic Regression",
    oversample='Yes')

svm_model = get_model_metrics(
    SVC(random_state=42),
    "SVC (RBF)",
    oversample='No')

oversample_svm_model = get_model_metrics(
    SVC(random_state=42),
    "SVC (RBF)",
    oversample='Yes')

linear_svc = get_model_metrics(
    LinearSVC(random_state=42),
    "SVC (Linear)",
    oversample='No')

oversample_linear_svc = get_model_metrics(
    LinearSVC(random_state=42),
    "SVC (Linear)",
    oversample='Yes')
compare_models = (
    pd.concat([logistic_reg, oversample_logistic_reg,
               svm_model, oversample_svm_model,
               linear_svc, oversample_linear_svc])
    .pivot(index=['Algorithm', 'Over Sampling'], columns='Metric', values='Value')
    .reset_index()
)
model_performance <- py$compare_models %>%
  rename_all(~ str_replace_all(., "test_", "")) %>%
  arrange(desc(recall))

kable(model_performance, caption = "Table 11") %>%
  kable_styling(full_width = FALSE)
Algorithm | Over Sampling | f1 | precision | recall |
---|---|---|---|---|
SVC (Linear) | Yes | 0.61 | 0.55 | 0.68 |
Logistic Regression | Yes | 0.61 | 0.55 | 0.67 |
SVC (RBF) | Yes | 0.59 | 0.56 | 0.63 |
SVC (RBF) | No | 0.56 | 0.61 | 0.52 |
Logistic Regression | No | 0.53 | 0.58 | 0.49 |
SVC (Linear) | No | 0.53 | 0.57 | 0.49 |
Table 11 shows that the linear SVC model on the oversampled data had the highest recall. Let’s try hyperparameter tuning to determine if performance can be improved.
grid = {'clf__loss': ['hinge', 'squared_hinge'],
        'clf__C': [0.5, 0.1, 0.005, 0.001]}

pipeline = Pipeline([('sampling', oversample_strategy),
                     ('clf', LinearSVC(random_state=42, max_iter=20000))])

grid_cv = GridSearchCV(pipeline, grid, scoring='recall', cv=kfold)

grid_cv.fit(train_features, np.ravel(train_target))
## GridSearchCV(cv=StratifiedKFold(n_splits=5, random_state=None, shuffle=False),
## estimator=Pipeline(steps=[('sampling',
## RandomOverSampler(random_state=42,
## sampling_strategy=0.8)),
## ('clf',
## LinearSVC(max_iter=20000,
## random_state=42))]),
## param_grid={'clf__C': [0.5, 0.1, 0.005, 0.001],
## 'clf__loss': ['hinge', 'squared_hinge']},
## scoring='recall')
Print the best hyperparameters.
grid_cv.best_params_
## {'clf__C': 0.005, 'clf__loss': 'hinge'}
Evaluate the tuned model on the oversampled data with the same cross validation setup.
tuned_oversample_svc_model = (
    get_model_metrics(
        LinearSVC(C=0.005, loss='hinge', random_state=42),
        "Tuned SVC (Linear)",
        oversample='Yes')
    .pivot(index=['Algorithm', 'Over Sampling'], columns='Metric', values='Value')
    .reset_index()
)
model_performance <- model_performance %>%
  bind_rows(py$tuned_oversample_svc_model %>%
              rename_all(~ str_replace_all(., "test_", ""))) %>%
  arrange(desc(recall))

kable(model_performance, caption = "Table 12") %>%
  kable_styling(full_width = FALSE)
Algorithm | Over Sampling | f1 | precision | recall |
---|---|---|---|---|
Tuned SVC (Linear) | Yes | 0.60 | 0.47 | 0.83 |
SVC (Linear) | Yes | 0.61 | 0.55 | 0.68 |
Logistic Regression | Yes | 0.61 | 0.55 | 0.67 |
SVC (RBF) | Yes | 0.59 | 0.56 | 0.63 |
SVC (RBF) | No | 0.56 | 0.61 | 0.52 |
Logistic Regression | No | 0.53 | 0.58 | 0.49 |
SVC (Linear) | No | 0.53 | 0.57 | 0.49 |
The tuned linear SVC model on the oversampled data has the highest recall. This model should produce similar performance metrics when applied to the test set; if the performance is significantly worse, it’s a sign that the model was overfitted during training. The tuned linear SVC will now be trained on the full oversampled training data and then used to generate predictions on the test data.
# load the test data
test_features = r.test_features
test_target = r.test_target

# specify the best model
best_model = LinearSVC(C=0.005, loss='hinge', penalty='l2', random_state=42)

# oversample the train data
train_oversample_features, train_oversample_target = oversample_strategy.fit_resample(
    train_features, train_target)

# train the model on the oversampled train data
model_fit_on_train = best_model.fit(train_oversample_features,
                                    np.ravel(train_oversample_target))

# make predictions on the test data
y_pred_class = model_fit_on_train.predict(test_features)

test_metrics = pd.DataFrame({"Algorithm": ["Linear SVC"],
                             "F1": [metrics.f1_score(test_target, y_pred_class).round(2)],
                             "Precision": [metrics.precision_score(test_target, y_pred_class).round(2)],
                             "Recall": [metrics.recall_score(test_target, y_pred_class).round(2)]})
The model performance on the test set is very similar to the performance shown during cross validation, so the model was not overfitted during training. 87% of actual cases were correctly identified (recall), while 44% of the subjects predicted to be cases were actual cases (precision).
kable(py$test_metrics, caption = "Table 13") %>%
kable_styling(full_width = FALSE)
Algorithm | F1 | Precision | Recall |
---|---|---|---|
Linear SVC | 0.58 | 0.44 | 0.87 |
plt.figure()
plot_confusion_matrix(model_fit_on_train, test_features, test_target,
                      display_labels=['control', 'case'])
## <sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay object at 0x7fd6a39b8b38>
plt.show()
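One portability note: plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in 1.2, so on newer versions the equivalent call is ConfusionMatrixDisplay.from_estimator. The recall in Table 13 can also be recovered directly from the raw confusion matrix counts:

# equivalent on scikit-learn >= 1.0, where plot_confusion_matrix is deprecated
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

ConfusionMatrixDisplay.from_estimator(model_fit_on_train, test_features, test_target,
                                      display_labels=['control', 'case'])
plt.show()

# recall = TP / (TP + FN) from the raw counts
tn, fp, fn, tp = confusion_matrix(np.ravel(test_target), y_pred_class).ravel()
print(round(tp / (tp + fn), 2))  # matches the 0.87 recall in Table 13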