From bec562fc92eec358b387590de0048fde13335496 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Thu, 9 Apr 2020 17:22:57 +0200
Subject: [PATCH] [bob.db.drive] Remove requirement

---
 bob/ip/binseg/configs/datasets/drive.py       |  4 +-
 bob/ip/binseg/configs/datasets/drive1024.py   | 28 +++---
 .../binseg/configs/datasets/drive1024test.py  | 26 +++---
 bob/ip/binseg/configs/datasets/drive1168.py   | 28 +++---
 bob/ip/binseg/configs/datasets/drive608.py    | 26 +++---
 bob/ip/binseg/configs/datasets/drive960.py    | 28 +++---
 bob/ip/binseg/configs/datasets/drivetest.py   |  4 +-
 .../starechasedb1iostarhrf544ssldrive.py      | 57 +++++-------
 bob/ip/binseg/data/binsegdataset.py           | 36 ++++----
 bob/ip/binseg/data/utils.py                   | 86 +++++++++++++++++--
 conda/meta.yaml                               |  1 -
 doc/setup.rst                                 | 15 +---
 12 files changed, 208 insertions(+), 131 deletions(-)

diff --git a/bob/ip/binseg/configs/datasets/drive.py b/bob/ip/binseg/configs/datasets/drive.py
index 3412e5be..5e5b4986 100644
--- a/bob/ip/binseg/configs/datasets/drive.py
+++ b/bob/ip/binseg/configs/datasets/drive.py
@@ -14,7 +14,7 @@ segmentation of blood vessels in retinal images.
 """
 
 from bob.ip.binseg.data.transforms import *
-transforms = Compose(
+_transforms = Compose(
     [
         CenterCrop((544, 544)),
         RandomHFlip(),
@@ -28,4 +28,4 @@ transforms = Compose(
 from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
 from bob.ip.binseg.data.drive import dataset as drive
 dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
-        transform=transforms)
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024.py b/bob/ip/binseg/configs/datasets/drive1024.py
index ea99feb0..f27ac3b7 100644
--- a/bob/ip/binseg/configs/datasets/drive1024.py
+++ b/bob/ip/binseg/configs/datasets/drive1024.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1024 x 1024 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((540, 540)),
@@ -19,8 +26,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024test.py b/bob/ip/binseg/configs/datasets/drive1024test.py
index c409dae5..54d1cf5c 100644
--- a/bob/ip/binseg/configs/datasets/drive1024test.py
+++ b/bob/ip/binseg/configs/datasets/drive1024test.py
@@ -1,16 +1,22 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (test set) for Vessel Segmentation
 
-#### Config ####
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1024 x 1024 (center-crop)
+* Test samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
 
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="test", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1168.py b/bob/ip/binseg/configs/datasets/drive1168.py
index f4a51d95..ce84af98 100644
--- a/bob/ip/binseg/configs/datasets/drive1168.py
+++ b/bob/ip/binseg/configs/datasets/drive1168.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 1168 x 1168 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         Crop(75, 10, 416, 544),
@@ -20,8 +27,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive608.py b/bob/ip/binseg/configs/datasets/drive608.py
index e251930b..d5794205 100644
--- a/bob/ip/binseg/configs/datasets/drive608.py
+++ b/bob/ip/binseg/configs/datasets/drive608.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 608 x 608 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((470, 544)),
@@ -20,8 +27,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive960.py b/bob/ip/binseg/configs/datasets/drive960.py
index 1a29eeee..a37f54b0 100644
--- a/bob/ip/binseg/configs/datasets/drive960.py
+++ b/bob/ip/binseg/configs/datasets/drive960.py
@@ -1,13 +1,20 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# coding=utf-8
 
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
+"""DRIVE (training set) for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-#### Config ####
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* This configuration resolution: 960 x 960 (center-crop)
+* Training samples: 20
+* Split reference: [DRIVE-2004]_
+"""
 
-transforms = Compose(
+from bob.ip.binseg.data.transforms import *
+_transforms = Compose(
     [
         RandomRotation(),
         CenterCrop((544, 544)),
@@ -19,8 +26,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index a92cd812..9315e59a 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -14,9 +14,9 @@ segmentation of blood vessels in retinal images.
 """
 
 from bob.ip.binseg.data.transforms import *
-transforms = Compose([CenterCrop((544, 544)), ToTensor()])
+_transforms = Compose([CenterCrop((544, 544)), ToTensor()])
 
 from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
 from bob.ip.binseg.data.drive import dataset as drive
 dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
-        transform=transforms)
+        transform=_transforms)
diff --git a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
index f126871f..7e4eda99 100644
--- a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
+++ b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
@@ -1,42 +1,29 @@
-from bob.ip.binseg.configs.datasets.stare544 import dataset as stare
-from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as chase
-from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as iostar
-from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.transforms import *
-import torch
-from bob.ip.binseg.data.binsegdataset import (
-    BinSegDataset,
-    SSLBinSegDataset,
-    UnLabeledBinSegDataset,
-)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 
+"""DRIVE (SSL training set) for Vessel Segmentation
 
-#### Config ####
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
 
-# PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([stare, chase, iostar, hrf])
+* Reference: [DRIVE-2004]_
+* This configuration resolution: 544 x 544 (center-crop)
+* Split reference: [DRIVE-2004]_
 
-#### Unlabeled STARE TRAIN ####
-unlabeled_transforms = Compose(
-    [
-        CenterCrop((544, 544)),
-        RandomHFlip(),
-        RandomVFlip(),
-        RandomRotation(),
-        ColorJitter(),
-        ToTensor(),
-    ]
-)
+The dataset available in this file is composed of STARE, CHASE-DB1, IOSTAR
+vessel and HRF (with annotated samples) and DRIVE without labels.
+"""
 
-# bob.db.dataset init
-drivebobdb = DRIVE(protocol="default")
+# Labelled bits
+import torch.utils.data
+from bob.ip.binseg.configs.datasets.stare544 import dataset as _stare
+from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as _chase
+from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as _iostar
+from bob.ip.binseg.configs.datasets.hrf544 import dataset as _hrf
+_labelled = torch.utils.data.ConcatDataset([_stare, _chase, _iostar, _hrf])
 
-# PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(
-    drivebobdb, split="train", transform=unlabeled_transforms
-)
+# Use DRIVE without labels in this setup
+from .drive import dataset as _unlabelled
 
-# SSL Dataset
-
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
+from bob.ip.binseg.data.utils import SSLDataset
+dataset = SSLDataset(_labelled, _unlabelled)
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index e2977d37..9c487da3 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -5,15 +5,15 @@ import random
 
 
 class BinSegDataset(Dataset):
-    """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
-    A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets.
+    A transform object can be passed that will be applied to the image, ground truth and mask (if present).
     It supports indexing such that dataset[i] can be used to get ith sample.
-    
+
     Parameters
-    ---------- 
+    ----------
     bobdb : :py:mod:`bob.db.base`
-        Binary segmentation bob database (e.g. bob.db.drive) 
-    split : str 
+        Binary segmentation bob database (e.g. bob.db.drive)
+    split : str
         ``'train'`` or ``'test'``. Defaults to ``'train'``
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transfroms. Defaults to ``None``.
@@ -48,7 +48,7 @@ class BinSegDataset(Dataset):
         Parameters
         ----------
         index : int
-        
+
         Returns
         -------
         list
@@ -68,12 +68,12 @@ class BinSegDataset(Dataset):
 
 
 class SSLBinSegDataset(Dataset):
-    """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
-    A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets.
+    A transform object can be passed that will be applied to the image, ground truth and mask (if present).
     It supports indexing such that dataset[i] can be used to get ith sample.
-    
+
     Parameters
-    ---------- 
+    ----------
     labeled_dataset : :py:class:`torch.utils.data.Dataset`
         BinSegDataset with labeled samples
     unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
@@ -98,7 +98,7 @@ class SSLBinSegDataset(Dataset):
         Parameters
         ----------
         index : int
-        
+
         Returns
         -------
         list
@@ -112,15 +112,15 @@ class SSLBinSegDataset(Dataset):
 
 class UnLabeledBinSegDataset(Dataset):
     # TODO: if switch to handle case were not a bob.db object but a path to a directory is used
-    """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
-    A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets.
+    A transform object can be passed that will be applied to the image, ground truth and mask (if present).
     It supports indexing such that dataset[i] can be used to get ith sample.
-    
+
     Parameters
-    ---------- 
+    ----------
     dv : :py:mod:`bob.db.base` or str
         Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images
-    split : str 
+    split : str
         ``'train'`` or ``'test'``. Defaults to ``'train'``
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transfroms. Defaults to ``None``.
@@ -148,7 +148,7 @@ class UnLabeledBinSegDataset(Dataset):
         Parameters
         ----------
         index : int
-        
+
         Returns
         -------
         list
diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
index ecb90bca..6eca077a 100644
--- a/bob/ip/binseg/data/utils.py
+++ b/bob/ip/binseg/data/utils.py
@@ -6,8 +6,12 @@
 
 
 import functools
+
 import nose.plugins.skip
+
+import torch
 import torch.utils.data
+
 import bob.extension
 
 
@@ -43,12 +47,13 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
 
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transfroms. Defaults to ``None``.
+
     """
 
     def __init__(self, samples, transform=None):
 
-        self.samples = samples
-        self.transform = transform
+        self._samples = samples
+        self._transform = transform
 
     def __len__(self):
         """
@@ -60,7 +65,7 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
             size of the dataset
 
         """
-        return len(self.samples)
+        return len(self._samples)
 
     def __getitem__(self, index):
         """
@@ -73,19 +78,86 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
         Returns
         -------
 
-        sample : tuple
+        sample : list
             The sample data: ``[key, image[, gt[, mask]]]``
 
         """
 
-        item = self.samples[index]
+        item = self._samples[index]
         data = item.data  # triggers data loading
 
         retval = [data["data"]]
         if "label" in data: retval.append(data["label"])
         if "mask" in data: retval.append(data["mask"])
 
-        if self.transform:
-            retval = self.transform(*retval)
+        if self._transform:
+            retval = self._transform(*retval)
 
         return [item.key] + retval
+
+
+class SSLDataset(torch.utils.data.Dataset):
+    """PyTorch dataset wrapper around labelled and unlabelled sample lists
+
+    Yields elements of the form:
+
+    .. code-block:: text
+
+       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]
+
+    The size of the dataset is the same as the labelled dataset.
+
+    Indexing works by selecting the right element on the labelled dataset, and
+    randomly picking another one from the unlabelled dataset
+
+    Parameters
+    ----------
+
+    labelled : :py:class:`torch.utils.data.Dataset`
+        Labelled dataset (**must** have "mask" and "label" entries for every
+        sample)
+
+    unlabelled : :py:class:`torch.utils.data.Dataset`
+        Unlabelled dataset (**may** have "mask" and "label" entries for every
+        sample, but are ignored)
+
+    """
+
+    def __init__(self, labelled, unlabelled):
+        self.labelled = labelled
+        self.unlabelled = unlabelled
+
+    def __len__(self):
+        """
+
+        Returns
+        -------
+
+        size : int
+            size of the dataset
+
+        """
+
+        return len(self.labelled)
+
+    def __getitem__(self, index):
+        """
+
+        Parameters
+        ----------
+        index : int
+            The index for the element to pick
+
+        Returns
+        -------
+
+        sample : list
+            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``
+
+        """
+
+        retval = self.labelled[index]
+        # gets an unlabelled sample randomly to follow the labelled sample
+        unlab = self.unlabelled[torch.randint(len(self.unlabelled), ()).item()]
+        # only interested in key and data
+        return retval + unlab[:2]
diff --git a/conda/meta.yaml b/conda/meta.yaml
index a8906290..4f30610b 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -74,7 +74,6 @@ test:
     - sphinx
     - sphinx_rtd_theme
     - sphinxcontrib-programoutput
-    - bob.db.drive
     - bob.db.stare
     - bob.db.chasedb1
     - bob.db.hrf
diff --git a/doc/setup.rst b/doc/setup.rst
index 7a1d3c7b..5eaa6653 100644
--- a/doc/setup.rst
+++ b/doc/setup.rst
@@ -22,9 +22,8 @@ this:
 Datasets
 --------
 
-The package supports a range of retina fundus datasets, but does not install the
-`bob.db` iterator APIs by default, or includes the raw data itself, which you
-must procure.
+The package supports a range of retina fundus datasets, but does not include
+the raw data itself, which you must procure.
 
 To setup a dataset, do the following:
 
@@ -38,16 +37,6 @@ To setup a dataset, do the following:
       you unpack them in their **pristine** state.  Changing the location of
       files within a dataset distribution will likely cause execution errors.
 
-2. Install the corresponding ``bob.db`` package (package names are marked on
-   :ref:`bob.ip.binseg.datasets`) with the following command:
-
-   .. code-block:: sh
-
-      # replace "<package>" by the corresponding package name
-      (<myenv>) $ conda install <package>
-      # example:
-      (<myenv>) $ conda install bob.db.drive #to install DRIVE iterators
-
 3.  For each dataset that you are planning to use, set the ``datadir`` to the
     root path where it is stored.  E.g.:
 
-- 
GitLab