diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 5c94091877ff735e74942795203aa82c8931b61e..4c92476360b9895016560976706763f1098c74d3 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import collections
+import functools
 import logging
 import multiprocessing
 import sys
@@ -27,6 +28,39 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
+def _sample_size_bytes(s: Sample) -> int:
+    """Recurses into the sample and computes its total occupancy in bytes.
+
+    Parameters
+    ----------
+
+    s
+        The sample to be analyzed
+
+
+    Returns
+    -------
+
+    size
+        The size in bytes occupied by this sample
+    """
+
+    def _tensor_size_bytes(t: torch.Tensor) -> int:
+        """Returns a tensor size in bytes."""
+        return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
+
+    size = int(s[0].element_size() * torch.prod(torch.tensor(s[0].shape)))
+    size += sys.getsizeof(s[1])
+
+    # check each element - if it is a tensor, then add its total space in
+    # bytes
+    for v in s[1].values():
+        if isinstance(v, torch.Tensor):
+            size += _tensor_size_bytes(v)
+
+    return size
+
+
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -59,6 +93,15 @@ class _DelayedLoadingDataset(Dataset):
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
 
+        # Test loading and log the output tensor size
+        first_sample = self[0]
+        logger.info(
+            f"Delayed loading dataset (first tensor): "
+            f"{list(first_sample[0].shape)}@{first_sample[0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
+        logger.info(f"Estimated sample size: {sample_size_mb:.1f} MB")
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [self.loader.label(k) for k in self.split]
@@ -75,6 +118,39 @@ class _DelayedLoadingDataset(Dataset):
             yield self[x]
 
 
+def _apply_loader_and_transforms(
+    info: typing.Any,
+    load: typing.Callable[[typing.Any], Sample],
+    model_transform: typing.Callable[[torch.Tensor], torch.Tensor],
+) -> Sample:
+    """Local wrapper to apply raw-data loading and transformation in a single
+    step.
+
+    Parameters
+    ----------
+
+    info
+        The sample information, as loaded from its split dictionary
+
+    load
+        The raw-data loader function to use for loading the sample
+
+    model_transform
+        A callable that will transform the loaded tensor into something
+        suitable for the model it will train. Typically, this will be a
+        composed transform.
+
+
+    Returns
+    -------
+
+    sample
+        The loaded and transformed sample.
+    """
+    sample = load(info)
+    return model_transform(sample[0]), sample[1]
+
+
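# --- editor's note: illustrative sketch only, not part of the patch ----------
# A minimal exercise of the two helpers above, assuming the (tensor, metadata)
# ``Sample`` layout used throughout this module; all names below are
# hypothetical stand-ins, and the module-private helpers are assumed in scope.
import torch
import torchvision.transforms

# a fake 1x512x512 float32 sample, as a raw-data loader might produce
fake_sample = (torch.zeros(1, 512, 512), dict(label=1, name="dummy.png"))

# 1 * 512 * 512 * 4 bytes for the tensor, plus the metadata dictionary
print(_sample_size_bytes(fake_sample) / (1024.0 * 1024.0))  # ~1.0 MB

# the same load-then-transform composition that _CachedDataset binds with
# functools.partial below, keeping it picklable for multiprocessing.Pool
resized = _apply_loader_and_transforms(
    "dummy.png",  # stand-in for one split entry
    load=lambda info: fake_sample,  # stand-in for RawDataLoader.sample
    model_transform=torchvision.transforms.Resize(256),
)
print(list(resized[0].shape))  # [1, 256, 256]
# ------------------------------------------------------------------------------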
 class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.
 
@@ -112,27 +188,41 @@ class _CachedDataset(Dataset):
         parallel: int = -1,
         transforms: typing.Sequence[Transform] = [],
     ):
-        self.transform = torchvision.transforms.Compose(transforms)
+        self.loader = functools.partial(
+            _apply_loader_and_transforms,
+            load=loader.sample,
+            model_transform=torchvision.transforms.Compose(transforms),
+        )
 
         if parallel < 0:
             self.data = [
-                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
             ]
         else:
             instances = parallel or multiprocessing.cpu_count()
             logger.info(f"Caching dataset using {instances} processes...")
             with multiprocessing.Pool(instances) as p:
                 self.data = list(
-                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
                 )
 
+        # Estimate memory occupancy
+        logger.info(
+            f"Cached dataset (first tensor): "
+            f"{list(self.data[0][0].shape)}@{self.data[0][0].dtype}"
+        )
+        sample_size_mb = _sample_size_bytes(self.data[0]) / (1024.0 * 1024.0)
+        logger.info(
+            f"Estimated RAM occupancy (sample / dataset): "
+            f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} MB"
+        )
+
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
         return [k[1]["label"] for k in self.data]
 
     def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.data[key]
-        return self.transform(tensor), metadata
+        return self.data[key]
 
     def __len__(self):
         return len(self.data)
 
@@ -338,14 +428,6 @@ class CachingDataModule(lightning.LightningDataModule):
         validation to balance sample picking probability, making sample
         across classes **and** datasets equitable.
 
-    model_transforms
-        A list of transforms (torch modules) that will be applied after
-        raw-data-loading, and just before data is fed into the model or
-        eventual data-augmentation transformations for all data loaders
-        produced by this data module. This part of the pipeline receives data
-        as output by the raw-data-loader, or model-related transforms (e.g.
-        resize adaptions), if any is specified.
-
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network). If the number of samples in the
@@ -382,6 +464,21 @@ class CachingDataModule(lightning.LightningDataModule):
         multiprocessing data loading. Set to 0 to enable as many data loading
         instances as processing cores as available in the system. Set to >= 1
         to enable that many multiprocessing instances for data loading.
+
+
+    Attributes
+    ----------
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations for all data loaders
+        produced by this data module. This part of the pipeline receives data
+        as output by the raw-data-loader, or model-related transforms (e.g.
+        resize adaptations), if any is specified. If data is cached, it is
+        cached **after** model-transforms are applied, as that is a potential
+        memory saver (e.g., if it contains a resizing operation to smaller
+        images).
     """
 
     DatasetDictionary = dict[str, Dataset]
 
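# --- editor's note: illustrative sketch only, not part of the patch ----------
# Why caching *after* model-transforms can save RAM, using the Montgomery
# native resolution quoted elsewhere in this patch (4020x4892) against a
# hypothetical 512x512 resize target.
import torch

raw = torch.zeros(1, 4020, 4892)  # float32, as loaded from disk
small = torch.zeros(1, 512, 512)  # float32, after a resizing transform
for t in (raw, small):
    print(t.element_size() * t.nelement() / (1024.0 * 1024.0))
# prints ~75.0, then 1.0 (MB per cached sample)
# ------------------------------------------------------------------------------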
""" DatasetDictionary = dict[str, Dataset] @@ -392,7 +489,6 @@ class CachingDataModule(lightning.LightningDataModule): raw_data_loader: RawDataLoader, cache_samples: bool = False, balance_sampler_by_class: bool = False, - model_transforms: list[Transform] = [], batch_size: int = 1, batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, @@ -407,7 +503,7 @@ class CachingDataModule(lightning.LightningDataModule): self.cache_samples = cache_samples self._train_sampler = None self.balance_sampler_by_class = balance_sampler_by_class - self.model_transforms = model_transforms + self.model_transforms: list[Transform] | None = None self.drop_incomplete_batch = drop_incomplete_batch self.parallel = parallel # immutable, otherwise would need to call @@ -551,6 +647,13 @@ class CachingDataModule(lightning.LightningDataModule): Name of the dataset to setup. """ + if self.model_transforms is None: + raise RuntimeError( + "Parameter `model_transforms` has not yet been " + "set. If you do not have model transforms, then " + "set it to an empty list." + ) + if name in self._datasets: logger.info( f"Dataset `{name}` is already setup. " diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py index ac31b9ce7fbce85fb688b394c99d591b83049f7f..ed284afc4ab63fd804a8110ff676c29917d968b4 100644 --- a/src/ptbench/data/image_utils.py +++ b/src/ptbench/data/image_utils.py @@ -31,6 +31,40 @@ class SingleAutoLevel16to8: ).convert("L") +def remove_black_borders(img: PIL.Image.Image, threshold: int=0) -> PIL.Image.Image: + """Remove black borders of CXR + + Parameters + ---------- + img + A PIL image + threshold + Threshold value from which borders are considered black. + Defaults to 0. + + Returns + ------- + A PIL image with black borders removed + """ + + img = numpy.asarray(img) + + if len(img.shape) == 2: # single channel + mask = numpy.asarray(img) > threshold + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + elif len(img.shape) == 3 and img.shape[2] == 3: + r_mask = img[:, :, 0] > threshold + g_mask = img[:, :, 1] > threshold + b_mask = img[:, :, 2] > threshold + + mask = r_mask | g_mask | b_mask + return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + + else: + raise NotImplementedError + + class RemoveBlackBorders: """Remove black borders of CXR.""" @@ -38,9 +72,7 @@ class RemoveBlackBorders: self.threshold = threshold def __call__(self, img): - img = numpy.asarray(img) - mask = numpy.asarray(img) > self.threshold - return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))]) + return remove_black_borders(img, self.threshold) def load_pil(path: str | pathlib.Path) -> PIL.Image.Image: diff --git a/src/ptbench/data/montgomery/__init__.py b/src/ptbench/data/montgomery/__init__.py index 65239cbf5d908075346675ad10e7c86569383f77..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/ptbench/data/montgomery/__init__.py +++ b/src/ptbench/data/montgomery/__init__.py @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Montgomery dataset for computer-aided diagnosis. - -The Montgomery database has been established to foster research -in computer-aided diagnosis of pulmonary diseases with a special -focus on pulmonary tuberculosis (TB). 
diff --git a/src/ptbench/data/montgomery/__init__.py b/src/ptbench/data/montgomery/__init__.py
index 65239cbf5d908075346675ad10e7c86569383f77..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/ptbench/data/montgomery/__init__.py
+++ b/src/ptbench/data/montgomery/__init__.py
@@ -1,88 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for computer-aided diagnosis.
-
-The Montgomery database has been established to foster research
-in computer-aided diagnosis of pulmonary diseases with a special
-focus on pulmonary tuberculosis (TB).
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 4020 x 4892
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
-
-import importlib.resources
-import os
-
-from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_baw, make_delayed
-
-_protocols = [
-    importlib.resources.files(__name__).joinpath("default.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
-]
-
-_datadir = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(sample):
-    return dict(
-        data=load_pil_baw(os.path.join(_datadir, sample["data"])),  # type: ignore
-        label=sample["label"],
-    )
-
-
-def _loader(context, sample):
-    # "context" is ignored in this case - database is homogeneous
-    # we return delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
-)
-"""Montgomery dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/montgomery/default.py b/src/ptbench/data/montgomery/default.py
index 1f5c0809869be5f011880e808e160024b3c1c1b0..bb57b9a7e8d95f9af40d36ac5a57349c8f514846 100644
--- a/src/ptbench/data/montgomery/default.py
+++ b/src/ptbench/data/montgomery/default.py
@@ -2,46 +2,60 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (default protocol)
+"""Montgomery datamodule for TB detection (``default`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
+The Montgomery database has been established to foster research in
+computer-aided diagnosis of pulmonary diseases with a special focus on
+pulmonary tuberculosis (TB).
 
-from clapper.logging import setup
+* Database reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 4020x4892 px or 4892x4020 px
+* This split:
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+  * Split reference: None
+  * Training samples: ?? of TB and healthy CXR
+  * Validation samples: ?? of TB and healthy CXR
+  * Test samples: ?? of TB and healthy CXR
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+Data specifications:
 
+* Raw data input (on disk):
 
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+  * PNG images 8 bit grayscale
+  * resolution: fixed to one of the cases above
+
+* Output image:
+
+  * Transforms:
+
+    * Load raw PNG with :py:mod:`PIL`
+    * Remove black borders
+    * Torch center cropping to get square image
+
+  * Final specifications:
 
-    def setup(self, stage: str):
-        self.dataset = _maker("default")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+    * Grayscale, encoded as a single plane image, 8 bits
+    * Square (4020x4020 px)
 
-datamodule = DefaultModule
+Protocol ``default``:
+
+  * Training samples: first 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import importlib.resources
+
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
+
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
+    ),
+    raw_data_loader=RawDataLoader(),
+)
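# --- editor's note: illustrative sketch only, not part of the patch ----------
# With model_transforms gone from the constructor, datamodules such as the one
# defined above must have the attribute assigned before Lightning's setup()
# hook runs (see the RuntimeError guard added to CachingDataModule); the
# transform choice here is hypothetical.
from torchvision import transforms

from ptbench.data.montgomery.default import datamodule

datamodule.model_transforms = [transforms.Resize(512)]  # or [] for none
datamodule.setup("fit")  # would raise RuntimeError if the attribute were None
# ------------------------------------------------------------------------------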
diff --git a/src/ptbench/data/montgomery/fold_0.py b/src/ptbench/data/montgomery/fold_0.py
index c60791be50ccd5186ce8e4af263efb7d7513b07a..e50d2e302f1c6b529c862c529bb77cf20aef8a57 100644
--- a/src/ptbench/data/montgomery/fold_0.py
+++ b/src/ptbench/data/montgomery/fold_0.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 0)
+"""Montgomery datamodule for TB detection (``fold 0`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_0.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_0")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_0_rgb.py b/src/ptbench/data/montgomery/fold_0_rgb.py
deleted file mode 100644
index 8e8b0c8914b6a63dd9ab854984ff2bc51cb4e255..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_0_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 0, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_0", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_1.py b/src/ptbench/data/montgomery/fold_1.py
index d6627e673978bcf960b8fb5f72add7cb4a13a141..3698a9edfa614f980b9b2352d97c7329965d371d 100644
--- a/src/ptbench/data/montgomery/fold_1.py
+++ b/src/ptbench/data/montgomery/fold_1.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 1)
+"""Montgomery datamodule for TB detection (``fold 1`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_1.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_1")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_1_rgb.py b/src/ptbench/data/montgomery/fold_1_rgb.py
deleted file mode 100644
index bc47a322c3fd779e3bc19924f6d7ac7c13e71847..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_1_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 1, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_1", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_2.py b/src/ptbench/data/montgomery/fold_2.py
index 8c5f4a66fd2af0b9f26b67241f45c630f69bd06a..b2d7ac2cfd8def5627b56d5353740e9676e1d9cc 100644
--- a/src/ptbench/data/montgomery/fold_2.py
+++ b/src/ptbench/data/montgomery/fold_2.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 2)
+"""Montgomery datamodule for TB detection (``fold 2`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_2.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_2")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_2_rgb.py b/src/ptbench/data/montgomery/fold_2_rgb.py
deleted file mode 100644
index b81a877b2bc7372a99812a27935e6daf42401568..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_2_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 2, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_2", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_3.py b/src/ptbench/data/montgomery/fold_3.py
index 8e685d7e3baa3a23924c62a77ffc61bf51e12056..1c566e4f528e587cfd8a3bd882e2c73ea5a46aa6 100644
--- a/src/ptbench/data/montgomery/fold_3.py
+++ b/src/ptbench/data/montgomery/fold_3.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 3)
+"""Montgomery datamodule for TB detection (``fold 3`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_3.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_3")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_3_rgb.py b/src/ptbench/data/montgomery/fold_3_rgb.py
deleted file mode 100644
index 7b600371c8d434d79049c6e6423b36e99f2a32cb..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_3_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 3, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_3", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_4.py b/src/ptbench/data/montgomery/fold_4.py
index 9459cb938605df06823a86a96fbd1cf374fe9738..4b68bd538f71115a01bae0fce87742be6ab711a8 100644
--- a/src/ptbench/data/montgomery/fold_4.py
+++ b/src/ptbench/data/montgomery/fold_4.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 4)
+"""Montgomery datamodule for TB detection (``fold 4`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_4.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_4")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_4_rgb.py b/src/ptbench/data/montgomery/fold_4_rgb.py
deleted file mode 100644
index 3eb136f654ab8d8d648468948e05dad774d85076..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_4_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 4, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_4", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_5.py b/src/ptbench/data/montgomery/fold_5.py
index 147690f6d54f15d50b52f88288dbc8a41dfb7f33..59891e8e1b5531b94fc996bfe25ef140ff39a83a 100644
--- a/src/ptbench/data/montgomery/fold_5.py
+++ b/src/ptbench/data/montgomery/fold_5.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 5)
+"""Montgomery datamodule for TB detection (``fold 5`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_5.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_5")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_5_rgb.py b/src/ptbench/data/montgomery/fold_5_rgb.py
deleted file mode 100644
index 3e7cb73f6957086b99147812b07f733dc51af9ec..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_5_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 5, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_5", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_6.py b/src/ptbench/data/montgomery/fold_6.py
index 69f24390ac01271c3e961950d429d973e535c380..e6c1d31a69ff20bbfd3ec4e53ba4eab0f9beec7f 100644
--- a/src/ptbench/data/montgomery/fold_6.py
+++ b/src/ptbench/data/montgomery/fold_6.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 6)
+"""Montgomery datamodule for TB detection (``fold 6`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_6.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_6")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_6_rgb.py b/src/ptbench/data/montgomery/fold_6_rgb.py
deleted file mode 100644
index ff3a8cdb0c00f511f4ebb7abcfabb10ae7853e99..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_6_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 6, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_6", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_7.py b/src/ptbench/data/montgomery/fold_7.py
index 20ba9d3a7da5ffcb8673e685a0534d82fdb7ed2b..44dd80512be61c32616188968a418b9963b41aed 100644
--- a/src/ptbench/data/montgomery/fold_7.py
+++ b/src/ptbench/data/montgomery/fold_7.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 7)
+"""Montgomery datamodule for TB detection (``fold 7`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_7.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_7")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_7_rgb.py b/src/ptbench/data/montgomery/fold_7_rgb.py
deleted file mode 100644
index 05664b06ab6393911a77b32418d6f2afb9d455fa..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_7_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 7, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_7", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_8.py b/src/ptbench/data/montgomery/fold_8.py
index e92ff959a9b1028c174c95719867f5086831d6c9..fd7edde69259023fa36ff05027fe1f0ad19d6661 100644
--- a/src/ptbench/data/montgomery/fold_8.py
+++ b/src/ptbench/data/montgomery/fold_8.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 8)
+"""Montgomery datamodule for TB detection (``fold 8`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "fold_8.json.bz2"
        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_8")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+)
diff --git a/src/ptbench/data/montgomery/fold_8_rgb.py b/src/ptbench/data/montgomery/fold_8_rgb.py
deleted file mode 100644
index b7d59359dcde32694affea0e3df88ad747f48e31..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/fold_8_rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (cross validation fold 8, RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold_8", RGB=True)
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery/fold_9.py b/src/ptbench/data/montgomery/fold_9.py
index 81bbf72e78826f7e9560189be149d51cb729064e..91228362f8c376d9ac9186f6675d80295e848f13 100644
--- a/src/ptbench/data/montgomery/fold_9.py
+++ b/src/ptbench/data/montgomery/fold_9.py
@@ -2,46 +2,22 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery dataset for TB detection (cross validation fold 9)
+"""Montgomery datamodule for TB detection (``fold 9`` protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
+See :py:mod:`ptbench.data.montgomery.default` for input/output details.
 """
 
-from clapper.logging import setup
+import importlib.resources
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker +from ..datamodule import CachingDataModule +from ..split import JSONDatabaseSplit +from .loader import RawDataLoader -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, +datamodule = CachingDataModule( + database_split=JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + "fold_9.json.bz2" ) - - def setup(self, stage: str): - self.dataset = _maker("fold_9") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule + ), + raw_data_loader=RawDataLoader(), +) diff --git a/src/ptbench/data/montgomery/fold_9_rgb.py b/src/ptbench/data/montgomery/fold_9_rgb.py deleted file mode 100644 index e961e08ffe49a94001252c641ba8bee86758b44f..0000000000000000000000000000000000000000 --- a/src/ptbench/data/montgomery/fold_9_rgb.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Montgomery dataset for TB detection (cross validation fold 9, RGB) - -* Split reference: first 64% of TB and healthy CXR for "train" 16% for -* "validation", 20% for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.montgomery` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold_9", RGB=True) - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ad856d5fcc45603015cf75c1c87885751f25bcd8 --- /dev/null +++ b/src/ptbench/data/montgomery/loader.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Specialized raw-data loaders for the Montgomery dataset.""" + +import os + +import PIL.Image + +from torchvision.transforms.functional import center_crop, to_tensor + +from ...utils.rc import load_rc +from ..image_utils import remove_black_borders +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the Montgomery dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. 
+    """
+
+    datadir: str
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            "datadir.montgomery", os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        # N.B.: Montgomery images are encoded as grayscale PNGs, so no need to
+        # convert them again with Image.convert("L").
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Returns the label of a single image sample.
+
+        Parameters
+        ----------
+
+        sample
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
diff --git a/src/ptbench/data/montgomery/rgb.py b/src/ptbench/data/montgomery/rgb.py
deleted file mode 100644
index c162126648f0baae5a921fa7f009da171fb8ccc7..0000000000000000000000000000000000000000
--- a/src/ptbench/data/montgomery/rgb.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Montgomery dataset for TB detection (default protocol, converted in RGB)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.montgomery` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . 
import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("default", RGB=True) - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index bfe93f44faaa9df235f357c1cc3a927412f4a011..a163b9bc6290f53e611d214bdfa03e0cf93eb492 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -2,33 +2,42 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen datamodule for computer-aided diagnosis (default protocol) +"""Shenzhen datamodule for computer-aided diagnosis (``default`` protocol) -See :py:mod:`ptbench.data.shenzhen` for more database details. +The standard digital image database for Tuberculosis was created by the National +Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s +Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from +out-patient clinics, and were captured as part of the daily routine using +Philips DR Digital Diagnose systems. -This configuration: +* Database reference: [MONTGOMERY-SHENZHEN-2014]_ +* Original resolution (height x width or width x height): 3000 x 3000 or less +* This split: -* Raw data input (on disk): + * Split reference: None + * Training samples: 64% of TB and healthy CXR (including labels) + * Validation samples: 16% of TB and healthy CXR (including labels) + * Test samples: 20% of TB and healthy CXR (including labels) + +Data specifications: - * PNG images (black and white, encoded as color images) - * Variable width and height: +* Raw data input (on disk): - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels + * PNG images (grayscale, encoded as RGB images with "inverted" grayscale scale) + * Variable width and height * Output image: - * Transforms: + * Transforms: - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) + * Load raw PNG with :py:mod:`PIL` + * Remove black borders + * Torch center cropping to get square image - * Final specifications: + * Final specifications: - * Fixed resolution: 512x512 pixels - * Color RGB encoding + * Grayscale, encoded as a single plane image, 8 bits + * Square, with varying resolutions, depending on the input image """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py index 888a0e60024480a3aaff65f6e3d819370fd22669..b505974491eea26e1da8931022eb168a42d57a0f 100644 --- a/src/ptbench/data/shenzhen/fold_0.py +++ b/src/ptbench/data/shenzhen/fold_0.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 0) -See :py:mod:`ptbench.data.shenzhen` for more database details. 
- -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py index 62d7fbd55c83ed746754cbc99dcc65fe48efbc6a..1041c3e4ef6d14942dadd4c680dc10fee0cfd17c 100644 --- a/src/ptbench/data/shenzhen/fold_1.py +++ b/src/ptbench/data/shenzhen/fold_1.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 1) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py index b41284cd9d1c4a56c70eff715078f82213dabb3c..5026116a9cd75ac406f334682b38ce760104444d 100644 --- a/src/ptbench/data/shenzhen/fold_2.py +++ b/src/ptbench/data/shenzhen/fold_2.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 2) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py index cca555064e9923433ef39f591b3e342365cf7afc..16c00157c5fa9fda38afc16614b75f2e766c33d5 100644 --- a/src/ptbench/data/shenzhen/fold_3.py +++ b/src/ptbench/data/shenzhen/fold_3.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 3) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. 
""" import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py index 897420076303e47406cc9efb3b6bf0d294ab3611..c0b0fdacdf90fdce168988057219923af73ad6a0 100644 --- a/src/ptbench/data/shenzhen/fold_4.py +++ b/src/ptbench/data/shenzhen/fold_4.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 4) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py index c520399d98ead9eeb1e3bdcfbe4dc48393adcebc..0397955e25d1077af68b825b5ecbf0d8974499db 100644 --- a/src/ptbench/data/shenzhen/fold_5.py +++ b/src/ptbench/data/shenzhen/fold_5.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 5) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py index a28f8fc5ca3e0ebd4b49fceaec99d3a2e94dd34c..145685ea96be63501a8afd771518b4b2f3f65c49 100644 --- a/src/ptbench/data/shenzhen/fold_6.py +++ b/src/ptbench/data/shenzhen/fold_6.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 6) -See :py:mod:`ptbench.data.shenzhen` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images (black and white, encoded as color images) - * Variable width and height: - - * widths: from 1130 to 3001 pixels - * heights: from 948 to 3001 pixels - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch resizing(512px, 512px) - * Torch center cropping (512px, 512px) - - * Final specifications: - - * Fixed resolution: 512x512 pixels - * Color RGB encoding +See :py:mod:`ptbench.data.shenzhen.default` for input/output details. """ import importlib.resources diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py index b0ea7b4324334980a2e55e4496ac4ab6af705d17..5b8d74034a18e2637a9a193557571521722e93bc 100644 --- a/src/ptbench/data/shenzhen/fold_7.py +++ b/src/ptbench/data/shenzhen/fold_7.py @@ -4,31 +4,7 @@ """Shenzhen datamodule for computer-aided diagnosis (fold 7) -See :py:mod:`ptbench.data.shenzhen` for more database details. 
diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py
index 9bbfbe84ab942cf5da5a8c5fc8318724908998f9..e9ce1a2f408543bc93d8f116a2f8834ab79c989f 100644
--- a/src/ptbench/data/shenzhen/fold_8.py
+++ b/src/ptbench/data/shenzhen/fold_8.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 8)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py
index 87c2afb328f9b09f420a1ddce5f5d0ea54346c43..6da8dd3d7a4260e7b9a478baea4b2848383f8459 100644
--- a/src/ptbench/data/shenzhen/fold_9.py
+++ b/src/ptbench/data/shenzhen/fold_9.py
@@ -4,31 +4,7 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (fold 9)
 
-See :py:mod:`ptbench.data.shenzhen` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images (black and white, encoded as color images)
-  * Variable width and height:
-
-    * widths: from 1130 to 3001 pixels
-    * heights: from 948 to 3001 pixels
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch resizing(512px, 512px)
-    * Torch center cropping (512px, 512px)
-
-  * Final specifications:
-
-    * Fixed resolution: 512x512 pixels
-    * Color RGB encoding
+See :py:mod:`ptbench.data.shenzhen.default` for input/output details.
 """
 
 import importlib.resources
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
index 49ccf8bfb217e411004228b4acf7c924e3ffec66..3409fed2e1a552c44135888df6d6bc4a874b427c 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -2,30 +2,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the National
-Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
-Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
-out-patient clinics, and were captured as part of the daily routine using
-Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
+"""Specialized raw-data loaders for the Shenzhen dataset."""
 
 import os
 
-import torchvision.transforms
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
 
 from ...utils.rc import load_rc
-from ..image_utils import RemoveBlackBorders, load_pil_baw
+from ..image_utils import remove_black_borders
 from ..typing import RawDataLoader as _BaseRawDataLoader
 from ..typing import Sample
 
@@ -45,22 +31,12 @@ class RawDataLoader(_BaseRawDataLoader):
     """
 
     datadir: str
-    transform: torchvision.transforms.Compose
 
     def __init__(self):
         self.datadir = load_rc().get(
             "datadir.shenzhen", os.path.realpath(os.curdir)
         )
 
-        self.transform = torchvision.transforms.Compose(
-            [
-                RemoveBlackBorders(),
-                torchvision.transforms.Resize(512),
-                torchvision.transforms.CenterCrop(512),
-                torchvision.transforms.ToTensor(),
-            ]
-        )
-
     def sample(self, sample: tuple[str, int]) -> Sample:
         """Loads a single image sample from the disk.
 
@@ -79,9 +55,19 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             The sample representation
         """
-        tensor = self.transform(
-            load_pil_baw(os.path.join(self.datadir, sample[0]))
+        # N.B.: Image.convert("L") is required to convert the inverted
+        # grayscale encoding back to standard grayscale.
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert(
+            "L"
         )
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # uncomment the code below to view the generated images:
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
 
         return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
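The new ``RawDataLoader.sample()`` above replaces the pre-composed torchvision
pipeline with explicit functional calls, deferring all resizing to the model
transforms. A standalone sketch of the equivalent steps (the file path is
hypothetical)::

    import PIL.Image

    from torchvision.transforms.functional import center_crop, to_tensor

    from ptbench.data.image_utils import remove_black_borders

    # convert("L") maps the inverted grayscale encoding back to standard
    # grayscale
    image = PIL.Image.open("CXR_png/CHNCXR_0001_0.png").convert("L")
    image = remove_black_borders(image)

    tensor = to_tensor(image)  # shape [1, H, W], float32, values in [0, 1]
    tensor = center_crop(tensor, min(*tensor.shape[1:]))  # crop to a square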
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index e2cb9b053a57c6c3a786c19c4cb47a99e8d5cddb..650aa7d4f60dfe5feda914deb8518415fda8876d 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -73,7 +73,8 @@ class Pasa(pl.LightningModule):
         self.name = "pasa"
 
         self.model_transforms = [
-            torchvision.transforms.Resize(512),
+            torchvision.transforms.Grayscale(),
+            torchvision.transforms.Resize(512, antialias=True),
         ]
 
         self._train_loss = train_loss
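With resizing removed from the raw-data loader, the Pasa model now declares its
own input adaptation through ``model_transforms``. A sketch of their effect on
a cached RGB sample (shapes are illustrative; ``Grayscale()`` expects a
3-channel input)::

    import torch
    import torchvision.transforms

    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Grayscale(),
            torchvision.transforms.Resize(512, antialias=True),
        ]
    )

    x = torch.rand(3, 1024, 1024)  # a square RGB tensor, as a loader may emit
    y = transforms(x)
    assert y.shape == (1, 512, 512)  # single channel, smaller edge at 512px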
diff --git a/tests/test_ch.py b/tests/test_ch.py
index 659e2c35ae092f3a90f7d072ba033786bb80bdf9..b28c81e93ef0765c59dd7252b93ecdecaeff6946 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -120,11 +120,6 @@ def test_loading():
 
     from ptbench.data.datamodule import _DelayedLoadingDataset
 
-    def _check_size(shape):
-        if shape[0] == 1 and shape[1] == 512 and shape[2] == 512:
-            return True
-        return False
-
     def _check_sample(s):
         assert len(s) == 2
 
@@ -132,10 +127,12 @@ def test_loading():
         metadata = s[1]
 
         assert isinstance(data, torch.Tensor)
-        assert _check_size(data.shape)  # Check size
+
+        assert data.size(0) == 3  # check 3 channels
+        assert data.size(1) == data.size(2)  # check square image
 
         assert (
-            torchvision.transforms.ToPILImage()(data).mode == "L"
+            torchvision.transforms.ToPILImage()(data).mode == "RGB"
         )  # Check colors
 
         assert "label" in metadata
diff --git a/tests/test_mc.py b/tests/test_mc.py
index 1b2aa4fd5a0317b939816bf625907c534ead7910..2fcd14ac131f5d919e954bd4d9562effd4a19296 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -4,131 +4,188 @@
 
 """Tests for Montgomery dataset."""
 
+import importlib
+
 import pytest
 
 
 def test_protocol_consistency():
-    from ptbench.data.montgomery import dataset
-
     # Default protocol
-    subset = dataset.subsets("default")
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+
     assert len(subset) == 3
 
     assert "train" in subset
     assert len(subset["train"]) == 88
     for s in subset["train"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "validation" in subset
     assert len(subset["validation"]) == 22
     for s in subset["validation"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "test" in subset
     assert len(subset["test"]) == 28
     for s in subset["test"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     # Check labels
     for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 0-7
     for f in range(8):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{f}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
        assert "train" in subset
         assert len(subset["train"]) == 99
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 14
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 8-9
     for f in range(8, 10):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{f}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 100
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 13
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_loading():
-    from ptbench.data.montgomery import dataset
+    import torch
+    import torchvision.transforms
+
+    from ptbench.data.datamodule import _DelayedLoadingDataset
 
     def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
-
-        assert "data" in data
-        assert data["data"].size in (
-            (4020, 4892),  # portrait
-            (4892, 4020),  # landscape
-            (512, 512),  # test database @ CI
-        )
-        assert data["data"].mode == "L"  # Check colors
+        data = s[0]
+        metadata = s[1]
+
+        assert isinstance(data, torch.Tensor)
+
+        assert data.size(0) == 1  # check single channel
+        assert data.size(1) == data.size(2)  # check square image
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+        assert (
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
+
+        assert "label" in metadata
+        assert metadata["label"] in [0, 1]  # Check labels
 
     limit = 30  # use this to limit testing to first images only, else None
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+    raw_data_loader = datamodule.raw_data_loader
+
+    # Use the private class so we can limit the number of samples to load
+    dataset = _DelayedLoadingDataset(
+        subset["train"][:limit],
+        raw_data_loader,
+    )
+
+    for s in dataset:
         _check_sample(s)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_check():
-    from ptbench.data.montgomery import dataset
+    from ptbench.data.split import check_database_split_loading
+
+    limit = 30  # use this to limit testing to first images only, else 0
+
+    # Default protocol
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    database_split = datamodule.database_split
+    raw_data_loader = datamodule.raw_data_loader
+
+    assert (
+        check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+        == 0
+    )
+
+    # Folds
+    for f in range(10):
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{f}"
+        ).datamodule
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
 
-    assert dataset.check() == 0
+        assert (
+            check_database_split_loading(
+                database_split, raw_data_loader, limit=limit
+            )
+            == 0
+        )
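The reworked tests rely on split entries now being plain ``(path, label)``
tuples instead of objects with ``key``/``label`` attributes. A short sketch of
the new access pattern (module and attribute names as used by the tests above)::

    import importlib

    datamodule = importlib.import_module(
        "ptbench.data.montgomery.default"
    ).datamodule

    # each split entry unpacks directly into a file path and a label
    for path, label in datamodule.database_split.subsets["train"]:
        assert path.startswith("CXR_png/MCUCXR_0")
        assert label in (0.0, 1.0)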