From 9c8674d8e2e06ffe7481606a952472ba046ad67f Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 27 Jun 2023 18:50:12 +0200 Subject: [PATCH] Created common DataModule for Shenzhen dataset --- src/ptbench/data/base_datamodule.py | 4 + src/ptbench/data/shenzhen/default.py | 95 +++------------------- src/ptbench/data/shenzhen/utils.py | 114 +++++++++++++++++++++++++++ src/ptbench/scripts/train.py | 2 +- 4 files changed, 128 insertions(+), 87 deletions(-) create mode 100644 src/ptbench/data/shenzhen/utils.py diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py index 3bcd441a..8377c663 100644 --- a/src/ptbench/data/base_datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -115,6 +115,10 @@ class BaseDataModule(pl.LightningDataModule): return loaders_dict + def update_module_properties(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + def _compute_chunk_size(self, batch_size, chunk_count): batch_chunk_size = batch_size if batch_size % chunk_count != 0: diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py index 8afac846..b5fb23f5 100644 --- a/src/ptbench/data/shenzhen/default.py +++ b/src/ptbench/data/shenzhen/default.py @@ -11,95 +11,18 @@ """ from clapper.logging import setup -from torchvision import transforms -from ..base_datamodule import BaseDataModule -from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset -from ..shenzhen import _protocols, _raw_data_loader -from ..transforms import ElasticDeformation, RemoveBlackBorders +from ..transforms import ElasticDeformation +from .utils import ShenzhenDataModule logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +protocol_name = "default" -class DefaultModule(BaseDataModule): - def __init__( - self, - batch_size=1, - batch_chunk_count=1, - drop_incomplete_batch=False, - cache_samples=False, - parallel=-1, - ): - super().__init__( - batch_size=batch_size, - drop_incomplete_batch=drop_incomplete_batch, - batch_chunk_count=batch_chunk_count, - parallel=parallel, - ) +augmentation_transforms = [ElasticDeformation(p=0.8)] - self._cache_samples = cache_samples - self._has_setup_fit = False - self._has_setup_predict = False - self._protocol = "default" - - self.raw_data_transforms = [ - RemoveBlackBorders(), - transforms.Resize(512), - transforms.CenterCrop(512), - transforms.ToTensor(), - ] - - self.model_transforms = [] - - self.augmentation_transforms = [ElasticDeformation(p=0.8)] - - def setup(self, stage: str): - json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - ) - - if self._cache_samples: - dataset = CachedDataset - else: - dataset = RuntimeDataset - - if not self._has_setup_fit and stage == "fit": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=True), - ) - - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_fit = True - - if not self._has_setup_predict and stage == "predict": - self.train_dataset = dataset( - json_protocol, - self._protocol, - "train", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - self.validation_dataset = dataset( - json_protocol, - self._protocol, - "validation", - _raw_data_loader, - self._build_transforms(is_train=False), - ) - - self._has_setup_predict = True - - -datamodule = DefaultModule +datamodule = ShenzhenDataModule( + protocol="default", + model_transforms=[], + augmentation_transforms=augmentation_transforms, +) diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py new file mode 100644 index 00000000..1521b674 --- /dev/null +++ b/src/ptbench/data/shenzhen/utils.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Shenzhen dataset for computer-aided diagnosis. + +The standard digital image database for Tuberculosis is created by the +National Library of Medicine, Maryland, USA in collaboration with Shenzhen +No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China. +The Chest X-rays are from out-patient clinics, and were captured as part of +the daily routine using Philips DR Digital Diagnose systems. + +* Reference: [MONTGOMERY-SHENZHEN-2014]_ +* Original resolution (height x width or width x height): 3000 x 3000 or less +* Split reference: none +* Protocol ``default``: + + * Training samples: 64% of TB and healthy CXR (including labels) + * Validation samples: 16% of TB and healthy CXR (including labels) + * Test samples: 20% of TB and healthy CXR (including labels) +""" +from clapper.logging import setup +from torchvision import transforms + +from ..base_datamodule import BaseDataModule +from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset +from ..shenzhen import _protocols, _raw_data_loader +from ..transforms import RemoveBlackBorders + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +class ShenzhenDataModule(BaseDataModule): + def __init__( + self, + protocol="default", + model_transforms=[], + augmentation_transforms=[], + batch_size=1, + batch_chunk_count=1, + drop_incomplete_batch=False, + cache_samples=False, + parallel=-1, + ): + super().__init__( + batch_size=batch_size, + drop_incomplete_batch=drop_incomplete_batch, + batch_chunk_count=batch_chunk_count, + parallel=parallel, + ) + + self._cache_samples = cache_samples + self._has_setup_fit = False + self._has_setup_predict = False + self._protocol = protocol + + self.raw_data_transforms = [ + RemoveBlackBorders(), + transforms.Resize(512), + transforms.CenterCrop(512), + transforms.ToTensor(), + ] + + self.model_transforms = model_transforms + + self.augmentation_transforms = augmentation_transforms + + def setup(self, stage: str): + json_protocol = JSONProtocol( + protocols=_protocols, + fieldnames=("data", "label"), + ) + + if self._cache_samples: + dataset = CachedDataset + else: + dataset = RuntimeDataset + + if not self._has_setup_fit and stage == "fit": + self.train_dataset = dataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=True), + ) + + self.validation_dataset = dataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + + self._has_setup_fit = True + + if not self._has_setup_predict and stage == "predict": + self.train_dataset = dataset( + json_protocol, + self._protocol, + "train", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + self.validation_dataset = dataset( + json_protocol, + self._protocol, + "validation", + _raw_data_loader, + self._build_transforms(is_train=False), + ) + + self._has_setup_predict = True diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 6b37e1a1..4d2a226b 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -270,7 +270,7 @@ def train( checkpoint_file = get_checkpoint(output_folder, resume_from) - datamodule = datamodule( + datamodule.update_module_properties( batch_size=batch_size, batch_chunk_count=batch_chunk_count, drop_incomplete_batch=drop_incomplete_batch, -- GitLab