Intro

Fastai2 provides a wide range of data augmentation techniques, and this blog focuses particularly on image augmentation. (This is an update to the article 'Data Augmentation Techniques' I wrote in 2018 using fastai v1 [1].)

Working with limited data has its own challenges, and data augmentation only produces positive results if the augmentation techniques actually enhance the current dataset. For example, is there any worth in training a network to 'learn' about a landmark in an upside-down orientation?

Invariance is the ability of convolutional neural networks to classify objects even when they are presented in different orientations. Data augmentation is a way of creating new 'data' with different orientations. The benefits of this are twofold: the first is the ability to generate 'more data' from limited data, and the second is that it helps prevent overfitting.

Most deep learning libraries apply augmentations step by step, whilst *fastai2 combines various augmentation parameters to reduce the number of computations and the number of lossy operations* [2].
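As a rough illustration of the idea (plain NumPy, not fastai internals): two affine operations can be fused into a single matrix before the image is resampled, so the pixels only go through one lossy interpolation.

#Conceptual sketch only: composing a rotation and a zoom into one matrix
import numpy as np
rotate = np.array([[0.866, -0.5, 0.],
                   [0.5,  0.866, 0.],
                   [0.,   0.,    1.]])   # ~30 degree rotation
zoom   = np.array([[1.1, 0.,  0.],
                   [0.,  1.1, 0.],
                   [0.,  0.,  1.]])      # 1.1x zoom
fused  = zoom @ rotate                   # apply this one matrix -> image resampled only once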

Fastai uses Pipelines to compose several transforms together. A Pipeline is defined by passing it a list of Transforms, which it then composes. In this blog I will look at the order in which these transforms are applied and what effect they have on image quality and efficiency. Pipelines are sorted by the internal order attribute (discussed more below), which has a default value of 0.
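As a minimal sketch of how this works (AddOne and TimesTwo are made-up transforms purely for illustration), a Pipeline re-orders the transforms it is given by their order attribute before composing them:

#Assumes: from fastai2.vision.all import *
class TimesTwo(Transform):
    order = 0
    def encodes(self, x): return x * 2

class AddOne(Transform):
    order = 1
    def encodes(self, x): return x + 1

pipe = Pipeline([AddOne(), TimesTwo()])
pipe(3)   # TimesTwo (order 0) runs before AddOne (order 1): (3*2)+1 = 7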

Using this as a high-level API example

from fastai2.vision.all import *
source = untar_data(URLs.PETS)
#High-level API example
testblock = DataBlock(blocks = (ImageBlock, CategoryBlock),
                 get_items=get_image_files, 
                 splitter=RandomSplitter(seed=42),
                 get_y=using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'),
                 item_tfms=[Resize(256)],
                 batch_tfms=[*aug_transforms(xtra_tfms=None)])
test_dls = testblock.dataloaders(source/'images')
test_dls.show_batch(max_n=6, nrows=1, ncols=6)

To check the order in which augmentations are conducted we can call *after_item* and *after_batch*

after_item

test_dls.after_item
Pipeline: Resize -> ToTensor

In this case the images are:

resized to a square of equal sides, in this case 256, and then

converted into a *channel* x *height* x *width* tensor
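We can sanity-check this by applying the same item transforms to one image from the dataset (a hedged sketch; the exact image depends on the split):

#Apply the item transforms manually to a single PILImage from the Pets dataset
x = test_dls.train_ds[0][0]        # (PILImage, TensorCategory) -> take the image
t = ToTensor()(Resize(256)(x))     # same steps as the after_item pipeline
t.shape                            # expected: torch.Size([3, 256, 256])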

But what does Resize do?

Click the button to view the Resize class

#collapse
#https://github.com/fastai/fastai2/blob/master/nbs/09_vision.augment.ipynb
@delegates()
class Resize(RandTransform):
    split_idx = None
    mode,mode_mask,order,final_size = Image.BILINEAR,Image.NEAREST,1,None
    "Resize image to `size` using `method`"
    def __init__(self, size, method=ResizeMethod.Crop, pad_mode=PadMode.Reflection,
                 resamples=(Image.BILINEAR, Image.NEAREST), **kwargs):
        super().__init__(**kwargs)
        self.size,self.pad_mode,self.method = _process_sz(size),pad_mode,method
        self.mode,self.mode_mask = resamples

    def before_call(self, b, split_idx):
        if self.method==ResizeMethod.Squish: return
        self.pcts = (0.5,0.5) if split_idx else (random.random(),random.random())

    def encodes(self, x:(Image.Image,TensorBBox,TensorPoint)):
        orig_sz = _get_sz(x)
        self.final_size = self.size
        if self.method==ResizeMethod.Squish:
            return x.crop_pad(orig_sz, Tuple(0,0), orig_sz=orig_sz, pad_mode=self.pad_mode,
                   resize_mode=self.mode_mask if isinstance(x,PILMask) else self.mode, resize_to=self.size)

        w,h = orig_sz
        op = (operator.lt,operator.gt)[self.method==ResizeMethod.Pad]
        m = w/self.size[0] if op(w/self.size[0],h/self.size[1]) else h/self.size[1]
        cp_sz = (int(m*self.size[0]),int(m*self.size[1]))
        tl = Tuple(int(self.pcts[0]*(w-cp_sz[0])), int(self.pcts[1]*(h-cp_sz[1])))
        return x.crop_pad(cp_sz, tl, orig_sz=orig_sz, pad_mode=self.pad_mode,
                   resize_mode=self.mode_mask if isinstance(x,PILMask) else self.mode, resize_to=self.size)

By default Resize crops the image to the size specified: the image is scaled so that its shorter dimension matches the target size and the longer dimension is then cropped. With ResizeMethod.Pad the longer dimension is matched instead and the remainder is filled with whatever is specified in pad_mode, while ResizeMethod.Squish simply distorts the image to the target size.

The method parameter can take 1 of 3 values: Crop (the default), Pad or Squish, e.g. *method=ResizeMethod.Squish*. The pad_mode parameter also takes 1 of 3 values: Border, Zeros and Reflection (the default), e.g. *pad_mode=PadMode.Reflection*.
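For example (the target size of 224 here is arbitrary), each combination can be requested explicitly:

rsz_crop   = Resize(224)                                                   # default method: ResizeMethod.Crop
rsz_squish = Resize(224, method=ResizeMethod.Squish)                       # distort to exactly 224x224
rsz_pad    = Resize(224, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros)  # keep aspect ratio, pad with zeros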

The images are resized/resampled using bilinear and nearest-neighbour interpolation [3].

We can check how initial image sizes are affected by Resize. I chose an image with numbers so that you can distinguish the different areas of the image more easily, and I colored each of the corners a different color to better see what effect Resize has on the image.

#Load a test image
image_path = 'C:/Users/avird/.fastai/data/0100-number_12.jpg'
img = Image.open(image_path)
img.shape, type(img)
((380, 500), PIL.JpegImagePlugin.JpegImageFile)
#Convert image into a fastai.PILImage
img = PILImage(PILImage.create(image_path).resize((500,380)))
img.shape, type(img)
((380, 500), fastai2.vision.core.PILImage)
#View the image
img

Fastai offers 3 resize methods (via *ResizeMethod*: Squish, Pad and Crop), and they can be plotted against each other to view the differences between them. Crop is the fastai default. To better view the differences I used a padding of zeros (the default pad_mode is Reflection).

> Image size 5

#collapse
#Use image size of 5
_,axs = plt.subplots(1,3,figsize=(20,20))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(5, method=method, pad_mode=PadMode.Zeros)
    show_image(rsz(img, split_idx=1), ctx=ax, title=f'{method}, size=5');

Using an image size of 5 we can see how the image is affected by Resize. At this size we can see all 4 colors in the corners, and there is not much difference between squish and crop. With pad, however, the image is scaled so that its longer dimension (the width, as the original image is 380 high by 500 wide) fits the target size of 5, and the remaining height is padded with zeros.

> Image size 15

#collapse
#Use image size of 15
_,axs = plt.subplots(1,3,figsize=(20,20))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(15, method=method, pad_mode=PadMode.Zeros)
    show_image(rsz(img, split_idx=1), ctx=ax, title=f'{method}, size=15');

At image size 15, both 'squish' and 'pad' still show all the colors in the corners, but with 'crop' you start to notice that the colors in each corner are beginning to fade as the image is cropped from the center.

> Image size 256

#collapse
#Use image size of 256
_,axs = plt.subplots(1,3,figsize=(20,20))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method, pad_mode=PadMode.Zeros)
    show_image(rsz(img, split_idx=0), ctx=ax, title=f'{method}, size=256');

At 256, both 'squish' and 'pad' still display the full image, while 'crop' displays the cropped image.

What impact could this have on real datasets?

Using an image from a Covid19 dataset [5]

#collapse
test_path = 'C:/Users/avird/.fastai/data/0002.jpeg'
testimg = Image.open(test_path)
img2 = PILImage(PILImage.create(test_path).resize((944, 656)))
img2

#collapse
#Use image size of 256
_,axs = plt.subplots(1,3,figsize=(20,20))
for ax,method in zip(axs.flatten(), [ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop]):
    rsz = Resize(256, method=method, pad_mode=PadMode.Zeros)
    show_image(rsz(img2, split_idx=1), ctx=ax, title=f'{method}, size=256');

In this case:

the default 'squish' resize method squishes the image on the horizontal axis. You can still view the whole image, but you can see that the ribcage has been constricted towards the center. The implication is that important features visible in the original image could be erased or diluted.

for the 'pad' resize the whole image is still viewable; nothing is distorted, but the leftover space above and below is filled with zero padding.

with 'crop', the image is cropped from the centre, so we lose image detail from the edges.

The implications of these choices really depend on the dataset, but the wrong choice could have a detrimental effect, leading to vital features being erased or diluted.
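As a hedged sketch of acting on this (the labelling via parent_label and the parameter values are illustrative only), padding could be requested instead of the default crop when the image edges matter:

#Hypothetical DataBlock keeping the whole image via zero padding
xrayblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                      get_items=get_image_files,
                      splitter=RandomSplitter(seed=42),
                      get_y=parent_label,
                      item_tfms=Resize(256, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros))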

after_batch

Back to the pets example: running *after_batch* shows us the batch augmentation pipeline. Previously, the item transforms resized the images and converted them to tensors so they could be collated and sent to the GPU.

test_dls.after_batch
Pipeline: IntToFloatTensor -> AffineCoordTfm -> LightingTfm

This reveals the pipeline process for the batch transformations:

convert ints to float tensors

apply all the affine transformations

followed by the lighting transformations

The order is important in order to maintain a number of key aspects:

Maintain image quality

Reduce computations

Improve efficiency
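The order attributes behind this sequence can be inspected directly (the values in the comment are what I would expect based on the orders discussed later in this post):

IntToFloatTensor.order, AffineCoordTfm.order, LightingTfm.order
#expected: (10, 30, 40)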

As mentioned in Fastbook [4], most machine learning libraries use a step-by-step process of augmentation, which can lead to reduced image quality. The DataBlock example above uses the high-level API, which is fairly flexible but not as flexible as the mid-level API.

The mid-level pipeline below is an exact equivalent of the high-level DataBlock above but allows for more customization; we will use it for the rest of the blog.

#collapse
#Helper for viewing single images
def repeat_one(source, n=128):
    """Single image helper for displaying batch"""
    return [get_image_files(source)[1]]*n

Mid-Level API and viewing a batch of a single image

#mid-level API example 
#num_workers = 0 because I use windows :) and windows does not support multiprocessing on CUDA [6]
tfms = [[PILImage.create], [using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'), Categorize]]
item_tfms = [ToTensor(), Resize(256)]
splitter=RandomSplitter(seed=42)
after_b = [IntToFloatTensor(), *aug_transforms(xtra_tfms=RandomResizedCrop(256), min_scale=0.9)]

dsets = Datasets(repeat_one(source/'images'), tfms=tfms)
dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=32, num_workers=0, splits=splitter)
dls.after_batch
Pipeline: RandomResizedCrop -> IntToFloatTensor -> AffineCoordTfm -> RandomResizedCropGPU -> LightingTfm

#collapse
dls.show_batch(max_n=6, nrows=1, ncols=6)

Image comparisons (Fastai v The Rest)

#create 1 batch
x,y = dls.one_batch()

Checking image quality and speed using step by step transformations.

#collapse
%%time
x1 = TensorImage(x.clone())
x1 = x1.affine_coord(sz=256)
x1 = x1.brightness(max_lighting=0.2, p=1.)
x1 = x1.zoom(max_zoom=1.1, p=0.5)
x1 = x1.warp(magnitude=0.2, p=0.5)

_,axs = subplots(1, 1, figsize=(5,5))
TensorImage(x1[0]).show(ctx=axs[0])
Wall time: 103 ms
<matplotlib.axes._subplots.AxesSubplot at 0x29421ecdfc8>

Checking image quality and speed using fastai2

#collapse
%%time
tfms = setup_aug_tfms([Brightness(max_lighting=0.2, p=1.,),
                       CropPad(size=256),
                       Zoom(max_zoom=1.1, p=0.5),
                       Warp(magnitude=0.2, p=0.5)
                      ])
x = Pipeline(tfms)(x)
_,axs = subplots(1, 1, figsize=(5,5))
TensorImage(x[0]).show(ctx=axs[0])
Wall time: 45.9 ms
<matplotlib.axes._subplots.AxesSubplot at 0x294220787c8>

Comparing the times above, using a Pipeline where a list of transforms is passed in is more than twice as fast as applying the augmentations step by step: the step-by-step method completed the task in 103ms compared to 46ms for fastai.

Looking at image quality side by side

#collapse
def image_comp():
    x,y = dls.one_batch()
    tfms = setup_aug_tfms([Brightness(max_lighting=0.3, p=1.,),
                       Resize(size=256),
                       Zoom(max_zoom=1.1, p=1.),
                       Warp(magnitude=0.2, p=1.)
                      ])
    x = Pipeline(tfms)(x)

    x1 = TensorImage(x.clone())
    x1 = x1.affine_coord(sz=256)
    x1 = x1.brightness(max_lighting=0.3, p=1.)
    x1 = x1.zoom(max_zoom=1.1, p=1.)
    x1 = x1.warp(magnitude=0.2, p=1.)

    _,axs = subplots(1, 2, figsize=(20,20))
    TensorImage(x[0]).show(ctx=axs[0], title='fastai')
    TensorImage(x1[0]).show(ctx=axs[1], title='other')
image_comp()

You can definitely see differences between the two pictures: the 'fastai' image is clearer compared to the 'other' image. How about some other examples?

image_comp()
image_comp()

List of Transforms

There are a number of transforms; here is a list of the most common ones:

RandomResizedCrop = "Picks a random scaled crop of an image and resize it to size" - order 1
IntToFloatTensor = "Transform image to float tensor, optionally dividing by 255 (e.g. for images)" - order 10
Rotate = "Apply a random rotation of at most max_deg with probability p to a batch of images"
Brightness = "Apply change in brightness of max_lighting to batch of images with probability p."
RandomErasing = "Randomly selects a rectangle region in an image and randomizes its pixels." - order 100
CropPad = "Center crop or pad an image to size" - order 0
Zoom = "Apply a random zoom of at most max_zoom with probability p to a batch of images"
Warp = "Apply perspective warping with magnitude and p on a batch of matrices"
Contrast = "Apply change in contrast of max_lighting to batch of images with probability p."
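Any of these can be passed as batch_tfms (or via xtra_tfms in aug_transforms). A small illustrative combination, with parameter values picked purely for demonstration:

batch_tfms = [IntToFloatTensor(),
              Rotate(max_deg=10, p=0.5),
              Brightness(max_lighting=0.2, p=0.75),
              Zoom(max_zoom=1.1, p=0.5),
              RandomErasing(p=0.5, max_count=2)]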

Pipeline for multiple augmentations

In the example above the after_batch pipeline consisted of IntToFloatTensor > affine transformations > lighting transformations.

However, what if we use additional augmentations? What does the pipeline look like then?

#collapse
source = untar_data(URLs.PETS)
#Helper for viewing single images
def repeat_one(source, n=128):
    """Single image helper for displaying batch"""
    return [get_image_files(source)[2]]*n

#collapse
#Include multiple transforms
tfms = [[PILImage.create], [using_attr(RegexLabeller(r'(.+)_\d+.jpg$'), 'name'), Categorize]]
item_tfms = [ToTensor(), Resize(296)]
splitter=RandomSplitter(seed=42)
xtra_tfms = [Rotate(max_deg=45, p=1.),
            RandomErasing(p=1., max_count=10, min_aspect=0.5, sl=0.2, sh=0.2),
            RandomResizedCrop(p=1., size=256),
            Brightness(max_lighting=0.2, p=1.),
            CropPad(size=256),
            Zoom(max_zoom=2.1, p=0.5),
            Warp(magnitude=0.2, p=1.0)
            ]
after_b = [IntToFloatTensor(), *aug_transforms(mult=1.0, do_flip=False, flip_vert=False, max_rotate=0., 
           max_zoom=1.1, max_lighting=0.,max_warp=0., p_affine=0.75, p_lighting=0.75, xtra_tfms=xtra_tfms, size=256,
           mode='bilinear', pad_mode=PadMode.Reflection, align_corners=True, batch=False, min_scale=0.9)]

mdsets = Datasets(repeat_one(source/'images'), tfms=tfms)
mdls = mdsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=32, num_workers=0, splits=splitter)

Looking at after_item - it is the same as before

mdls.after_item
Pipeline: Resize -> ToTensor
mdls.after_batch
Pipeline: CropPad -> IntToFloatTensor -> AffineCoordTfm -> LightingTfm -> RandomErasing

after_batch is now a different story, and we can see how fastai computes its augmentations. These are all applied in sequence (based on their order attribute), starting with

CropPad

followed by affine

lighting

and random erasing transforms.

Here is what the batch looks like

mdls.show_batch(max_n=6, nrows=1, ncols=6)

The order number determines the sequence of the transforms: for example, CropPad is order 0 and Resize and RandomCrop are order 1, hence the reason they appear first on the list. IntToFloatTensor is order 10 and runs after the PIL transforms, on the GPU. Affine transforms are order 30, as is RandomResizedCropGPU, and lighting transforms are order 40. RandomErasing is order 100.

Viewing the order of transforms

#for example
mdls.after_batch
Pipeline: RandomResizedCrop -> CropPad -> IntToFloatTensor -> AffineCoordTfm -> RandomResizedCropGPU -> LightingTfm -> RandomErasing
RandomResizedCrop.order, CropPad.order, IntToFloatTensor.order, AffineCoordTfm.order, RandomResizedCropGPU.order, RandomErasing.order
(0, 0, 10, 30, 30, 100)

You can force where a transform runs by explicitly specifying the order attribute within the transform class.
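A minimal sketch (the class name NormalizeLast is hypothetical): setting order inside the transform class decides where it slots into the batch pipeline.

class NormalizeLast(Transform):
    order = 99   # after the lighting transforms (order 40) but before RandomErasing (order 100)
    def encodes(self, x:TensorImage): return (x - x.mean()) / (x.std() + 1e-6)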

Interesting Observations

There were some interesting observations during this experimentation. Adding a *min_scale* value in aug_transforms adds RandomResizedCropGPU to the pipeline:

mdls.after_batch
Pipeline: CropPad -> IntToFloatTensor -> AffineCoordTfm -> RandomResizedCropGPU -> LightingTfm -> RandomErasing

However, if you add RandomResizedCrop as well as a min_scale value, the pipeline now looks like this:

mdls.after_batch
Pipeline: RandomResizedCrop -> CropPad -> IntToFloatTensor -> AffineCoordTfm -> RandomResizedCropGPU -> LightingTfm -> RandomErasing

And if you use RandomResizedCrop with no min_scale value the pipeline is now:

mdls.after_batch
Pipeline: RandomResizedCrop -> CropPad -> IntToFloatTensor -> AffineCoordTfm -> LightingTfm -> RandomErasing

Still to do

There is clearly a plethora of options, and additional experimentation is needed to see what impact the various pipelines have on image quality, efficiency and end results - *work in progress*

Manually going through the pipeline

An attempt to manually go through the pipeline.

#collapse
image_path = 'C:/Users/avird/.fastai/data/oxford-iiit-pet/images/Abyssinian_1.jpg'
TEST_IMAGE = Image.open(image_path)
img = PILImage(PILImage.create(image_path))
img.shape, type(img)                
((400, 600), fastai2.vision.core.PILImage)

This is the original image, 400 pixels in height and 600 in width.

#collapse
#Original Image
img

Resize to 256 using default crop and reflection padding

#collapse
#Resize to 256 using default crop and reflection padding
r = Resize(256, method=ResizeMethod.Crop, pad_mode=PadMode.Reflection)
w = r(img)
w.shape, type(w)
((256, 256), fastai2.vision.core.PILImage)

#collapse
w

Crop the image using size 256

#collapse
#Crop
crp = CropPad(256)
c = r(crp(img))
c.shape, type(c)
((256, 256), fastai2.vision.core.PILImage)

#collapse
c

Convert PILImage into a TensorImage

timg = TensorImage(array(c)).permute(2,0,1).float()/255.
timg.shape, type(timg)
(torch.Size([3, 256, 256]), fastai2.torch_core.TensorImage)
h = TensorImage(timg[None].expand(3, *timg.shape).clone())
h.shape, type(h)
(torch.Size([3, 3, 256, 256]), fastai2.torch_core.TensorImage)
#if do_flip=True and flip_vert=False -> Flip
fli = Flip(p=0.5)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y = fli(h)
    show_image(y[0], ctx=ax, cmap='Greys')
#if do_flip=True and flip_vert=True -> Dihedral
dih = Dihedral(p=0.5)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y1 = dih(h)
    show_image(y1[0], ctx=ax)
#Rotate 
rot = Rotate(max_deg=45, p=1.)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y2 = rot(h)
    show_image(y2[0], ctx=ax)
# Zoom
zoo = Zoom(max_zoom=4.1, p=0.5)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y3 = zoo(h)
    show_image(y3[0], ctx=ax)
# Warp
war = Warp(magnitude=0.7, p=0.5)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y4 = war(h)
    show_image(y4[0], ctx=ax)
#Brightness
bri = h.brightness(draw=0.9, p=1.)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y5 = bri
    show_image(y5[0], ctx=ax)
#Contrast
con = h.contrast(draw=1.9, p=0.5)
_,axs = plt.subplots(1,2, figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
    y6 = con
    show_image(y6[0], ctx=ax)

View the images side by side

#collapse
_,axs = plt.subplots(1,8, figsize=(20,9))
show_image(img, ctx=axs[0], title='original')
show_image(w, ctx=axs[1], title='resize 256')
show_image(y[0], ctx=axs[2], title='flip')
show_image(y2[0], ctx=axs[3], title='rotate')
show_image(y3[0], ctx=axs[4], title='zoom')
show_image(y4[0], ctx=axs[5], title='warp')
show_image(y5[0], ctx=axs[6], title='brightness')
show_image(y6[0], ctx=axs[7], title='contrast')