Commit eaaff3ed authored by Daniel CARRON, committed by André Anjos

Added support for post transforms

parent bd79cd2f
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -118,7 +118,7 @@ montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
 shenzhen = "ptbench.configs.datasets.shenzhen.default"
-shenzhen_rgb = "ptbench.data.shenzhen.rgb"
+shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
 shenzhen_f2 = "ptbench.data.shenzhen.fold_2"
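This hunk fixes the shenzhen_rgb key so it points into the same ptbench.configs.datasets namespace as its neighbours. The keys are packaging entry points that map short configuration names to importable modules; a minimal sketch of resolving one at runtime with the standard library, assuming (since the section header is outside this hunk) the group is named "ptbench.config":

    from importlib.metadata import entry_points

    # Hypothetical lookup: only the key -> module mapping comes from the
    # hunk above; the group name is an assumption. Requires Python 3.10+.
    ep = next(
        e for e in entry_points(group="ptbench.config") if e.name == "shenzhen_rgb"
    )
    config = ep.load()  # imports ptbench.configs.datasets.shenzhen.rgb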
@@ -55,6 +55,7 @@ class DefaultModule(BaseDataModule):
             fieldnames=("data", "label"),
             loader=samples_loader,
+            post_transforms=self.post_transforms,
         )
         (
             self.train_dataset,
             self.validation_dataset,
@@ -2,19 +2,20 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-"""Shenzhen dataset for TB detection (default protocol, converted in RGB)
+"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)

-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
+* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
 * This configuration resolution: 512 x 512 (default)
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """

 from clapper.logging import setup
+from torchvision import transforms

-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols

 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
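The switch from two-dot to four-dot relative imports implies this configuration module sits four package levels below the project root (plausibly ptbench/configs/datasets/shenzhen/rgb.py, inferred from the entry-point hunk above since this page does not show file names), with the shared loading machinery it pulls in consolidated under ptbench.data.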
@@ -25,6 +26,7 @@ class DefaultModule(BaseDataModule):
         train_batch_size=1,
         predict_batch_size=1,
         drop_incomplete_batch=False,
+        cache_samples=False,
         multiproc_kwargs=None,
     ):
         super().__init__(
@@ -34,14 +36,39 @@ class DefaultModule(BaseDataModule):
             multiproc_kwargs=multiproc_kwargs,
         )

+        self.cache_samples = cache_samples
+        self.post_transforms = [
+            transforms.ToPILImage(),
+            transforms.Lambda(lambda x: x.convert("RGB")),
+            transforms.ToTensor(),
+        ]
+
     def setup(self, stage: str):
-        self.dataset = _maker("default", RGB=True)
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded in memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+            post_transforms=self.post_transforms,
+        )
+
         (
             self.train_dataset,
             self.validation_dataset,
             self.extra_validation_datasets,
             self.predict_dataset,
-        ) = return_subsets(self.dataset)
+        ) = return_subsets(self.json_dataset, "default")


 datamodule = DefaultModule
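The three post transforms registered in __init__ above convert the single-channel tensor produced by the base Shenzhen pipeline into a three-channel tensor, the input layout RGB-pretrained backbones expect. A self-contained sketch of that round trip, with a random tensor standing in for a preprocessed chest X-ray:

    import torch
    from torchvision import transforms

    post = transforms.Compose(
        [
            transforms.ToPILImage(),                        # 1xHxW float tensor -> "L" PIL image
            transforms.Lambda(lambda x: x.convert("RGB")),  # replicate grey across 3 channels
            transforms.ToTensor(),                          # back to a 3xHxW float tensor
        ]
    )

    gray = torch.rand(1, 512, 512)  # stand-in for a preprocessed CXR
    print(post(gray).shape)         # torch.Size([3, 512, 512])

The datamodule = DefaultModule line at the end is the module-level alias that the shenzhen_rgb entry point ultimately exposes.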
@@ -75,7 +75,7 @@ class JSONDataset:
     * ``data``: which contains the data associated witht this sample

     """

-    def __init__(self, protocols, fieldnames, loader):
+    def __init__(self, protocols, fieldnames, loader, post_transforms=[]):
         if isinstance(protocols, dict):
             self._protocols = protocols
         else:
@@ -87,6 +87,7 @@ class JSONDataset:
             }
         self.fieldnames = fieldnames
         self._loader = loader
+        self.post_transforms = post_transforms

     def check(self, limit=0):
         """For each protocol, check if all data can be correctly accessed.
@@ -176,6 +177,7 @@ class JSONDataset:
             self._loader(
                 dict(protocol=protocol, subset=subset, order=n),
                 dict(zip(self.fieldnames, k)),
+                self.post_transforms,
             )
             for n, k in tqdm.tqdm(enumerate(samples))
         ]
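A side note on the post_transforms=[] default introduced here and echoed in the loader helpers below: Python evaluates a mutable default once, so one list object is shared across all calls. That stays harmless as long as the list is only read, which is the case throughout this diff, but the classic pitfall is easy to demonstrate:

    # Self-contained demonstration of the shared-mutable-default pitfall;
    # not a bug in this diff, where the default lists are never mutated.
    def collect(item, bucket=[]):
        bucket.append(item)
        return bucket

    print(collect(1))  # [1]
    print(collect(2))  # [1, 2] - the same list object was reused

A post_transforms=None default, expanded to an empty list inside the body, would sidestep the hazard entirely.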
@@ -70,15 +70,15 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")


-def make_cached(sample, loader, key=None):
+def make_cached(sample, loader, additional_transforms=[], key=None):
     return Sample(
-        loader(sample),
+        loader(sample, additional_transforms),
         key=key or sample["data"],
         label=sample["label"],
     )


-def make_delayed(sample, loader, key=None):
+def make_delayed(sample, loader, additional_transforms=[], key=None):
     """Returns a delayed-loading Sample object.

     Parameters
@@ -105,7 +105,7 @@ def make_delayed(sample, loader, key=None):
         sample loading.
     """
     return DelayedSample(
-        functools.partial(loader, sample),
+        functools.partial(loader, sample, additional_transforms),
         key=key or sample["data"],
         label=sample["label"],
     )
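The two helpers differ only in when loader actually runs: make_cached calls it immediately, so the decoded image occupies memory from construction onwards, while make_delayed defers the call behind functools.partial until the sample is first accessed. A stand-in sketch of that distinction (Sample and DelayedSample are project types, so a plain function and dict substitute for them here):

    import functools

    def loader(sample, additional_transforms=[]):
        print("decoding", sample["data"])  # marks the moment loading happens
        return {"data": "<pixels>", "label": sample["label"]}

    sample = {"data": "some/image.png", "label": 0}  # made-up sample entry

    eager = loader(sample, [])                    # make_cached path: decodes now
    lazy = functools.partial(loader, sample, [])  # make_delayed path: decodes later
    result = lazy()                               # "decoding" prints only here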
@@ -51,29 +51,31 @@ _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 _resize_size = 512
 _cc_size = 512

-_data_transforms = transforms.Compose(
-    [
-        RemoveBlackBorders(),
-        transforms.Resize(_resize_size),
-        transforms.CenterCrop(_cc_size),
-        transforms.ToTensor(),
-    ]
-)
+_data_transforms = [
+    RemoveBlackBorders(),
+    transforms.Resize(_resize_size),
+    transforms.CenterCrop(_cc_size),
+    transforms.ToTensor(),
+]


-def _raw_data_loader(sample):
+def _raw_data_loader(sample, additional_transforms=[]):
     raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
+    base_transforms = transforms.Compose(
+        _data_transforms + additional_transforms
+    )
     return dict(
-        data=_data_transforms(raw_data),
+        data=base_transforms(raw_data),
         label=sample["label"],
     )


-def _cached_loader(context, sample):
-    return make_cached(sample, _raw_data_loader)
+def _cached_loader(context, sample, additional_transforms=[]):
+    return make_cached(sample, _raw_data_loader, additional_transforms)


-def _delayed_loader(context, sample):
+def _delayed_loader(context, sample, additional_transforms=[]):
     # "context" is ignored in this case - database is homogeneous
     # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
+    return make_delayed(sample, _raw_data_loader, additional_transforms)
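Turning _data_transforms from a Compose object into a plain list is what enables the rest of this hunk: lists concatenate with +, so _raw_data_loader can append the per-configuration additional_transforms before building a single Compose. A self-contained sketch of the pattern, leaving out the project-specific RemoveBlackBorders:

    from PIL import Image
    from torchvision import transforms

    # Keep the base pipeline as a list so callers can extend it before composing.
    base = [transforms.Resize(512), transforms.CenterCrop(512)]
    extra = [transforms.Lambda(lambda x: x.convert("RGB")), transforms.ToTensor()]

    pipeline = transforms.Compose(base + extra)  # concatenate, then compose once
    out = pipeline(Image.new("L", (800, 900)))   # stand-in for a grayscale CXR
    print(out.shape)                             # torch.Size([3, 512, 512])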