From bec562fc92eec358b387590de0048fde13335496 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Thu, 9 Apr 2020 17:22:57 +0200
Subject: [PATCH] [bob.db.drive] Remove requirement

---
 bob/ip/binseg/configs/datasets/drive.py      |  4 +-
 bob/ip/binseg/configs/datasets/drive1024.py  | 28 +++---
 .../binseg/configs/datasets/drive1024test.py | 26 +++---
 bob/ip/binseg/configs/datasets/drive1168.py  | 28 +++---
 bob/ip/binseg/configs/datasets/drive608.py   | 26 +++---
 bob/ip/binseg/configs/datasets/drive960.py   | 28 +++---
 bob/ip/binseg/configs/datasets/drivetest.py  |  4 +-
 .../starechasedb1iostarhrf544ssldrive.py     | 57 +++++-------
 bob/ip/binseg/data/binsegdataset.py          | 36 ++++----
 bob/ip/binseg/data/utils.py                  | 86 +++++++++++++++++--
 conda/meta.yaml                              |  1 -
 doc/setup.rst                                | 15 +---
 12 files changed, 208 insertions(+), 131 deletions(-)

diff --git a/bob/ip/binseg/configs/datasets/drive.py b/bob/ip/binseg/configs/datasets/drive.py
index 3412e5be..5e5b4986 100644
--- a/bob/ip/binseg/configs/datasets/drive.py
+++ b/bob/ip/binseg/configs/datasets/drive.py
@@ -14,7 +14,7 @@ segmentation of blood vessels in retinal images.
 """
 
 from bob.ip.binseg.data.transforms import *
-transforms = Compose(
+_transforms = Compose(
     [
         CenterCrop((544, 544)),
         RandomHFlip(),
@@ -28,4 +28,4 @@ transforms = Compose(
 from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
 from bob.ip.binseg.data.drive import dataset as drive
 dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
-        transform=transforms)
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024.py b/bob/ip/binseg/configs/datasets/drive1024.py
index ea99feb0..f27ac3b7 100644
--- a/bob/ip/binseg/configs/datasets/drive1024.py
+++ b/bob/ip/binseg/configs/datasets/drive1024.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1024 x 1024 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((540, 540)),
@@ -19,8 +26,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024test.py b/bob/ip/binseg/configs/datasets/drive1024test.py
index c409dae5..54d1cf5c 100644
--- a/bob/ip/binseg/configs/datasets/drive1024test.py
+++ b/bob/ip/binseg/configs/datasets/drive1024test.py
@@ -1,16 +1,22 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (test set) for Vessel Segmentation
 
-#### Config ####
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1024 x 1024 (center-crop)
+* Test samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
 
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="test", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1168.py b/bob/ip/binseg/configs/datasets/drive1168.py
index f4a51d95..ce84af98 100644
--- a/bob/ip/binseg/configs/datasets/drive1168.py
+++ b/bob/ip/binseg/configs/datasets/drive1168.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1168 x 1168 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         Crop(75, 10, 416, 544),
@@ -20,8 +27,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive608.py b/bob/ip/binseg/configs/datasets/drive608.py
index e251930b..d5794205 100644
--- a/bob/ip/binseg/configs/datasets/drive608.py
+++ b/bob/ip/binseg/configs/datasets/drive608.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 608 x 608 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((470, 544)),
@@ -20,8 +27,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive960.py b/bob/ip/binseg/configs/datasets/drive960.py
index 1a29eeee..a37f54b0 100644
--- a/bob/ip/binseg/configs/datasets/drive960.py
+++ b/bob/ip/binseg/configs/datasets/drive960.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 960 x 960 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((544, 544)),
@@ -19,8 +26,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index a92cd812..9315e59a 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -14,9 +14,9 @@ segmentation of blood vessels in retinal images.
 """
 
 from bob.ip.binseg.data.transforms import *
-transforms = Compose([CenterCrop((544, 544)), ToTensor()])
+_transforms = Compose([CenterCrop((544, 544)), ToTensor()])
 
 from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
 from bob.ip.binseg.data.drive import dataset as drive
 dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
-        transform=transforms)
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
index f126871f..7e4eda99 100644
--- a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
+++ b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
@@ -1,42 +1,29 @@
-from bob.ip.binseg.configs.datasets.stare544 import dataset as stare
-from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as chase
-from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as iostar
-from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-import torch
-from bob.ip.binseg.data.binsegdataset import (
-    BinSegDataset,
-    SSLBinSegDataset,
-    UnLabeledBinSegDataset,
-)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 
+"""DRIVE (SSL training set) for Vessel Segmentation
 
-#### Config ####
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-# PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([stare, chase, iostar, hrf])
+* Reference: [DRIVE-2004]_
+* This configuration resolution: 544 x 544 (center-crop)
+* Split reference: [DRIVE-2004]_
 
-#### Unlabeled STARE TRAIN ####
-unlabeled_transforms = Compose(
-    [
-        CenterCrop((544, 544)),
-        RandomHFlip(),
-        RandomVFlip(),
-        RandomRotation(),
-        ColorJitter(),
-        ToTensor(),
-    ]
-)
+The dataset available in this file is composed of STARE, CHASE-DB1, IOSTAR
+vessel and HRF (with annotated samples) and DRIVE without labels.
+""" -# bob.db.dataset init -drivebobdb = DRIVE(protocol="default") +# Labelled bits +import torch.utils.data +from bob.ip.binseg.configs.datasets.stare544 import dataset as _stare +from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as _chase +from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as _iostar +from bob.ip.binseg.configs.datasets.hrf544 import dataset as _hrf +_labelled = torch.utils.data.ConcatDataset([_stare, _chase, _iostar, _hrf]) -# PyTorch dataset -unlabeled_dataset = UnLabeledBinSegDataset( - drivebobdb, split="train", transform=unlabeled_transforms -) +# Use DRIVE without labels in this setup +from .drive import dataset as _unlabelled -# SSL Dataset - -dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset) +from bob.ip.binseg.data.utils import SSLDataset +dataset = SSLDataset(_labelled, _unlabelled) diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py index e2977d37..9c487da3 100644 --- a/bob/ip/binseg/data/binsegdataset.py +++ b/bob/ip/binseg/data/binsegdataset.py @@ -5,15 +5,15 @@ import random class BinSegDataset(Dataset): - """PyTorch dataset wrapper around bob.db binary segmentation datasets. - A transform object can be passed that will be applied to the image, ground truth and mask (if present). + """PyTorch dataset wrapper around bob.db binary segmentation datasets. + A transform object can be passed that will be applied to the image, ground truth and mask (if present). It supports indexing such that dataset[i] can be used to get ith sample. - + Parameters - ---------- + ---------- bobdb : :py:mod:`bob.db.base` - Binary segmentation bob database (e.g. bob.db.drive) - split : str + Binary segmentation bob database (e.g. bob.db.drive) + split : str ``'train'`` or ``'test'``. Defaults to ``'train'`` transform : :py:mod:`bob.ip.binseg.data.transforms`, optional A transform or composition of transfroms. Defaults to ``None``. @@ -48,7 +48,7 @@ class BinSegDataset(Dataset): Parameters ---------- index : int - + Returns ------- list @@ -68,12 +68,12 @@ class BinSegDataset(Dataset): class SSLBinSegDataset(Dataset): - """PyTorch dataset wrapper around bob.db binary segmentation datasets. - A transform object can be passed that will be applied to the image, ground truth and mask (if present). + """PyTorch dataset wrapper around bob.db binary segmentation datasets. + A transform object can be passed that will be applied to the image, ground truth and mask (if present). It supports indexing such that dataset[i] can be used to get ith sample. - + Parameters - ---------- + ---------- labeled_dataset : :py:class:`torch.utils.data.Dataset` BinSegDataset with labeled samples unlabeled_dataset : :py:class:`torch.utils.data.Dataset` @@ -98,7 +98,7 @@ class SSLBinSegDataset(Dataset): Parameters ---------- index : int - + Returns ------- list @@ -112,15 +112,15 @@ class SSLBinSegDataset(Dataset): class UnLabeledBinSegDataset(Dataset): # TODO: if switch to handle case were not a bob.db object but a path to a directory is used - """PyTorch dataset wrapper around bob.db binary segmentation datasets. - A transform object can be passed that will be applied to the image, ground truth and mask (if present). + """PyTorch dataset wrapper around bob.db binary segmentation datasets. + A transform object can be passed that will be applied to the image, ground truth and mask (if present). It supports indexing such that dataset[i] can be used to get ith sample. 
-    
+
     Parameters
-    ---------- 
+    ----------
     dv : :py:mod:`bob.db.base` or str
         Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images
-    split : str 
+    split : str
         ``'train'`` or ``'test'``. Defaults to ``'train'``
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transfroms. Defaults to ``None``.
@@ -148,7 +148,7 @@ class UnLabeledBinSegDataset(Dataset):
         Parameters
         ----------
         index : int
-        
+
         Returns
         -------
         list
diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index ecb90bca..6eca077a 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -6,8 +6,12 @@
 import functools
+
 import nose.plugins.skip
+
+import torch
 import torch.utils.data
+
 import bob.extension
@@ -43,12 +47,13 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transfroms. Defaults to ``None``.
 
+
     """
 
     def __init__(self, samples, transform=None):
-        self.samples = samples
-        self.transform = transform
+        self._samples = samples
+        self._transform = transform
 
     def __len__(self):
@@ -60,7 +65,7 @@
             size of the dataset
 
         """
-        return len(self.samples)
+        return len(self._samples)
 
     def __getitem__(self, index):
@@ -73,19 +78,86 @@
 
         Returns
        -------
-        sample : tuple
+        sample : list
             The sample data: ``[key, image[, gt[, mask]]]``
 
         """
-        item = self.samples[index]
+        item = self._samples[index]
         data = item.data  # triggers data loading
         retval = [data["data"]]
         if "label" in data:
             retval.append(data["label"])
         if "mask" in data:
             retval.append(data["mask"])
-        if self.transform:
-            retval = self.transform(*retval)
+        if self._transform:
+            retval = self._transform(*retval)
         return [item.key] + retval
+
+
+class SSLDataset(torch.utils.data.Dataset):
+    """PyTorch dataset wrapper around labelled and unlabelled sample lists
+
+    Yields elements of the form:
+
+    .. code-block:: text
+
+       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]
+
+    The size of the dataset is the same as the labelled dataset.
+
+    Indexing works by selecting the right element on the labelled dataset, and
+    randomly picking another one from the unlabelled dataset
+
+    Parameters
+    ----------
+
+    labelled : :py:class:`torch.utils.data.Dataset`
+        Labelled dataset (**must** have "mask" and "label" entries for every
+        sample)
+
+    unlabelled : :py:class:`torch.utils.data.Dataset`
+        Unlabelled dataset (**may** have "mask" and "label" entries for every
+        sample, but are ignored)
+
+    """
+
+    def __init__(self, labelled, unlabelled):
+        self.labelled = labelled
+        self.unlabelled = unlabelled
+
+    def __len__(self):
+        """
+
+        Returns
+        -------
+
+        size : int
+            size of the dataset
+
+        """
+
+        return len(self.labelled)
+
+    def __getitem__(self, index):
+        """
+
+        Parameters
+        ----------
+        index : int
+            The index for the element to pick
+
+        Returns
+        -------
+
+        sample : list
+            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``
+
+        """
+
+        retval = self.labelled[index]
+        # gets one unlabelled sample randomly to follow the labelled sample
+        unlab = self.unlabelled[torch.randint(len(self.unlabelled), ())]
+        # only interested in key and data
+        return retval + unlab[:2]
diff --git a/conda/meta.yaml b/conda/meta.yaml
index a8906290..4f30610b 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -74,7 +74,6 @@ test:
     - sphinx
     - sphinx_rtd_theme
     - sphinxcontrib-programoutput
-    - bob.db.drive
     - bob.db.stare
     - bob.db.chasedb1
     - bob.db.hrf
diff --git a/doc/setup.rst b/doc/setup.rst
index 7a1d3c7b..5eaa6653 100644
--- a/doc/setup.rst
+++ b/doc/setup.rst
@@ -22,9 +22,8 @@ this:
 Datasets
 --------
 
-The package supports a range of retina fundus datasets, but does not install the
-`bob.db` iterator APIs by default, or includes the raw data itself, which you
-must procure.
+The package supports a range of retina fundus datasets, but does not include
+the raw data itself, which you must procure.
 
 To setup a dataset, do the following:
 
@@ -38,16 +37,6 @@ To setup a dataset, do the following:
    you unpack them in their **pristine** state.  Changing the location of
    files within a dataset distribution will likely cause execution errors.
 
-2. Install the corresponding ``bob.db`` package (package names are marked on
-   :ref:`bob.ip.binseg.datasets`) with the following command:
-
-   .. code-block:: sh
-
-      # replace "<package>" by the corresponding package name
-      (<myenv>) $ conda install <package>
-      # example:
-      (<myenv>) $ conda install bob.db.drive #to install DRIVE iterators
-
 3. For each dataset that you are planning to use, set the ``datadir`` to the
    root path where it is stored.  E.g.:
-- 
GitLab
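Usage sketch (not part of the patch): the snippet below is a minimal example of how the reworked, bob.db-free DRIVE configuration introduced above might be consumed. It assumes bob.ip.binseg is installed with the new bob.ip.binseg.data.drive accessors and that the DRIVE "datadir" has been configured as described in doc/setup.rst; names such as "drive_train", "loader" and "batch" are purely illustrative.

    # minimal sketch: wrap the patched DRIVE training configuration in a DataLoader
    from torch.utils.data import DataLoader

    from bob.ip.binseg.configs.datasets.drive import dataset as drive_train

    loader = DataLoader(drive_train, batch_size=2, shuffle=True)
    for batch in loader:
        keys, images = batch[0], batch[1]  # sample keys and transformed images
        # remaining entries hold the vessel ground-truth (and a mask, when present)
        print(keys, images.shape, len(batch))
        break

The SSL configuration from starechasedb1iostarhrf544ssldrive.py can be loaded the same way; per the SSLDataset wrapper added in this patch, each of its samples additionally carries the key and image of a randomly drawn unlabelled DRIVE sample.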