diff --git a/pyproject.toml b/pyproject.toml
index a418e59bd1e3f464ef4be58cd28431ae0efcb61e..549b347ebf05ef8554a725196a8576e8e809c2af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -118,7 +118,7 @@ montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
 shenzhen = "ptbench.configs.datasets.shenzhen.default"
-shenzhen_rgb = "ptbench.data.shenzhen.rgb"
+shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
 shenzhen_f2 = "ptbench.data.shenzhen.fold_2"
diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 6f5b31ff41c2b226059794d8ef1fdffd7828e8b4..11c285bcc93380291f273e191281035462eb0d6f 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -55,6 +55,7 @@ class DefaultModule(BaseDataModule):
             fieldnames=("data", "label"),
             loader=samples_loader,
         )
+
         (
             self.train_dataset,
             self.validation_dataset,
diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ceb952e28c0a1181772565b425043a9a08cb646
--- /dev/null
+++ b/src/ptbench/configs/datasets/shenzhen/rgb.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for TB detection (default protocol, converted to RGB)
+
+* Split reference: first 64% of TB and healthy CXR for "train", 16% for "validation", 20% for "test"
+* This configuration resolution: 512 x 512 (default)
+* See :py:mod:`ptbench.data.shenzhen` for dataset details
+"""
+
+from clapper.logging import setup
+from torchvision import transforms
+
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        cache_samples=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+        self.cache_samples = cache_samples
+
+        self.post_transforms = [
+            transforms.ToPILImage(),
+            transforms.Lambda(lambda x: x.convert("RGB")),
+            transforms.ToTensor(),
+        ]
+
+    def setup(self, stage: str):
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded in memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+            post_transforms=self.post_transforms,
+        )
+
+        (
+            self.train_dataset,
+            self.validation_dataset,
+            self.extra_validation_datasets,
+            self.predict_dataset,
+        ) = return_subsets(self.json_dataset, "default")
+
+
+datamodule = DefaultModule
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index b1ffcadadff653340b39d4278ea9a6d1db795106..07f18acf8b1ed3c68be535c621d371f026c29ff7 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -75,7 +75,7 @@ class JSONDataset:
       * ``data``: which contains the data associated witht this sample
     """

-    def __init__(self, protocols, fieldnames, loader):
+    def __init__(self, protocols, fieldnames, loader, post_transforms=[]):
         if isinstance(protocols, dict):
             self._protocols = protocols
         else:
@@ -87,6 +87,7 @@ class JSONDataset:
             }
         self.fieldnames = fieldnames
         self._loader = loader
+        self.post_transforms = post_transforms

     def check(self, limit=0):
         """For each protocol, check if all data can be correctly accessed.
@@ -176,6 +177,7 @@ class JSONDataset:
                 self._loader(
                     dict(protocol=protocol, subset=subset, order=n),
                     dict(zip(self.fieldnames, k)),
+                    self.post_transforms,
                 )
                 for n, k in tqdm.tqdm(enumerate(samples))
             ]
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
index 931c62912ec8beffa21c34d39d056a8bcac17506..a11aefee77becbbbb07e15a25d5404594c16999a 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/loader.py
@@ -70,15 +70,15 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")


-def make_cached(sample, loader, key=None):
+def make_cached(sample, loader, additional_transforms=[], key=None):
     return Sample(
-        loader(sample),
+        loader(sample, additional_transforms),
         key=key or sample["data"],
         label=sample["label"],
     )


-def make_delayed(sample, loader, key=None):
+def make_delayed(sample, loader, additional_transforms=[], key=None):
     """Returns a delayed-loading Sample object.
     Parameters
     ----------
@@ -105,7 +105,7 @@
         sample loading.
""" return DelayedSample( - functools.partial(loader, sample), + functools.partial(loader, sample, additional_transforms), key=key or sample["data"], label=sample["label"], ) diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py index 9abf568964b3eab377ffcfcca98cf1e31a1d0cb3..d284b1b28905594e0c9c7fd20cb31f7c566468e3 100644 --- a/src/ptbench/data/shenzhen/__init__.py +++ b/src/ptbench/data/shenzhen/__init__.py @@ -51,29 +51,31 @@ _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir)) _resize_size = 512 _cc_size = 512 -_data_transforms = transforms.Compose( - [ - RemoveBlackBorders(), - transforms.Resize(_resize_size), - transforms.CenterCrop(_cc_size), - transforms.ToTensor(), - ] -) +_data_transforms = [ + RemoveBlackBorders(), + transforms.Resize(_resize_size), + transforms.CenterCrop(_cc_size), + transforms.ToTensor(), +] -def _raw_data_loader(sample): +def _raw_data_loader(sample, additional_transforms=[]): raw_data = load_pil_baw(os.path.join(_datadir, sample["data"])) + + base_transforms = transforms.Compose( + _data_transforms + additional_transforms + ) return dict( - data=_data_transforms(raw_data), + data=base_transforms(raw_data), label=sample["label"], ) -def _cached_loader(context, sample): - return make_cached(sample, _raw_data_loader) +def _cached_loader(context, sample, additional_transforms=[]): + return make_cached(sample, _raw_data_loader, additional_transforms) -def _delayed_loader(context, sample): +def _delayed_loader(context, sample, additional_transforms=[]): # "context" is ignored in this case - database is homogeneous # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader) + return make_delayed(sample, _raw_data_loader, additional_transforms) diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py deleted file mode 100644 index 7bdb8fe3ce6826fb98d0d6356f2e1b429670a3d1..0000000000000000000000000000000000000000 --- a/src/ptbench/data/shenzhen/rgb.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Shenzhen dataset for TB detection (default protocol, converted in RGB) - -* Split reference: first 64% of TB and healthy CXR for "train" 16% for -* "validation", 20% for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("default", RGB=True) - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule