DATA LOADING AND TRANSFORMATION TO TENSOR (AUTOENCODER WITH A DICTIONARY SAMPLE AS OUTPUT) TUTORIAL

Prerequisites

To run this tutorial, please make sure the following packages are installed:

- PyTorch 0.4.1
- TorchVision 0.2.1
- PIL: For image I/O and transforms
- Matplotlib: To generate plots, histograms, etc.

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from os import listdir
from os.path import join
from PIL import Image
import numpy as np
import random

DATASET CLASS

torch.utils.data.Dataset is an abstract class representing a dataset. Your custom dataset should inherit Dataset and override the following methods (a minimal skeleton is sketched after the list):

- __len__ so that len(dataset) returns the size of the dataset.
- __getitem__ to support indexing such that dataset[i] can be used to get the i-th sample.
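Before writing the real class, it helps to see the bare minimum such a class must contain. This is a generic, hypothetical skeleton only; our actual dataset class follows right after:
class MyDataset(Dataset):
    """Hypothetical minimal skeleton of a custom dataset."""
    def __init__(self, samples):
        self.samples = samples      # any indexable collection

    def __len__(self):
        return len(self.samples)    # len(dataset) returns the size of the dataset

    def __getitem__(self, index):
        return self.samples[index]  # dataset[i] returns the i-th sample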

Let’s create a dataset class for our autoencoder dataset. We will read the 'Input' image directory and the 'Ground Truth' image directory in __init__, but leave the reading of images to __getitem__. This is memory efficient because the images are not all stored in memory at once but read as required. Each sample of our dataset will be a dictionary of the form {'img_in': image, 'img_gt': image}. Our dataset will take an optional argument transform so that any required processing can be applied to the sample.
class AutoEncoderDataSet(Dataset):
    def __init__(self, dir_in, dir_gt, transform=None):
        self.dir_in = self.load_dir_single(dir_in)
        self.dir_gt = self.load_dir_single(dir_gt)
        self.transform = transform

    def is_image_file(self, filename):
        return any(filename.endswith(extension) for extension in [".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG"])

    def load_img(self, filename):
        img = Image.open(filename)

        return img

    def load_dir_single(self, directory):
        return [join(directory, x) for x in listdir(directory) if self.is_image_file(x)]

    def __len__(self):
        return len(self.dir_in)

    def __getitem__(self, index):
        img_in = self.load_img(self.dir_in[index])
        img_gt = self.load_img(self.dir_gt[index])
        sample = {'img_in': img_in, 'img_gt': img_gt}

        if self.transform:
            sample = self.transform(sample)

        return sample
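Before adding any transforms, we can sanity-check the dataset by loading a raw sample. The following minimal snippet assumes the same 'img/tr/in/' and 'img/tr/gt/' directories that are used later in this tutorial:
raw_dataset = AutoEncoderDataSet('img/tr/in/', 'img/tr/gt/')
print(len(raw_dataset))                              # number of input images found
sample = raw_dataset[0]                              # a dictionary sample, still as PIL Images
print(sample['img_in'].size, sample['img_gt'].size)  # PIL sizes are (width, height)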
Let’s create a few transforms:

- RandomCrop to crop the images randomly. This is data augmentation.
- RandomHorizontalFlip to flip the images horizontally at random. This is data augmentation.
- RandomVerticalFlip to flip the images vertically at random. This is data augmentation.
- RandomRotate to rotate the images by 90, 180, or 270 degrees at random. This is data augmentation.
- ToTensor to convert the PIL images to torch tensors (we need to swap axes).

We will write them as callable classes instead of simple functions so that the parameters of the transform need not be passed every time it is called. For this, we just need to implement the __call__ method and, if required, the __init__ method.
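To make the pattern concrete, here is a minimal, hypothetical transform that is not part of the pipeline below; it only illustrates the __init__/__call__ structure that the real transforms follow:
class ToMode(object):
    """Hypothetical example: convert both images in the sample to a given PIL mode."""
    def __init__(self, mode='L'):
        self.mode = mode  # stored once, so it need not be passed on every call

    def __call__(self, sample):
        img_in, img_gt = sample['img_in'], sample['img_gt']
        return {'img_in': img_in.convert(self.mode), 'img_gt': img_gt.convert(self.mode)}

# Instantiate once, then call the object like a function on each sample:
# to_gray = ToMode('L')
# sample = to_gray(sample)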

Crop the given PIL Images randomly:
class RandomCrop(object):
    """Crop the given PIL Images randomly."""
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be cropped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly cropped images.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        w, h = img_in.size
        new_h, new_w = self.output_size

        # Sample one crop window and apply it to both images so the pair stays aligned.
        # randint's upper bound is exclusive; the +1 allows cropping an image that is
        # exactly the output size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        img_in = img_in.crop((left, top, left + new_w, top + new_h))
        img_gt = img_gt.crop((left, top, left + new_w, top + new_h))

        return {'img_in': img_in, 'img_gt': img_gt}

Horizontally flip the given PIL Images randomly with a given probability:
class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Images randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be flipped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly flipped images.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            img_in = img_in.transpose(Image.FLIP_LEFT_RIGHT)
            img_gt = img_gt.transpose(Image.FLIP_LEFT_RIGHT)

        return {'img_in': img_in, 'img_gt': img_gt}

Vertically flip the given PIL Images randomly with a given probability:
class RandomVerticalFlip(object):
    """Vertically flip the given PIL Images randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be flipped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly flipped image.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            img_in = img_in.transpose(Image.FLIP_TOP_BOTTOM)
            img_gt = img_gt.transpose(Image.FLIP_TOP_BOTTOM)

        return {'img_in': img_in, 'img_gt': img_gt}

Rotate the given PIL Image randomly with a given probability:
class RandomRotate(object):
    """Rotate the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being rotated. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be rotated.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly rotated images.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            # Image.ROTATE_90/180/270 are PIL transpose constants: exact 90-degree
            # rotations without interpolation, applied identically to both images.
            method = random.choice([Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270])
            img_in = img_in.transpose(method)
            img_gt = img_gt.transpose(method)

        return {'img_in': img_in, 'img_gt': img_gt}

Convert ndarrays to Tensors:
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        img_in, img_gt = sample['img_in'], sample['img_gt']
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img_in = np.asarray(img_in)
        img_in = img_in.transpose((2, 0, 1))
        img_gt = np.asarray(img_gt)
        img_gt = img_gt.transpose((2, 0, 1))

        return {'img_in': torch.from_numpy(img_in), 'img_gt': torch.from_numpy(img_gt)}
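Note that torch.from_numpy keeps the uint8 dtype of the underlying arrays, so the resulting tensors hold integer values in [0, 255]. If your model expects floating-point inputs in [0, 1], one option is to append a small extra transform after ToTensor(); this is a hypothetical sketch, not part of this tutorial's pipeline:
class ToFloat(object):
    """Hypothetical transform: convert the uint8 tensors to float32 in [0, 1]."""
    def __call__(self, sample):
        return {'img_in': sample['img_in'].float() / 255.0,
                'img_gt': sample['img_gt'].float() / 255.0}
# e.g. transforms.Compose([..., ToTensor(), ToFloat()])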
torchvision.transforms.Compose is a simple callable class which allows us to compose several transforms together. We will use the RandomCrop(128), RandomHorizontalFlip(), RandomVerticalFlip(), RandomRotate(), and ToTensor() transforms. We pass the composed transform to our AutoEncoderDataSet class as follows:
    composed = transforms.Compose([RandomCrop(128), RandomHorizontalFlip(), RandomVerticalFlip(), RandomRotate(), ToTensor()])
    auto_encoder_dataset = AutoEncoderDataSet(ps['DIR_IMG_IN'], ps['DIR_IMG_GT'], composed)
Let’s instantiate this class and iterate through the data samples.
def main():
    composed = transforms.Compose([RandomCrop(128), RandomHorizontalFlip(), RandomVerticalFlip(), RandomRotate(), ToTensor()])
    auto_encoder_dataset = AutoEncoderDataSet('img/tr/in/', 'img/tr/gt/', composed)
    for i in range(len(auto_encoder_dataset)):
        sample = auto_encoder_dataset[i]
        img_in, img_gt = sample['img_in'], sample['img_gt']
        print(i, 'Input image:', img_in.size(), 'Ground truth image:', img_gt.size())
Out:
      
0 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
1 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
2 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
3 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
4 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
5 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
6 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
7 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
8 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
9 Input image: torch.Size([3, 128, 128]) Ground truth image: torch.Size([3, 128, 128])
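Iterating over the dataset one sample at a time is fine for inspection, but for training you would typically wrap it in a torch.utils.data.DataLoader to get batching, shuffling, and multi-process loading. Because each sample is a dictionary of tensors, the default collate function returns a dictionary of batched tensors. A minimal sketch, with an arbitrary batch size and worker count:
from torch.utils.data import DataLoader

data_loader = DataLoader(auto_encoder_dataset, batch_size=4, shuffle=True, num_workers=2)

for i_batch, batch in enumerate(data_loader):
    # Each batch is a dictionary of stacked tensors of shape N x C x H x W.
    print(i_batch, batch['img_in'].size(), batch['img_gt'].size())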

The full example code:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from os import listdir
from os.path import join
from PIL import Image
import numpy as np
import random


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        img_in, img_gt = sample['img_in'], sample['img_gt']
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img_in = np.asarray(img_in)
        img_in = img_in.transpose((2, 0, 1))
        img_gt = np.asarray(img_gt)
        img_gt = img_gt.transpose((2, 0, 1))

        return {'img_in': torch.from_numpy(img_in), 'img_gt': torch.from_numpy(img_gt)}

class RandomRotate(object):
    """Rotate the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being rotated. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be rotated.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly rotated images.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            # Image.ROTATE_90/180/270 are PIL transpose constants: exact 90-degree
            # rotations without interpolation, applied identically to both images.
            method = random.choice([Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270])
            img_in = img_in.transpose(method)
            img_gt = img_gt.transpose(method)

        return {'img_in': img_in, 'img_gt': img_gt}


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Images randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be flipped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly flipped image.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            img_in = img_in.transpose(Image.FLIP_TOP_BOTTOM)
            img_gt = img_gt.transpose(Image.FLIP_TOP_BOTTOM)

        return {'img_in': img_in, 'img_gt': img_gt}


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Images randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be flipped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly flipped image.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        if random.random() < self.p:
            img_in = img_in.transpose(Image.FLIP_LEFT_RIGHT)
            img_gt = img_gt.transpose(Image.FLIP_LEFT_RIGHT)

        return {'img_in': img_in, 'img_gt': img_gt}


class RandomCrop(object):
    """Crop the given PIL Images randomly."""
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """
        Args:
            sample {img_in PIL Image, img_gt PIL Image}: Images to be cropped.
        Returns:
            {img_in PIL Image, img_gt PIL Image}: Randomly cropped image.
        """
        img_in, img_gt = sample['img_in'], sample['img_gt']

        w, h = img_in.size
        new_h, new_w = self.output_size

        # Sample one crop window and apply it to both images so the pair stays aligned.
        # randint's upper bound is exclusive; the +1 allows cropping an image that is
        # exactly the output size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        img_in = img_in.crop((left, top, left + new_w, top + new_h))
        img_gt = img_gt.crop((left, top, left + new_w, top + new_h))

        return {'img_in': img_in, 'img_gt': img_gt}


class AutoEncoderDataSet(Dataset):
    def __init__(self, dir_in, dir_gt, transform=None):
        self.dir_in = self.load_dir_single(dir_in)
        self.dir_gt = self.load_dir_single(dir_gt)
        self.transform = transform

    def is_image_file(self, filename):
        return any(filename.endswith(extension) for extension in [".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG"])

    def load_img(self, filename):
        img = Image.open(filename)

        return img

    def load_dir_single(self, directory):
        return [join(directory, x) for x in listdir(directory) if self.is_image_file(x)]

    def __len__(self):
        return len(self.dir_in)

    def __getitem__(self, index):
        img_in = self.load_img(self.dir_in[index])
        img_gt = self.load_img(self.dir_gt[index])
        sample = {'img_in': img_in, 'img_gt': img_gt}

        if self.transform:
            sample = self.transform(sample)

        return sample


def main(ps):
    composed = transforms.Compose([RandomCrop(128), RandomHorizontalFlip(), RandomVerticalFlip(), RandomRotate(), ToTensor()])
    auto_encoder_dataset = AutoEncoderDataSet(ps['DIR_IMG_IN'], ps['DIR_IMG_GT'], composed)
    for i in range(len(auto_encoder_dataset)):
        sample = auto_encoder_dataset[i]
        img_in, img_gt = sample['img_in'], sample['img_gt']
        print(i, 'Input image:', img_in.size(), 'Ground truth image:', img_gt.size())


if __name__ == "__main__":
    ps = {
        'DIR_IMG_IN': 'img/tr/in/',
        'DIR_IMG_GT': 'img/tr/gt/'
    }
    main(ps)

REFERENCES:

1. torch.utils.data.Dataset, PyTorch documentation: https://pytorch.org/docs/stable/data.html