diff --git a/bob/ip/binseg/configs/datasets/drive.py b/bob/ip/binseg/configs/datasets/drive.py
index 179e2e258aeba8d0dbbdeacb4a719252908ec4d6..3412e5be738e151afccc3b0506112b622cb2ee5a 100644
--- a/bob/ip/binseg/configs/datasets/drive.py
+++ b/bob/ip/binseg/configs/datasets/drive.py
@@ -13,12 +13,7 @@ segmentation of blood vessels in retinal images.
 * Split reference: [DRIVE-2004]_
 """
 
-from bob.db.drive import Database as DRIVE
 from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
-
-#### Config ####
-
 transforms = Compose(
     [
         CenterCrop((544, 544)),
@@ -30,8 +25,7 @@ transforms = Compose(
     ]
 )
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="train", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["train"],
+        transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index 2f0aa772c8862f8b485c32e73c9a5ad965071663..a92cd812403b640e62178f9d018a8b7bf17588e5 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -13,16 +13,10 @@ segmentation of blood vessels in retinal images.
 * Split reference: [DRIVE-2004]_
 """
 
-from bob.db.drive import Database as DRIVE
 from bob.ip.binseg.data.transforms import *
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
-
-#### Config ####
-
 transforms = Compose([CenterCrop((544, 544)), ToTensor()])
 
-# bob.db.dataset init
-bobdb = DRIVE(protocol="default")
-
-# PyTorch dataset
-dataset = BinSegDataset(bobdb, split="test", transform=transforms)
+from bob.ip.binseg.data.utils import DelayedSample2TorchDataset
+from bob.ip.binseg.data.drive import dataset as drive
+dataset = DelayedSample2TorchDataset(drive.subsets("default")["test"],
+        transform=transforms)
diff --git a/bob/ip/binseg/data/csvdataset.py b/bob/ip/binseg/data/csvdataset.py
index 133a3d0aaf621797251038b70ba27ce2d0ea1ea5..9e65e3638e1c30403e2cbe9a9a4867bf37a9ed64 100644
--- a/bob/ip/binseg/data/csvdataset.py
+++ b/bob/ip/binseg/data/csvdataset.py
@@ -13,6 +13,7 @@ import torchvision.transforms.functional as VF
 import bob.io.base
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -87,7 +88,9 @@ class CSVDataset(Dataset):
 
     """
 
-    def __init__(self, path, root_path=None, check_available=True, transform=None):
+    def __init__(
+        self, path, root_path=None, check_available=True, transform=None
+    ):
 
         self.root_path = root_path or os.path.dirname(path)
         self.transform = transform
@@ -99,7 +102,7 @@ class CSVDataset(Dataset):
                 retval.append(os.path.join(root, p))
             return retval
 
-        with open(path, newline='') as f:
+        with open(path, newline="") as f:
             reader = csv.reader(f)
             self.data = [_make_abs_path(self.root_path, k) for k in reader]
@@ -111,14 +114,16 @@ class CSVDataset(Dataset):
             if not os.path.exists(p):
                 errors += 1
                 logger.error(f"Cannot find {p}")
-        assert errors == 0, f"There {errors} files which cannot be " \
-                f"found on your filelist ({path}) dataset"
+        assert errors == 0, (
+            f"There are {errors} files which cannot be "
+            f"found in your filelist ({path}) dataset"
+        )
 
         # check all data entries have the same size
-        assert all(len(k) == len(self.data[0]) for k in self.data), \
-            f"There is an inconsistence on your dataset - not all " \
-            f"entries have length=={len(self.data[0])}"
-
+        assert all(len(k) == len(self.data[0]) for k in self.data), (
+            f"There is an inconsistency in your dataset - not all "
+            f"entries have length=={len(self.data[0])}"
+        )
 
     def __len__(self):
         """
@@ -175,6 +180,6 @@ class CSVDataset(Dataset):
         if stem.startswith(self.root_path):
             stem = os.path.relpath(stem, self.root_path)
         elif stem.startswith(os.pathsep):
-            stem = stem[len(os.pathsep):]
+            stem = stem[len(os.pathsep) :]
 
         return [stem] + sample
diff --git a/bob/ip/binseg/data/drive/__init__.py b/bob/ip/binseg/data/drive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d101adde6d1f2c8bdf0301d1f7e6abacecead0
--- /dev/null
+++ b/bob/ip/binseg/data/drive/__init__.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import os
+import pkg_resources
+
+import bob.extension
+
+from ..jsondataset import JSONDataset
+from ..loader import load_pil_rgb, load_pil_1
+
+
+_protocols = [
+    pkg_resources.resource_filename(__name__, "default.json"),
+    pkg_resources.resource_filename(__name__, "second-annotation.json"),
+]
+
+_root_path = bob.extension.rc.get('bob.db.drive.datadir',
+        os.path.realpath(os.curdir))
+
+def _loader(s):
+    return dict(
+        data=load_pil_rgb(s["data"]),
+        label=load_pil_1(s["label"]),
+        mask=load_pil_1(s["mask"]),
+    )
+
+dataset = JSONDataset(protocols=_protocols, root_path=_root_path, loader=_loader)
+"""DRIVE dataset for Vessel Segmentation
+
+The DRIVE database has been established to enable comparative studies on
+segmentation of blood vessels in retinal images.
+
+* Reference: [DRIVE-2004]_
+* Original resolution (height x width): 584 x 565
+* Training samples: 20 (including labels and masks)
+* Test samples: 20 (including labels from 2 annotators and masks)
+* Split reference: [DRIVE-2004]_
+"""
diff --git a/bob/ip/binseg/data/drive/default.json b/bob/ip/binseg/data/drive/default.json
new file mode 100644
index 0000000000000000000000000000000000000000..6707e6edd93546915d7f9960f8ba85d3449a8d6f
--- /dev/null
+++ b/bob/ip/binseg/data/drive/default.json
@@ -0,0 +1,206 @@
+{
+    "train": [
+        [
+            "training/images/21_training.tif",
+            "training/1st_manual/21_manual1.gif",
+            "training/mask/21_training_mask.gif"
+        ],
+        [
+            "training/images/22_training.tif",
+            "training/1st_manual/22_manual1.gif",
+            "training/mask/22_training_mask.gif"
+        ],
+        [
+            "training/images/23_training.tif",
+            "training/1st_manual/23_manual1.gif",
+            "training/mask/23_training_mask.gif"
+        ],
+        [
+            "training/images/24_training.tif",
+            "training/1st_manual/24_manual1.gif",
+            "training/mask/24_training_mask.gif"
+        ],
+        [
+            "training/images/25_training.tif",
+            "training/1st_manual/25_manual1.gif",
+            "training/mask/25_training_mask.gif"
+        ],
+        [
+            "training/images/26_training.tif",
+            "training/1st_manual/26_manual1.gif",
+            "training/mask/26_training_mask.gif"
+        ],
+        [
+            "training/images/27_training.tif",
+            "training/1st_manual/27_manual1.gif",
+            "training/mask/27_training_mask.gif"
+        ],
+        [
+            "training/images/28_training.tif",
+            "training/1st_manual/28_manual1.gif",
+            "training/mask/28_training_mask.gif"
+        ],
+        [
+            "training/images/29_training.tif",
+            "training/1st_manual/29_manual1.gif",
+            "training/mask/29_training_mask.gif"
+        ],
+        [
+            "training/images/30_training.tif",
+            "training/1st_manual/30_manual1.gif",
+            "training/mask/30_training_mask.gif"
+        ],
+        [
+            "training/images/31_training.tif",
+            "training/1st_manual/31_manual1.gif",
+            "training/mask/31_training_mask.gif"
+        ],
+        [
+            "training/images/32_training.tif",
+            "training/1st_manual/32_manual1.gif",
+            "training/mask/32_training_mask.gif"
+        ],
+        [
"training/images/33_training.tif", + "training/1st_manual/33_manual1.gif", + "training/mask/33_training_mask.gif" + ], + [ + "training/images/34_training.tif", + "training/1st_manual/34_manual1.gif", + "training/mask/34_training_mask.gif" + ], + [ + "training/images/35_training.tif", + "training/1st_manual/35_manual1.gif", + "training/mask/35_training_mask.gif" + ], + [ + "training/images/36_training.tif", + "training/1st_manual/36_manual1.gif", + "training/mask/36_training_mask.gif" + ], + [ + "training/images/37_training.tif", + "training/1st_manual/37_manual1.gif", + "training/mask/37_training_mask.gif" + ], + [ + "training/images/38_training.tif", + "training/1st_manual/38_manual1.gif", + "training/mask/38_training_mask.gif" + ], + [ + "training/images/39_training.tif", + "training/1st_manual/39_manual1.gif", + "training/mask/39_training_mask.gif" + ], + [ + "training/images/40_training.tif", + "training/1st_manual/40_manual1.gif", + "training/mask/40_training_mask.gif" + ] + ], + "test": [ + [ + "test/images/01_test.tif", + "test/1st_manual/01_manual1.gif", + "test/mask/01_test_mask.gif" + ], + [ + "test/images/02_test.tif", + "test/1st_manual/02_manual1.gif", + "test/mask/02_test_mask.gif" + ], + [ + "test/images/03_test.tif", + "test/1st_manual/03_manual1.gif", + "test/mask/03_test_mask.gif" + ], + [ + "test/images/04_test.tif", + "test/1st_manual/04_manual1.gif", + "test/mask/04_test_mask.gif" + ], + [ + "test/images/05_test.tif", + "test/1st_manual/05_manual1.gif", + "test/mask/05_test_mask.gif" + ], + [ + "test/images/06_test.tif", + "test/1st_manual/06_manual1.gif", + "test/mask/06_test_mask.gif" + ], + [ + "test/images/07_test.tif", + "test/1st_manual/07_manual1.gif", + "test/mask/07_test_mask.gif" + ], + [ + "test/images/08_test.tif", + "test/1st_manual/08_manual1.gif", + "test/mask/08_test_mask.gif" + ], + [ + "test/images/09_test.tif", + "test/1st_manual/09_manual1.gif", + "test/mask/09_test_mask.gif" + ], + [ + "test/images/10_test.tif", + "test/1st_manual/10_manual1.gif", + "test/mask/10_test_mask.gif" + ], + [ + "test/images/11_test.tif", + "test/1st_manual/11_manual1.gif", + "test/mask/11_test_mask.gif" + ], + [ + "test/images/12_test.tif", + "test/1st_manual/12_manual1.gif", + "test/mask/12_test_mask.gif" + ], + [ + "test/images/13_test.tif", + "test/1st_manual/13_manual1.gif", + "test/mask/13_test_mask.gif" + ], + [ + "test/images/14_test.tif", + "test/1st_manual/14_manual1.gif", + "test/mask/14_test_mask.gif" + ], + [ + "test/images/15_test.tif", + "test/1st_manual/15_manual1.gif", + "test/mask/15_test_mask.gif" + ], + [ + "test/images/16_test.tif", + "test/1st_manual/16_manual1.gif", + "test/mask/16_test_mask.gif" + ], + [ + "test/images/17_test.tif", + "test/1st_manual/17_manual1.gif", + "test/mask/17_test_mask.gif" + ], + [ + "test/images/18_test.tif", + "test/1st_manual/18_manual1.gif", + "test/mask/18_test_mask.gif" + ], + [ + "test/images/19_test.tif", + "test/1st_manual/19_manual1.gif", + "test/mask/19_test_mask.gif" + ], + [ + "test/images/20_test.tif", + "test/1st_manual/20_manual1.gif", + "test/mask/20_test_mask.gif" + ] + ] +} diff --git a/bob/ip/binseg/data/drive/second-annotation.json b/bob/ip/binseg/data/drive/second-annotation.json new file mode 100644 index 0000000000000000000000000000000000000000..fee520debd55220ccfd82145df4c70e39b1fc6b5 --- /dev/null +++ b/bob/ip/binseg/data/drive/second-annotation.json @@ -0,0 +1,104 @@ +{ + "test": [ + [ + "test/images/01_test.tif", + "test/2nd_manual/01_manual2.gif", + "test/mask/01_test_mask.gif" + ], + [ + 
"test/images/02_test.tif", + "test/2nd_manual/02_manual2.gif", + "test/mask/02_test_mask.gif" + ], + [ + "test/images/03_test.tif", + "test/2nd_manual/03_manual2.gif", + "test/mask/03_test_mask.gif" + ], + [ + "test/images/04_test.tif", + "test/2nd_manual/04_manual2.gif", + "test/mask/04_test_mask.gif" + ], + [ + "test/images/05_test.tif", + "test/2nd_manual/05_manual2.gif", + "test/mask/05_test_mask.gif" + ], + [ + "test/images/06_test.tif", + "test/2nd_manual/06_manual2.gif", + "test/mask/06_test_mask.gif" + ], + [ + "test/images/07_test.tif", + "test/2nd_manual/07_manual2.gif", + "test/mask/07_test_mask.gif" + ], + [ + "test/images/08_test.tif", + "test/2nd_manual/08_manual2.gif", + "test/mask/08_test_mask.gif" + ], + [ + "test/images/09_test.tif", + "test/2nd_manual/09_manual2.gif", + "test/mask/09_test_mask.gif" + ], + [ + "test/images/10_test.tif", + "test/2nd_manual/10_manual2.gif", + "test/mask/10_test_mask.gif" + ], + [ + "test/images/11_test.tif", + "test/2nd_manual/11_manual2.gif", + "test/mask/11_test_mask.gif" + ], + [ + "test/images/12_test.tif", + "test/2nd_manual/12_manual2.gif", + "test/mask/12_test_mask.gif" + ], + [ + "test/images/13_test.tif", + "test/2nd_manual/13_manual2.gif", + "test/mask/13_test_mask.gif" + ], + [ + "test/images/14_test.tif", + "test/2nd_manual/14_manual2.gif", + "test/mask/14_test_mask.gif" + ], + [ + "test/images/15_test.tif", + "test/2nd_manual/15_manual2.gif", + "test/mask/15_test_mask.gif" + ], + [ + "test/images/16_test.tif", + "test/2nd_manual/16_manual2.gif", + "test/mask/16_test_mask.gif" + ], + [ + "test/images/17_test.tif", + "test/2nd_manual/17_manual2.gif", + "test/mask/17_test_mask.gif" + ], + [ + "test/images/18_test.tif", + "test/2nd_manual/18_manual2.gif", + "test/mask/18_test_mask.gif" + ], + [ + "test/images/19_test.tif", + "test/2nd_manual/19_manual2.gif", + "test/mask/19_test_mask.gif" + ], + [ + "test/images/20_test.tif", + "test/2nd_manual/20_manual2.gif", + "test/mask/20_test_mask.gif" + ] + ] +} diff --git a/bob/ip/binseg/data/drive/test.py b/bob/ip/binseg/data/drive/test.py new file mode 100644 index 0000000000000000000000000000000000000000..238f9e0f901e12b6f65509a9dfe8ba5295b6fc85 --- /dev/null +++ b/bob/ip/binseg/data/drive/test.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# coding=utf-8 + + +"""Tests for DRIVE""" + +import os +import nose.tools + +from ..utils import rc_variable_set, DelayedSample2TorchDataset +from ..transforms import Compose, CenterCrop +from . 
+
+
+def test_protocol_consistency():
+
+    subset = dataset.subsets("default")
+    nose.tools.eq_(len(subset), 2)
+
+    assert "train" in subset
+    nose.tools.eq_(len(subset["train"]), 20)
+    for s in subset["train"]:
+        assert s.key.startswith(os.path.join("training", "images"))
+
+    assert "test" in subset
+    nose.tools.eq_(len(subset["test"]), 20)
+    for s in subset["test"]:
+        assert s.key.startswith(os.path.join("test", "images"))
+
+    subset = dataset.subsets("second-annotation")
+    nose.tools.eq_(len(subset), 1)
+
+    assert "test" in subset
+    nose.tools.eq_(len(subset["test"]), 20)
+    for s in subset["test"]:
+        assert s.key.startswith(os.path.join("test", "images"))
+
+
+@rc_variable_set('bob.db.drive.datadir')
+def test_loading():
+
+    def _check_sample(s):
+        data = s.data
+        assert isinstance(data, dict)
+        nose.tools.eq_(len(data), 3)
+        assert "data" in data
+        nose.tools.eq_(data["data"].size, (565, 584))
+        nose.tools.eq_(data["data"].mode, "RGB")
+        assert "label" in data
+        nose.tools.eq_(data["label"].size, (565, 584))
+        nose.tools.eq_(data["label"].mode, "1")
+        assert "mask" in data
+        nose.tools.eq_(data["mask"].size, (565, 584))
+        nose.tools.eq_(data["mask"].mode, "1")
+
+    subset = dataset.subsets("default")
+    for s in subset["train"]: _check_sample(s)
+    for s in subset["test"]: _check_sample(s)
+
+    subset = dataset.subsets("second-annotation")
+    for s in subset["test"]: _check_sample(s)
+
+
+@rc_variable_set('bob.db.drive.datadir')
+def test_check():
+    nose.tools.eq_(dataset.check(), 0)
+
+
+@rc_variable_set('bob.db.drive.datadir')
+def test_torch_dataset():
+
+    def _check_sample(s):
+        nose.tools.eq_(len(s), 4)
+        assert isinstance(s[0], str)
+        nose.tools.eq_(s[1].size, (544, 544))
+        nose.tools.eq_(s[1].mode, "RGB")
+        nose.tools.eq_(s[2].size, (544, 544))
+        nose.tools.eq_(s[2].mode, "1")
+        nose.tools.eq_(s[3].size, (544, 544))
+        nose.tools.eq_(s[3].mode, "1")
+
+
+    transforms = Compose([CenterCrop((544, 544))])
+
+    subset = dataset.subsets("default")
+
+    torch_dataset = DelayedSample2TorchDataset(subset["train"], transforms)
+    nose.tools.eq_(len(torch_dataset), 20)
+    for s in torch_dataset: _check_sample(s)
+
+    torch_dataset = DelayedSample2TorchDataset(subset["test"], transforms)
+    nose.tools.eq_(len(torch_dataset), 20)
+    for s in torch_dataset: _check_sample(s)
diff --git a/bob/ip/binseg/data/jsondataset.py b/bob/ip/binseg/data/jsondataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..99c554bfec6fe6482719aecfcbc27e05d4a08a89
--- /dev/null
+++ b/bob/ip/binseg/data/jsondataset.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import os
+import copy
+import json
+import functools
+
+import logging
+logger = logging.getLogger(__name__)
+
+from .sample import DelayedSample
+
+
+class JSONDataset:
+    """
+    Generic multi-protocol filelist dataset
+
+    To create a new dataset, you need to provide one or more JSON formatted
+    filelists (one per protocol) with the following contents:
+
+    .. code-block:: json
+
+       {
+           "subset1": [
+               {
+                   "data": "path/to/data",
+                   "label": "path/to/optional/label",
+                   "mask": "path/to/optional/mask"
+               }
+           ],
+           "subset2": [
+           ]
+       }
+
+    Optionally, you may also format your JSON file like this, where each sample
+    is described as a list of up to 3 elements:
+
+    .. code-block:: json
+
+       {
+           "subset1": [
+               [
+                   "path/to/data",
+                   "path/to/optional/label",
+                   "path/to/optional/mask"
+               ]
+           ],
+           "subset2": [
+           ]
+       }
+
+    If your dataset does not have labels or masks, you may also represent it
+    like this:
+
+    .. code-block:: json
+
+       {
+           "subset1": [
+               "path/to/data1",
+               "path/to/data2"
+           ],
+           "subset2": [
+           ]
+       }
+
+    Where:
+
+    * ``data``: absolute or relative path leading to original image
+    * ``label``: (optional) absolute or relative path with manual segmentation
+      information
+    * ``mask``: (optional) absolute or relative path with a mask that indicates
+      valid regions in the image where automatic segmentation should occur
+
+    Relative paths are interpreted with respect to the location of the JSON
+    file, or to an optional ``root_path`` parameter, if provided.
+
+    There are no requirements concerning image or ground-truth homogeneity.
+    Anything that can be loaded by our image and data loaders is OK.
+
+    Notice that all rows must have the same number of entries.
+
+    To generate a dataset without ground-truth (e.g., for prediction tasks),
+    omit the ``label`` and ``mask`` entries.
+
+
+    Parameters
+    ----------
+
+    protocols : [str]
+        Paths to one or more JSON formatted files containing the various
+        protocols to be recognized by this dataset.
+
+    root_path : str
+        Path to a common filesystem root where files with relative paths should
+        be sitting.  If not set, the current directory is used to resolve
+        relative paths.
+
+    loader : object
+        A function that receives, as input, a dictionary with ``{key: path}``
+        entries, and returns a dictionary with the loaded data.
+
+    """
+
+    def __init__(self, protocols, root_path, loader):
+
+        self.protocols = dict(
+            (os.path.splitext(os.path.basename(k))[0], os.path.realpath(k))
+            for k in protocols
+        )
+        self.root_path = root_path
+        self.loader = loader
+
+    def check(self):
+        """For each protocol, check all files are available on the filesystem
+
+        Returns
+        -------
+
+        errors : int
+            Number of errors found
+
+        """
+
+        errors = 0
+        for proto in self.protocols:
+            logger.info(f"Checking protocol '{proto}'...")
+            for name, samples in self.subsets(proto).items():
+                logger.info(f"Checking subset '{name}'...")
+                for sample in samples:
+                    try:
+                        sample.data  # triggers loading
+                        logger.info(f"{sample.key}: OK")
+                    except Exception as e:
+                        logger.error(f"{sample.key}: {e}")
+                        errors += 1
+        return errors
+
+    def subsets(self, protocol):
+        """Returns all subsets in a protocol
+
+        This method will load JSON information for a given protocol and return
+        all subsets of the given protocol after converting each entry into a
+        :py:class:`bob.ip.binseg.data.sample.DelayedSample`.
+
+        Parameters
+        ----------
+
+        protocol : str
+            Name of the protocol data to load
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of
+            :py:class:`bob.ip.binseg.data.sample.DelayedSample` objects, with
+            the proper loading implemented.  Each delayed sample also carries a
+            ``key`` parameter, which contains the relative path of the sample
+            without its extension.  This parameter can be used for recording
+            sample transforms during check-pointing.
+
+        """
+
+        with open(self.protocols[protocol], "r") as f:
+            data = json.load(f)
+
+        # returns fixed sample representations as DelayedSamples
+        retval = {}
+
+        for subset, samples in data.items():
+            delayeds = []
+            for k in samples:
+
+                if isinstance(k, dict):
+                    item = k
+
+                elif isinstance(k, list):
+                    item = {"data": k[0]}
+                    if len(k) > 1: item["label"] = k[1]
+                    if len(k) > 2: item["mask"] = k[2]
+
+                elif isinstance(k, str):
+                    item = {"data": k}
+
+                key = os.path.splitext(item["data"])[0]
+
+                # make paths absolute
+                abs_item = copy.deepcopy(item)
+                for k,v in item.items():
+                    if not os.path.isabs(v):
+                        abs_item[k] = os.path.join(self.root_path, v)
+
+                load = functools.partial(self.loader, abs_item)
+                delayeds.append(DelayedSample(load, key=key))
+
+            retval[subset] = delayeds
+
+        return retval
diff --git a/bob/ip/binseg/data/loader.py b/bob/ip/binseg/data/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a235ceb24a46d80bdc95237824a05cacb75639
--- /dev/null
+++ b/bob/ip/binseg/data/loader.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+"""Data loading code"""
+
+
+import PIL.Image
+
+
+def load_pil_rgb(path):
+    """Loads sample data
+
+    Parameters
+    ----------
+
+    path : str
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image : PIL.Image.Image
+        A PIL image in RGB mode
+
+    """
+
+    return PIL.Image.open(path).convert("RGB")
+
+
+def load_pil_1(path):
+    """Loads a sample binary label or mask
+
+    Parameters
+    ----------
+
+    path : str
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image : PIL.Image.Image
+        A PIL image in mode "1"
+
+    """
+
+    return PIL.Image.open(path).convert(mode="1", dither=None)
diff --git a/bob/ip/binseg/data/sample.py b/bob/ip/binseg/data/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d85bf35b4d548d0ea25aeda194dc1bc33bb8a2f
--- /dev/null
+++ b/bob/ip/binseg/data/sample.py
@@ -0,0 +1,110 @@
+"""Base definition of sample
+
+.. todo::
+
+   Copied from bob/bob.pipelines **TEMPORARILY**! Remove this and use the
+   package directly!
+
+"""
+
+from collections.abc import MutableSequence
+
+
+def _copy_attributes(s, d):
+    """Copies attributes from a dictionary to self
+    """
+    s.__dict__.update(
+        dict([k, v] for k, v in d.items() if k not in ("data", "load", "samples"))
+    )
+
+
+class DelayedSample:
+    """Representation of a sample that can be loaded via a callable
+
+    The optional ``**kwargs`` argument allows you to attach more attributes to
+    this sample instance.
+
+
+    Parameters
+    ----------
+
+    load : object
+        A Python function that can be called without parameters, to load the
+        sample in question from whatever medium
+
+    parent : :py:class:`DelayedSample`, :py:class:`Sample`, None
+        If passed, consider this as a parent of this sample, to copy
+        information
+
+    kwargs : dict
+        Further attributes of this sample, to be stored and eventually
+        transmitted to transformed versions of the sample
+
+    """
+
+    def __init__(self, load, parent=None, **kwargs):
+        self.load = load
+        if parent is not None:
+            _copy_attributes(self, parent.__dict__)
+        _copy_attributes(self, kwargs)
+
+    @property
+    def data(self):
+        """Loads the data from the disk file"""
+        return self.load()
+
+
+class Sample:
+    """Representation of a sample that is sufficient for the blocks in this module
+
+    Each sample must have the following attributes:
+
+    * attribute ``data``: Contains the data for this sample
+
+
+    Parameters
+    ----------
+
+    data : object
+        Object representing the data to initialize this sample with.
+
+    parent : object
+        A parent object from which to inherit all other attributes (except
+        ``data``)
+
+    """
+
+    def __init__(self, data, parent=None, **kwargs):
+        self.data = data
+        if parent is not None:
+            _copy_attributes(self, parent.__dict__)
+        _copy_attributes(self, kwargs)
+
+
+
+class SampleSet(MutableSequence):
+    """A set of samples with extra attributes
+    https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes
+    """
+
+    def __init__(self, samples, parent=None, **kwargs):
+        self.samples = samples
+        if parent is not None:
+            _copy_attributes(self, parent.__dict__)
+        _copy_attributes(self, kwargs)
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, item):
+        return self.samples.__getitem__(item)
+
+    def __setitem__(self, key, item):
+        return self.samples.__setitem__(key, item)
+
+    def __delitem__(self, item):
+        return self.samples.__delitem__(item)
+
+    def insert(self, index, item):
+        # if not item in self.samples:
+        self.samples.insert(index, item)
diff --git a/bob/ip/binseg/data/utils.py b/bob/ip/binseg/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..50db3da3b05b50064ce981d7b633ba45b7a32537
--- /dev/null
+++ b/bob/ip/binseg/data/utils.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+
+"""Common utilities"""
+
+
+import functools
+import nose.plugins.skip
+import torch.utils.data
+import bob.extension
+
+
+def rc_variable_set(name):
+    """
+    Decorator that skips a test if the named Bob RC variable is not set
+    """
+
+    def wrapped_function(test):
+        @functools.wraps(test)
+        def wrapper(*args, **kwargs):
+            if bob.extension.rc[name]:
+                return test(*args, **kwargs)
+            else:
+                raise nose.plugins.skip.SkipTest("Bob's RC variable '%s' is not set" % name)
+
+        return wrapper
+
+    return wrapped_function
+
+
+class DelayedSample2TorchDataset(torch.utils.data.Dataset):
+    """PyTorch dataset wrapper around DelayedSample lists
+
+    A transform object can be passed, which will be applied to the image,
+    ground truth and mask (if present).
+
+    It supports indexing, so that ``dataset[i]`` can be used to get the i-th
+    sample.
+
+    Parameters
+    ----------
+    samples : list
+        A list of :py:class:`bob.ip.binseg.data.sample.DelayedSample` objects
+
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transforms. Defaults to ``None``.
+ """ + + def __init__(self, samples, transform=None): + + self.samples = samples + self.transform = transform + + def __len__(self): + """ + + Returns + ------- + + size : int + size of the dataset + + """ + return len(self.samples) + + def __getitem__(self, index): + """ + + Parameters + ---------- + + index : int + + Returns + ------- + + sample : tuple + The sample data: ``[key, image[, gt[, mask]]]`` + + """ + + item = self.samples[index] + data = item.data # triggers data loading + + retval = [data["data"]] + if "label" in data: retval.append(data["label"]) + if "mask" in data: retval.append(data["mask"]) + + if self.transform: + retval = self.transform(*retval) + + return [item.key] + retval diff --git a/conda/meta.yaml b/conda/meta.yaml index 3a12db9ce5b2cff4a44017c1afe4740413fa89fb..a8906290fd911d639e066d14eefa39162d20c66a 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -29,6 +29,8 @@ requirements: - pytorch {{ pytorch }} # [linux] - numpy {{ numpy }} - bob.extension + - bob.core + - bob.io.base run: - python - setuptools @@ -36,11 +38,10 @@ requirements: - {{ pin_compatible('torchvision') }} # [linux] - {{ pin_compatible('numpy') }} - pandas + - pillow - matplotlib - tqdm - tabulate - - bob.core - - bob.io.base test: imports: diff --git a/doc/api.rst b/doc/api.rst index 500e6a6459d457155d908ceec91b8e3ac44c756b..11f26d99a2efaeb707c4d30cf2fba220b706add9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -17,11 +17,24 @@ Data Manipulation bob.ip.binseg.data bob.ip.binseg.data.binsegdataset - bob.ip.binseg.data.csvdataset bob.ip.binseg.data.folderdataset + bob.ip.binseg.data.csvdataset + bob.ip.binseg.data.jsondataset + bob.ip.binseg.data.loader + bob.ip.binseg.data.sample + bob.ip.binseg.data.utils bob.ip.binseg.data.transforms +Datasets +-------- + +.. autosummary:: + :toctree: api/dataset + + bob.ip.binseg.data.drive + + Engines ------- @@ -179,16 +192,3 @@ Datasets bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544 bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544ssldrive bob.ip.binseg.configs.datasets.staretest - -Test Units ----------- - -.. autosummary:: - :toctree: api/tests - - bob.ip.binseg.test - bob.ip.binseg.test.test_basemetrics - bob.ip.binseg.test.test_batchmetrics - bob.ip.binseg.test.test_checkpointer - bob.ip.binseg.test.test_summary - bob.ip.binseg.test.test_transforms