diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py index 3bcd441a59a8e08c2e4d9aa33808c8d651026100..8377c66328241f549dbb2f10946f9cc973ef7a6f 100644 --- a/src/ptbench/data/base_datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -115,6 +115,10 @@ class BaseDataModule(pl.LightningDataModule): return loaders_dict + def update_module_properties(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + def _compute_chunk_size(self, batch_size, chunk_count): batch_chunk_size = batch_size if batch_size % chunk_count != 0: diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 8afac8469920dde1777b0efc6ae3919c1e381f13..b5fb23f59277f1f6511c6707929856cc2357be2f 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -11,95 +11,18 @@ """ from clapper.logging import setup -from torchvision import transforms -from ..base_datamodule import BaseDataModule -from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset -from ..shenzhen import _protocols, _raw_data_loader -from ..transforms import ElasticDeformation, RemoveBlackBorders +from ..transforms import ElasticDeformation +from .utils import ShenzhenDataModule logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +protocol_name = "default" -class DefaultModule(BaseDataModule): - def __init__( - self, - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - cache_samples=False, - parallel=-1, - ): - super().__init__( - batch_size=batch_size, - drop_incomplete_batch=drop_incomplete_batch, - batch_chunk_count=batch_chunk_count, - parallel=parallel, - ) +augmentation_transforms = [ElasticDeformation(p=0.8)] - self._cache_samples = cache_samples - self._has_setup_fit = False - self._has_setup_predict = False - self._protocol = "default" - - self.raw_data_transforms = [ - RemoveBlackBorders(), - transforms.Resize(512), - transforms.CenterCrop(512), - transforms.ToTensor(), - ] - - self.model_transforms = [] - - self.augmentation_transforms = [ElasticDeformation(p=0.8)] - - def setup(self, stage: str): - json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - ) - - if self._cache_samples: - dataset = CachedDataset - else: - dataset = RuntimeDataset - - if not self._has_setup_fit and stage == "fit": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=True), - ) - - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_fit = True - - if not self._has_setup_predict and stage == "predict": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_predict = True - - -datamodule = DefaultModule +datamodule = ShenzhenDataModule( + protocol="default", + model_transforms=[], + augmentation_transforms=augmentation_transforms, +) diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1521b674212feec942e73df081e7b20c19d89e29 --- /dev/null +++ b/src/ptbench/data/shenzhen/utils.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for computer-aided diagnosis. + +The standard digital image database for Tuberculosis is created by the +National Library of Medicine, Maryland, USA in collaboration with Shenzhen +No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China. +The Chest X-rays are from out-patient clinics, and were captured as part of +the daily routine using Philips DR Digital Diagnose systems. + +* Reference: [MONTGOMERY-SHENZHEN-2014]_ +* Original resolution (height x width or width x height): 3000 x 3000 or less +* Split reference: none +* Protocol ``default``: + + * Training samples: 64% of TB and healthy CXR (including labels) + * Validation samples: 16% of TB and healthy CXR (including labels) + * Test samples: 20% of TB and healthy CXR (including labels) +""" +from clapper.logging import setup +from torchvision import transforms + +from ..base_datamodule import BaseDataModule +from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset +from ..shenzhen import _protocols, _raw_data_loader +from ..transforms import RemoveBlackBorders + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +class ShenzhenDataModule(BaseDataModule): + def __init__( + self, + protocol="default", + model_transforms=[], + augmentation_transforms=[], + batch_size=1, + batch_chunk_count=1, + drop_incomplete_batch=False, + cache_samples=False, + parallel=-1, + ): + super().__init__( + batch_size=batch_size, + drop_incomplete_batch=drop_incomplete_batch, + batch_chunk_count=batch_chunk_count, + parallel=parallel, + ) + + self._cache_samples = cache_samples + self._has_setup_fit = False + self._has_setup_predict = False + self._protocol = protocol + + self.raw_data_transforms = [ + RemoveBlackBorders(), + transforms.Resize(512), + transforms.CenterCrop(512), + transforms.ToTensor(), + ] + + self.model_transforms = model_transforms + + self.augmentation_transforms = augmentation_transforms + + def setup(self, stage: str): + json_protocol = JSONProtocol( + protocols=_protocols, + fieldnames=("data", "label"), + ) + + if self._cache_samples: + dataset = CachedDataset + else: + dataset = RuntimeDataset + + if not self._has_setup_fit and stage == "fit": + self.train_dataset = dataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=True), + ) + + self.validation_dataset = dataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + + self._has_setup_fit = True + + if not self._has_setup_predict and stage == "predict": + self.train_dataset = dataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + self.validation_dataset = dataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + + self._has_setup_predict = True diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 6b37e1a109623290e10d6c79009cc69bb403ab05..4d2a226b5b0b479b3f84c748d24aebd43d8d6dea 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -270,7 +270,7 @@ def train( checkpoint_file = get_checkpoint(output_folder, resume_from) - datamodule = datamodule( + datamodule.update_module_properties( batch_size=batch_size, batch_chunk_count=batch_chunk_count, drop_incomplete_batch=drop_incomplete_batch,