Note: Functions listed with a ✔ are custom functions


The goals of this starter notebook are to:

  • use DICOMs as the image input
  • give a high-level overview of the considerations involved, and of what the results mean, when creating a model that predicts medical conditions

The dataset used is the SIIM-ACR Pneumothorax Segmentation dataset, conveniently provided by fastai. It contains 250 DICOM images (175 No Pneumothorax and 75 Pneumothorax).


  • patient overlap
  • sampling
  • evaluating AI models for medical use
#load dependencies
from fastai.basics import *
from fastai.callback.all import *
from fastai.vision.all import *
from fastai.medical.imaging import *

import pydicom
import seaborn as sns
matplotlib.rcParams['image.cmap'] = 'bone'

Load the Data

pneumothorax_source = untar_data(URLs.SIIM_SMALL)
items = get_dicom_files(pneumothorax_source, recurse=True, folders='sm')
df = pd.read_csv(pneumothorax_source/f"labels_sm.csv")
(#26) [Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000000 - Copy.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000000.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000002.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000005.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000006 - Copy.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000006.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000007.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000008.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000009.dcm'),Path('C:/Users/avird/.fastai/data/siim_small/sm/No Pneumothorax/000011.dcm')...]

Note: The SIIM_SMALL dataset has no duplicate patient IDs and has an equal number of males and females, so I am using a custom, even smaller dataset to show the functionality of DicomSplit and DataSplit below.

Viewing the Data

View Dicom

The show function is specifically tailored to display .dcm formats. By customizing the show function we can now view patient information with each image.

@patch
def show_dinfo(self:DcmDataset, scale=True, cmap=plt.cm.bone, min_px=-1100, max_px=None, **kwargs):
    """show function that prints patient attributes from the DICOM header"""
    #choose how the pixel values are scaled before display
    px = (self.windowed(*scale) if isinstance(scale,tuple)
          else self.hist_scaled(min_px=min_px,max_px=max_px,brks=scale) if isinstance(scale,(ndarray,Tensor))
          else self.hist_scaled(min_px=min_px,max_px=max_px) if scale
          else self.scaled_px)
    print(f'Patient Age: {self.PatientAge}')
    print(f'Patient Sex: {self.PatientSex}')
    print(f'Body Part Examined: {self.BodyPartExamined}')
    print(f'Rows: {self.Rows} Columns: {self.Columns}')
    show_image(px, cmap=cmap, **kwargs)
patient = 7
sample = dcmread(items[patient])
sample.show_dinfo()
Patient Age: 31
Patient Sex: M
Body Part Examined: CHEST
Rows: 1024 Columns: 1024

Create a Dataframe of all the tags in the header section

Tip: The header section of a DICOM image contains a lot of useful information (known as tags), and fastai provides a convenient way of getting that information into a DataFrame: the from_dicoms function. However, we do not need to port all of the information to a DataFrame.

For example, we can create a DataFrame of all the tags in the header section. In this case there are 42 columns, and in most cases we do not need all of this information.

full_dataframe = pd.DataFrame.from_dicoms(items)
SpecificCharacterSet SOPClassUID SOPInstanceUID StudyDate StudyTime AccessionNumber Modality ConversionType ReferringPhysicianName SeriesDescription ... LossyImageCompression LossyImageCompressionMethod fname MultiPixelSpacing PixelSpacing1 img_min img_max img_mean img_std img_pct_window
0 ISO_IR 100 1.2.840.10008. 19010101 000000.00 CR WSD view: PA ... 01 ISO_10918_1 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000000 - Copy.dcm 1 0.168 0 254 160.398039 53.854885 0.087029

1 rows × 42 columns

We can create a custom dataframe that takes into consideration the information we want, for example:

  • filename
  • patient ID
  • age
  • sex
  • row size
  • column size
#updating to accommodate only the information we want
def _dcm2dict2(fn, **kwargs):
    t = fn.dcmread()
    return fn, t.PatientID, t.PatientAge, t.PatientSex, t.Rows, t.Columns

def _from_dicoms2(cls, fns, n_workers=0, **kwargs):
    return pd.DataFrame(parallel(_dcm2dict2, fns, n_workers=n_workers, **kwargs))
pd.DataFrame.from_dicoms2 = classmethod(_from_dicoms2)
test_df = pd.DataFrame.from_dicoms2(items)
test_df.columns=['file', 'PatientID', 'Age', 'Sex', 'Rows', 'Cols']
file PatientID Age Sex Rows Cols
0 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000000 - Copy.dcm 16d7f894-55d7-4d95-8957-d18987f0e981 62 M 1024 1024
1 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000000.dcm 16d7f894-55d7-4d95-8957-d18987f0e981 62 M 1024 1024
2 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000002.dcm 850ddeb3-73ac-45e0-96bf-7d275bc83782 52 F 1024 1024
3 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000005.dcm e0fd6161-2b8d-4757-96bc-6cf620a993d5 65 F 1024 1024
4 C:\Users\avird\.fastai\data\siim_small\sm\No Pneumothorax\000006 - Copy.dcm 99171908-3665-48e8-82c8-66d0098ce209 52 F 1024 1024


We can do some exploratory data analysis and see that this custom dataset has duplicate patient IDs

#Plot 2 comparisons
def plot_comparison(df, feature, feature1):
    "Plot 2 comparisons from a dataframe"
    fig, (ax1, ax2) = plt.subplots(1,2, figsize = (16, 4))
    s1 = sns.countplot(x=df[feature], ax=ax1)
    s2 = sns.countplot(x=df[feature1], ax=ax2)
plot_comparison(test_df, 'PatientID', 'Sex')
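
The duplicates visible in the plot can also be listed directly. The following is a minimal sketch using only the standard library; `duplicated_ids` is a hypothetical helper (not part of fastai), and the IDs are copied from the first five rows of the table above.

```python
from collections import Counter

# Patient IDs copied from the first five rows of the table above (illustrative subset).
patient_ids = [
    "16d7f894-55d7-4d95-8957-d18987f0e981",
    "16d7f894-55d7-4d95-8957-d18987f0e981",
    "850ddeb3-73ac-45e0-96bf-7d275bc83782",
    "e0fd6161-2b8d-4757-96bc-6cf620a993d5",
    "99171908-3665-48e8-82c8-66d0098ce209",
]

def duplicated_ids(ids):
    "Hypothetical helper: return the IDs that occur more than once, with their counts."
    counts = Counter(ids)
    return {pid: n for pid, n in counts.items() if n > 1}

dups = duplicated_ids(patient_ids)
print(dups)
```

With the full DataFrame, the same idea could be applied to the values of the `PatientID` column.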

Age comparison

def age_comparison(df, feature):
    "Plot histogram of age range in dataset"
    fig, (ax1) = plt.subplots(1,1, figsize = (16, 4))
    s1 = sns.countplot(x=df[feature], ax=ax1)
age_comparison(test_df, 'Age')


Some considerations when modelling the data:

  • patient overlap - is there any patient overlap between the train and validation sets
  • sampling - how many negative and positive cases are represented in the train/val split
  • augmentations - which augmentations are used, and why some of them may not be useful in certain cases

Patient Overlap

It is important to know if there is going to be any patient overlap when creating the train and validation sets, as this may lead to an overly optimistic result when evaluating the model. The great thing about DICOMs is that we can check whether there are any duplicate patient IDs in the train and valid sets when we split our data.

DicomSplit ✔

DicomSplit is a custom function based on the default fastai splitting function: it splits the data into train and validation sets according to the valid_pct value, and it also checks whether identical patient IDs exist in both the train and validation sets.

def DicomSplit(valid_pct=0.2, seed=None, **kwargs):
    "Splits `items` between train/val with `valid_pct` and checks if identical patient IDs exist in both sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]
        val_p = o[rand_idx[:cut]]
        #collect the patient ID of every file in the training split
        for im in trn_p:
            train_list.append(im.dcmread().PatientID)
        #collect the patient ID of every file in the validation split
        for jm in val_p:
            valid_list.append(jm.dcmread().PatientID)
        #IDs present in both splits indicate patient overlap
        print(set(train_list) & set(valid_list))
        return rand_idx[cut:], rand_idx[:cut]
    return _inner

Tip: There are a number of ways of setting the seed to ensure reproducible results. The easiest way in fastai is to use set_seed or you could incorporate it within a function

set_seed(7)
trn,val = DicomSplit(valid_pct=0.2)(items)
trn, val
((#21) [2,13,9,12,11,24,8,14,16,6...], (#5) [19,18,3,23,17])

The custom test dataset only has 26 images (a small number of images to show how DicomSplit works), which is split into a train set of 21 and a valid set of 5 using a valid_pct of 0.2. By customizing RandomSplitter in DicomSplit you can check whether there are any duplicate PatientIDs between the 2 sets.

In this case there is a duplicate ID: 6224213b-a185-4821-8490-c9cba260a959, this patient is present in both the train and validation sets.

Important: When working with a medical dataset, it is important to consider that the splits should be based on patient identifiers, and not on the individual examples.
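
One way to act on this is to shuffle and split the patients rather than the images. The sketch below is a standard-library illustration and not the notebook's DicomSplit: `patient_level_split` and its input format (a list whose i-th entry is the patient ID of item i) are assumptions, and `valid_pct` here applies to the number of patients, not the number of images.

```python
import random
from collections import defaultdict

def patient_level_split(patient_ids, valid_pct=0.2, seed=None):
    """Hypothetical helper: split item indices so each patient is in exactly one set.

    `patient_ids[i]` is assumed to be the patient ID of item i.
    """
    # Group item indices by patient
    groups = defaultdict(list)
    for idx, pid in enumerate(patient_ids):
        groups[pid].append(idx)
    # Shuffle the patients (not the items) reproducibly
    patients = sorted(groups)
    random.Random(seed).shuffle(patients)
    n_valid = max(1, int(valid_pct * len(patients)))
    # All items of a patient follow that patient into one split
    valid_idx = [i for p in patients[:n_valid] for i in groups[p]]
    train_idx = [i for p in patients[n_valid:] for i in groups[p]]
    return train_idx, valid_idx

# Toy IDs for illustration: patients "a" and "c" each have two images
ids = ["a", "a", "b", "c", "c", "d", "e", "f", "g", "h"]
train_idx, valid_idx = patient_level_split(ids, valid_pct=0.2, seed=7)
overlap = {ids[i] for i in train_idx} & {ids[i] for i in valid_idx}
print(overlap)
```

Because whole patients are assigned to a split, the resulting image proportions may differ slightly from `valid_pct`.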


This dataset consists of 2 classes Pneumothorax and No Pneumothorax. It is important to consider how this data is represented within the train and validation sets.

DataSplit ✔

DataSplit looks at how many Pneumothorax and No Pneumothorax images there are in the training and validation sets. This is to view how fair the train/val split is to ensure good model sampling.

def DataSplit(valid_pct=0.2, seed=None, **kwargs):
    "Check the number of each class in train and valid sets"
    def _inner(o, **kwargs):
        train_list=[]; valid_list=[]
        if seed is not None: torch.manual_seed(seed)
        rand_idx = L(int(i) for i in torch.randperm(len(o)))
        cut = int(valid_pct * len(o))
        trn_p = o[rand_idx[cut:]]
        val_p = o[rand_idx[:cut]]
        #the class label is the name of the folder that contains each file
        for p in trn_p:
            train_list.append(Path(p).parent.name)
        for q in val_p:
            valid_list.append(Path(q).parent.name)
        train_totals = {x:train_list.count(x) for x in train_list}
        valid_totals = {x:valid_list.count(x) for x in valid_list}
        print(f'Train totals: {train_totals}\nValid totals: {valid_totals}')
        return rand_idx[cut:], rand_idx[:cut]
    return _inner

We can now see how the data is split for a given seed using set_seed. For example, if we set it to 7 as in the prior example:

set_seed(7)
trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
Train totals: {'No Pneumothorax': 16, 'Pneumothorax': 5}
Valid totals: {'Pneumothorax': 3, 'No Pneumothorax': 2}
((#21) [2,13,9,12,11,24,8,14,16,6...], (#5) [19,18,3,23,17])

In this case the train set has 16 No Pneumothorax and 5 Pneumothorax images and the valid set has 2 No Pneumothorax and 3 Pneumothorax images

How about using a seed of 77?

set_seed(77)
trn,val = DataSplit(valid_pct=0.2)(items)
trn, val
Train totals: {'No Pneumothorax': 14, 'Pneumothorax': 7}
Valid totals: {'Pneumothorax': 1, 'No Pneumothorax': 4}
((#21) [9,18,8,10,12,17,3,1,20,22...], (#5) [19,6,14,7,5])

In this case the train set has 14 No Pneumothorax and 7 Pneumothorax images and the valid set has 4 No Pneumothorax and 1 Pneumothorax image

Note: In these cases we can see the difference that the choice of seed makes to the distribution of images within the training and validation sets
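
The effect of the seed can be sketched without fastai. This minimal example uses Python's `random` module instead of `torch.randperm`, so the indices will not match the outputs above; `split_indices` is a hypothetical helper that only illustrates that the same seed reproduces the same split.

```python
import random

def split_indices(n, valid_pct=0.2, seed=None):
    "Hypothetical helper: shuffle range(n) with `seed` and cut off the first valid_pct as validation."
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = int(valid_pct * n)
    return idx[cut:], idx[:cut]

a = split_indices(26, seed=7)
b = split_indices(26, seed=7)   # same seed, so the same split
c = split_indices(26, seed=77)  # a different seed generally gives a different split
print(a == b)
print(len(a[0]), len(a[1]))
```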

Note: >>Work in progress: various techniques to help with unbalanced datasets (oversampling, undersampling, stratified k-fold cross validation etc). This is especially relevant for medical image datasets, where there are typically a lot more images of 'normal' compared to 'diseased'.
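
As one illustration of these techniques, a stratified split keeps the class proportions the same in both sets. The sketch below is a standard-library illustration under stated assumptions: `stratified_split` is a hypothetical helper, and the label list simply reproduces the 175/75 class counts quoted for the full dataset.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, valid_pct=0.2, seed=None):
    "Hypothetical helper: split indices so each class contributes valid_pct of its items to validation."
    rng = random.Random(seed)
    # Group item indices by class label
    by_class = defaultdict(list)
    for idx, lbl in enumerate(labels):
        by_class[lbl].append(idx)
    train_idx, valid_idx = [], []
    # Shuffle and cut each class separately
    for lbl in sorted(by_class):
        idxs = by_class[lbl]
        rng.shuffle(idxs)
        cut = round(valid_pct * len(idxs))
        valid_idx += idxs[:cut]
        train_idx += idxs[cut:]
    return train_idx, valid_idx

# Class counts taken from the dataset description (175 negative, 75 positive)
labels = ["No Pneumothorax"] * 175 + ["Pneumothorax"] * 75
train_idx, valid_idx = stratified_split(labels, valid_pct=0.2, seed=7)
valid_counts = Counter(labels[i] for i in valid_idx)
print(valid_counts)
```

With these counts the validation set receives 35 No Pneumothorax and 15 Pneumothorax images regardless of the seed, so the 70/30 class ratio is preserved in both sets.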


img1 = (pneumothorax_source/'chest1.png'); img2 = (pneumothorax_source/'chest2.png')

Note: >>Work in progress: Choosing the right augmentations is important in determining how they affect the sampling process. For example, in some cases it may not be a good idea to flip images.
Here is an image of a 'normal' patient in its correct orientation (heart showing in the middle right)

If we flip the image

We can now see the heart is middle left. If the classifier was looking to detect defects of the heart then this type of augmentation would not be suitable.


Tip: Garbage Collector is a great way of freeing up memory if needed
#clear out some memory
import gc
gc.collect()

Switching back to the full dataset which contains 250 images.

items_full = get_dicom_files(pneumothorax_source, recurse=True, folders='train')
df_full = pd.read_csv(pneumothorax_source/f"labels.csv")

Using DataSplit we can check the number of each class represented in the training and validation sets


trn,val = DataSplit(valid_pct=0.2)(items_full)
trn, val
Train totals: {'No Pneumothorax': 141, 'Pneumothorax': 59}
Valid totals: {'No Pneumothorax': 34, 'Pneumothorax': 16}
((#200) [33,65,231,167,74,127,184,89,122,79...],
 (#50) [115,233,139,163,161,177,57,21,34,99...])

Create the DataBlock and specify some transforms

xtra_tfms = [RandomResizedCrop(194)]
batch_tfms = [*aug_transforms(do_flip=False, flip_vert=False, xtra_tfms=xtra_tfms), Normalize.from_stats(*imagenet_stats)]

For the splitter, we have already split the data into training and validation sets (the val indices returned by DataSplit above; DicomSplit returns indices in the same form) and we now need to incorporate this split into the DataBlock. fastai provides a convenient way to do this, IndexSplitter, which splits items so that val_idx are in the validation set and the others are in the training set. Our custom splitters return the indexes of the dataset, so we can feed these into IndexSplitter with val_idx set to the val index created above.

pneumothorax = DataBlock(blocks=(ImageBlock(cls=PILDicom), CategoryBlock),
                   get_x=lambda x:f'{pneumothorax_source}/{x[0]}',
                   get_y=lambda x:x[1],
                   splitter=IndexSplitter(val),
                   batch_tfms=batch_tfms)

dls = pneumothorax.dataloaders(df_full.values, bs=16, num_workers=0)
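
To make the idea of an index-based splitter concrete, here is a minimal pure-Python sketch of the behaviour described above. It is an illustration only, not fastai's implementation of IndexSplitter; `index_split` is a hypothetical helper, and the validation indices are the ones printed earlier for the 26-image dataset.

```python
def index_split(n_items, valid_idx):
    "Hypothetical helper: everything not listed in `valid_idx` goes to the training set."
    valid = set(valid_idx)
    train_idx = [i for i in range(n_items) if i not in valid]
    return train_idx, list(valid_idx)

# Validation indices taken from the earlier 26-image split output
train_idx, valid_idx = index_split(26, [19, 18, 3, 23, 17])
print(len(train_idx), len(valid_idx))
```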

Check train and valid sizes

len(dls.train_ds), len(dls.valid_ds)
(200, 50)
#net = xresnext50(pretrained=False, sa=True, act_cls=Mish, n_out=dls.c)
net = resnet50()
learn = Learner(dls, net, metrics=accuracy)

Tip: If you do not specify a loss function or optimization function, fastai automatically allocates one. You can view the loss_func and opt_func as follows:
learn.loss_func
FlattenedLoss of CrossEntropyLoss()
learn.opt_func
<function fastai.optimizer.Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-05, wd=0.01, decouple_wd=True)>
learn.lr_find()
SuggestedLRs(lr_min=0.012022644281387329, lr_steep=0.0003311311302240938)
learn.fit_one_cycle(3, slice(1e-2))
epoch train_loss valid_loss accuracy time
0 0.678739 3.618579 0.360000 00:10
1 0.680920 0.631048 0.740000 00:10
2 0.661906 0.582651 0.720000 00:10


To see how show_results works, I slightly tweaked it so that we can see the ground truth, predictions and probabilities.

Tip: With with_decoded=True, get_preds returns a tuple: Probabilities, Ground Truth, Prediction
def show_results2(self, ds_idx=1, dl=None, max_n=9, shuffle=False, **kwargs):
    if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
    b = dl.one_batch()
    p,g,preds = self.get_preds(dl=[b], with_decoded=True)
    print(f'Ground Truth: {g}\n Probabilities: {p}\n Prediction: {preds}\n')
    self.dls.show_results(b, preds, max_n=max_n, **kwargs)
show_results2(learn, max_n=12, nrows=2, ncols=6)
Ground Truth: TensorCategory([0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0], dtype=torch.int32)
 Probabilities: tensor([[6.5734e-01, 3.4266e-01, 2.0420e-10,  ..., 2.2660e-10, 2.3010e-10,
        [5.9493e-01, 4.0506e-01, 1.7759e-08,  ..., 1.7563e-08, 1.9316e-08,
        [8.7132e-01, 1.2868e-01, 3.4132e-11,  ..., 3.7495e-11, 3.8186e-11,
        [7.7597e-01, 2.2403e-01, 1.9229e-09,  ..., 2.0393e-09, 1.9173e-09,
        [6.5044e-01, 3.4956e-01, 5.0422e-12,  ..., 6.0412e-12, 5.9057e-12,
        [7.3728e-01, 2.6272e-01, 2.0523e-14,  ..., 2.5750e-14, 2.8775e-14,
 Prediction: tensor([0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0])

The probabilities determine the predicted outcome. In this dataset there are only 2 classes, No Pneumothorax and Pneumothorax, so only the first two values in each row matter (the remaining values in the output above are close to zero): the first value is the probability that the image belongs to class 0 (No Pneumothorax) and the second value is the probability that the image belongs to class 1 (Pneumothorax).

Tip: It is a good idea to look at the probabilities to see how certain the model is. For example, [0.4903, 0.5097] and [0.1903, 0.8097] both produce the same result, that the image belongs to class 1, but in the second case the model is a lot more certain.
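
The two probability pairs from the tip can be compared with a few lines of plain Python. This is a minimal sketch; `predict_with_confidence` is a hypothetical helper, not a fastai function.

```python
def predict_with_confidence(probs):
    "Hypothetical helper: return (predicted class index, probability of that class)."
    pred = max(range(len(probs)), key=lambda i: probs[i])
    return pred, probs[pred]

# Same predicted class, very different certainty
for probs in ([0.4903, 0.5097], [0.1903, 0.8097]):
    pred, conf = predict_with_confidence(probs)
    print(f"class {pred} with probability {conf:.4f}")
```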

Model Evaluation

Because medical models are high impact it is important to know how good a model is at detecting a certain condition.


What is the accuracy of the model above? You can simply use accuracy to get that information


The above model has an accuracy of 66%.

  • Accuracy is the probability that the model is correct, or to be more specific:
  • Accuracy is the probability that the model is correct and the patient has the condition PLUS the probability that the model is correct and the patient does not have the condition

False Positive & False Negative

There are some other key terms that need to be used when evaluating medical models:

  • False Negative is an error in which a test result improperly indicates no presence of a condition (the result is negative), when in reality it is present.
  • False Positive is an error in which a test result improperly indicates presence of a condition, such as a disease (the result is positive), when in reality it is not present

Sensitivity & Specificity

  • Sensitivity or True Positive Rate is the probability that the model classifies a patient as having the disease given that the patient actually does have the disease. Sensitivity quantifies the avoidance of false negatives

Example: A new test was tested on 10,000 patients who have the condition. If the new test has a sensitivity of 90%, it will correctly detect 9,000 patients (True Positives) but will miss 1,000 patients (False Negatives) who have the condition but were tested as not having it

  • Specificity or True Negative Rate is where the model classifies a patient as not having the disease given the patient actually does not have the disease. Specificity quantifies the avoidance of false positives

Tip: Understanding and using sensitivity, specificity and predictive values is a great paper if you are interested in learning more


Most medical testing is evaluated via PPV (Positive Predictive Value) or NPV (Negative Predictive Value).

  • PPV - if the model predicts a patient has a condition, what is the probability that the patient actually has the condition

  • NPV - if the model predicts a patient does not have a condition what is the probability that the patient actually does not have the condition

The ideal value of the PPV, with a perfect test, is 1 (100%), and the worst possible value would be zero

The ideal value of the NPV, with a perfect test, is 1 (100%), and the worst possible value would be zero
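
Unlike sensitivity and specificity, PPV and NPV depend on how common the condition is in the population being tested. The following sketch applies Bayes' rule to show this; `predictive_values` is a hypothetical helper, and the sensitivity, specificity and prevalence values are made up for illustration.

```python
def predictive_values(sensitivity, specificity, prevalence):
    "Hypothetical helper: PPV and NPV from sensitivity, specificity and prevalence (Bayes' rule)."
    tp = sensitivity * prevalence              # fraction of population: true positives
    fn = (1 - sensitivity) * prevalence        # false negatives
    tn = specificity * (1 - prevalence)        # true negatives
    fp = (1 - specificity) * (1 - prevalence)  # false positives
    return tp / (tp + fp), tn / (tn + fn)

# Illustrative values only: a test with 90% sensitivity and 90% specificity
for prevalence in (0.01, 0.28, 0.5):
    ppv, npv = predictive_values(0.9, 0.9, prevalence)
    print(f"prevalence={prevalence:.2f}  PPV={ppv:.3f}  NPV={npv:.3f}")
```

Even a test with 90% sensitivity and specificity has a PPV of only about 8% when just 1% of the tested population has the condition.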

Confusion Matrix

Plot a confusion matrix - note that this is plotted against the validation set

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix()

We can manually reproduce the results interpreted from plot_confusion_matrix

upp, low = interp.confusion_matrix()
tn, fp = upp[0], upp[1]
fn, tp = low[0], low[1]
print(tn, fp, fn, tp)
29 7 10 4

Note: Sensitivity = True Positive/(True Positive + False Negative)
sensitivity = tp/(tp + fn)

In this case the model has a sensitivity of only about 29% (4 of 14): it correctly detects 29% of the patients who have Pneumothorax (true positives) but misses the other 71% (false negatives: patients who actually have Pneumothorax but were told they did not! Not a good situation to be in).

Note: This is also known as a Type II error

Note: Specificity = True Negative/(False Positive + True Negative)
specificity = tn/(fp + tn)

In this case the model has a specificity of about 81% (29 of 36): it correctly identifies 81% of the patients who do NOT have Pneumothorax but incorrectly classifies the other 19% as having Pneumothorax (False Positive) when they actually do not.

Note: This is also known as a Type I error

Positive Predictive Value (PPV)

ppv = tp/(tp+fp)

In this case the PPV is about 36% (4 of 11), so the model performs poorly in correctly predicting patients with Pneumothorax

Negative Predictive Value (NPV)

npv = tn/(tn+fn)

The NPV is about 74% (29 of 39), so this model is better at predicting patients with No Pneumothorax
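
The four metrics can be collected in one place. The sketch below plugs in the confusion-matrix counts printed above (29, 7, 10, 4); `diagnostic_metrics` is a hypothetical helper written in plain Python, not a fastai function.

```python
def diagnostic_metrics(tn, fp, fn, tp):
    "Hypothetical helper: standard diagnostic metrics from confusion-matrix counts."
    return {
        "sensitivity": tp / (tp + fn),   # true positive rate
        "specificity": tn / (tn + fp),   # true negative rate
        "ppv": tp / (tp + fp),           # positive predictive value
        "npv": tn / (tn + fn),           # negative predictive value
        "accuracy": (tp + tn) / (tn + fp + fn + tp),
    }

# Counts from the confusion matrix above
m = diagnostic_metrics(tn=29, fp=7, fn=10, tp=4)
for name, value in m.items():
    print(f"{name}: {value:.2f}")
```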

Some of these metrics can be calculated using sklearn's classification report

                 precision    recall  f1-score   support

No Pneumothorax       0.74      0.81      0.77        36
   Pneumothorax       0.36      0.29      0.32        14

       accuracy                           0.66        50
      macro avg       0.55      0.55      0.55        50
   weighted avg       0.64      0.66      0.65        50

Calculating Accuracy

The accuracy of this model, as mentioned before, is 66% - let's now calculate this!

We can also look at Accuracy as:

Tip: accuracy = sensitivity * prevalence + specificity * (1 - prevalence)
Prevalence is a statistical concept referring to the number of cases of a disease that are present in a particular population at a given time.

The prevalence in this case is how many patients in the valid dataset have the condition compared to the total number of patients in the valid dataset. To view the number of Pneumothorax patients in the valid set


There are 14 Pneumothorax images among the 50 images in the valid set (see the support column in the classification report above), hence the prevalence here is 14/50 = 0.28

accuracy = (sensitivity * 0.28) + (specificity * (1 - 0.28))
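
The identity can be checked against the confusion-matrix counts printed earlier (TN = 29, FP = 7, FN = 10, TP = 4, so N = 50 validation images). The derivation below is a worked sketch that takes the prevalence from those counts, i.e. (TP + FN)/N:

```latex
\begin{aligned}
\text{accuracy}
  &= \text{sensitivity}\cdot\text{prevalence} + \text{specificity}\cdot(1-\text{prevalence}) \\
  &= \frac{TP}{TP+FN}\cdot\frac{TP+FN}{N} + \frac{TN}{TN+FP}\cdot\frac{TN+FP}{N} \\
  &= \frac{TP+TN}{N}
   = \frac{4}{14}\cdot\frac{14}{50} + \frac{29}{36}\cdot\frac{36}{50}
   = \frac{4+29}{50} = 0.66
\end{aligned}
```

This agrees with the accuracy of 0.66 shown in the classification report.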