Commit f92fc663 authored by Daniel CARRON

Created common DataModule for Shenzhen dataset

parent 2fcec25b
Pipeline #75318 failed
base_datamodule.py:

@@ -115,6 +115,10 @@ class BaseDataModule(pl.LightningDataModule):
         return loaders_dict

+    def update_module_properties(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
     def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
         if batch_size % chunk_count != 0:
...
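The new helper simply forwards keyword arguments to setattr, so callers can override attributes on an already-constructed datamodule instead of re-instantiating it. A minimal sketch of that behavior, using a hypothetical stand-in class rather than the real BaseDataModule:

# Hypothetical stand-in illustrating update_module_properties; the real
# method lives on BaseDataModule and behaves identically.
class _Demo:
    batch_size = 1
    cache_samples = False

    def update_module_properties(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)  # rebind each named attribute in place

demo = _Demo()
demo.update_module_properties(batch_size=16, cache_samples=True)
assert demo.batch_size == 16 and demo.cache_samples is True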
Shenzhen ``default`` configuration module:

@@ -11,95 +11,18 @@
 """

 from clapper.logging import setup
-from torchvision import transforms

-from ..base_datamodule import BaseDataModule
-from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
-from ..shenzhen import _protocols, _raw_data_loader
-from ..transforms import ElasticDeformation, RemoveBlackBorders
+from ..transforms import ElasticDeformation
+from .utils import ShenzhenDataModule

 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        parallel=-1,
-    ):
-        super().__init__(
-            batch_size=batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            batch_chunk_count=batch_chunk_count,
-            parallel=parallel,
-        )
-
-        self._cache_samples = cache_samples
-        self._has_setup_fit = False
-        self._has_setup_predict = False
-        self._protocol = "default"
-
-        self.raw_data_transforms = [
-            RemoveBlackBorders(),
-            transforms.Resize(512),
-            transforms.CenterCrop(512),
-            transforms.ToTensor(),
-        ]
-
-        self.model_transforms = []
-
-        self.augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-    def setup(self, stage: str):
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-        )
-
-        if self._cache_samples:
-            dataset = CachedDataset
-        else:
-            dataset = RuntimeDataset
-
-        if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=True),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self._has_setup_fit = True
-
-        if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self._has_setup_predict = True
-
-
-datamodule = DefaultModule
+protocol_name = "default"
+
+augmentation_transforms = [ElasticDeformation(p=0.8)]
+
+datamodule = ShenzhenDataModule(
+    protocol="default",
+    model_transforms=[],
+    augmentation_transforms=augmentation_transforms,
+)
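With the shared class in place, each protocol-specific configuration shrinks from a full LightningDataModule subclass to a handful of declarations. A hypothetical sibling module for another split might look like this (the "fold_0" protocol name is illustrative and assumes such an entry exists in _protocols):

# Hypothetical configuration module for an assumed "fold_0" protocol;
# only the protocol name and the transform lists change between variants.
from ..transforms import ElasticDeformation
from .utils import ShenzhenDataModule

datamodule = ShenzhenDataModule(
    protocol="fold_0",
    model_transforms=[],
    augmentation_transforms=[ElasticDeformation(p=0.8)],
)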
New file — common Shenzhen datamodule (imported above as ``.utils``):

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Shenzhen dataset for computer-aided diagnosis.

The standard digital image database for tuberculosis was created by the
National Library of Medicine, Maryland, USA, in collaboration with Shenzhen
No. 3 People's Hospital, Guangdong Medical College, Shenzhen, China. The
chest X-rays are from out-patient clinics, and were captured as part of the
daily routine using Philips DR Digital Diagnose systems.

* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:

  * Training samples: 64% of TB and healthy CXR (including labels)
  * Validation samples: 16% of TB and healthy CXR (including labels)
  * Test samples: 20% of TB and healthy CXR (including labels)
"""

from clapper.logging import setup
from torchvision import transforms

from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


class ShenzhenDataModule(BaseDataModule):
    def __init__(
        self,
        protocol="default",
        model_transforms=[],
        augmentation_transforms=[],
        batch_size=1,
        batch_chunk_count=1,
        drop_incomplete_batch=False,
        cache_samples=False,
        parallel=-1,
    ):
        super().__init__(
            batch_size=batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            batch_chunk_count=batch_chunk_count,
            parallel=parallel,
        )

        self._cache_samples = cache_samples
        self._has_setup_fit = False
        self._has_setup_predict = False
        self._protocol = protocol

        # fixed pre-processing applied when loading the raw images
        self.raw_data_transforms = [
            RemoveBlackBorders(),
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]

        self.model_transforms = model_transforms
        self.augmentation_transforms = augmentation_transforms

    def setup(self, stage: str):
        json_protocol = JSONProtocol(
            protocols=_protocols,
            fieldnames=("data", "label"),
        )

        # choose between cached and on-the-fly sample loading
        if self._cache_samples:
            dataset = CachedDataset
        else:
            dataset = RuntimeDataset

        # "fit" stage: only the training split receives augmentations
        if not self._has_setup_fit and stage == "fit":
            self.train_dataset = dataset(
                json_protocol,
                self._protocol,
                "train",
                _raw_data_loader,
                self._build_transforms(is_train=True),
            )
            self.validation_dataset = dataset(
                json_protocol,
                self._protocol,
                "validation",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )
            self._has_setup_fit = True

        # "predict" stage: rebuild both splits without augmentations
        if not self._has_setup_predict and stage == "predict":
            self.train_dataset = dataset(
                json_protocol,
                self._protocol,
                "train",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )
            self.validation_dataset = dataset(
                json_protocol,
                self._protocol,
                "validation",
                _raw_data_loader,
                self._build_transforms(is_train=False),
            )
            self._has_setup_predict = True
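For reference, a minimal sketch of driving the common module directly; the parameter values below are illustrative, not defaults taken from any configuration:

# Illustrative instantiation of the common datamodule (values are arbitrary):
dm = ShenzhenDataModule(
    protocol="default",
    model_transforms=[],
    augmentation_transforms=[],
    batch_size=4,
    cache_samples=True,  # use CachedDataset instead of RuntimeDataset
)
dm.setup("fit")  # builds dm.train_dataset and dm.validation_dataset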
Training script (``train`` command):

@@ -270,7 +270,7 @@ def train(
     checkpoint_file = get_checkpoint(output_folder, resume_from)

-    datamodule = datamodule(
+    datamodule.update_module_properties(
         batch_size=batch_size,
         batch_chunk_count=batch_chunk_count,
         drop_incomplete_batch=drop_incomplete_batch,
...