Starting with Medical Imaging
Starter notebook that looks at some high-level considerations when building and evaluating models for medical diagnosis
Goal:
The goals of this starter notebook are to:
- use DICOMs as the image input
- give a high-level overview of the considerations that need to be taken into account, and what the results mean, when creating a model that predicts medical conditions
The dataset used is conveniently provided by fastai:
- the SIIM-ACR Pneumothorax Segmentation dataset, containing 250 DICOM images (175 No Pneumothorax and 75 Pneumothorax)
Considerations:
- patient overlap
- sampling
- evaluating AI models for medical use
# load dependencies
from fastai.basics import *
from fastai.callback.all import *
from fastai.vision.all import *
from fastai.medical.imaging import *
import pydicom
import seaborn as sns
matplotlib.rcParams['image.cmap'] = 'bone'
pneumothorax_source = untar_data(URLs.SIIM_SMALL)
items = get_dicom_files(pneumothorax_source, recurse=True, folders='sm')
df = pd.read_csv(pneumothorax_source/"labels_sm.csv")
items
This smaller `sm` subset is used to demonstrate the custom `DicomSplit` and `DataSplit` functions below.
The `show` function is specifically tailored to display `.dcm` formats. By customizing the `show` function we can now view patient information alongside each image.
@patch
@delegates(show_image)
def show_dinfo(self:DcmDataset, scale=True, cmap=plt.cm.bone, min_px=-1100, max_px=None, **kwargs):
"""show function that prints patient attributes from DICOM head"""
px = (self.windowed(*scale) if isinstance(scale,tuple)
else self.hist_scaled(min_px=min_px,max_px=max_px,brks=scale) if isinstance(scale,(ndarray,Tensor))
else self.hist_scaled(min_px=min_px,max_px=max_px) if scale
else self.scaled_px)
print(f'Patient Age: {self.PatientAge}')
print(f'Patient Sex: {self.PatientSex}')
print(f'Body Part Examined: {self.BodyPartExamined}')
print(f'Rows: {self.Rows} Columns: {self.Columns}')
show_image(px, cmap=cmap, **kwargs)
patient = 7
sample = dcmread(items[patient])
sample.show_dinfo()
DICOM images contain a lot of useful information (known as `tags`) and fastai provides a convenient way of getting that information into a `DataFrame` with the `from_dicoms` function. However, we do not need all of the information ported to a `DataFrame`.

For example, we can create a `DataFrame` of all the `tags` in the header section. In this case there are 42 columns, and in most cases we do not need all of this information.
full_dataframe = pd.DataFrame.from_dicoms(items)
full_dataframe[:1]
We can create a custom dataframe that takes into consideration the information we want, for example:
- filename
- age
- sex
- row size
- column size
# custom version of from_dicoms that extracts only the tags we need
def _dcm2dict2(fn, **kwargs):
    # read one DICOM file and pull out just the tags we care about
    t = fn.dcmread()
    return fn, t.PatientID, t.PatientAge, t.PatientSex, t.Rows, t.Columns
@delegates(parallel)
def _from_dicoms2(cls, fns, n_workers=0, **kwargs):
return pd.DataFrame(parallel(_dcm2dict2, fns, n_workers=n_workers, **kwargs))
pd.DataFrame.from_dicoms2 = classmethod(_from_dicoms2)
test_df = pd.DataFrame.from_dicoms2(items)
test_df.columns=['file', 'PatientID', 'Age', 'Sex', 'Rows', 'Cols']
test_df.to_csv('test_df.csv')
test_df.head()
We can do some exploratory data analysis and see that this custom dataset has duplicate patient IDs.
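Before plotting, a quick pandas check on the `test_df` built above confirms this directly (a minimal sketch):

# count how many rows share each PatientID; any count above 1 is a duplicate
dupes = test_df['PatientID'].value_counts()
print(dupes[dupes > 1])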
#Plot 2 comparisons
def plot_comparison(df, feature, feature1):
"Plot 3 comparisons from a dataframe"
fig, (ax1, ax2) = plt.subplots(1,2, figsize = (16, 4))
s1 = sns.countplot(df[feature], ax=ax1)
s1.set_title(feature)
s2 = sns.countplot(df[feature1], ax=ax2)
s2.set_title(feature1)
plt.show()
plot_comparison(test_df, 'PatientID', 'Sex')
Age comparison
def age_comparison(df, feature):
"Plot hisogram of age range in dataset"
fig, (ax1) = plt.subplots(1,1, figsize = (16, 4))
s1 = sns.countplot(df[feature], ax=ax1)
s1.set_title(feature)
plt.show()
age_comparison(test_df, 'Age')
Some considerations when modelling the data:
- is there patient overlap between the train and validation sets?
- sampling: how many negative and positive cases are represented in the train/val split?
- augmentations: which augmentations are used, and why some may not be appropriate
It is important to know if there is going to be any patient overlap when creating the train and validation sets, as this may lead to an overly optimistic result when evaluating against a test set. The great thing about DICOMs is that we can check whether any duplicate `PatientID`s end up in both the train and valid sets when we split our data.

`DicomSplit` is a custom function that uses the default fastai splitting logic, splitting the data into train and validation sets based on the `valid_pct` value, but now also checks whether identical patient IDs exist in both the train and validation sets.
def DicomSplit(valid_pct=0.2, seed=None, **kwargs):
    "Splits `items` between train/val with `valid_pct` and checks if identical patient IDs exist in both the train and valid sets"
    def _inner(o, **kwargs):
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p, val_p = o[rand_idx[cut:]], o[rand_idx[:cut]]
        # collect the PatientID from each file's DICOM header
        train_list = [im.dcmread().PatientID for im in trn_p]
        valid_list = [jm.dcmread().PatientID for jm in val_p]
        # print any patient IDs that appear in both sets
        print(set(train_list) & set(valid_list))
        return rand_idx[cut:], rand_idx[:cut]
    return _inner
To make the split reproducible, one option in fastai is to use `set_seed`, or you could incorporate the seed within the function.
set_seed(7)
trn,val = DicomSplit(valid_pct=0.2)(items)
trn, val
The custom test dataset only has 26 images (a small number, to show how `DicomSplit` works), which is split into a `train` set of 21 and a `valid` set of 5 using a `valid_pct` of 0.2. By customizing `RandomSplitter` in `DicomSplit` you can check whether any duplicate `PatientID`s exist between the 2 sets.

In this case there is a duplicate ID, `6224213b-a185-4821-8490-c9cba260a959`: this patient is present in both the train and validation sets.
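One way to avoid this kind of leakage altogether is to split at the patient level, so that every image belonging to a patient lands in the same set. Below is a minimal sketch (a custom helper, not a built-in fastai splitter):

import random

def PatientSplit(valid_pct=0.2, seed=None):
    "Split at the patient level so no PatientID appears in both sets"
    def _inner(o):
        if seed is not None: random.seed(seed)
        # group file indices by the PatientID in each DICOM header
        groups = {}
        for i, fn in enumerate(o):
            groups.setdefault(fn.dcmread().PatientID, []).append(i)
        pids = list(groups)
        random.shuffle(pids)
        cut = int(valid_pct * len(pids))
        # assemble train/valid indices from whole patient groups
        val_idx = [i for pid in pids[:cut] for i in groups[pid]]
        trn_idx = [i for pid in pids[cut:] for i in groups[pid]]
        return trn_idx, val_idx
    return _inner

trn_idx, val_idx = PatientSplit(valid_pct=0.2, seed=7)(items)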
This dataset consists of 2 classes, `Pneumothorax` and `No Pneumothorax`. It is important to consider how these classes are represented within the train and validation sets.

`DataSplit` looks at how many `Pneumothorax` and `No Pneumothorax` images there are in the training and validation sets. This shows how fair the train/val split is, to ensure good sampling.
def DataSplit(valid_pct=0.2, seed=None, **kwargs):
    "Check the number of each class in train and valid sets"
    def _inner(o, **kwargs):
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p, val_p = o[rand_idx[cut:]], o[rand_idx[:cut]]
        # the class label is the name of each file's parent folder
        train_list = [p.parent.name for p in trn_p]
        valid_list = [q.parent.name for q in val_p]
        train_totals = {x: train_list.count(x) for x in train_list}
        valid_totals = {x: valid_list.count(x) for x in valid_list}
        print(f'Train totals: {train_totals}\nValid totals: {valid_totals}')
        return rand_idx[cut:], rand_idx[:cut]
    return _inner
We can now see how the data is split using `set_seed`. For example, setting it to 7 as in the prior example:
set_seed(7)
trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
In this case the `train` set has 16 `No Pneumothorax` and 5 `Pneumothorax` images and the `valid` set has 2 `No Pneumothorax` and 3 `Pneumothorax` images.
How about using a seed of 77?
set_seed(77)
trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
In this case the `train` set has 14 `No Pneumothorax` and 7 `Pneumothorax` images and the `valid` set has 4 `No Pneumothorax` and 1 `Pneumothorax` image.
img1 = (pneumothorax_source/'chest1.png'); img2 = (pneumothorax_source/'chest2.png')
Image.open(img1)
If we flip the image
Image.open(img2)
We can now see the heart appears in the middle-left of the image. If the classifier was looking to detect defects of the heart, then this type of augmentation (a horizontal flip) would not be suitable.
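For reference, the mirrored view above can be produced programmatically. A minimal sketch with PIL, using the `img1` path from the cell above:

# horizontally mirror the chest X-ray; this is what a horizontal-flip
# augmentation would do, and it moves the heart to the opposite side
Image.open(img1).transpose(Image.FLIP_LEFT_RIGHT)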
Python's garbage collector is a great way of freeing up memory if needed:
#clear out some memory
import gc
gc.collect()
Switching back to the full dataset which contains 250 images.
items_full = get_dicom_files(pneumothorax_source, recurse=True, folders='train')
df_full = pd.read_csv(pneumothorax_source/"labels.csv")
Using `DataSplit` we can check the number of each class represented in the training and validation sets:
set_seed(7)
trn,val = DataSplit(valid_pct=0.2)(items_full)
trn, val
Create the `DataBlock` and specify some transforms:
xtra_tfms = [RandomResizedCrop(194)]
batch_tfms = [*aug_transforms(do_flip=False, flip_vert=False, xtra_tfms=xtra_tfms), Normalize.from_stats(*imagenet_stats)]
For the splitter, we have already split the data into training and validation sets, and we now need to incorporate this split into the `DataBlock`. fastai provides a convenient method, `IndexSplitter`, which splits `items` so that `val_idx` are in the validation set and the others are in the training set. `DicomSplit` and `DataSplit` both return the index split of the dataset, so we can feed this into `IndexSplitter` with `val_idx` set to the `val` indices created above.
pneumothorax = DataBlock(blocks=(ImageBlock(cls=PILDicom), CategoryBlock),
get_x=lambda x:f'{pneumothorax_source}/{x[0]}',
get_y=lambda x:x[1],
splitter=IndexSplitter(val),
item_tfms=Resize(256),
batch_tfms=batch_tfms)
dls = pneumothorax.dataloaders(df_full.values, bs=16, num_workers=0)
Check the `train` and `valid` sizes:
len(dls.train_ds), len(dls.valid_ds)
#net = xresnext50(pretrained=False, sa=True, act_cls=Mish, n_out=dls.c)
# plain resnet50 with the head sized to the number of classes in the data
net = resnet50(num_classes=dls.c)
learn = Learner(dls,
net,
metrics=[accuracy],
cbs=[ShowGraphCallback()])
If a loss function or optimizer is not specified, fastai automatically allocates one. You can view the `loss_func` and `opt_func` as follows:
learn.loss_func
learn.opt_func
learn.lr_find()
learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-2))
To see how `show_results` works, I slightly tweaked it so that we can see the ground truth, predictions, and the probabilities.

`get_preds` with `with_decoded=True` returns a tuple: probabilities, ground truth, decoded predictions.
def show_results2(self, ds_idx=1, dl=None, max_n=9, shuffle=False, **kwargs):
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    # get_preds with with_decoded=True returns (probabilities, targets, decoded predictions)
    p,g,preds = self.get_preds(dl=[b], with_decoded=True)
    print(f'Ground Truth: {g}\n Probabilities: {p}\n Prediction: {preds}\n')
    dl.show_results(b, preds, max_n=max_n, **kwargs)
show_results2(learn, max_n=12, nrows=2, ncols=6)
The probabilities determine the predicted outcome. In this dataset there are only 2 classes, `No Pneumothorax` and `Pneumothorax`, hence each probability has 2 values: the first is the probability that the image belongs to class 0 (`No Pneumothorax`) and the second is the probability that it belongs to class 1 (`Pneumothorax`). The probabilities also show how *certain* the model is. For example, [0.4903, 0.5097] and [0.1903, 0.8097] both produce the same result, that the image belongs to class 1, but in the second case the model is a lot more certain that it belongs to class 1.
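A minimal sketch with `torch` (already imported via fastai), using the two hypothetical probability vectors from the example above:

probs = torch.tensor([[0.4903, 0.5097],
                      [0.1903, 0.8097]])
print(probs.argmax(dim=1))       # tensor([1, 1]): both predict class 1, Pneumothorax
print(probs.max(dim=1).values)   # tensor([0.5097, 0.8097]): the second is far more certain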
Because medical models are high impact, it is important to know how good a model is at detecting a certain condition.

What is the accuracy of the model above? Since `accuracy` was passed in as a metric, `learn.validate()` reports it alongside the validation loss:

# validate() returns [valid_loss, accuracy] computed on the validation set
learn.validate()
The above model has an `accuracy` of 66.5%.
- `Accuracy` is the probability that the model is correct, or to be more specific:
- `Accuracy` is the probability that the model is correct and the patient has the condition PLUS the probability that the model is correct and the patient does not have the condition
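Equivalently, in terms of confusion-matrix counts (a hedged sketch; the counts `tn`, `fp`, `fn`, `tp` are extracted from the confusion matrix later in this notebook):

def accuracy_from_counts(tn, fp, fn, tp):
    # correct predictions, with and without the condition, over all predictions
    return (tp + tn) / (tp + tn + fp + fn)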
There are some other key terms that need to be used when evaluating medical models:
- `False Negative` is an error in which a test result improperly indicates no presence of a condition (the result is negative), when in reality it is present.
- `False Positive` is an error in which a test result improperly indicates presence of a condition, such as a disease (the result is positive), when in reality it is not present.
- `Sensitivity` or `True Positive Rate` is where the model classifies a patient as having the disease given that the patient actually does have the disease. Sensitivity quantifies the avoidance of false negatives.

Example: a new test is tried on 10,000 patients who all have the condition; if the test has a sensitivity of 90%, it will correctly detect 9,000 patients (True Positives) but will miss 1,000 patients (False Negatives) who have the condition but were tested as not having it.

- `Specificity` or `True Negative Rate` is where the model classifies a patient as not having the disease given that the patient actually does not have the disease. Specificity quantifies the avoidance of false positives.
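Both metrics are one-liners over confusion-matrix counts. A minimal sketch, with the worked sensitivity example above as a sanity check:

def sensitivity(tp, fn): return tp / (tp + fn)   # true positive rate
def specificity(tn, fp): return tn / (tn + fp)   # true negative rate

# worked example from above: 10,000 patients with the condition, 90% sensitivity
print(sensitivity(tp=9000, fn=1000))   # 0.9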
Most medical testing is evaluated via `PPV` (Positive Predictive Value) or `NPV` (Negative Predictive Value).

- `PPV`: if the model predicts a patient has a condition, what is the probability that the patient actually has the condition?
- `NPV`: if the model predicts a patient does not have a condition, what is the probability that the patient actually does not have the condition?

The ideal value of both the PPV and NPV, with a perfect test, is 1 (100%); the worst possible value is zero.
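Unlike sensitivity and specificity, PPV and NPV depend on prevalence. A hedged sketch deriving them from sensitivity, specificity and prevalence via Bayes' rule (the helper names are mine, not from fastai or sklearn):

def ppv(sens, spec, prev):
    # P(has condition | positive test)
    return (sens * prev) / (sens * prev + (1 - spec) * (1 - prev))

def npv(sens, spec, prev):
    # P(no condition | negative test)
    return (spec * (1 - prev)) / (spec * (1 - prev) + (1 - sens) * prev)

# the same test looks very different in low- vs high-prevalence populations
print(ppv(0.90, 0.90, prev=0.01))   # ~0.08
print(ppv(0.90, 0.90, prev=0.30))   # ~0.79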
Plot a confusion matrix; note that this is plotted against the `valid` dataset:
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(dls.valid_ds)==len(losses)==len(idxs)
interp.plot_confusion_matrix(figsize=(7,7))
We can manually reproduce the results interpreted from `plot_confusion_matrix`:
# unpack the 2x2 confusion matrix: rows are actuals, columns are predictions
upp, low = interp.confusion_matrix()
tn, fp = upp[0], upp[1]
fn, tp = low[0], low[1]
print(tn, fp, fn, tp)
sensitivity = tp/(tp + fn)
sensitivity
In this case the model only has a sensitivity of 28%, so it correctly detects only 28% of the patients who have Pneumothorax (True Positives) and misses the other 72% (False Negatives: patients who actually have Pneumothorax but were told they did not, which is not a good situation to be in).
specificity = tn/(fp + tn)
specificity
In this case the model has a specificity of 80%, so it correctly detects 80% of the patients who do NOT have Pneumothorax, but incorrectly classifies the other 20% as having Pneumothorax (False Positives) when they actually do not.
Positive Predictive Value (PPV)
ppv = tp/(tp+fp)
ppv
In this case the model performs poorly at correctly predicting patients with Pneumothorax.
Negative Predictive Value (NPV)
npv = tn/(tn+fn)
npv
This model is better at predicting patients with No Pneumothorax
Some of these metrics can be calculated using `sklearn`'s classification report:
interp.print_classification_report()
The accuracy of this model, as mentioned before, is 66%. Let's now calculate this!
We can also look at `Accuracy` as:

Accuracy = Sensitivity × Prevalence + Specificity × (1 - Prevalence)

`Prevalence` is a statistical concept referring to the number of cases of a disease that are present in a particular population at a given time. The prevalence in this case is how many patients in the `valid` dataset have the condition compared to the total number. To view the number of `Pneumothorax` patients in the `valid` set:
t = dls.valid_ds.cat
#t[0]
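An alternative hedged sketch that tallies the class labels directly from the raw validation items (assuming each item is a `(file, label)` row from `df_full.values`):

from collections import Counter

# count class labels in the validation split without decoding any images
print(Counter(row[1] for row in dls.valid_ds.items))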
There are 20 `Pneumothorax` images in the valid set, hence the prevalence here is 20/75 = 0.27.
accuracy = (sensitivity * 0.27) + (specificity * (1 - 0.27))
accuracy