Starting with Medical Imaging
A starter notebook that looks at some high-level considerations when building and evaluating models for medical diagnosis
Goal:
The goals of this starter notebook are to:
- use DICOMs as the image input
- give a high-level overview of the considerations that need to be taken into account, and what the results mean, when creating a model that predicts medical conditions
The dataset used is conveniently provided by fastai - the SIIM-ACR Pneumothorax Segmentation dataset - and contains 250 DICOM images (175 No Pneumothorax and 75 Pneumothorax)
Considerations:
- patient overlap
- sampling
- evaluating AI models for medical use
This notebook is based on this fastai notebook. For more information about DICOMs and fastai medical imaging you can click here
# load dependencies
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
from fastai2.medical.imaging import *
import pydicom
import seaborn as sns
matplotlib.rcParams['image.cmap'] = 'bone'
pneumothorax_source = untar_data(URLs.SIIM_SMALL)
items = get_dicom_files(pneumothorax_source, recurse=True, folders='sm')
df = pd.read_csv(pneumothorax_source/f"labels_sm.csv")
items
Side Note: The SIIM_SMALL dataset has no duplicate patient IDs and has an equal number of males and females, so I used a custom, even smaller dataset to show the functionality of DicomSplit and DataSplit below
The show function is specifically tailored to display .dcm formats. By customizing the show function we can now view patient information alongside each image
@patch
@delegates(show_image)
def show_dinfo(self:DcmDataset, scale=True, cmap=plt.cm.bone, min_px=-1100, max_px=None, **kwargs):
    "Show function that also prints patient attributes from the DICOM header"
    px = (self.windowed(*scale) if isinstance(scale,tuple)
          else self.hist_scaled(min_px=min_px,max_px=max_px,brks=scale) if isinstance(scale,(ndarray,Tensor))
          else self.hist_scaled(min_px=min_px,max_px=max_px) if scale
          else self.scaled_px)
    print(f'Patient Age: {self.PatientAge}')
    print(f'Patient Sex: {self.PatientSex}')
    print(f'Body Part Examined: {self.BodyPartExamined}')
    print(f'Rows: {self.Rows} Columns: {self.Columns}')
    show_image(px, cmap=cmap, **kwargs)
patient = 7
sample = dcmread(items[patient])
sample.show_dinfo()
DICOM files contain a lot of useful information, but it is difficult to view it image by image, so we need to capture this information in a dataframe for easier viewing and data manipulation.
Create a dataframe: customize the functions so that we include what we want in our dataframe
# updated to pull the header attributes we want into the dataframe
def _dcm2dict2(fn, **kwargs):
    t = fn.dcmread()
    return fn, t.PatientID, t.PatientAge, t.PatientSex, t.BodyPartExamined, t.Modality, t.Rows, t.Columns, t.BitsStored, t.PixelRepresentation

@delegates(parallel)
def _from_dicoms2(cls, fns, n_workers=0, **kwargs):
    return pd.DataFrame(parallel(_dcm2dict2, fns, n_workers=n_workers, **kwargs))

pd.DataFrame.from_dicoms2 = classmethod(_from_dicoms2)
test_df = pd.DataFrame.from_dicoms2(items)
test_df.columns=['file', 'PatientID', 'Age', 'Sex', 'Bodypart', 'Modality', 'Rows', 'Cols', 'BitsStored', 'PixelRep' ]
test_df.to_csv('test_df.csv')
test_df.head()
We can now view the information (note: this is for my custom dataset)
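Because the dataframe includes PatientID, a quick duplicate check takes one line of pandas (a small sketch, assuming the test_df built above):
# count how many PatientID values appear more than once in the dataframe we just built
test_df['PatientID'].duplicated().sum()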
#Plot 3 comparisons
def plot_comparison(df, feature, feature1, feature2):
    "Plot 3 comparisons from a dataframe"
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize = (16, 4))
    s1 = sns.countplot(df[feature], ax=ax1)
    s1.set_title(feature)
    s2 = sns.countplot(df[feature1], ax=ax2)
    s2.set_title(feature1)
    s3 = sns.countplot(df[feature2], ax=ax3)
    s3.set_title(feature2)
    plt.show()
plot_comparison(test_df, 'PatientID', 'Sex', 'Bodypart')
Age comparison
def age_comparison(df, feature):
    "Plot a histogram of the age range in the dataset"
    fig, (ax1) = plt.subplots(1,1, figsize = (16, 4))
    s1 = sns.countplot(df[feature], ax=ax1)
    s1.set_title(feature)
    plt.show()
age_comparison(test_df, 'Age')
Considerations:
- patient overlap between the train and valid sets
- sampling - how many negative and positive cases are in the train/valid split (class imbalance)
- augmentations - consideration of which augmentations are used and why, since in some cases they may not be useful
It is important to know whether there is going to be any patient overlap when creating the train and validation sets, as this may lead to overly optimistic results on the validation set. The great thing about DICOMs is that we can check whether any duplicate PatientIDs exist in the train and valid sets when we split our data
def DicomSplit(valid_pct=0.2, seed=None, **kwargs):
    "Splits `items` between train/valid with `valid_pct` and checks whether identical patient IDs exist in both sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]   # training files
        val_p = o[rand_idx[:cut]]   # validation files
        # collect the PatientID from each file in the two splits
        for im in trn_p:
            train_list.append(im.dcmread().PatientID)
        for jm in val_p:
            valid_list.append(jm.dcmread().PatientID)
        # print any PatientIDs that appear in both the train and valid sets
        print(set(train_list) & set(valid_list))
        return rand_idx[cut:], rand_idx[:cut]
    return _inner
set_seed(7)
trn,val = DicomSplit(valid_pct=0.2)(items)
trn, val
The custom test dataset only has 26 images, which is split into a train set of 24 and a valid set of 5 using a valid_pct of 0.2. By customizing RandomSplitter into DicomSplit you can check whether there are any duplicate PatientIDs across the two sets. In this case there is a duplicate ID: 6224213b-a185-4821-8490-c9cba260a959
Using set_seed allows for reproducible results and ensures we use the same seed when training
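One way to avoid such overlap entirely is to split at the patient level rather than the image level, so that all images from a given patient land in the same set. A minimal sketch (not part of the original notebook), assuming the test_df dataframe built earlier; patient_level_split is a hypothetical helper:
# Hypothetical helper: split whole patients between train and valid so no PatientID appears in both sets
def patient_level_split(df, valid_pct=0.2, seed=7):
    rng = np.random.RandomState(seed)
    ids = df['PatientID'].unique()
    rng.shuffle(ids)
    cut = int(valid_pct * len(ids))
    valid_ids = set(ids[:cut])
    is_valid = df['PatientID'].isin(valid_ids)
    return list(df.index[~is_valid]), list(df.index[is_valid])

trn_idx, val_idx = patient_level_split(test_df)
len(trn_idx), len(val_idx)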
This dataset has 2 classes, Pneumothorax and No Pneumothorax. DataSplit looks at how many Pneumothorax and No Pneumothorax images are in the train and valid sets. This shows how fair the train/valid split is, to ensure good model sampling
def DataSplit(valid_pct=0.2, seed=None, **kwargs):
    "Check the number of each class in the train and valid sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]
        val_p = o[rand_idx[:cut]]
        # the class name is the parent folder of each file
        for p in trn_p:
            train_list.append(p.parent.name)
        for q in val_p:
            valid_list.append(q.parent.name)
        print(f'train: {train_list}\n valid: {valid_list}')
        return rand_idx[cut:], rand_idx[:cut]
    return _inner
Using the same set_seed we can get reproducible results
set_seed(7)
trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
With this test dataset, the train set has 14 No Pneumothorax and 5 Pneumothorax images, and the valid set has 4 No Pneumothorax and 3 Pneumothorax images
>>Work in progress: various techniques to help with imbalanced datasets (oversampling, k-fold); this is especially relevant for medical image datasets, where there are typically a lot more images of 'normal' compared to 'diseased'. A sketch of one approach follows below.
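As a placeholder, one common approach is to oversample the minority class with a weighted sampler. This sketch uses plain PyTorch's WeightedRandomSampler with purely illustrative labels; it is not part of the original notebook:
# Illustrative class labels only: 0 = No Pneumothorax (175 images), 1 = Pneumothorax (75 images)
from torch.utils.data import WeightedRandomSampler
labels = np.array([0]*175 + [1]*75)
class_counts = np.bincount(labels)            # [175, 75]
class_weights = 1. / class_counts             # the rarer class gets a larger weight
sample_weights = class_weights[labels]        # one weight per image
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
# the sampler could then be passed to a torch DataLoader via DataLoader(..., sampler=sampler)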
img1 = (pneumothorax_source/'chest1.png'); img2 = (pneumothorax_source/'chest2.png')
>>Work in progress: choosing the right augmentations is important in determining how they affect the sampling process. For example, in some cases it may not be a good idea to flip images.
Here is an image of a 'normal' patient in its correct orientation (heart showing in the middle right)
Image.open(img1)
If we flip the image
Image.open(img2)
We can now see the heart is on the middle left. If the classifier were looking to detect defects of the heart then this type of augmentation would not be suitable.
#clear out some memory
import gc
gc.collect()
#switch back to the full dataset
items = get_dicom_files(pneumothorax_source, recurse=True, folders='train')
df = pd.read_csv(pneumothorax_source/f"labels.csv")
xtra_tfms = [RandomResizedCrop(194)]
batch_tfms = [*aug_transforms(do_flip=False, flip_vert=False, xtra_tfms=xtra_tfms), Normalize.from_stats(*imagenet_stats)]
set_seed(7)
pneumothorax = DataBlock(blocks=(ImageBlock(cls=PILDicom), CategoryBlock),
get_x=lambda x:pneumothorax_source/f"{x[0]}",
get_y=lambda x:x[1],
splitter=RandomSplitter(valid_pct=0.2),
item_tfms=Resize(256),
batch_tfms=batch_tfms)
dls = pneumothorax.dataloaders(df.values, bs=12, num_workers=0)
Check the train and valid sizes
len(dls.train_ds), len(dls.valid_ds)
net = xresnext50(pretrained=False, sa=True, act_cls=Mish, n_out=dls.c)
learn = Learner(dls,
net,
loss_func=LabelSmoothingCrossEntropy(),
opt_func=Adam,
metrics=[accuracy],
cbs=[ShowGraphCallback()])
If you do not specify a loss function or optimization function, fastai automatically allocates one. You can view the loss_func and opt_func as follows:
learn.loss_func
learn.opt_func
learn.lr_find()
learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-3))
def show_results2(self, ds_idx=1, dl=None, max_n=9, shuffle=False, **kwargs):
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    # get_preds with with_decoded=True returns (predictions, targets, decoded predictions)
    probs, targets, preds = self.get_preds(dl=[b], with_decoded=True)
    print(f'Actual: {targets}\n Preds: {preds}\n')
    self.dls.show_results(b, preds, max_n=max_n, **kwargs)
show_results2(learn, max_n=12, nrows=2, ncols=6)
Because medical models are high impact it is important to know how good a model is at detecting a certain condition.
The above model has an accuracy of 74%. One needs to look deeper into how that 74% was calculated and whether it is acceptable.
- Accuracy is the probability that the model is correct, or
- Accuracy is the probability that the model is correct and the patient has the condition PLUS the probability that the model is correct and the patient does not have the condition
There are some other key terms that need to be used when evaluating medical models:
- False Negative is an error in which a test result improperly indicates no presence of a condition (the result is negative), when in reality it is present
- False Positive is an error in which a test result improperly indicates presence of a condition, such as a disease (the result is positive), when in reality it is not present
- Sensitivity or True Positive Rate is where the model classifies a patient as having the disease given that the patient actually does have the disease. Sensitivity quantifies the avoidance of false negatives. Example: a new test is tried on 10,000 patients; if the test has a sensitivity of 90% it will correctly detect 9,000 (True Positive) patients but will miss 1,000 (False Negative) patients that have the condition but were tested as not having it (see the worked example after this list)
- Specificity or True Negative Rate is where the model classifies a patient as not having the disease given that the patient actually does not have the disease. Specificity quantifies the avoidance of false positives
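To make the sensitivity example above concrete, here is the arithmetic in a couple of lines (the numbers are taken from the example in the text, and assume all 10,000 patients have the condition):
# worked example: 10,000 patients who have the condition, test with 90% sensitivity
patients_with_condition = 10_000
sensitivity = 0.90
true_positives = int(sensitivity * patients_with_condition)    # 9,000 correctly detected
false_negatives = patients_with_condition - true_positives     # 1,000 missed
true_positives, false_negatives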
Understanding and using sensitivity, specificity and predictive values is a great paper if you are interested in learning more
Most medical testing is evaluated via PPV (Positive Predictive Value) or NPV (Negative Predictive Value).
- PPV - if the model predicts a patient has a condition, what is the probability that the patient actually has the condition
- NPV - if the model predicts a patient does not have a condition, what is the probability that the patient actually does not have the condition
The ideal value of both the PPV and the NPV, with a perfect test, is 1 (100%), and the worst possible value would be zero
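These definitions translate directly into code. Below is a small helper, a sketch that is not part of the original notebook, that derives all of these metrics from raw confusion-matrix counts; the same quantities are computed step by step in the cells that follow:
# Sketch: derive the evaluation metrics from confusion-matrix counts (tp, fp, tn, fn)
def diagnostic_metrics(tp, fp, tn, fn):
    "Return sensitivity, specificity, PPV, NPV and accuracy"
    sensitivity = tp / (tp + fn)   # avoidance of false negatives
    specificity = tn / (tn + fp)   # avoidance of false positives
    ppv = tp / (tp + fp)           # P(condition | predicted positive)
    npv = tn / (tn + fn)           # P(no condition | predicted negative)
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    return sensitivity, specificity, ppv, npv, accuracy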
Plot a confusion matrix - note that this is plotted against the valid dataset, whose size is 75 in this case
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(dls.valid_ds)==len(losses)==len(idxs)
interp.plot_confusion_matrix(figsize=(7,7))
upp, low = interp.confusion_matrix()   # the two rows of the 2x2 confusion matrix
tn, fp = upp[0], upp[1]                # first row: actual No Pneumothorax
fn, tp = low[0], low[1]                # second row: actual Pneumothorax
print(tn, fp, fn, tp)
Sensitivity = True Positive/(True Positive + False Negative)
sensitivity = tp/(tp + fn)
sensitivity
In this case the model only has a sensitivity of 28% and hence is only capable of correctly detecting 28% of the True Positives (i.e. patients who have Pneumothorax); it will miss the other 72% as False Negatives (patients that actually have Pneumothorax but were told they did not - not a good situation to be in).
A False Negative is also known as a Type II error
Specificity = True Negative/(False Positive + True Negative)
specificity = tn/(fp + tn)
specificity
In this case the model has a specificity of 91% and hence can correctly detect 91% of the time that a patient does not have Pneumothorax, but will incorrectly classify 9% of the patients as having Pneumothorax (False Positive) when they actually do not.
A False Positive is also known as a Type I error
Positive Predictive Value (PPV)
ppv = tp/(tp+fp)
ppv
In this case the model performs poorly in correctly predicting patients with Pneumothorax
Negative Predictive Value (NPV)
npv = tn/(tn+fn)
npv
This model is better at predicting patients with No Pneumothorax
Some of these metrics can be calculated using sklearn's classification report
interp.print_classification_report()
The accuracy of this model, as mentioned before, is 74% - let's now calculate this!
We can also look at Accuracy as:
accuracy = sensitivity * prevalence + specificity * (1 - prevalence)
Prevalence is a statistical concept referring to the number of cases of a disease that are present in a particular population at a given time.
The prevalence in this case is how many patients in the valid dataset have the condition compared to the total number. To view the number of Pneumothorax patients in the valid set:
t= dls.valid_ds.cat
#t[0]
There are 20 Pneumothorax images in the valid set, hence the prevalence here is 20/75 = 0.27
accuracy = (0.28 * 0.27) + (0.91 * (1 - 0.27))
accuracy
By reviewing the metrics above, you can evaluate how well or how poorly your model is performing