Skip to content
Snippets Groups Projects
Commit bec562fc authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[bob.db.drive] Remove requirement

parent 4877170d
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -14,7 +14,7 @@ segmentation of blood vessels in retinal images. ...@@ -14,7 +14,7 @@ segmentation of blood vessels in retinal images.
""" """
from bob.ip.binseg.data.transforms import * from bob.ip.binseg.data.transforms import *
transforms = Compose( _transforms = Compose(
[ [
CenterCrop((544, 544)), CenterCrop((544, 544)),
RandomHFlip(), RandomHFlip(),
...@@ -28,4 +28,4 @@ transforms = Compose( ...@@ -28,4 +28,4 @@ transforms = Compose(
from bob.ip.binseg.data.utils import DelayedSample2TorchDataset from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
from bob.ip.binseg.data.drive import dataset as drive from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"], dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
transform=transforms) transform=_transforms)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # coding=utf-8
from bob.db.drive import Database as DRIVE """DRIVE (training set) for Vessel Segmentation
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
#### Config #### * Reference: [DRIVE-2004]_
* Original resolution (height x width): 584 x 565
* This configuration resolution: 1024 x 1024 (center-crop)
* Training samples: 20
* Split reference: [DRIVE-2004]_
"""
transforms = Compose( from bob.ip.binseg.data.transforms import *
_transforms = Compose(
[ [
RandomRotation(), RandomRotation(),
CenterCrop((540, 540)), CenterCrop((540, 540)),
...@@ -19,8 +26,7 @@ transforms = Compose( ...@@ -19,8 +26,7 @@ transforms = Compose(
] ]
) )
# bob.db.dataset init from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
bobdb = DRIVE(protocol="default") from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
# PyTorch dataset transform=_transforms)
dataset = BinSegDataset(bobdb, split="train", transform=transforms)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # coding=utf-8
from bob.db.drive import Database as DRIVE """DRIVE (training set) for Vessel Segmentation
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config #### The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()]) * Reference: [DRIVE-2004]_
* Original resolution (height x width): 584 x 565
* This configuration resolution: 1024 x 1024 (center-crop)
* Test samples: 20
* Split reference: [DRIVE-2004]_
"""
# bob.db.dataset init from bob.ip.binseg.data.transforms import *
bobdb = DRIVE(protocol="default") _transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
# PyTorch dataset from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
dataset = BinSegDataset(bobdb, split="test", transform=transforms) from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
transform=_transforms)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # coding=utf-8
from bob.db.drive import Database as DRIVE """DRIVE (training set) for Vessel Segmentation
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
#### Config #### * Reference: [DRIVE-2004]_
* Original resolution (height x width): 584 x 565
* This configuration resolution: 1168 x 1168 (center-crop)
* Training samples: 20
* Split reference: [DRIVE-2004]_
"""
transforms = Compose( from bob.ip.binseg.data.transforms import *
_transforms = Compose(
[ [
RandomRotation(), RandomRotation(),
Crop(75, 10, 416, 544), Crop(75, 10, 416, 544),
...@@ -20,8 +27,7 @@ transforms = Compose( ...@@ -20,8 +27,7 @@ transforms = Compose(
] ]
) )
# bob.db.dataset init from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
bobdb = DRIVE(protocol="default") from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
# PyTorch dataset transform=_transforms)
dataset = BinSegDataset(bobdb, split="train", transform=transforms)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from bob.db.drive import Database as DRIVE """DRIVE (training set) for Vessel Segmentation
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
#### Config #### * Reference: [DRIVE-2004]_
* Original resolution (height x width): 584 x 565
* This configuration resolution: 608 x 608 (center-crop)
* Training samples: 20
* Split reference: [DRIVE-2004]_
"""
transforms = Compose( from bob.ip.binseg.data.transforms import *
_transforms = Compose(
[ [
RandomRotation(), RandomRotation(),
CenterCrop((470, 544)), CenterCrop((470, 544)),
...@@ -20,8 +27,7 @@ transforms = Compose( ...@@ -20,8 +27,7 @@ transforms = Compose(
] ]
) )
# bob.db.dataset init from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
bobdb = DRIVE(protocol="default") from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
# PyTorch dataset transform=_transforms)
dataset = BinSegDataset(bobdb, split="train", transform=transforms)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # coding=utf-8
from bob.db.drive import Database as DRIVE """DRIVE (training set) for Vessel Segmentation
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
#### Config #### * Reference: [DRIVE-2004]_
* Original resolution (height x width): 584 x 565
* This configuration resolution: 960 x 960 (center-crop)
* Training samples: 20
* Split reference: [DRIVE-2004]_
"""
transforms = Compose( from bob.ip.binseg.data.transforms import *
_transforms = Compose(
[ [
RandomRotation(), RandomRotation(),
CenterCrop((544, 544)), CenterCrop((544, 544)),
...@@ -19,8 +26,7 @@ transforms = Compose( ...@@ -19,8 +26,7 @@ transforms = Compose(
] ]
) )
# bob.db.dataset init from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
bobdb = DRIVE(protocol="default") from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
# PyTorch dataset transform=_transforms)
dataset = BinSegDataset(bobdb, split="train", transform=transforms)
...@@ -14,9 +14,9 @@ segmentation of blood vessels in retinal images. ...@@ -14,9 +14,9 @@ segmentation of blood vessels in retinal images.
""" """
from bob.ip.binseg.data.transforms import * from bob.ip.binseg.data.transforms import *
transforms = Compose([CenterCrop((544, 544)), ToTensor()]) _transforms = Compose([CenterCrop((544, 544)), ToTensor()])
from bob.ip.binseg.data.utils import DelayedSample2TorchDataset from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
from bob.ip.binseg.data.drive import dataset as drive from bob.ip.binseg.data.drive import dataset as drive
dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"], dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
transform=transforms) transform=_transforms)
from bob.ip.binseg.configs.datasets.stare544 import dataset as stare #!/usr/bin/env python
from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as chase # -*- coding: utf-8 -*-
from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as iostar
from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf
from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.transforms import *
import torch
from bob.ip.binseg.data.binsegdataset import (
BinSegDataset,
SSLBinSegDataset,
UnLabeledBinSegDataset,
)
"""DRIVE (SSL training set) for Vessel Segmentation
#### Config #### The DRIVE database has been established to enable comparative studies on
segmentation of blood vessels in retinal images.
# PyTorch dataset * Reference: [DRIVE-2004]_
labeled_dataset = torch.utils.data.ConcatDataset([stare, chase, iostar, hrf]) * This configuration resolution: 544 x 544 (center-crop)
* Split reference: [DRIVE-2004]_
#### Unlabeled STARE TRAIN #### The dataset available in this file is composed of STARE, CHASE-DB1, IOSTAR
unlabeled_transforms = Compose( vessel and HRF (with annotated samples) and DRIVE without labels.
[ """
CenterCrop((544, 544)),
RandomHFlip(),
RandomVFlip(),
RandomRotation(),
ColorJitter(),
ToTensor(),
]
)
# bob.db.dataset init # Labelled bits
drivebobdb = DRIVE(protocol="default") import torch.utils.data
from bob.ip.binseg.configs.datasets.stare544 import dataset as _stare
from bob.ip.binseg.configs.datasets.chasedb1544 import dataset as _chase
from bob.ip.binseg.configs.datasets.iostarvessel544 import dataset as _iostar
from bob.ip.binseg.configs.datasets.hrf544 import dataset as _hrf
_labelled = torch.utils.data.ConcatDataset([_stare, _chase, _iostar, _hrf])
# PyTorch dataset # Use DRIVE without labels in this setup
unlabeled_dataset = UnLabeledBinSegDataset( from .drive import dataset as _unlabelled
drivebobdb, split="train", transform=unlabeled_transforms
)
# SSL Dataset from bob.ip.binseg.data.utils import SSLDataset
dataset = SSLDataset(_labelled, _unlabelled)
dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
...@@ -5,15 +5,15 @@ import random ...@@ -5,15 +5,15 @@ import random
class BinSegDataset(Dataset): class BinSegDataset(Dataset):
"""PyTorch dataset wrapper around bob.db binary segmentation datasets. """PyTorch dataset wrapper around bob.db binary segmentation datasets.
A transform object can be passed that will be applied to the image, ground truth and mask (if present). A transform object can be passed that will be applied to the image, ground truth and mask (if present).
It supports indexing such that dataset[i] can be used to get ith sample. It supports indexing such that dataset[i] can be used to get ith sample.
Parameters Parameters
---------- ----------
bobdb : :py:mod:`bob.db.base` bobdb : :py:mod:`bob.db.base`
Binary segmentation bob database (e.g. bob.db.drive) Binary segmentation bob database (e.g. bob.db.drive)
split : str split : str
``'train'`` or ``'test'``. Defaults to ``'train'`` ``'train'`` or ``'test'``. Defaults to ``'train'``
transform : :py:mod:`bob.ip.binseg.data.transforms`, optional transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
A transform or composition of transfroms. Defaults to ``None``. A transform or composition of transfroms. Defaults to ``None``.
...@@ -48,7 +48,7 @@ class BinSegDataset(Dataset): ...@@ -48,7 +48,7 @@ class BinSegDataset(Dataset):
Parameters Parameters
---------- ----------
index : int index : int
Returns Returns
------- -------
list list
...@@ -68,12 +68,12 @@ class BinSegDataset(Dataset): ...@@ -68,12 +68,12 @@ class BinSegDataset(Dataset):
class SSLBinSegDataset(Dataset): class SSLBinSegDataset(Dataset):
"""PyTorch dataset wrapper around bob.db binary segmentation datasets. """PyTorch dataset wrapper around bob.db binary segmentation datasets.
A transform object can be passed that will be applied to the image, ground truth and mask (if present). A transform object can be passed that will be applied to the image, ground truth and mask (if present).
It supports indexing such that dataset[i] can be used to get ith sample. It supports indexing such that dataset[i] can be used to get ith sample.
Parameters Parameters
---------- ----------
labeled_dataset : :py:class:`torch.utils.data.Dataset` labeled_dataset : :py:class:`torch.utils.data.Dataset`
BinSegDataset with labeled samples BinSegDataset with labeled samples
unlabeled_dataset : :py:class:`torch.utils.data.Dataset` unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
...@@ -98,7 +98,7 @@ class SSLBinSegDataset(Dataset): ...@@ -98,7 +98,7 @@ class SSLBinSegDataset(Dataset):
Parameters Parameters
---------- ----------
index : int index : int
Returns Returns
------- -------
list list
...@@ -112,15 +112,15 @@ class SSLBinSegDataset(Dataset): ...@@ -112,15 +112,15 @@ class SSLBinSegDataset(Dataset):
class UnLabeledBinSegDataset(Dataset): class UnLabeledBinSegDataset(Dataset):
# TODO: if switch to handle case were not a bob.db object but a path to a directory is used # TODO: if switch to handle case were not a bob.db object but a path to a directory is used
"""PyTorch dataset wrapper around bob.db binary segmentation datasets. """PyTorch dataset wrapper around bob.db binary segmentation datasets.
A transform object can be passed that will be applied to the image, ground truth and mask (if present). A transform object can be passed that will be applied to the image, ground truth and mask (if present).
It supports indexing such that dataset[i] can be used to get ith sample. It supports indexing such that dataset[i] can be used to get ith sample.
Parameters Parameters
---------- ----------
dv : :py:mod:`bob.db.base` or str dv : :py:mod:`bob.db.base` or str
Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images Binary segmentation bob database (e.g. bob.db.drive) or path to folder containing unlabeled images
split : str split : str
``'train'`` or ``'test'``. Defaults to ``'train'`` ``'train'`` or ``'test'``. Defaults to ``'train'``
transform : :py:mod:`bob.ip.binseg.data.transforms`, optional transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
A transform or composition of transfroms. Defaults to ``None``. A transform or composition of transfroms. Defaults to ``None``.
...@@ -148,7 +148,7 @@ class UnLabeledBinSegDataset(Dataset): ...@@ -148,7 +148,7 @@ class UnLabeledBinSegDataset(Dataset):
Parameters Parameters
---------- ----------
index : int index : int
Returns Returns
------- -------
list list
......
...@@ -6,8 +6,12 @@ ...@@ -6,8 +6,12 @@
import functools import functools
import nose.plugins.skip import nose.plugins.skip
import torch
import torch.utils.data import torch.utils.data
import bob.extension import bob.extension
...@@ -43,12 +47,13 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset): ...@@ -43,12 +47,13 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
transform : :py:mod:`bob.ip.binseg.data.transforms`, optional transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
A transform or composition of transfroms. Defaults to ``None``. A transform or composition of transfroms. Defaults to ``None``.
""" """
def __init__(self, samples, transform=None): def __init__(self, samples, transform=None):
self.samples = samples self._samples = samples
self.transform = transform self._transform = transform
def __len__(self): def __len__(self):
""" """
...@@ -60,7 +65,7 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset): ...@@ -60,7 +65,7 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
size of the dataset size of the dataset
""" """
return len(self.samples) return len(self._samples)
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -73,19 +78,86 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset): ...@@ -73,19 +78,86 @@ class DelayedSample2TorchDataset(torch.utils.data.Dataset):
Returns Returns
------- -------
sample : tuple sample : list
The sample data: ``[key, image[, gt[, mask]]]`` The sample data: ``[key, image[, gt[, mask]]]``
""" """
item = self.samples[index] item = self._samples[index]
data = item.data # triggers data loading data = item.data # triggers data loading
retval = [data["data"]] retval = [data["data"]]
if "label" in data: retval.append(data["label"]) if "label" in data: retval.append(data["label"])
if "mask" in data: retval.append(data["mask"]) if "mask" in data: retval.append(data["mask"])
if self.transform: if self._transform:
retval = self.transform(*retval) retval = self._transform(*retval)
return [item.key] + retval return [item.key] + retval
class SSLDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapper around labelled and unlabelled sample lists.

    Yields elements of the form:

    .. code-block:: text

       [key, image, ground-truth, [mask,] unlabelled-key, unlabelled-image]

    The size of the dataset is the same as the labelled dataset.

    Indexing works by selecting the right element on the labelled dataset, and
    randomly picking another one from the unlabelled dataset.


    Parameters
    ----------

    labelled : :py:class:`torch.utils.data.Dataset`
        Labelled dataset (**must** have "mask" and "label" entries for every
        sample)

    unlabelled : :py:class:`torch.utils.data.Dataset`
        Unlabelled dataset (**may** have "mask" and "label" entries for every
        sample, but are ignored)

    """

    def __init__(self, labelled, unlabelled):
        self.labelled = labelled
        self.unlabelled = unlabelled

    def __len__(self):
        """
        Returns
        -------

        size : int
            size of the dataset (same as the labelled dataset)

        """
        return len(self.labelled)

    def __getitem__(self, index):
        """
        Parameters
        ----------

        index : int
            The index for the element to pick

        Returns
        -------

        sample : list
            The sample data: ``[key, image, gt, [mask, ]unlab-key, unlab-image]``

        """
        retval = self.labelled[index]
        # gets an unlabelled sample randomly to follow the labelled sample
        # BUGFIX: the original called ``torch.randint(len(self.unlabelled))``,
        # which raises TypeError (``torch.randint`` requires a ``size``
        # argument); also, ``.item()`` is needed to index with a plain int
        pick = torch.randint(len(self.unlabelled), ()).item()
        unlab = self.unlabelled[pick]
        # only interested in the key and data of the unlabelled sample
        return retval + unlab[:2]
...@@ -74,7 +74,6 @@ test: ...@@ -74,7 +74,6 @@ test:
- sphinx - sphinx
- sphinx_rtd_theme - sphinx_rtd_theme
- sphinxcontrib-programoutput - sphinxcontrib-programoutput
- bob.db.drive
- bob.db.stare - bob.db.stare
- bob.db.chasedb1 - bob.db.chasedb1
- bob.db.hrf - bob.db.hrf
......
...@@ -22,9 +22,8 @@ this: ...@@ -22,9 +22,8 @@ this:
Datasets Datasets
-------- --------
The package supports a range of retina fundus datasets, but does not install the The package supports a range of retina fundus datasets, but does not include
`bob.db` iterator APIs by default, or includes the raw data itself, which you the raw data itself, which you must procure.
must procure.
To setup a dataset, do the following: To setup a dataset, do the following:
...@@ -38,16 +37,6 @@ To setup a dataset, do the following: ...@@ -38,16 +37,6 @@ To setup a dataset, do the following:
you unpack them in their **pristine** state. Changing the location of you unpack them in their **pristine** state. Changing the location of
files within a dataset distribution will likely cause execution errors. files within a dataset distribution will likely cause execution errors.
2. Install the corresponding ``bob.db`` package (package names are marked on
:ref:`bob.ip.binseg.datasets`) with the following command:
.. code-block:: sh
# replace "<package>" by the corresponding package name
(<myenv>) $ conda install <package>
# example:
(<myenv>) $ conda install bob.db.drive #to install DRIVE iterators
3. For each dataset that you are planning to use, set the ``datadir`` to the 3. For each dataset that you are planning to use, set the ``datadir`` to the
root path where it is stored. E.g.: root path where it is stored. E.g.:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment