diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 6dfd0f560f55c468a7fc8541bf7c507363b3d8eb..3c8a09dc10735050f7326a924ec82967b1c51741 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -1,86 +1,70 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+"""Image transformations for our pipelines
+
+The difference between the methods here and those from
+:py:mod:`torchvision.transforms` is that these support multiple simultaneous
+image inputs, which are required to feed segmentation networks (e.g. image and
+labels or masks).  We also take care of data augmentation, in which random
+flipping and rotation need to be applied across all input images, while color
+jittering, for example, is applied only to the input image.
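+
+For example, assuming ``img`` and ``gt`` are :py:class:`PIL.Image.Image`
+objects, a sketch of the intended usage, in which all inputs go through the
+same transforms:
+
+.. code-block:: python
+
+   from bob.ip.binseg.data.transforms import Compose, RandomRotation, ToTensor
+
+   transforms = Compose([RandomRotation(p=1.0), ToTensor()])
+   # the same random angle is applied to image and ground-truth
+   img_t, gt_t = transforms(img, gt)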
 """
-Image transformations for our pipelines.
 
-All transforms work with :py:class:`PIL.Image.Image` objects. We make heavy use
-of `torchvision <https://github.com/pytorch/vision>`_.
-"""
+import random
 
+import PIL.Image
+import torchvision.transforms
+import torchvision.transforms.functional
 
-import torchvision.transforms.functional as VF
-import random
-from PIL import Image
-from torchvision.transforms.transforms import Lambda
-from torchvision.transforms.transforms import Compose as TorchVisionCompose
-import math
-import warnings
-import collections
 import bob.core
 
-_pil_interpolation_to_str = {
-    Image.NEAREST: "PIL.Image.NEAREST",
-    Image.BILINEAR: "PIL.Image.BILINEAR",
-    Image.BICUBIC: "PIL.Image.BICUBIC",
-    Image.LANCZOS: "PIL.Image.LANCZOS",
-    Image.HAMMING: "PIL.Image.HAMMING",
-    Image.BOX: "PIL.Image.BOX",
-}
-Iterable = collections.abc.Iterable
 
-# Compose
+class TupleMixin:
+    """Adds support to work with tuples of objects to torchvision transforms"""
 
+    def __call__(self, *args):
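+        # via Python's MRO, ``super().__call__`` resolves to the torchvision
+        # transform this mixin is combined with, applying it to every input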
+        return [super(TupleMixin, self).__call__(k) for k in args]
 
-class Compose:
-    """Composes several transforms.
 
-    Attributes
-    ----------
-    transforms : list
-        list of transforms to compose.
-    """
+class CenterCrop(TupleMixin, torchvision.transforms.CenterCrop):
+    pass
 
-    def __init__(self, transforms):
-        self.transforms = transforms
 
-    def __call__(self, *args):
-        for t in self.transforms:
-            args = t(*args)
-        return args
+class Pad(TupleMixin, torchvision.transforms.Pad):
+    pass
 
-    def __repr__(self):
-        format_string = self.__class__.__name__ + "("
-        for t in self.transforms:
-            format_string += "\n"
-            format_string += "    {0}".format(t)
-        format_string += "\n)"
-        return format_string
 
+class Resize(TupleMixin, torchvision.transforms.Resize):
+    pass
 
-# Preprocessing
 
+class ToTensor(TupleMixin, torchvision.transforms.ToTensor):
+    pass
 
-class CenterCrop:
-    """
-    Crop at the center.
 
-    Attributes
-    ----------
-    size : int
-        target size
-    """
+class Compose(torchvision.transforms.Compose):
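+    """Composes several transforms, applying each to all input images"""
+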
+    def __call__(self, *args):
+        for t in self.transforms:
+            args = t(*args)
+        return args
 
-    def __init__(self, size):
-        self.size = size
 
-    def __call__(self, *args):
-        return [VF.center_crop(img, self.size) for img in args]
+class _Crop:
+    def __init__(self, i, j, h, w):
+        self.i = i
+        self.j = j
+        self.h = h
+        self.w = w
+
+    def __call__(self, img):
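+        # PIL's crop() takes a (left, upper, right, lower) box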
+        return img.crop((self.j, self.i, self.j + self.w, self.i + self.h))
 
 
-class Crop:
+class Crop(TupleMixin, _Crop):
     """
-    Crop at the given coordinates.
+    Crops all input images at the given coordinates.
 
     Attributes
     ----------
@@ -94,46 +78,17 @@ class Crop:
         width of the cropped image.
     """
 
-    def __init__(self, i, j, h, w):
-        self.i = i
-        self.j = j
-        self.h = h
-        self.w = w
-
-    def __call__(self, *args):
-        return [
-            img.crop((self.j, self.i, self.j + self.w, self.i + self.h)) for img in args
-        ]
-
-
-class Pad:
-    """
-    Constant padding
-
-    Attributes
-    ----------
-    padding : int or tuple
-        padding on each border. If a single int is provided this is used to pad all borders.
-        If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively.
-        If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively.
-
-    fill : int
-        pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
-        This value is only used when the padding_mode is constant
-    """
+    pass
 
-    def __init__(self, padding, fill=0):
-        self.padding = padding
-        self.fill = fill
 
-    def __call__(self, *args):
-        return [
-            VF.pad(img, self.padding, self.fill, padding_mode="constant")
-            for img in args
-        ]
+class _AutoLevel16to8:
+    def __call__(self, img):
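+        # maps the image's own intensity extrema to the full [0, 255] range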
+        return PIL.Image.fromarray(
+            bob.core.convert(img, "uint8", (0, 255), img.getextrema())
+        )
 
 
-class AutoLevel16to8:
+class AutoLevel16to8(TupleMixin, _AutoLevel16to8):
     """Converts a 16-bit image to 8-bit representation using "auto-level"
 
     This transform assumes that the input images are gray-scaled.
@@ -143,293 +98,109 @@ class AutoLevel16to8:
     destination image.
     """
 
-    def _process_one(self, img):
-        return Image.fromarray(
-            bob.core.convert(img, "uint8", (0, 255), img.getextrema())
-        )
+    pass
 
-    def __call__(self, *args):
-        return [self._process_one(img) for img in args]
+
+class _ToRGB:
+    def __call__(self, img):
+        return img.convert(mode="RGB")
 
 
-class ToRGB:
+class ToRGB(TupleMixin, _ToRGB):
     """Converts from any input format to RGB, using an ADAPTIVE conversion.
 
     This transform takes the input image and converts it to RGB using
-    py:method:`Image.Image.convert`, with `mode='RGB'` and using all other
+    :py:meth:`PIL.Image.Image.convert`, with ``mode='RGB'`` and using all other
     defaults.  This may be aggressive if applied to 16-bit images without
     further considerations.
     """
 
-    def __call__(self, *args):
-        return [img.convert(mode="RGB") for img in args]
+    pass
 
 
-class ToTensor:
-    """Converts :py:class:`PIL.Image.Image` to :py:class:`torch.Tensor` """
+class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
+    """Randomly flips all input images horizontally"""
 
     def __call__(self, *args):
-        return [VF.to_tensor(img) for img in args]
-
-
-# Augmentations
-
-
-class RandomHFlip:
-    """
-    Flips horizontally
-
-    Attributes
-    ----------
-    prob : float
-        probability at which imgage is flipped. Defaults to ``0.5``
-    """
-
-    def __init__(self, prob=0.5):
-        self.prob = prob
-
-    def __call__(self, *args):
-        if random.random() < self.prob:
-            return [VF.hflip(img) for img in args]
-
+        if random.random() < self.p:
+            return [
+                torchvision.transforms.functional.hflip(img) for img in args
+            ]
         else:
             return args
 
 
-class RandomVFlip:
-    """
-    Flips vertically
-
-    Attributes
-    ----------
-    prob : float
-        probability at which imgage is flipped. Defaults to ``0.5``
-    """
-
-    def __init__(self, prob=0.5):
-        self.prob = prob
+class RandomVerticalFlip(torchvision.transforms.RandomVerticalFlip):
+    """Randomly flips all input images vertically"""
 
     def __call__(self, *args):
-        if random.random() < self.prob:
-            return [VF.vflip(img) for img in args]
-
+        if random.random() < self.p:
+            return [
+                torchvision.transforms.functional.vflip(img) for img in args
+            ]
         else:
             return args
 
 
-class RandomRotation:
-    """
-    Rotates by degree
+class RandomRotation(torchvision.transforms.RandomRotation):
+    """Randomly rotates all input images by the same amount
 
-    Attributes
-    ----------
-    degree_range : tuple
-        range of degrees in which image and ground truth are rotated. Defaults to ``(-15, +15)``
-    prob : float
-        probability at which imgage is rotated. Defaults to ``0.5``
+    Unlike the current torchvision implementation, we also accept a probability
+    for applying the rotation.
     """
 
-    def __init__(self, degree_range=(-15, +15), prob=0.5):
-        self.prob = prob
-        self.degree_range = degree_range
+    def __init__(self, p=0.5, **kwargs):
+        kwargs.setdefault("degrees", 15)
+        kwargs.setdefault("resample", PIL.Image.BILINEAR)
+        super(RandomRotation, self).__init__(**kwargs)
+        self.p = p
 
     def __call__(self, *args):
-        if random.random() < self.prob:
-            degree = random.randint(*self.degree_range)
-            return [VF.rotate(img, degree, resample=Image.BILINEAR) for img in args]
+        if random.random() < self.p:
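+            # a single random angle is drawn and applied to all inputs, so
+            # images and ground-truth remain aligned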
+            angle = self.get_params(self.degrees)
+            return [
+                torchvision.transforms.functional.rotate(
+                    img, angle, self.resample, self.expand, self.center
+                )
+                for img in args
+            ]
         else:
             return args
 
 
-class ColorJitter(object):
-    """
-    Randomly change the brightness, contrast, saturation and hue
+class ColorJitter(torchvision.transforms.ColorJitter):
+    """Randomly applies a color jitter transformation on the **first** image
 
-    Attributes
+    Note that this transform, unlike others in this module, only affects the
+    first image passed as input argument.  Unlike the current torchvision
+    implementation, we also accept a probability for applying the jitter.
+
+    Parameters
     ----------
-    brightness : float
-        how much to jitter brightness. brightness_factor
-        is chosen uniformly from ``[max(0, 1 - brightness), 1 + brightness]``.
-    contrast : float
-        how much to jitter contrast. contrast_factor
-        is chosen uniformly from ``[max(0, 1 - contrast), 1 + contrast]``.
-    saturation : float
-        how much to jitter saturation. saturation_factor
-        is chosen uniformly from ``[max(0, 1 - saturation), 1 + saturation]``.
-    hue : float
-        how much to jitter hue. hue_factor is chosen uniformly from
-        ``[-hue, hue]``. Should be >=0 and <= 0.5
-    prob : float
-        probability at which the operation is applied
-    """
 
-    def __init__(
-        self, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02, prob=0.5
-    ):
-        self.brightness = brightness
-        self.contrast = contrast
-        self.saturation = saturation
-        self.hue = hue
-        self.prob = prob
-
-    @staticmethod
-    def get_params(brightness, contrast, saturation, hue):
-        transforms = []
-        if brightness > 0:
-            brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
-            transforms.append(
-                Lambda(lambda img: VF.adjust_brightness(img, brightness_factor))
-            )
+    p : float
+        probability at which the operation is applied
 
-        if contrast > 0:
-            contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
-            transforms.append(
-                Lambda(lambda img: VF.adjust_contrast(img, contrast_factor))
-            )
 
-        if saturation > 0:
-            saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
-            transforms.append(
-                Lambda(lambda img: VF.adjust_saturation(img, saturation_factor))
-            )
+    **kwargs : dict
+        passed to parent
 
-        if hue > 0:
-            hue_factor = random.uniform(-hue, hue)
-            transforms.append(Lambda(lambda img: VF.adjust_hue(img, hue_factor)))
-
-        random.shuffle(transforms)
-        transform = TorchVisionCompose(transforms)
+    """
 
-        return transform
+    def __init__(self, p=0.5, **kwargs):
+        kwargs.setdefault("brightness", 0.3)
+        kwargs.setdefault("contrast", 0.3)
+        kwargs.setdefault("saturation", 0.02)
+        kwargs.setdefault("hue", 0.02)
+        super(ColorJitter, self).__init__(**kwargs)
+        self.p = p
 
     def __call__(self, *args):
-        if random.random() < self.prob:
+        if random.random() < self.p:
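+            # only the first input (the image) is jittered; ground-truth and
+            # masks in the remaining inputs pass through untouched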
             transform = self.get_params(
                 self.brightness, self.contrast, self.saturation, self.hue
             )
-            trans_img = transform(args[0])
-            return [trans_img, *args[1:]]
-        else:
-            return args
-
-
-class RandomResizedCrop:
-    """Crop to random size and aspect ratio.
-    A crop of random size of the original size and a random aspect ratio of
-    the original aspect ratio is made. This crop is finally resized to
-    given size. This is popularly used to train the Inception networks.
-
-    Attributes
-    ----------
-    size : int
-        expected output size of each edge
-    scale : tuple
-        range of size of the origin size cropped. Defaults to ``(0.08, 1.0)``
-    ratio : tuple
-        range of aspect ratio of the origin aspect ratio cropped. Defaults to ``(3. / 4., 4. / 3.)``
-    interpolation :
-        Defaults to ``PIL.Image.BILINEAR``
-    prob : float
-        probability at which the operation is applied. Defaults to ``0.5``
-    """
-
-    def __init__(
-        self,
-        size,
-        scale=(0.08, 1.0),
-        ratio=(3.0 / 4.0, 4.0 / 3.0),
-        interpolation=Image.BILINEAR,
-        prob=0.5,
-    ):
-        if isinstance(size, tuple):
-            self.size = size
-        else:
-            self.size = (size, size)
-        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
-            warnings.warn("range should be of kind (min, max)")
-
-        self.interpolation = interpolation
-        self.scale = scale
-        self.ratio = ratio
-        self.prob = prob
-
-    @staticmethod
-    def get_params(img, scale, ratio):
-        area = img.size[0] * img.size[1]
-
-        for attempt in range(10):
-            target_area = random.uniform(*scale) * area
-            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
-            aspect_ratio = math.exp(random.uniform(*log_ratio))
-
-            w = int(round(math.sqrt(target_area * aspect_ratio)))
-            h = int(round(math.sqrt(target_area / aspect_ratio)))
-
-            if w <= img.size[0] and h <= img.size[1]:
-                i = random.randint(0, img.size[1] - h)
-                j = random.randint(0, img.size[0] - w)
-                return i, j, h, w
-
-        # Fallback to central crop
-        in_ratio = img.size[0] / img.size[1]
-        if in_ratio < min(ratio):
-            w = img.size[0]
-            h = w / min(ratio)
-        elif in_ratio > max(ratio):
-            h = img.size[1]
-            w = h * max(ratio)
-        else:  # whole image
-            w = img.size[0]
-            h = img.size[1]
-        i = (img.size[1] - h) // 2
-        j = (img.size[0] - w) // 2
-        return i, j, h, w
-
-    def __call__(self, *args):
-        if random.random() < self.prob:
-            imgs = []
-            for img in args:
-                i, j, h, w = self.get_params(img, self.scale, self.ratio)
-                img = VF.resized_crop(img, i, j, h, w, self.size, self.interpolation)
-                imgs.append(img)
-            return imgs
+            return [transform(args[0]), *args[1:]]
         else:
             return args
-
-    def __repr__(self):
-        interpolate_str = _pil_interpolation_to_str[self.interpolation]
-        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
-        format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
-        format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
-        format_string += ", interpolation={0})".format(interpolate_str)
-        return format_string
-
-
-class Resize:
-    """Resize to given size.
-
-    Attributes
-    ----------
-    size : tuple or int
-        Desired output size. If size is a sequence like
-        (h, w), output size will be matched to this. If size is an int,
-        smaller edge of the image will be matched to this number.
-        i.e, if height > width, then image will be rescaled to
-        (size * height / width, size)
-    interpolation : int
-        Desired interpolation. Default is``PIL.Image.BILINEAR``
-    """
-
-    def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
-        self.size = size
-        self.interpolation = interpolation
-
-    def __call__(self, *args):
-        return [VF.resize(img, self.size, self.interpolation) for img in args]
-
-    def __repr__(self):
-        interpolate_str = _pil_interpolation_to_str[self.interpolation]
-        return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
-            self.size, interpolate_str
-        )
diff --git a/bob/ip/binseg/test/test_transforms.py b/bob/ip/binseg/test/test_transforms.py
index e71716a0bc927186ad2121340e7e0b230b5adfae..826343f632d9c4e7b62fbecea68df45bcf71728b 100644
--- a/bob/ip/binseg/test/test_transforms.py
+++ b/bob/ip/binseg/test/test_transforms.py
@@ -1,50 +1,341 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import random
+
+import nose.tools
+import numpy
 import torch
-import unittest
-import numpy as np
-from bob.ip.binseg.data.transforms import *
+import torchvision.transforms.functional
 
-transforms = Compose(
-    [RandomHFlip(prob=1), RandomHFlip(prob=1), RandomVFlip(prob=1), RandomVFlip(prob=1)]
-)
+from ..data.transforms import *
 
 
-def create_img():
-    t = torch.randn((3, 42, 24))
-    pil = VF.to_pil_image(t)
+def _create_img(size):
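+    # size is (planes, height, width); notice PIL's Image.size is,
+    # conversely, a (width, height) tuple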
+    t = torch.randn(size)
+    pil = torchvision.transforms.functional.to_pil_image(t)
     return pil
 
 
-class Tester(unittest.TestCase):
-    """
-    Unit test for random flips
-    """
-
-    def test_flips(self):
-        transforms = Compose(
-            [
-                RandomHFlip(prob=1),
-                RandomHFlip(prob=1),
-                RandomVFlip(prob=1),
-                RandomVFlip(prob=1),
-            ]
-        )
-        img, gt, mask = [create_img() for i in range(3)]
-        img_t, gt_t, mask_t = transforms(img, gt, mask)
-        self.assertTrue(np.all(np.array(img_t) == np.array(img)))
-        self.assertTrue(np.all(np.array(gt_t) == np.array(gt)))
-        self.assertTrue(np.all(np.array(mask_t) == np.array(mask)))
-
-    def test_to_tensor(self):
-        transforms = ToTensor()
-        img, gt, mask = [create_img() for i in range(3)]
-        img_t, gt_t, mask_t = transforms(img, gt, mask)
-        self.assertEqual(str(img_t.dtype), "torch.float32")
-        self.assertEqual(str(gt_t.dtype), "torch.float32")
-        self.assertEqual(str(mask_t.dtype), "torch.float32")
-
-
-if __name__ == "__main__":
-    unittest.main()
+def test_center_crop():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    crop_size = (10, 12)  # (height, width)
+
+    # test
+    bh = (im_size[1] - crop_size[0]) // 2
+    bw = (im_size[2] - crop_size[1]) // 2
+    idx = (slice(bh, -bh), slice(bw, -bw), slice(0, im_size[0]))
+    transforms = CenterCrop(crop_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    nose.tools.eq_(
+        img_t.size, (crop_size[1], crop_size[0])
+    )  # confirms the above
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t) == numpy.array(img)[idx])
+    assert numpy.all(numpy.array(gt_t) == numpy.array(gt)[idx])
+    assert numpy.all(numpy.array(mask_t) == numpy.array(mask)[idx])
+
+
+def test_center_crop_uneven():
+
+    # parameters
+    im_size = (3, 23, 20)  # (planes, height, width)
+    crop_size = (10, 13)  # (height, width)
+
+    # test
+    bh = (im_size[1] - crop_size[0]) // 2
+    bw = (im_size[2] - crop_size[1]) // 2
+    # when the crop size is uneven, this is what happens - notice here that the
+    # image height is uneven, and the crop width as well - the attributions of
+    # extra pixels will depend on what is uneven (original image or crop)
+    idx = (slice(bh, -(bh + 1)), slice((bw + 1), -bw), slice(0, im_size[0]))
+    transforms = CenterCrop(crop_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    nose.tools.eq_(
+        img_t.size, (crop_size[1], crop_size[0])
+    )  # confirms the above
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t) == numpy.array(img)[idx])
+    assert numpy.all(numpy.array(gt_t) == numpy.array(gt)[idx])
+    assert numpy.all(numpy.array(mask_t) == numpy.array(mask)[idx])
+
+
+def test_pad_default():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    pad_size = 2
+
+    # test
+    idx = (
+        slice(pad_size, -pad_size),
+        slice(pad_size, -pad_size),
+        slice(0, im_size[0]),
+    )
+    transforms = Pad(pad_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t)[idx] == numpy.array(img))
+    assert numpy.all(numpy.array(gt_t)[idx] == numpy.array(gt))
+    assert numpy.all(numpy.array(mask_t)[idx] == numpy.array(mask))
+
+    # checks that the border introduced with padding is all about "fill"
+    img_t = numpy.array(img_t)
+    img_t[idx] = 0
+    nose.tools.eq_(img_t.sum(), 0)
+
+    gt_t = numpy.array(gt_t)
+    gt_t[idx] = 0
+    nose.tools.eq_(gt_t.sum(), 0)
+
+    mask_t = numpy.array(mask_t)
+    mask_t[idx] = 0
+    nose.tools.eq_(mask_t.sum(), 0)
+
+
+def test_pad_2tuple():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    pad_size = (1, 2)  # left/right, top/bottom
+    fill = (3, 4, 5)
+
+    # test
+    idx = (
+        slice(pad_size[1], -pad_size[1]),
+        slice(pad_size[0], -pad_size[0]),
+        slice(0, im_size[0]),
+    )
+    transforms = Pad(pad_size, fill)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t)[idx] == numpy.array(img))
+    assert numpy.all(numpy.array(gt_t)[idx] == numpy.array(gt))
+    assert numpy.all(numpy.array(mask_t)[idx] == numpy.array(mask))
+
+    # checks that the border introduced with padding is all about "fill"
+    img_t = numpy.array(img_t)
+    img_t[idx] = 0
+    border_size_plane = img_t[:, :, 0].size - numpy.array(img)[:, :, 0].size
+    expected_sum = sum(fill[k] * border_size_plane for k in range(3))
+    nose.tools.eq_(img_t.sum(), expected_sum)
+
+    gt_t = numpy.array(gt_t)
+    gt_t[idx] = 0
+    nose.tools.eq_(gt_t.sum(), expected_sum)
+
+    mask_t = numpy.array(mask_t)
+    mask_t[idx] = 0
+    nose.tools.eq_(mask_t.sum(), expected_sum)
+
+
+def test_pad_4tuple():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    pad_size = (1, 2, 3, 4)  # left, top, right, bottom
+    fill = (3, 4, 5)
+
+    # test
+    idx = (
+        slice(pad_size[1], -pad_size[3]),
+        slice(pad_size[0], -pad_size[2]),
+        slice(0, im_size[0]),
+    )
+    transforms = Pad(pad_size, fill)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t)[idx] == numpy.array(img))
+    assert numpy.all(numpy.array(gt_t)[idx] == numpy.array(gt))
+    assert numpy.all(numpy.array(mask_t)[idx] == numpy.array(mask))
+
+    # checks that the border introduced with padding is all about "fill"
+    img_t = numpy.array(img_t)
+    img_t[idx] = 0
+    border_size_plane = img_t[:, :, 0].size - numpy.array(img)[:, :, 0].size
+    expected_sum = sum(fill[k] * border_size_plane for k in range(3))
+    nose.tools.eq_(img_t.sum(), expected_sum)
+
+    gt_t = numpy.array(gt_t)
+    gt_t[idx] = 0
+    nose.tools.eq_(gt_t.sum(), expected_sum)
+
+    mask_t = numpy.array(mask_t)
+    mask_t[idx] = 0
+    nose.tools.eq_(mask_t.sum(), expected_sum)
+
+
+def test_resize_downscale_w():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    new_size = 10  # (smallest edge)
+
+    # test
+    transforms = Resize(new_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    new_size = (new_size, (new_size * im_size[1]) // im_size[2])
+    nose.tools.eq_(img_t.size, new_size)
+    nose.tools.eq_(gt_t.size, new_size)
+    nose.tools.eq_(mask_t.size, new_size)
+
+
+def test_resize_downscale_hw():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    new_size = (10, 12)  # (height, width)
+
+    # test
+    transforms = Resize(new_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    nose.tools.eq_(img_t.size, (new_size[1], new_size[0]))
+    nose.tools.eq_(gt_t.size, (new_size[1], new_size[0]))
+    nose.tools.eq_(mask_t.size, (new_size[1], new_size[0]))
+
+
+def test_crop():
+
+    # parameters
+    im_size = (3, 22, 20)  # (planes, height, width)
+    crop_size = (3, 2, 10, 12)  # (upper, left, height, width)
+
+    # test
+    idx = (
+        slice(crop_size[0], crop_size[0]+crop_size[2]),
+        slice(crop_size[1], crop_size[1]+crop_size[3]),
+        slice(0, im_size[0]),
+    )
+    transforms = Crop(*crop_size)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    nose.tools.eq_(img.size, (im_size[2], im_size[1]))  # confirms the above
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.array(img_t) == numpy.array(img)[idx])
+    assert numpy.all(numpy.array(gt_t) == numpy.array(gt)[idx])
+    assert numpy.all(numpy.array(mask_t) == numpy.array(mask)[idx])
+
+
+def test_to_tensor():
+
+    transforms = ToTensor()
+    img, gt, mask = [_create_img((3, 5, 5)) for i in range(3)]
+    gt = gt.convert("1", dither=None)
+    mask = mask.convert("1", dither=None)
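+    # even binary (mode "1") inputs should be converted to float32 tensors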
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    nose.tools.eq_(img_t.dtype, torch.float32)
+    nose.tools.eq_(gt_t.dtype, torch.float32)
+    nose.tools.eq_(mask_t.dtype, torch.float32)
+
+
+def test_horizontal_flip():
+
+    transforms = RandomHorizontalFlip(p=1)
+
+    im_size = (3, 24, 42)  # (planes, height, width)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.flip(img_t, axis=1) == numpy.array(img))
+    assert numpy.all(numpy.flip(gt_t, axis=1) == numpy.array(gt))
+    assert numpy.all(numpy.flip(mask_t, axis=1) == numpy.array(mask))
+
+
+def test_vertical_flip():
+
+    transforms = RandomVerticalFlip(p=1)
+
+    im_size = (3, 24, 42)  # (planes, height, width)
+    img, gt, mask = [_create_img(im_size) for i in range(3)]
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+
+    # notice that PIL->array does array.transpose(1, 2, 0)
+    # so it creates an array that is (height, width, planes)
+    assert numpy.all(numpy.flip(img_t, axis=0) == numpy.array(img))
+    assert numpy.all(numpy.flip(gt_t, axis=0) == numpy.array(gt))
+    assert numpy.all(numpy.flip(mask_t, axis=0) == numpy.array(mask))
+
+
+def test_rotation():
+
+    im_size = (3, 24, 42)  # (planes, height, width)
+    transforms = RandomRotation(degrees=90, p=1)
+    img = _create_img(im_size)
+
+    # asserts all images are rotated the same
+    # and they are different from the original
+    random.seed(42)
+    img1_t, img2_t, img3_t = transforms(img, img, img)
+    nose.tools.eq_(img1_t.size, (im_size[2], im_size[1]))
+    assert numpy.all(numpy.array(img1_t) == numpy.array(img2_t))
+    assert numpy.all(numpy.array(img1_t) == numpy.array(img3_t))
+    assert numpy.any(numpy.array(img1_t) != numpy.array(img))
+
+    # asserts two random transforms are not the same
+    img_t2, = transforms(img)
+    assert numpy.any(numpy.array(img_t2) != numpy.array(img1_t))
+
+
+def test_color_jitter():
+
+    im_size = (3, 24, 42)  # (planes, height, width)
+    transforms = ColorJitter(p=1)
+    img = _create_img(im_size)
+
+    # asserts only the first image is jittered
+    # and it is different from the original
+    # all others match the input data
+    random.seed(42)
+    img1_t, img2_t, img3_t = transforms(img, img, img)
+    nose.tools.eq_(img1_t.size, (im_size[2], im_size[1]))
+    assert numpy.any(numpy.array(img1_t) != numpy.array(img))
+    assert numpy.any(numpy.array(img1_t) != numpy.array(img2_t))
+    assert numpy.all(numpy.array(img2_t) == numpy.array(img3_t))
+    assert numpy.all(numpy.array(img2_t) == numpy.array(img))
+
+    # asserts two random transforms are not the same
+    img1_t2, img2_t2, img3_t2 = transforms(img, img, img)
+    assert numpy.any(numpy.array(img1_t2) != numpy.array(img1_t))
+    assert numpy.all(numpy.array(img2_t2) == numpy.array(img))
+    assert numpy.all(numpy.array(img3_t2) == numpy.array(img))
+
+
+def test_compose():
+
+    transforms = Compose(
+        [
+            RandomVerticalFlip(p=1),
+            RandomHorizontalFlip(p=1),
+            RandomVerticalFlip(p=1),
+            RandomHorizontalFlip(p=1),
+        ]
+    )
+
+    img, gt, mask = [_create_img((3, 24, 42)) for i in range(3)]
+    img_t, gt_t, mask_t = transforms(img, gt, mask)
+    assert numpy.all(numpy.array(img_t) == numpy.array(img))
+    assert numpy.all(numpy.array(gt_t) == numpy.array(gt))
+    assert numpy.all(numpy.array(mask_t) == numpy.array(mask))