Skip to content
Snippets Groups Projects
Commit 2d491569 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[configs.datasets] Allow data augmentation to be turned-off contextually

parent 8dd85019
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -12,19 +12,15 @@ from ...data.transforms import ( ...@@ -12,19 +12,15 @@ from ...data.transforms import (
) )
RANDOM_ROTATION = [_rotation()]
"""Shared data augmentation based on random rotation only"""

RANDOM_FLIP_JITTER = [_hflip(), _vflip(), _jitter()]
"""Shared data augmentation transforms without random rotation"""
def make_subset(l, transforms, prefixes, suffixes):
    """Creates a new data set, applying transforms

    Parameters
    ----------

    l : list
        List of samples to wrap into a torch-compatible dataset
        (NOTE(review): this parameter's description was hidden in the diff
        view — confirm against the repository)

    transforms : list
        A list of transforms that needs to be applied to all samples in the set

    prefixes : list
        A list of data augmentation operations that needs to be applied
        **before** the transforms above

    suffixes : list
        A list of data augmentation operations that needs to be applied
        **after** the transforms above

    Returns
    -------

    dataset : :py:class:`torch.utils.data.Dataset`
        The wrapped dataset (NOTE(review): return description was hidden in
        the diff view — confirm against the repository)

    """

    # function-scope import — presumably to avoid a circular import at module
    # load time; confirm against the repository layout
    from ...data.utils import SampleList2TorchDataset as wrapper

    return wrapper(l, transforms, prefixes, suffixes)
def make_trainset(l, transforms, rotation_before=False):
    """Creates a new training set with data augmentation applied

    Random flip/jitter augmentation is always applied **after**
    ``transforms``; random rotation is applied either before or after
    ``transforms`` depending on ``rotation_before``.

    (NOTE(review): the original docstring was hidden in the diff view — the
    description above is reconstructed from the code; confirm upstream.)

    Parameters
    ----------

    l : list
        List of samples to wrap into a torch-compatible training dataset

    transforms : list
        A list of transforms that needs to be applied to all samples in the set

    rotation_before : :py:class:`bool`, Optional
        If ``True``, applies random rotation **before** ``transforms`` and the
        remaining augmentations after; otherwise all augmentations come after.

    Returns
    -------

    dataset : :py:class:`torch.utils.data.Dataset`
        An augmented dataset, built via :py:func:`make_subset`

    """

    if rotation_before:
        return make_subset(
            l,
            transforms=transforms,
            prefixes=RANDOM_ROTATION,
            suffixes=RANDOM_FLIP_JITTER,
        )

    return make_subset(
        l,
        transforms,
        prefixes=[],
        suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
    )
def make_dataset(subsets, transforms):
    """Creates a dictionary of datasets from a dictionary of subsets

    The ``"train"`` subset is wrapped with data augmentation (via
    :py:func:`make_trainset`, with ``rotation_before=False``); every other
    subset is wrapped with the plain ``transforms`` only.

    (NOTE(review): the original docstring was hidden in the diff view — the
    description above is reconstructed from the code; confirm upstream.)

    Parameters
    ----------

    subsets : dict
        A dictionary mapping subset names (e.g. ``"train"``) to sample lists

    transforms : list
        A list of transforms that needs to be applied to all samples in each
        subset

    Returns
    -------

    datasets : dict
        A dictionary mapping each subset name to its wrapped dataset

    """

    # NOTE(review): this initialization was hidden behind a hunk marker in
    # the diff view; it is required by the loop below — confirm upstream
    retval = {}

    for key in subsets.keys():
        if key == "train":
            retval[key] = make_trainset(
                subsets[key], transforms=transforms, rotation_before=False
            )
        else:
            retval[key] = make_subset(
                subsets[key], transforms=transforms, prefixes=[], suffixes=[]
            )

    return retval
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
"""Common utilities""" """Common utilities"""
import contextlib
import PIL.Image import PIL.Image
import PIL.ImageOps import PIL.ImageOps
import PIL.ImageChops import PIL.ImageChops
...@@ -17,8 +19,9 @@ from .transforms import Compose, ToTensor ...@@ -17,8 +19,9 @@ from .transforms import Compose, ToTensor
def invert_mode1_image(img):
    """Inverts a binary PIL image (mode == ``"1"``)"""

    # PIL cannot invert mode-"1" images directly: round-trip through "RGB",
    # then convert back to binary without dithering
    return PIL.ImageOps.invert(img.convert("RGB")).convert(
        mode="1", dither=None
    )
def subtract_mode1_images(img1, img2):
    """Subtracts one binary PIL image (mode == ``"1"``) from another

    (NOTE(review): the original docstring was hidden in the diff view — this
    summary is reconstructed from the code; confirm upstream.)
    """

    return PIL.ImageChops.subtract(img1, img2)
def overlayed_image(
    img,
    label,
    mask=None,
    label_color=(0, 255, 0),
    mask_color=(0, 0, 255),
    alpha=0.4,
):
    """Creates an image showing existing labels and mask

    This function creates a new representation of the input image ``img``
    overlaying ``label`` (blended in ``label_color``) and, if given, the
    negative of ``mask`` (blended in ``mask_color``).

    (NOTE(review): most of the original docstring, including parameter
    descriptions, was hidden in the diff view — confirm upstream.)
    """

    # creates a representation of labels with the right color
    label_colored = PIL.ImageOps.colorize(
        label.convert("L"), (0, 0, 0), label_color
    )

    # blend image and label together - first blend to get vessels drawn with a
    # slight "label_color" tone on top, then composite with original image, not
    # to loose brightness.
    retval = PIL.Image.blend(img, label_colored, alpha)
    retval = PIL.Image.composite(img, retval, invert_mode1_image(label))

    # creates a representation of the mask negative with the right color
    if mask is not None:
        antimask_colored = PIL.ImageOps.colorize(
            mask.convert("L"), mask_color, (0, 0, 0)
        )
        tmp = PIL.Image.blend(retval, antimask_colored, alpha)
        retval = PIL.Image.composite(retval, tmp, mask)

    # NOTE(review): the trailing return was hidden behind the next hunk
    # marker in the diff view; restored here — confirm upstream
    return retval
...@@ -106,6 +116,14 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -106,6 +116,14 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
It supports indexing such that dataset[i] can be used to get ith sample. It supports indexing such that dataset[i] can be used to get ith sample.
Attributes
----------
augmented : bool
Tells if this set has data augmentation prefixes or suffixes installed.
Parameters Parameters
---------- ----------
...@@ -114,17 +132,31 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -114,17 +132,31 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
transforms : :py:class:`list`, Optional transforms : :py:class:`list`, Optional
a list of transformations to be applied to **both** image and a list of transformations to be applied to **both** image and
ground-truth data. Notice that image changing transformations such as ground-truth data. Notice a last transform
:py:class:`.transforms.ColorJitter` are only applied to the image and
**not** to ground-truth. Also notice a last transform
(:py:class:`bob.ip.binseg.data.transforms.ToTensor`) is always applied. (:py:class:`bob.ip.binseg.data.transforms.ToTensor`) is always applied.
prefixes : :py:class:`list`, Optional
a list of data augmentation transformations to be applied to **both**
image and ground-truth data and **before** ``transforms`` above.
Notice that transforms like
:py:class:`bob.ip.binseg.data.transforms.ColorJitter` are only applied
to the input image.
suffixes : :py:class:`list`, Optional
a list of data augmentation transformations to be applied to **both**
image and ground-truth data and **after** ``transforms`` above.
Notice that transforms like
:py:class:`bob.ip.binseg.data.transforms.ColorJitter` are only applied
to the input image.
""" """
def __init__(self, samples, transforms=[]): def __init__(self, samples, transforms=[], prefixes=[], suffixes=[]):
self._samples = samples self._samples = samples
self._transform = Compose(transforms + [ToTensor()]) self._middle = transforms
self._transforms = Compose(prefixes + transforms + suffixes + [ToTensor()])
self.augmented = bool(prefixes or suffixes)
def __len__(self): def __len__(self):
""" """
...@@ -138,6 +170,18 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -138,6 +170,18 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
""" """
return len(self._samples) return len(self._samples)
@contextlib.contextmanager
def not_augmented(self):
"""Context to avoid data augmentation to be applied to self"""
backup = (self.augmented, self._transforms)
self.augmented = False
self._transforms = Compose(self._middle + [ToTensor()])
try:
yield self
finally:
self.augmented, self._transforms = backup
def __getitem__(self, key): def __getitem__(self, key):
""" """
...@@ -161,11 +205,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset): ...@@ -161,11 +205,13 @@ class SampleList2TorchDataset(torch.utils.data.Dataset):
data = item.data # triggers data loading data = item.data # triggers data loading
retval = [data["data"]] retval = [data["data"]]
if "label" in data: retval.append(data["label"]) if "label" in data:
if "mask" in data: retval.append(data["mask"]) retval.append(data["label"])
if "mask" in data:
retval.append(data["mask"])
if self._transform: if self._transforms:
retval = self._transform(*retval) retval = self._transforms(*retval)
return [item.key] + retval return [item.key] + retval
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment