Goal:

The goals of this starter notebook are to:

  • use DICOMs as the image input
  • give a high-level overview of the considerations that need to be taken into account, and what the results mean, when creating a model that predicts medical conditions

The dataset, conveniently provided by fastai, is a subset of the SIIM-ACR Pneumothorax Segmentation dataset and contains 250 DICOM images (175 No Pneumothorax and 75 Pneumothorax)

Considerations:

  • patient overlap
  • sampling
  • evaluating AI models for medical use

This notebook is based on this fastai notebook. For more information about DICOMs and fastai medical imaging, you can click here

#load dependencies
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
from fastai2.medical.imaging import *

import pydicom
import seaborn as sns
matplotlib.rcParams['image.cmap'] = 'bone'

Load the Data

pneumothorax_source = untar_data(URLs.SIIM_SMALL)
items = get_dicom_files(pneumothorax_source, recurse=True, folders='sm')
df = pd.read_csv(pneumothorax_source/f"labels_sm.csv")
items
(#26) [Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000000 - Copy.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000000.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000002.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000005.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000006 - Copy.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000006.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000007.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000008.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000009.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000011.dcm')...]

Side Note: The SIIM_SMALL dataset has no duplicate patient IDs and an equal number of males and females, so I used a custom, even smaller dataset to show the functionality of DicomSplit and DataSplit below

Viewing the Data

View Dicom

The show function is specifically tailored to display .dcm formats. By customizing the show function we can now view patient information alongside each image

@patch
@delegates(show_image)
def show_dinfo(self:DcmDataset, scale=True, cmap=plt.cm.bone, min_px=-1100, max_px=None, **kwargs):
    """show function that prints patient attributes from DICOM head"""
    px = (self.windowed(*scale) if isinstance(scale,tuple)
          else self.hist_scaled(min_px=min_px,max_px=max_px,brks=scale) if isinstance(scale,(ndarray,Tensor))
          else self.hist_scaled(min_px=min_px,max_px=max_px) if scale
          else self.scaled_px)
    print(f'Patient Age: {self.PatientAge}')
    print(f'Patient Sex: {self.PatientSex}')
    print(f'Body Part Examined: {self.BodyPartExamined}')
    print(f'Rows: {self.Rows} Columns: {self.Columns}')
    show_image(px, cmap=cmap, **kwargs)
patient = 7
sample = dcmread(items[patient])
sample.show_dinfo()
Patient Age: 31
Patient Sex: M
Body Part Examined: CHEST
Rows: 1024 Columns: 1024

Create a Dataframe

DICOM files contain a lot of useful information, but it is difficult to review it image by image, so we capture this information in a dataframe for better viewing and data manipulation.

Customize the functions so that we include what we want in our dataframe:

#updated versions of dcm2dict/from_dicoms that extract the attributes we want
def _dcm2dict2(fn, **kwargs):
    t = fn.dcmread()
    return fn, t.PatientID, t.PatientAge, t.PatientSex, t.BodyPartExamined, t.Modality, t.Rows, t.Columns, t.BitsStored, t.PixelRepresentation

@delegates(parallel)
def _from_dicoms2(cls, fns, n_workers=0, **kwargs):
    return pd.DataFrame(parallel(_dcm2dict2, fns, n_workers=n_workers, **kwargs))
pd.DataFrame.from_dicoms2 = classmethod(_from_dicoms2)
test_df = pd.DataFrame.from_dicoms2(items)
test_df.columns=['file', 'PatientID', 'Age', 'Sex', 'Bodypart', 'Modality', 'Rows', 'Cols', 'BitsStored', 'PixelRep' ]
test_df.to_csv('test_df.csv')
test_df.head()
file PatientID Age Sex Bodypart Modality Rows Cols BitsStored PixelRep
0 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000000 - Copy.dcm 16d7f894-55d7-4d95-8957-d18987f0e981 62 M CHEST CR 1024 1024 8 0
1 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000000.dcm 16d7f894-55d7-4d95-8957-d18987f0e981 62 M CHEST CR 1024 1024 8 0
2 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000002.dcm 850ddeb3-73ac-45e0-96bf-7d275bc83782 52 F CHEST CR 1024 1024 8 0
3 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000005.dcm e0fd6161-2b8d-4757-96bc-6cf620a993d5 65 F CHEST CR 1024 1024 8 0
4 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000006 - Copy.dcm 99171908-3665-48e8-82c8-66d0098ce209 52 F CHEST CR 1024 1024 8 0
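Since patient overlap matters for the train/valid split later, it is worth a quick pandas sanity check on the dataframe we just built (a small sketch using the test_df from above, not part of the original notebook):

#flag any PatientID that appears on more than one row
dup = test_df[test_df['PatientID'].duplicated(keep=False)]
print(f"{dup['PatientID'].nunique()} patient ID(s) appear more than once")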

We can now view the information (note: this is for my custom dataset)

#Plot 3 comparisons
def plot_comparison(df, feature, feature1, feature2):
    "Plot 3 comparisons from a dataframe"
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize = (16, 4))
    s1 = sns.countplot(df[feature], ax=ax1)
    s1.set_title(feature)
    s2 = sns.countplot(df[feature1], ax=ax2)
    s2.set_title(feature1)
    s3 = sns.countplot(df[feature2], ax=ax3)
    s3.set_title(feature2)
    plt.show()
plot_comparison(test_df, 'PatientID', 'Sex', 'Bodypart')

Age comparison

def age_comparison(df, feature):
    "Plot hisogram of age range in dataset"
    fig, (ax1) = plt.subplots(1,1, figsize = (16, 4))
    s1 = sns.countplot(df[feature], ax=ax1)
    s1.set_title(feature)
    plt.show()
    
age_comparison(test_df, 'Age')

Modelling

Considerations:

  • patient overlap between the train and valid sets
  • sampling - how many negative and positive cases are in the train/valid split (class imbalance)
  • augmentations - which augmentations are used, and why some of them may not be appropriate

Patient Overlap

It is important to know whether there is any patient overlap between the train and validation sets, as this may lead to overly optimistic validation results. The great thing about DICOMs is that we can check for duplicate PatientIDs across the two sets when we split our data

def DicomSplit(valid_pct=0.2, seed=None, **kwargs):
    "Splits `items` between train/valid with `valid_pct` and checks whether identical patient IDs exist in both the train and valid sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]
        val_p = o[rand_idx[:cut]]

        #collect the PatientID from each file's DICOM header
        for im in trn_p:
            train_list.append(im.dcmread().PatientID)
        for jm in val_p:
            valid_list.append(jm.dcmread().PatientID)
        #print any IDs that appear in both sets
        print(set(train_list) & set(valid_list))
        return rand_idx[cut:], rand_idx[:cut]
    return _inner
set_seed(7)

trn,val = DicomSplit(valid_pct=0.2)(items)
trn, val
{'6224213b-a185-4821-8490-c9cba260a959'}
((#21) [2,13,9,12,11,24,8,14,16,6...], (#5) [19,18,3,23,17])

The custom test dataset only has 26 images, which is split into a train set of 21 and a valid set of 5 using a valid_pct of 0.2. By customizing RandomSplitter into DicomSplit you can check whether any duplicate PatientIDs appear in both sets. In this case there is one duplicate ID: 6224213b-a185-4821-8490-c9cba260a959

Using set_seed allows for reproducible results and ensures we get the same split when training
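DicomSplit only reports overlap; to rule it out entirely, one option is to split at the patient level so that every image from a given patient lands on the same side of the split. Here is a minimal sketch (patient_level_split is a hypothetical helper, not part of fastai; it relies on the same .dcmread() patch used above):

import random

def patient_level_split(items, valid_pct=0.2, seed=None):
    "Group files by PatientID, then assign whole patients to train or valid"
    groups = {}
    for f in items:
        groups.setdefault(f.dcmread().PatientID, []).append(f)
    pids = list(groups)
    if seed is not None: random.seed(seed)
    random.shuffle(pids)
    cut = int(valid_pct * len(pids))
    #all files from a patient go to exactly one side, so no overlap is possible
    valid = [f for pid in pids[:cut] for f in groups[pid]]
    train = [f for pid in pids[cut:] for f in groups[pid]]
    return train, valid

train_files, valid_files = patient_level_split(items, valid_pct=0.2, seed=7)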

Sampling

This dataset has two classes, Pneumothorax and No Pneumothorax. DataSplit counts how many images of each class end up in the train and valid sets, so you can see how fair the train/valid split is and ensure sensible sampling

def DataSplit(valid_pct=0.2, seed=None, **kwargs):
    "Check the number of each class in the train and valid sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]
        val_p = o[rand_idx[:cut]]

        #the class is encoded in each file's parent folder name
        for p in trn_p:
            train_list.append(p.parent.name)
        for q in val_p:
            valid_list.append(q.parent.name)
        print(f'train: {train_list}\n valid: {valid_list}')
        return rand_idx[cut:], rand_idx[:cut]
    return _inner

Using the same set_seed we get reproducible results:

set_seed(7)

trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
train: ['No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'No Pneumothorax']
 valid: ['Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'No Pneumothorax']
((#21) [2,13,9,12,11,24,8,14,16,6...], (#5) [19,18,3,23,17])

With this test dataset, the train set has 16 No Pneumothorax and 5 Pneumothorax images, and the valid set has 2 No Pneumothorax and 3 Pneumothorax images

>>Work in progress work on various techniques to help with unbalanced datasets especially true for medical image datasets where there are typically alot more images of 'normal' compared to 'diseased' (oversampling, k-fold)

Augmentations

img1 = (pneumothorax_source/'chest1.png'); img2 = (pneumothorax_source/'chest2.png')

>>Work in progress Choosing the right augmentations is important in determing how it affects the sampling process. For example in some cases it may not be a good idea to flip images.

Here is an image of a 'normal' patient in its correct orientation (heart showing in the middle right)

Image.open(img1)

If we flip the image

Image.open(img2)

We can now see the heart is middle left. If the classifier were looking to detect defects of the heart, this type of augmentation would not be suitable (note do_flip=False and flip_vert=False in the batch transforms below).

#clear out some memory
import gc
gc.collect()
14734
#switch back to the full dataset
items = get_dicom_files(pneumothorax_source, recurse=True, folders='train')
df = pd.read_csv(pneumothorax_source/f"labels.csv")
xtra_tfms = [RandomResizedCrop(194)]
batch_tfms = [*aug_transforms(do_flip=False, flip_vert=False, xtra_tfms=xtra_tfms), Normalize.from_stats(*imagenet_stats)]
set_seed(7)
pneumothorax = DataBlock(blocks=(ImageBlock(cls=PILDicom), CategoryBlock),
                   get_x=lambda x:pneumothorax_source/f"{x[0]}",
                   get_y=lambda x:x[1],
                   splitter=RandomSplitter(valid_pct=0.2),
                   item_tfms=Resize(256),
                   batch_tfms=batch_tfms)

dls = pneumothorax.dataloaders(df.values, bs=12, num_workers=0)
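Before training, it can be worth eyeballing a batch to check that the images and labels look sensible (show_batch is a standard fastai call, though not in the original notebook):

#display a few transformed images with their labels
dls.show_batch(max_n=4, nrows=1)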

Check train and valid sizes

len(dls.train_ds), len(dls.valid_ds)
(200, 50)
net = xresnext50(pretrained=False, sa=True, act_cls=Mish, n_out=dls.c)
learn = Learner(dls, 
                net,
                loss_func=LabelSmoothingCrossEntropy(),
                opt_func=Adam,
                metrics=[accuracy], 
                cbs=[ShowGraphCallback()])

If you do not specify a loss function or optimizer, fastai automatically allocates one. You can view the loss_func and opt_func as follows:

learn.loss_func
LabelSmoothingCrossEntropy()
learn.opt_func
<function fastai2.optimizer.Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-05, wd=0.01, decouple_wd=True)>
learn.lr_find()
SuggestedLRs(lr_min=0.00010000000474974513, lr_steep=0.0005754399462603033)
learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-3))
epoch train_loss valid_loss accuracy time
0 0.697180 0.693759 0.620000 00:24
1 0.718116 0.737178 0.720000 00:25
2 0.672629 0.627202 0.740000 00:24
def show_results2(self, ds_idx=1, dl=None, max_n=9, shuffle=False, **kwargs):
    "Show a batch of predictions alongside the actual labels"
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    t,a,preds = self.get_preds(dl=[b], with_decoded=True)
    print(f'Actual: {a}\n Preds: {preds}\n')
    self.dls.show_results(b, preds, max_n=max_n, **kwargs)
show_results2(learn, max_n=12, nrows=2, ncols=6)
Actual: TensorCategory([0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1], dtype=torch.int32)
 Preds: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

Model Evaluation

Because medical models are high impact it is important to know how good a model is at detecting a certain condition.

Accuracy

The above model has an accuracy of 74%. One needs to look deeper into how the accuracy of 74% was calculated and whether it is acceptable.

  • Accuracy is the probability that the model is correct, or
  • Accuracy is the probability that the model is correct and the patient has the condition PLUS the probability that the model is correct and the patient does not have the condition
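In terms of the confusion-matrix counts computed later in this notebook, the second definition reduces to:

accuracy = (TP + TN) / (TP + TN + FP + FN)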

False Positive & False Negative

There are some other key terms that need to be used when evaluating medical models:

  • False Negative is an error in which a test result improperly indicates no presence of a condition (the result is negative), when in reality it is present.
  • False Positive is an error in which a test result improperly indicates presence of a condition, such as a disease (the result is positive), when in reality it is not present

Sensitivity & Specificity

  • Sensitivity or True Positive Rate is the probability that the model classifies a patient as having the disease given that the patient actually does have the disease. Sensitivity quantifies the avoidance of false negatives

Example: a new test is given to 10,000 patients who have the condition. If the test has a sensitivity of 90%, it will correctly detect 9,000 patients (True Positives) but will miss 1,000 patients (False Negatives) who have the condition but were tested as not having it

  • Specificity or True Negative Rate is the probability that the model classifies a patient as not having the disease given that the patient actually does not have the disease. Specificity quantifies the avoidance of false positives

Understanding and using sensitivity, specificity and predictive values is a great paper if you are interested in learning more

PPV and NPV

Most medical testing is evaluated via PPV (Positive Predictive Value) and NPV (Negative Predictive Value).

  • PPV - if the model predicts that a patient has a condition, what is the probability that the patient actually has the condition

  • NPV - if the model predicts that a patient does not have a condition, what is the probability that the patient actually does not have the condition

The ideal value of both PPV and NPV, with a perfect test, is 1 (100%); the worst possible value is zero
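In terms of counts, these are computed as follows (both are calculated on this model's predictions below):

PPV = TP / (TP + FP)
NPV = TN / (TN + FN)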

Confusion Matrix

Plot a confusion matrix - note that this is computed on the valid dataset, which has 50 images in this case

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(dls.valid_ds)==len(losses)==len(idxs)
interp.plot_confusion_matrix(figsize=(7,7))
upp, low = interp.confusion_matrix()
tn, fp = upp[0], upp[1]
fn, tp = low[0], low[1]
print(tn, fp, fn, tp)
33 3 10 4
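As a quick sanity check (not in the original notebook), the four counts should sum to the size of the validation set:

assert tn + fp + fn + tp == len(dls.valid_ds)   #33 + 3 + 10 + 4 == 50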

Sensitivity = True Positive/(True Positive + False Negative)

sensitivity = tp/(tp + fn)
sensitivity
0.2857142857142857

In this case the model has a sensitivity of only about 29%, meaning it correctly detects only 29% of the patients who actually have Pneumothorax (True Positives) and misses the other 71% (False Negatives: patients who have Pneumothorax but were told they did not - not a good situation to be in).

This is also known as a Type II error

Specificity = True Negative/(False Positive + True Negative)

specificity = tn/(fp + tn)
specificity
0.9166666666666666

In this case the model has a specificity of about 92%, meaning it correctly identifies 92% of the patients who do not have Pneumothorax, but incorrectly classifies the remaining 8% as having Pneumothorax (False Positives) when they actually do not.

This is also known as a Type I error

Positive Predictive Value (PPV)

ppv = tp/(tp+fp)
ppv
0.5714285714285714

In this case the model performs poorly at correctly predicting patients with Pneumothorax (PPV of 57%)

Negative Predictive Value (NPV)

npv = tn/(tn+fn)
npv
0.7674418604651163

This model is better at predicting patients with No Pneumothorax (NPV of 77%)
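These four calculations can be wrapped in a small helper (a sketch; eval_binary is a hypothetical name, not a fastai function) so they are easy to recompute after every training run:

def eval_binary(tn, fp, fn, tp):
    "Return the four evaluation metrics from binary confusion-matrix counts"
    return {'sensitivity': tp/(tp+fn),
            'specificity': tn/(fp+tn),
            'ppv':         tp/(tp+fp),
            'npv':         tn/(tn+fn)}

eval_binary(tn, fp, fn, tp)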

Some of these metrics can be calculated using sklearn's classification report

interp.print_classification_report()
                 precision    recall  f1-score   support

No Pneumothorax       0.77      0.92      0.84        36
   Pneumothorax       0.57      0.29      0.38        14

       accuracy                           0.74        50
      macro avg       0.67      0.60      0.61        50
   weighted avg       0.71      0.74      0.71        50

Calculating Accuracy

The accuracy of this model, as mentioned before, is 74% - let's now calculate this!

We can also look at Accuracy as:

accuracy = sensitivity * prevalence + specificity * (1 - prevalence)

Prevalence is a statistical concept referring to the number of cases of a disease that are present in a particular population at a given time.

The prevalence in this case is the number of patients in the valid dataset who have the condition, divided by the total number. To view the number of Pneumothorax patients in the valid set:

#count the positive labels in the valid set (iterating decodes each image, so this can be slow)
valid_labels = [int(y) for _, y in dls.valid_ds]
sum(valid_labels)
14

There are 14 Pneumothorax images in the valid set (matching tp + fn from the confusion matrix), hence the prevalence here is 14/50 = 0.28

prevalence = 14/50
accuracy = (sensitivity * prevalence) + (specificity * (1 - prevalence))
accuracy
0.74

This works out because sensitivity * prevalence + specificity * (1 - prevalence) reduces to (TP + TN) / total = 37/50 = 0.74, which matches the accuracy reported during training.

By reviewing the metrics above, you can evaluate how well or how poorly your model is performing

fin