From bd79cd2f2d45238fcaf534641eceb8ad4d63497c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 13 Jun 2023 15:00:37 +0200
Subject: [PATCH] Removed reliance on make_dataset and added an option to
 cache samples for the default Shenzhen datamodule

---
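Notes: the new cache_samples flag is threaded from the command line down to
the datamodule constructor, which then picks an eager (_cached_loader) or
lazy (_delayed_loader) sample loader inside setup(). A minimal usage sketch,
assuming the module layout introduced by this patch; the batch sizes and the
stage name below are illustrative only:

    # hypothetical usage; only cache_samples is new in this signature
    from ptbench.configs.datasets.shenzhen.default import DefaultModule

    datamodule = DefaultModule(
        train_batch_size=4,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        cache_samples=True,  # pre-load every sample into memory
    )

    # setup() builds the JSONDataset with the chosen loader and splits it
    # into train/validation/extra-validation/predict via return_subsets()
    datamodule.setup(stage="fit")

On the command line the same behaviour is exposed through the new
--cache-samples option of "ptbench train". Relatedly, the ElasticDeformation
augmentation moves out of the dataset pipeline (the removed _maker) into the
model configuration as train_transforms, applied per batch in
PASA.training_step.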
 pyproject.toml                                |  2 +-
 src/ptbench/configs/datasets/__init__.py      |  1 -
 .../configs/datasets/shenzhen/__init__.py     |  3 +
 .../datasets}/shenzhen/default.py             | 29 +++++++--
 src/ptbench/configs/models/pasa.py            |  8 ++-
 src/ptbench/data/__init__.py                  | 62 ++++++++++---------
 src/ptbench/data/dataset.py                   |  5 +-
 src/ptbench/data/loader.py                    | 10 ++-
 src/ptbench/data/shenzhen/__init__.py         | 58 +++++++----------
 src/ptbench/data/transforms.py                |  5 +-
 src/ptbench/models/pasa.py                    | 15 ++++-
 src/ptbench/scripts/train.py                  | 10 +++
 12 files changed, 131 insertions(+), 77 deletions(-)
 create mode 100644 src/ptbench/configs/datasets/shenzhen/__init__.py
 rename src/ptbench/{data => configs/datasets}/shenzhen/default.py (57%)
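
The loader split is the core of the data-side change: make_delayed keeps the
previous lazy behaviour, while the new make_cached runs the raw-data loader
immediately and stores the decoded image. A rough sketch of the difference,
assuming a configured datadir.shenzhen; the sample entry is hypothetical:

    # make_cached is new in this patch; make_delayed already existed
    from ptbench.data.loader import make_cached, make_delayed
    from ptbench.data.shenzhen import _raw_data_loader

    sample = {"data": "CXR_png/CHNCXR_0001_0.png", "label": 0}  # hypothetical

    cached = make_cached(sample, _raw_data_loader)  # image decoded right now
    lazy = make_delayed(sample, _raw_data_loader)   # decoding deferred until
                                                    # the data is first used

Caching trades memory for epoch speed; that trade-off is exactly what the
--cache-samples option toggles.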

diff --git a/pyproject.toml b/pyproject.toml
index e43899d2..a418e59b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7"
 montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
-shenzhen = "ptbench.data.shenzhen.default"
+shenzhen = "ptbench.configs.datasets.shenzhen.default"
 shenzhen_rgb = "ptbench.data.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
diff --git a/src/ptbench/configs/datasets/__init__.py b/src/ptbench/configs/datasets/__init__.py
index 400d5423..1e4d9131 100644
--- a/src/ptbench/configs/datasets/__init__.py
+++ b/src/ptbench/configs/datasets/__init__.py
@@ -267,7 +267,6 @@ def get_positive_weights(dataset):
         the positive weight of each class in the dataset given as input
     """
     targets = []
-
     if isinstance(dataset, torch.utils.data.ConcatDataset):
         for ds in dataset.datasets:
             for s in ds._samples:
diff --git a/src/ptbench/configs/datasets/shenzhen/__init__.py b/src/ptbench/configs/datasets/shenzhen/__init__.py
new file mode 100644
index 00000000..84b9088e
--- /dev/null
+++ b/src/ptbench/configs/datasets/shenzhen/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
similarity index 57%
rename from src/ptbench/data/shenzhen/default.py
rename to src/ptbench/configs/datasets/shenzhen/default.py
index bbeabcaf..6f5b31ff 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -12,9 +12,10 @@
 
 from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -25,6 +26,7 @@ class DefaultModule(BaseDataModule):
         train_batch_size=1,
         predict_batch_size=1,
         drop_incomplete_batch=False,
+        cache_samples=False,
         multiproc_kwargs=None,
     ):
         super().__init__(
@@ -34,14 +36,31 @@ class DefaultModule(BaseDataModule):
             multiproc_kwargs=multiproc_kwargs,
         )
 
+        self.cache_samples = cache_samples
+
     def setup(self, stage: str):
-        self.dataset = _maker("default")
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded in memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+        )
         (
             self.train_dataset,
             self.validation_dataset,
             self.extra_validation_datasets,
             self.predict_dataset,
-        ) = return_subsets(self.dataset)
+        ) = return_subsets(self.json_dataset, "default")
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b921..cda0540f 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -13,7 +13,9 @@ Reference: [PASA-2019]_
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torchvision import transforms
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -26,5 +28,9 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+train_transforms = transforms.Compose([ElasticDeformation(p=0.8)])
+
 # model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model = PASA(
+    train_transforms, criterion, criterion_valid, optimizer, optimizer_configs
+)
diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 516af66b..8131b51b 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -6,6 +6,8 @@ import torch
 
 from clapper.logging import setup
 
+from .utils import SampleListDataset
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -301,41 +303,45 @@ def get_positive_weights(dataset):
     return positive_weights
 
 
-def return_subsets(dataset):
+def return_subsets(dataset, protocol):
     train_dataset = None
     validation_dataset = None
     extra_validation_datasets = None
     predict_dataset = None
 
-    if "__train__" in dataset:
-        logger.info("Found (dedicated) '__train__' set for training")
-        train_dataset = dataset["__train__"]
-    else:
-        train_dataset = dataset["train"]
-
-    if "__valid__" in dataset:
-        logger.info("Found (dedicated) '__valid__' set for validation")
-        validation_dataset = dataset["__valid__"]
-
-    if "__extra_valid__" in dataset:
-        if not isinstance(dataset["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(dataset['__extra_valid__'])}, "
-                f"which is invalid."
+    subsets = dataset.subsets(protocol)
+    if "train" in subsets:
+        train_dataset = SampleListDataset(subsets["train"], [])
+
+        if "validation" in subsets:
+            validation_dataset = SampleListDataset(subsets["validation"], [])
+        else:
+            logger.warning(
+                "No validation dataset found, using training set instead."
             )
-        logger.info(
-            f"Found {len(dataset['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = dataset["__extra_valid__"]
-    else:
-        extra_validation_datasets = None
+            validation_dataset = train_dataset
+
+        if "__extra_valid__" in subsets.keys():
+            if not isinstance(subsets["__extra_valid__"], list):
+                raise RuntimeError(
+                    f"If present, dataset['__extra_valid__'] must be a list, "
+                    f"but you passed a {type(subsets['__extra_valid__'])}, "
+                    f"which is invalid."
+                )
+            logger.info(
+                f"Found {len(subsets['__extra_valid__'])} extra validation "
+                f"set(s) to be tracked during training"
+            )
+            logger.info(
+                "Extra validation sets are NOT used for model checkpointing!"
+            )
+            extra_validation_datasets = [
+                SampleListDataset(s, []) for s in subsets["__extra_valid__"]
+            ]
+        else:
+            extra_validation_datasets = None
 
-    predict_dataset = dataset
+    predict_dataset = subsets
 
     return (
         train_dataset,
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 1c425562..b1ffcada 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -10,6 +10,7 @@ import pathlib
 import random
 
 import torch
+import tqdm
 
 from torchvision.transforms import RandomRotation
 
@@ -169,12 +170,14 @@ class JSONDataset:
 
         retval = {}
         for subset, samples in data.items():
+            logger.info(f"Loading subset {subset} samples.")
+
             retval[subset] = [
                 self._loader(
                     dict(protocol=protocol, subset=subset, order=n),
                     dict(zip(self.fieldnames, k)),
                 )
-                for n, k in enumerate(samples)
+                for n, k in enumerate(tqdm.tqdm(samples))
             ]
 
         return retval
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
index 12a7517e..931c6291 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/loader.py
@@ -10,7 +10,7 @@ import functools
 
 import PIL.Image
 
-from .sample import DelayedSample
+from .sample import DelayedSample, Sample
 
 
 def load_pil(path):
@@ -70,6 +70,14 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")
 
 
+def make_cached(sample, loader, key=None):
+    """Returns a Sample object with its data pre-loaded (cached) in memory."""
+    return Sample(
+        loader(sample),
+        key=key or sample["data"],
+        label=sample["label"],
+    )
+
+
 def make_delayed(sample, loader, key=None):
     """Returns a delayed-loading Sample object.
 
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index a854e559..9abf5689 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -19,16 +19,15 @@ the daily routine using Philips DR Digital Diagnose systems.
   * Validation samples: 16% of TB and healthy CXR (including labels)
   * Test samples: 20% of TB and healthy CXR (including labels)
 """
-
 import importlib.resources
 import os
 
 from clapper.logging import setup
+from torchvision import transforms
 
 from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_baw, make_delayed
+from ..loader import load_pil_baw, make_cached, make_delayed
+from ..transforms import RemoveBlackBorders
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -49,45 +48,32 @@ _protocols = [
 
 _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 
+_resize_size = 512
+_cc_size = 512
+
+_data_transforms = transforms.Compose(
+    [
+        RemoveBlackBorders(),
+        transforms.Resize(_resize_size),
+        transforms.CenterCrop(_cc_size),
+        transforms.ToTensor(),
+    ]
+)
+
 
 def _raw_data_loader(sample):
+    raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
     return dict(
-        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
+        data=_data_transforms(raw_data),
         label=sample["label"],
     )
 
 
-def _loader(context, sample):
+def _cached_loader(context, sample):
+    # "context" is ignored in this case - database is homogeneous
+    # we return pre-loaded samples so images remain cached in memory
+    return make_cached(sample, _raw_data_loader)
+
+
+def _delayed_loader(context, sample):
     # "context" is ignored in this case - database is homogeneous
     # we return delayed samples to avoid loading all images at once
     return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols, fieldnames=("data", "label"), loader=_loader
-)
-"""Shenzhen dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index c1f1f7d0..34c3a605 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -19,6 +19,7 @@ import numpy
 import PIL.Image
 
 from scipy.ndimage import gaussian_filter, map_coordinates
+from torchvision import transforms
 
 
 class SingleAutoLevel16to8:
@@ -76,6 +77,8 @@ class ElasticDeformation:
         self.random_state = random_state
         self.p = p
 
+        self.tensor_transform = transforms.Compose([transforms.ToTensor()])
+
     def __call__(self, img):
         if random.random() < self.p:
             img = numpy.asarray(img)
@@ -114,6 +117,6 @@ class ElasticDeformation:
             result[:, :] = map_coordinates(
                 img[:, :], indices, order=self.spline_order, mode=self.mode
             ).reshape(shape)
-            return PIL.Image.fromarray(result)
+            return self.tensor_transform(PIL.Image.fromarray(result))
         else:
             return img
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index cae11375..c127239d 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -17,16 +17,23 @@ class PASA(pl.LightningModule):
     """
 
     def __init__(
-        self, criterion, criterion_valid, optimizer, optimizer_configs
+        self,
+        train_transforms,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["train_transforms"])
 
         self.name = "pasa"
 
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
+        self.train_transforms = train_transforms
+
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -126,6 +133,10 @@ class PASA(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
+        # apply training-time augmentations; the result must be assigned
+        # back, as re-binding the loop variable alone has no effect
+        images = torch.cat(
+            [self.train_transforms(img.squeeze(0)).unsqueeze(0) for img in images]
+        )
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 083d85d9..c59c81a3 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -165,6 +165,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default="cpu",
     cls=ResourceOption,
 )
+@click.option(
+    "--cache-samples",
+    help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
+    required=True,
+    show_default=True,
+    default=False,
+    cls=ResourceOption,
+)
 @click.option(
     "--seed",
     "-s",
@@ -235,6 +243,7 @@ def train(
     datamodule,
     checkpoint_period,
     accelerator,
+    cache_samples,
     seed,
     parallel,
     normalization,
@@ -293,6 +302,7 @@ def train(
         train_batch_size=batch_chunk_size,
         drop_incomplete_batch=drop_incomplete_batch,
         multiproc_kwargs=multiproc_kwargs,
+        cache_samples=cache_samples,
     )
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
-- 
GitLab