Commit 9c8674d8 authored by Daniel CARRON, committed by André Anjos

Created common DataModule for Shenzhen dataset

parent fc9cb8e4
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -115,6 +115,10 @@ class BaseDataModule(pl.LightningDataModule):
 
         return loaders_dict
 
+    def update_module_properties(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
     def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
         if batch_size % chunk_count != 0:
......
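Note: update_module_properties simply forwards keyword arguments onto the datamodule instance via setattr. A minimal, self-contained sketch of the same pattern (the Example class and the values below are illustrative only, not part of this commit):

# Illustrative stand-in reproducing the setattr-forwarding pattern added above.
class Example:
    def __init__(self):
        self.batch_size = 1
        self.drop_incomplete_batch = False

    def update_module_properties(self, **kwargs):
        # Same mechanics as the new BaseDataModule method: every keyword
        # argument becomes an attribute on the instance.
        for k, v in kwargs.items():
            setattr(self, k, v)


example = Example()
example.update_module_properties(batch_size=8, drop_incomplete_batch=True)
assert example.batch_size == 8
assert example.drop_incomplete_batch is True

This lets a command-line front-end adjust runtime options (batch size, chunking, and so on) on an already-constructed datamodule object instead of re-instantiating a class, which is what the train() hunk further down switches to.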
@@ -11,95 +11,18 @@
 """
 
 from clapper.logging import setup
-from torchvision import transforms
-
-from ..base_datamodule import BaseDataModule
-from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
-from ..shenzhen import _protocols, _raw_data_loader
-from ..transforms import ElasticDeformation, RemoveBlackBorders
+
+from ..transforms import ElasticDeformation
+from .utils import ShenzhenDataModule
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 protocol_name = "default"
 
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        parallel=-1,
-    ):
-        super().__init__(
-            batch_size=batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            batch_chunk_count=batch_chunk_count,
-            parallel=parallel,
-        )
-
-        self._cache_samples = cache_samples
-        self._has_setup_fit = False
-        self._has_setup_predict = False
-        self._protocol = "default"
-
-        self.raw_data_transforms = [
-            RemoveBlackBorders(),
-            transforms.Resize(512),
-            transforms.CenterCrop(512),
-            transforms.ToTensor(),
-        ]
-
-        self.model_transforms = []
-
-        self.augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-    def setup(self, stage: str):
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-        )
-
-        if self._cache_samples:
-            dataset = CachedDataset
-        else:
-            dataset = RuntimeDataset
-
-        if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=True),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_fit = True
-
-        if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_predict = True
-
-datamodule = DefaultModule
+
+augmentation_transforms = [ElasticDeformation(p=0.8)]
+
+datamodule = ShenzhenDataModule(
+    protocol="default",
+    model_transforms=[],
+    augmentation_transforms=augmentation_transforms,
+)
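With the shared class in place, a dataset configuration reduces to a module-level ShenzhenDataModule instance, as above. As a hedged illustration of how further configurations could reuse it (the protocol name "fold_0" below is hypothetical and only valid if such a key exists in the dataset's protocol definitions):

# Hypothetical sibling configuration module; names are illustrative only.
from ..transforms import ElasticDeformation
from .utils import ShenzhenDataModule

# Same augmentation policy as the "default" configuration above.
augmentation_transforms = [ElasticDeformation(p=0.8)]

datamodule = ShenzhenDataModule(
    protocol="fold_0",  # hypothetical protocol key
    model_transforms=[],
    augmentation_transforms=augmentation_transforms,
)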
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms

from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


class ShenzhenDataModule(BaseDataModule):
    def __init__(
        self,
        protocol="default",
        model_transforms=[],
        augmentation_transforms=[],
        batch_size=1,
        batch_chunk_count=1,
        drop_incomplete_batch=False,
        cache_samples=False,
        parallel=-1,
    ):
        super().__init__(
            batch_size=batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            batch_chunk_count=batch_chunk_count,
            parallel=parallel,
        )

        self._cache_samples = cache_samples
        self._has_setup_fit = False
        self._has_setup_predict = False
        self._protocol = protocol

        self.raw_data_transforms = [
            RemoveBlackBorders(),
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]

        self.model_transforms = model_transforms
        self.augmentation_transforms = augmentation_transforms

    def setup(self, stage: str):
        json_protocol = JSONProtocol(
            protocols=_protocols,
            fieldnames=("data", "label"),
        )

        if self._cache_samples:
            dataset = CachedDataset
        else:
            dataset = RuntimeDataset

        if not self._has_setup_fit and stage == "fit":
            self.train_dataset = dataset(
                json_protocol,
                self._protocol,
                "train",
                _raw_data_loader,
                self._build_transforms(is_train=True),
            )
            self.validation_dataset = dataset(
                json_protocol,
                self._protocol,
                "validation",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )

            self._has_setup_fit = True

        if not self._has_setup_predict and stage == "predict":
            self.train_dataset = dataset(
                json_protocol,
                self._protocol,
                "train",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )
            self.validation_dataset = dataset(
                json_protocol,
                self._protocol,
                "validation",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )

            self._has_setup_predict = True
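A minimal usage sketch for the new class, assuming it is imported from within the package (as in the configuration module above) and relying only on the attributes visible in this commit; the data-loader helpers themselves live in BaseDataModule and are not shown here:

# Illustrative only; exercised outside of any Lightning Trainer.
from .utils import ShenzhenDataModule

dm = ShenzhenDataModule(protocol="default", cache_samples=False)

# Runtime options can be injected later by the CLI through
# update_module_properties(), as the train() hunk below does.
dm.update_module_properties(batch_size=4, batch_chunk_count=2)

# Lightning calls setup() with the current stage: "fit" builds the train and
# validation datasets for the selected protocol with training transforms,
# while "predict" rebuilds both without augmentations.
dm.setup("fit")
train_split = dm.train_dataset
validation_split = dm.validation_dataset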
@@ -270,7 +270,7 @@ def train(
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule = datamodule(
+    datamodule.update_module_properties(
         batch_size=batch_size,
         batch_chunk_count=batch_chunk_count,
         drop_incomplete_batch=drop_incomplete_batch,
......