From bd79cd2f2d45238fcaf534641eceb8ad4d63497c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 13 Jun 2023 15:00:37 +0200
Subject: [PATCH] Removed reliance on make_dataset and added an option to
 cache samples for the default Shenzhen configuration

---
 pyproject.toml                             |  2 +-
 src/ptbench/configs/datasets/__init__.py   |  1 -
 .../configs/datasets/shenzhen/__init__.py  |  3 +
 .../datasets}/shenzhen/default.py          | 29 ++++++--
 src/ptbench/configs/models/pasa.py         |  8 ++-
 src/ptbench/data/__init__.py               | 64 ++++++++---------
 src/ptbench/data/dataset.py                |  5 +-
 src/ptbench/data/loader.py                 | 11 ++-
 src/ptbench/data/shenzhen/__init__.py      | 58 +++++----------
 src/ptbench/data/transforms.py             |  5 +-
 src/ptbench/models/pasa.py                 | 20 ++++--
 src/ptbench/scripts/train.py               | 10 +++
 12 files changed, 139 insertions(+), 77 deletions(-)
 create mode 100644 src/ptbench/configs/datasets/shenzhen/__init__.py
 rename src/ptbench/{data => configs/datasets}/shenzhen/default.py (57%)

diff --git a/pyproject.toml b/pyproject.toml
index e43899d2..a418e59b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7"
 montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
-shenzhen = "ptbench.data.shenzhen.default"
+shenzhen = "ptbench.configs.datasets.shenzhen.default"
 shenzhen_rgb = "ptbench.data.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
diff --git a/src/ptbench/configs/datasets/__init__.py b/src/ptbench/configs/datasets/__init__.py
index 400d5423..1e4d9131 100644
--- a/src/ptbench/configs/datasets/__init__.py
+++ b/src/ptbench/configs/datasets/__init__.py
@@ -267,7 +267,6 @@ def get_positive_weights(dataset):
         the positive weight of each class in the dataset given as input
     """
     targets = []
-
     if isinstance(dataset, torch.utils.data.ConcatDataset):
         for ds in dataset.datasets:
             for s in ds._samples:
diff --git a/src/ptbench/configs/datasets/shenzhen/__init__.py b/src/ptbench/configs/datasets/shenzhen/__init__.py
new file mode 100644
index 00000000..84b9088e
--- /dev/null
+++ b/src/ptbench/configs/datasets/shenzhen/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
similarity index 57%
rename from src/ptbench/data/shenzhen/default.py
rename to src/ptbench/configs/datasets/shenzhen/default.py
index bbeabcaf..6f5b31ff 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -12,9 +12,10 @@
 
 from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -25,6 +26,7 @@ class DefaultModule(BaseDataModule):
         train_batch_size=1,
         predict_batch_size=1,
         drop_incomplete_batch=False,
+        cache_samples=False,
         multiproc_kwargs=None,
     ):
         super().__init__(
@@ -34,14 +36,31 @@
             multiproc_kwargs=multiproc_kwargs,
         )
 
+        self.cache_samples = cache_samples
+
     def setup(self, stage: str):
-        self.dataset = _maker("default")
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded into memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+        )
         (
             self.train_dataset,
             self.validation_dataset,
             self.extra_validation_datasets,
             self.predict_dataset,
-        ) = return_subsets(self.dataset)
+        ) = return_subsets(self.json_dataset, "default")
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b921..cda0540f 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -13,7 +13,9 @@ Reference: [PASA-2019]_
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torchvision import transforms
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -26,5 +28,9 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+train_transforms = transforms.Compose([ElasticDeformation(p=0.8)])
+
 # model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model = PASA(
+    train_transforms, criterion, criterion_valid, optimizer, optimizer_configs
+)
diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 516af66b..8131b51b 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -6,6 +6,8 @@
 import torch
 
 from clapper.logging import setup
 
+from .utils import SampleListDataset
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -301,41 +303,47 @@ def get_positive_weights(dataset):
     return positive_weights
 
 
-def return_subsets(dataset):
+def return_subsets(dataset, protocol):
     train_dataset = None
     validation_dataset = None
     extra_validation_datasets = None
     predict_dataset = None
 
-    if "__train__" in dataset:
-        logger.info("Found (dedicated) '__train__' set for training")
-        train_dataset = dataset["__train__"]
-    else:
-        train_dataset = dataset["train"]
-
-    if "__valid__" in dataset:
-        logger.info("Found (dedicated) '__valid__' set for validation")
-        validation_dataset = dataset["__valid__"]
-
-    if "__extra_valid__" in dataset:
-        if not isinstance(dataset["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(dataset['__extra_valid__'])}, "
-                f"which is invalid."
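+    # "subsets" maps each subset name defined by the chosen protocol
+    # (e.g. "train", "validation") to its list of samples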
+    subsets = dataset.subsets(protocol)
+    if "train" in subsets.keys():
+        train_dataset = SampleListDataset(subsets["train"], [])
+
+    if "validation" in subsets.keys():
+        validation_dataset = SampleListDataset(subsets["validation"], [])
+    else:
+        logger.warning(
+            "No validation dataset found, using training set instead."
         )
-        logger.info(
-            f"Found {len(dataset['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = dataset["__extra_valid__"]
-    else:
-        extra_validation_datasets = None
+        validation_dataset = train_dataset
+
+    if "__extra_valid__" in subsets.keys():
+        if not isinstance(subsets["__extra_valid__"], list):
+            raise RuntimeError(
+                f"If present, dataset['__extra_valid__'] must be a list, "
+                f"but you passed a {type(subsets['__extra_valid__'])}, "
+                f"which is invalid."
+            )
+        logger.info(
+            f"Found {len(subsets['__extra_valid__'])} extra validation "
+            f"set(s) to be tracked during training"
+        )
+        logger.info(
+            "Extra validation sets are NOT used for model checkpointing!"
+        )
+        extra_validation_datasets = SampleListDataset(
+            subsets["__extra_valid__"], []
+        )
+    else:
+        extra_validation_datasets = None
 
-    predict_dataset = dataset
+    predict_dataset = subsets
 
     return (
         train_dataset,
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 1c425562..b1ffcada 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -10,6 +10,7 @@
 import pathlib
 import random
 
 import torch
+import tqdm
 
 from torchvision.transforms import RandomRotation
 
@@ -169,12 +170,14 @@ class JSONDataset:
 
         retval = {}
         for subset, samples in data.items():
+            logger.info(f"Loading samples for subset {subset}.")
+
             retval[subset] = [
                 self._loader(
                     dict(protocol=protocol, subset=subset, order=n),
                     dict(zip(self.fieldnames, k)),
                 )
-                for n, k in enumerate(samples)
+                for n, k in enumerate(tqdm.tqdm(samples))
             ]
 
         return retval
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
index 12a7517e..931c6291 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/loader.py
@@ -10,7 +10,7 @@ import functools
 
 import PIL.Image
 
-from .sample import DelayedSample
+from .sample import DelayedSample, Sample
 
 
 def load_pil(path):
@@ -70,6 +70,15 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")
 
 
+def make_cached(sample, loader, key=None):
+    """Returns an eagerly-loaded (in-memory) Sample object."""
+    return Sample(
+        loader(sample),
+        key=key or sample["data"],
+        label=sample["label"],
+    )
+
+
 def make_delayed(sample, loader, key=None):
     """Returns a delayed-loading Sample object.
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index a854e559..9abf5689 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -19,16 +19,15 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Validation samples: 16% of TB and healthy CXR (including labels)
 * Test samples: 20% of TB and healthy CXR (including labels)
 """
-
 import importlib.resources
 import os
 
 from clapper.logging import setup
+from torchvision import transforms
 
 from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_baw, make_delayed
+from ..loader import load_pil_baw, make_cached, make_delayed
+from ..transforms import RemoveBlackBorders
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -49,45 +48,32 @@
 
 _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 
+_resize_size = 512
+_cc_size = 512
+
+_data_transforms = transforms.Compose(
+    [
+        RemoveBlackBorders(),
+        transforms.Resize(_resize_size),
+        transforms.CenterCrop(_cc_size),
+        transforms.ToTensor(),
+    ]
+)
+
 
 def _raw_data_loader(sample):
+    raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
     return dict(
-        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
+        data=_data_transforms(raw_data),
         label=sample["label"],
     )
 
 
-def _loader(context, sample):
+def _cached_loader(context, sample):
+    return make_cached(sample, _raw_data_loader)
+
+
+def _delayed_loader(context, sample):
     # "context" is ignored in this case - database is homogeneous
     # we returned delayed samples to avoid loading all images at once
     return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols, fieldnames=("data", "label"), loader=_loader
-)
-"""Shenzhen dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index c1f1f7d0..34c3a605 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -19,6 +19,7 @@ import numpy
 import PIL.Image
 
 from scipy.ndimage import gaussian_filter, map_coordinates
+from torchvision import transforms
 
 
 class SingleAutoLevel16to8:
@@ -76,6 +77,8 @@ class ElasticDeformation:
         self.random_state = random_state
         self.p = p
 
+        self.tensor_transform = transforms.Compose([transforms.ToTensor()])
+
     def __call__(self, img):
         if random.random() < self.p:
             img = numpy.asarray(img)
@@ -114,6 +117,6 @@ class ElasticDeformation:
             result[:, :] = map_coordinates(
                 img[:, :], indices, order=self.spline_order, mode=self.mode
             ).reshape(shape)
-            return PIL.Image.fromarray(result)
+            return self.tensor_transform(PIL.Image.fromarray(result))
         else:
             return img
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index cae11375..c127239d 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -17,16 +17,23 @@ class PASA(pl.LightningModule):
     """
 
     def __init__(
-        self, criterion, criterion_valid, optimizer, optimizer_configs
+        self,
+        train_transforms,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["train_transforms"])
 
         self.name = "pasa"
 
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
+        self.train_transforms = train_transforms
+
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -126,6 +133,15 @@ class PASA(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
+        # apply the train-time transforms to each image and rebuild the batch
+        images = torch.cat(
+            [
+                torch.unsqueeze(
+                    self.train_transforms(torch.squeeze(img, 0)), 0
+                )
+                for img in images
+            ]
+        )
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 083d85d9..c59c81a3 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -165,6 +165,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default="cpu",
     cls=ResourceOption,
 )
+@click.option(
+    "--cache-samples/--no-cache-samples",
+    help="If set, loads all samples into memory; otherwise samples are loaded at runtime.",
+    required=True,
+    show_default=True,
+    default=False,
+    cls=ResourceOption,
+)
 @click.option(
     "--seed",
     "-s",
@@ -235,6 +243,7 @@ def train(
     datamodule,
     checkpoint_period,
     accelerator,
+    cache_samples,
     seed,
     parallel,
     normalization,
@@ -293,6 +302,7 @@ def train(
         train_batch_size=batch_chunk_size,
         drop_incomplete_batch=drop_incomplete_batch,
         multiproc_kwargs=multiproc_kwargs,
+        cache_samples=cache_samples,
     )
 
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
-- 
GitLab
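
Usage sketch (not part of the patch, and assuming the Shenzhen raw data and
the "datadir.shenzhen" configuration key are in place): the new switch is
reachable from the command line through "--cache-samples", or directly on
the datamodule added by this patch:

    from ptbench.configs.datasets.shenzhen.default import DefaultModule

    # build the default Shenzhen datamodule with in-memory caching enabled
    datamodule = DefaultModule(cache_samples=True)
    datamodule.prepare_data()
    # setup() builds the JSONDataset with the cached loader, so every sample
    # is read (and transformed) once, up-front
    datamodule.setup("fit")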